轻松学pytorch – 使用多标签损失函数训练卷积网络
点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达
大家好,我还在坚持继续写,如果我没有记错的话,这个是系列文章的第十五篇,pytorch中有很多非常方便使用的损失函数,本文就演示了如何通过多标签损失函数训练验证码识别网络,实现验证码识别。
这个数据是来自Kaggle上的一个验证码识别例子,作者采用的是迁移学习,基于ResNet18做到的训练。
https://www.kaggle.com/anjalichoudhary12/captcha-with-pytorch
这个数据集总计有1070张验证码图像,我把其中的1040张用作训练,30张作为测试,使用pytorch自定义了一个数据集类,代码如下:
1import torch
2import numpy as np
3from torch.utils.data import Dataset, DataLoader
4from torchvision import transforms
5import os
6import cv2 as cv
7
8NUMBER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
9ALPHABET = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
10ALL_CHAR_SET = NUMBER + ALPHABET
11ALL_CHAR_SET_LEN = len(ALL_CHAR_SET)
12MAX_CAPTCHA = 5
13
14
15def output_nums():
16 return MAX_CAPTCHA * ALL_CHAR_SET_LEN
17
18
19def encode(a):
20 onehot = [0]*ALL_CHAR_SET_LEN
21 idx = ALL_CHAR_SET.index(a)
22 onehot[idx] += 1
23 return onehot
24
25
26class CapchaDataset(Dataset):
27 def __init__(self, root_dir):
28 self.transform = transforms.Compose([transforms.ToTensor()])
29 img_files = os.listdir(root_dir)
30 self.txt_labels = []
31 self.encodes = []
32 self.images = []
33 for file_name in img_files:
34 label = file_name[:-4]
35 label_oh = []
36 for i in label:
37 label_oh += encode(i)
38 self.images.append(os.path.join(root_dir, file_name))
39 self.encodes.append(np.array(label_oh))
40 self.txt_labels.append(label)
41
42 def __len__(self):
43 return len(self.images)
44
45 def num_of_samples(self):
46 return len(self.images)
47
48 def __getitem__(self, idx):
49 if torch.is_tensor(idx):
50 idx = idx.tolist()
51 image_path = self.images[idx]
52 else:
53 image_path = self.images[idx]
54 img = cv.imread(image_path) # BGR order
55 h, w, c = img.shape
56 # rescale
57 img = cv.resize(img, (128, 32))
58 img = (np.float32(img) /255.0 - 0.5) / 0.5
59 # H, W C to C, H, W
60 img = img.transpose((2, 0, 1))
61 sample = {'image': torch.from_numpy(img), 'encode': self.encodes[idx], 'label': self.txt_labels[idx]}
62 return sample
基于ResNet的block结构,我实现了一个比较简单的残差网络,最后加一个全连接层输出多个标签。验证码是有5个字符的,每个字符的是小写26个字母加上0~9十个数字,总计36个类别,所以5个字符就有5x36=180个输出,其中每个字符是独热编码,这个可以从数据集类的实现看到。模型的输入与输出格式:
输入:NCHW=Nx3x32x128
卷积层最终输出:NCHW=Nx256x1x4
全连接层:Nx(256x4)
最终输出层:Nx180
代码实现如下:
1class CapchaResNet(torch.nn.Module):
2 def __init__(self):
3 super(CapchaResNet, self).__init__()
4 self.cnn_layers = torch.nn.Sequential(
5 # 卷积层 (128x32x3)
6 ResidualBlock(3, 32, 1),
7 ResidualBlock(32, 64, 2),
8 ResidualBlock(64, 64, 2),
9 ResidualBlock(64, 128, 2),
10 ResidualBlock(128, 256, 2),
11 ResidualBlock(256, 256, 2),
12 )
13
14 self.fc_layers = torch.nn.Sequential(
15 torch.nn.Linear(256 * 4, output_nums()),
16 )
17
18 def forward(self, x):
19 # stack convolution layers
20 x = self.cnn_layers(x)
21 out = x.view(-1, 4 * 256)
22 out = self.fc_layers(out)
23 return out
使用多标签损失函数,Adam优化器,代码实现如下:
1model = CapchaResNet()
2print(model)
3
4# 使用GPU
5if train_on_gpu:
6 model.cuda()
7
8ds = CapchaDataset("D:/python/pytorch_tutorial/capcha/samples")
9num_train_samples = ds.num_of_samples()
10bs = 16
11dataloader = DataLoader(ds, batch_size=bs, shuffle=True)
12
13# 训练模型的次数
14num_epochs = 25
15# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
16optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
17model.train()
18
19# 损失函数
20mul_loss = torch.nn.MultiLabelSoftMarginLoss()
21index = 0
22for epoch in range(num_epochs):
23 train_loss = 0.0
24 for i_batch, sample_batched in enumerate(dataloader):
25 images_batch, oh_labels = \
26 sample_batched['image'], sample_batched['encode']
27 if train_on_gpu:
28 images_batch, oh_labels= images_batch.cuda(), oh_labels.cuda()
29 optimizer.zero_grad()
30
31 # forward pass: compute predicted outputs by passing inputs to the model
32 m_label_out_ = model(images_batch)
33 oh_labels = torch.autograd.Variable(oh_labels.float())
34
35 # calculate the batch loss
36 loss = mul_loss(m_label_out_, oh_labels)
37
38 # backward pass: compute gradient of the loss with respect to model parameters
39 loss.backward()
40
41 # perform a single optimization step (parameter update)
42 optimizer.step()
43
44 # update training loss
45 train_loss += loss.item()
46 if index % 100 == 0:
47 print('step: {} \tTraining Loss: {:.6f} '.format(index, loss.item()))
48 index += 1
49
50 # 计算平均损失
51 train_loss = train_loss / num_train_samples
52
53 # 显示训练集与验证集的损失函数
54 print('Epoch: {} \tTraining Loss: {:.6f} '.format(epoch, train_loss))
55
56# save model
57model.eval()
58torch.save(model, 'capcha_recognize_model.pt')
调用保存之后的模型,对图像测试代码如下:
1cnn_model = torch.load("./capcha_recognize_model.pt")
2root_dir = "D:/python/pytorch_tutorial/capcha/testdata"
3files = os.listdir(root_dir)
4one_hot_len = ALL_CHAR_SET_LEN
5for file in files:
6 if os.path.isfile(os.path.join(root_dir, file)):
7 image = cv.imread(os.path.join(root_dir, file))
8 h, w, c = image.shape
9 img = cv.resize(image, (128, 32))
10 img = (np.float32(img) /255.0 - 0.5) / 0.5
11 img = img.transpose((2, 0, 1))
12 x_input = torch.from_numpy(img).view(1, 3, 32, 128)
13 probs = cnn_model(x_input.cuda())
14 mul_pred_labels = probs.squeeze().cpu().tolist()
15 c0 = ALL_CHAR_SET[np.argmax(mul_pred_labels[0:one_hot_len])]
16 c1 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len:one_hot_len*2])]
17 c2 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*2:one_hot_len*3])]
18 c3 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*3:one_hot_len*4])]
19 c4 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*4:one_hot_len*5])]
20 pred_txt = '%s%s%s%s%s' % (c0, c1, c2, c3, c4)
21 cv.putText(image, pred_txt, (10, 20), cv.FONT_HERSHEY_PLAIN, 1.5, (0, 0, 255), 2)
22 print("current code : %s, predict code : %s "%(file[:-4], pred_txt))
23 cv.imshow("capcha predict", image)
24 cv.waitKey(0)
其中对输入结果,要根据每个字符的独热编码,截取成五个独立的字符分类标签,然后使用argmax获取index根据index查找类别标签,得到最终的验证码预测字符串,代码运行结果如下:
好消息!
小白学视觉知识星球
开始面向外开放啦👇👇👇
下载1:OpenCV-Contrib扩展模块中文版教程 在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。 下载2:Python视觉实战项目52讲 在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。 下载3:OpenCV实战项目20讲 在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。 交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~