ViT 微调实战
小白学视觉
共 9694字,需浏览 20分钟
·
2024-08-01 10:05
重磅干货,第一时间送达
重磅干货,第一时间送达
探索 CIFAR-10 图像分类
介绍
你一定听说过“Attention is all you need”吧?Transformers 最初从文本起步,如今已无处不在,甚至在图像领域也有一种称为视觉变换器(Vision Transformer, ViT)的模型,它最早是在论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》(一张图片胜过 16x16 个单词:用于大规模图像识别的 Transformers)中提出的。这不仅仅是又一个浮华的趋势;事实证明,它们是强有力的竞争者,可以与卷积神经网络 (CNN) 等传统模型相媲美。
-
将图像分成多个块,将这些块传递到全连接(FC)网络或 FC+CNN 以获取输入嵌入向量。 -
添加位置信息。 -
将其传递到传统的 Transformer 编码器中,并在末端附加一个 FC 层。
问题描述
设置环境
!pip install torch torchvision
!pip install transformers datasets
!pip install transformers[torch]
# PyTorch
import torch
import torchvision
from torchvision.transforms import Normalize, Resize, ToTensor, Compose
# For displaying images
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage
# Dataset loading
from datasets import load_dataset
# Transformers (fixed: the original used the Chinese word "从" instead of "from")
from transformers import ViTImageProcessor, ViTForImageClassification
from transformers import TrainingArguments, Trainer
# Matrix operations
import numpy as np
# Evaluation (fixed: "confused_matrix" is not a sklearn name; L160 calls confusion_matrix)
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
数据预处理
# Pull a 5k-image training slice and a 1k-image test slice of CIFAR-10,
# then carve 10% of the training slice off as a validation set.
trainds, testds = load_dataset("cifar10", split=["train[:5000]","test[:1000]"])
split = trainds.train_test_split(test_size=0.1)
trainds, valds = split["train"], split["test"]
trainds, valds, testds
# Output
(Dataset({
features: ['img', 'label'],
num_rows: 4500
}),
Dataset({
features: ['img', 'label'],
num_rows: 500
}),
Dataset({
features: ['img', 'label'],
num_rows: 1000
}))
trainds.features,trainds.num_rows,trainds[ 0 ]
# Output
({'img': Image(decode=True, id=None),
'label': ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], id=None)},
4500,
{'img': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,
'label': 0})
# Build index<->name lookup tables for the 10 CIFAR class labels.
label_names = trainds.features['label'].names
itos = {idx: name for idx, name in enumerate(label_names)}
stoi = {name: idx for idx, name in enumerate(label_names)}
itos
# Output
{0: 'airplane',
1: 'automobile',
2: 'bird',
3: 'cat',
4: 'deer',
5: 'dog',
6: 'frog',
7: 'horse',
8: 'ship',
9: 'truck'}
# Peek at one training sample: print its class name, then display the image.
index = 0
sample = trainds[index]
img, lab = sample['img'], itos[sample['label']]
print(lab)
img
# Load the ViT image processor and build the preprocessing pipeline:
# resize 3x32x32 -> 3x224x224, convert to tensor, then normalize with the
# processor's default mean/std (maps pixel values into [-1, 1]).
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
mu, sigma = processor.image_mean, processor.image_std  # default normalization stats
size = processor.size
norm = Normalize(mean=mu, std=sigma)
_transf = Compose([Resize(size['height']), ToTensor(), norm])
# apply transforms to PIL Image and store it to 'pixels' key
def transf(arg):
    """Batch transform: convert every PIL image under 'img' to an RGB
    normalized tensor and store the resulting list under a 'pixels' key."""
    arg['pixels'] = [_transf(pil_img.convert('RGB')) for pil_img in arg['img']]
    return arg
# Register the on-the-fly transform on all three splits, then sanity-check by
# denormalizing one sample back into [0, 1] and displaying it.
for ds in (trainds, valds, testds):
    ds.set_transform(transf)
idx = 0
ex = trainds[idx]['pixels']
ex = (ex + 1) / 2  # imshow expects pixel values in [0, 1]
exi = ToPILImage()(ex)
plt.imshow(exi)
plt.show()
微调模型
# Load the pretrained ViT and inspect its default classification head
# (1000 ImageNet classes).
model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_name)
print(model.classifier)
# Output
Linear(in_features=768, out_features=1000, bias=True)
# Reload with a fresh 10-class head for CIFAR-10: ignore_mismatched_sizes lets
# the new head replace the 1000-way one, and id2label/label2id wire in the
# human-readable class names.
model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True, id2label=itos, label2id=stoi)
print(model.classifier)
# Output
Linear(in_features=768, out_features=10, bias=True)
Hugging Face Trainer(训练器)
# Hugging Face training configuration: evaluate and save a checkpoint at the
# end of every epoch, and restore whichever checkpoint scored best on
# validation accuracy when training finishes.
args = TrainingArguments(
    "test-cifar-10",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_dir='logs',
    remove_unused_columns=False,  # keep the 'pixels' column for our collator
)
def collate_fn(examples):
    """Stack per-sample image tensors and labels into the batch dict ViT expects."""
    batch_pixels = [ex["pixels"] for ex in examples]
    batch_labels = [ex["label"] for ex in examples]
    return {
        "pixel_values": torch.stack(batch_pixels),
        "labels": torch.tensor(batch_labels),
    }
def compute_metrics(eval_pred):
    """Compute validation accuracy from a Trainer EvalPrediction.

    eval_pred unpacks into (logits, label_ids); the predicted class is the
    argmax over the class dimension.
    """
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    # sklearn's convention is accuracy_score(y_true, y_pred). Accuracy is
    # symmetric so the original reversed order produced the same value,
    # but the conventional order is clearer and safer for future edits.
    return dict(accuracy=accuracy_score(labels, predictions))
# Wire everything into the Trainer. Fix: the original never passed `args`, so
# the TrainingArguments defined above were silently ignored and Trainer fell
# back to defaults; `args` is Trainer's second positional parameter.
trainer = Trainer(
    model,
    args,
    train_dataset=trainds,
    eval_dataset=valds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)
训练模型
# Run fine-tuning (3 epochs over the 4.5k-image training split); the captured
# output below shows one run took ~23 minutes (train_runtime ~1358 s).
trainer.train()
# Output
TrainOutput(global_step=675, training_loss=0.22329048227380824, metrics={'train_runtime': 1357.9833, 'train_samples_per_second': 9.941, 'train_steps_per_second': 0.497, 'total_flos': 1.046216869705728e+18, 'train_loss': 0.22329048227380824, 'epoch': 3.0})
评估
# Run inference over the full test split; outputs.metrics holds loss/accuracy.
outputs = trainer.predict(testds)
print(outputs.metrics)
# Output
{'test_loss': 0.07223748415708542, 'test_accuracy': 0.973, 'test_runtime': 28.5169, 'test_samples_per_second': 35.067, 'test_steps_per_second': 4.383}
# Spot-check one sample: (predicted class name, ground-truth class name).
itos[np.argmax(outputs.predictions[0])], itos[outputs.label_ids[0]]
# Build and plot the 10x10 confusion matrix over all test predictions.
y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(1)
labels = trainds.features['label'].names
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot(xticks_rotation=45)
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学好计算机视觉。
下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
评论