文章目录

  • 1. MNIST 手写数字识别
  • 2. 聚焦数据集扩充后的模型训练
  • 3. pytorch 手写数字识别基本实现
    • 3.1完整代码及 MNIST 测试集测试结果
      • 3.1.1代码
      • 3.1.2 MNIST 测试集测试结果
    • 3.2 使用自己的图片进行测试
      • 3.2.1 测试图片预处理代码
      • 3.2.2 测试图片结果
  • 4. 数据增强
    • 4.1 手动读取 MNIST 数据集
    • 4.2 数据增强
      • 4.2.1 像素反转
      • 4.2.2 图像旋转
        • 4.2.2.1 图像类别统计
        • 4.2.2.2 根据类别进行等量均类划分
      • 4.2.3 像素反转 + 图像旋转
      • 4.2.4 选择加载不同的处理后的数据集
    • 4.3 完整代码
  • 5. 模型再训练
    • 5.1 怎么加载 split 后的数据?
      • 5.1.1 创建自己的 dataset 类
      • 5.1.2 load 分割好的数据
    • 5.2 加载完成后怎么和原始数据合并,然后送入模型进行训练?
    • 5.3 完整代码
    • 5.4 训练结果
      • 5.4.1 只进行像素反转
        • 5.4.1.1 测试结果
        • 5.4.1.2 在自己的数据上测试
          • 测试代码
          • 测试结果
      • 5.4.2 只进行图像旋转
        • 5.4.2.1 测试结果
        • 5.4.2.2 在自己的数据上测试
          • 测试代码
          • 测试结果
      • 5.4.3 二者同时进行
        • 5.4.3.1 测试结果
        • 5.4.3.2 在自己的数据上测试
          • 测试代码
          • 测试结果
    • 5.5 结果整合
  • 结语

1. MNIST 手写数字识别

MNIST 数据集分为两部分,分别是训练集和测试集,其中训练集含有 60000 张图片,测试集中含有 10000 张图片。从官网下载的数据集主要包括有 4 个文件:

文件名称 文件用途
train-images-idx3-ubyte.gz 训练集图像
train-labels-idx1-ubyte.gz 训练集 label
t10k-images-idx3-ubyte.gz 测试集图像
t10k-labels-idx1-ubyte.gz 测试集 label

参考:
MNIST 数据集介绍 1
MNIST 数据集介绍 2

2. 聚焦数据集扩充后的模型训练

Internet 中有很多关于 pytorch 实现手写数字识别的博客了,所以本文不再对这一方面作过多的叙述。更多地,本文对 MNIST 数据集进行了扩充,利用 3 中不同的数据集构成对模型进行训练,每类数据集构成都包含了 12000 张图片。这 3 种不同的数据集构成如下:

  • 原始数据集(60000 张)+ 像素反转后的图片(60000 张)
  • 原始数据集(60000 张)+ 对图像进行 90°, 180°, 270° 等量均类旋转后的图片(60000 张)(注意:此处的等量均类是指对每个角度都旋转了 20000 张图片,同时,这 20000 张图片中包含了数字 0-9 这十个类别的图片各 2000 张)
  • 原始数据集(60000 张)+ 像素反转后的图片(30000 张)+ 等量均类旋转的图片(30000 张)

建议自己尝试进行数据分割,也可以利用分割好了的数据 click->已分割好了的数据

3. pytorch 手写数字识别基本实现

3.1完整代码及 MNIST 测试集测试结果

3.1.1代码

完整代码如下:

import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Imageclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),)self.fullyConnected = nn.Sequential(nn.Flatten(),nn.Linear(in_features=7 * 7 * 64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=10),)def forward(self, img):output = self.conv1(img)output = self.conv2(output)output = self.conv3(output)output = self.fullyConnected(output)return outputdef get_device():if torch.cuda.is_available():train_device = torch.device('cuda')else:train_device = torch.device('cpu')return train_devicedef get_data_loader(dat_path, bat_size, trans, to_train=False):dat_set = torchvision.datasets.MNIST(root=dat_path, train=to_train, transform=trans, download=True)if to_train is True:dat_loader = torch.utils.data.DataLoader(dat_set, batch_size=bat_size, shuffle=True)else:dat_loader = torch.utils.data.DataLoader(dat_set, batch_size=bat_size)return dat_set, dat_loaderdef show_part_of_image(dat_loader, row, col):iteration = enumerate(dat_loader)idx, (exam_img, exam_label) = next(iteration)fig = plt.figure(num=1)for i in range(row * col):plt.subplot(row, col, i + 1)plt.tight_layout()plt.imshow(exam_img[i][0], cmap='gray', interpolation='none')plt.title('Number: {}'.format(exam_label[i]))plt.xticks([])plt.yticks([])plt.show()def train(network, dat_loader, device, epos, loss_function, optimizer):for epoch in range(1, epos + 1):network.train(mode=True)for idx, (train_img, train_label) in enumerate(dat_loader):train_img = train_img.to(device)train_label = train_label.to(device)outputs = network(train_img)optimizer.zero_grad()loss = loss_function(outputs, train_label)loss.backward()optimizer.step()if idx % 100 == 0:cnt = idx * len(train_img) + (epoch - 1) * len(dat_loader.dataset)print('epoch: {}, [{}/{}({:.0f}%)], loss: {:.6f}'.format(epoch,idx * len(train_img),len(dat_loader.dataset),(100 * cnt) / (len(dat_loader.dataset) * epos),loss.item()))print('------------------------------------------------')print('Training ended.')return networkdef test(network, dat_loader, device, loss_function):test_loss_avg, correct, total = 0, 0, 0test_loss = []network.train(mode=False)with torch.no_grad():for idx, (test_img, test_label) in enumerate(dat_loader):test_img = test_img.to(device)test_label = test_label.to(device)total += test_label.size(0)outputs = network(test_img)loss = loss_function(outputs, test_label)test_loss.append(loss.item())predictions = torch.argmax(outputs, dim=1)correct += torch.sum(predictions == test_label)test_loss_avg = np.average(test_loss)print('Total: {}, Correct: {}, Accuracy: {:.2f}%, AverageLoss: {:.6f}'.format(total, correct,correct / total * 100,test_loss_avg))def show_part_of_test_result(network, dat_loader, row, col):iteration = enumerate(dat_loader)idx, (exam_img, exam_label) = next(iteration)with torch.no_grad():outputs = network(exam_img)fig = plt.figure()for i in range(row * col):plt.subplot(row, col, i + 1)plt.tight_layout()plt.imshow(exam_img[i][0], cmap='gray', interpolation='none')plt.title('Number: {}, Prediction: {}'.format(exam_label[i], outputs.data.max(1, keepdim=True)[1][i].item()))plt.xticks([])plt.yticks([])plt.show()batch_size, epochs = 64, 10
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])
my_device = get_device()path = './data'
_, train_data_loader = get_data_loader(path, batch_size, transform, True)
print('Training data loaded.')show_part_of_image(train_data_loader, 3, 3)_, test_data_loader = get_data_loader(path, batch_size, transform)
print('Testing data loaded.')cnn = CNN()
loss_func = nn.CrossEntropyLoss()
optim = torch.optim.Adam(cnn.parameters(), lr=0.01)cnn = train(cnn, train_data_loader, my_device, epochs, loss_func, optim)
test(cnn, test_data_loader, my_device, loss_func)show_part_of_test_result(cnn, test_data_loader, 5, 2)torch.save(cnn, './cnn.pth')

3.1.2 MNIST 测试集测试结果

模型测试结果:

其中一些超参数如下:

  • batch_size: 64
  • epochs: 10

同时,采用交叉熵 CrossEntropyLoss 来计算 loss,Adam 来进行优化:

模型在测试集上的准确率达到了 97.32%,从右侧的测试集采样结果来看,正确率也相对较高;

3.2 使用自己的图片进行测试

另外,还在画图中做了 0-9 这 10 个数字代入模型进行识别。注意:在画图中做的图片必须要是 28 * 28 的大小(当然也可以用 python 进行裁剪,这里就偷个懒~)
还需要注意的是,MNIST 数据集中的图片是黑底白字的,而通过画图做出的图片是白底黑字的,因此若想得到准确结果的话,必须要对需要测试的图片进行像素反转的预处理操作;

3.2.1 测试图片预处理代码

注意:由于将模型保存进了 cnn.pth 文件,测试时直接 torch.load('./cnn.pth') 即可(当然也可以用官方推荐的只保存参数的方法);需要注意的是:记得把网络结构的定义复制过来,否则会报错;

import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import matplotlib.pyplot as pltclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),)self.fullyConnected = nn.Sequential(nn.Flatten(),nn.Linear(in_features=7 * 7 * 64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=10),)def forward(self, input):output = self.conv1(input)output = self.conv2(output)output = self.conv3(output)output = self.fullyConnected(output)return outputmodel = torch.load('./cnn.pth')
model.eval()transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])
unloader = transforms.ToPILImage()for k in range(10):infile = './testImgs/raw/' + '{}.jpg'.format(k)img = Image.open(infile)img = img.convert('L')img_array = np.array(img)# 像素反转for i in range(28):for j in range(28):img_array[i, j] = 255 - img_array[i, j]# print(img_array)img = Image.fromarray(img_array)# img.show()img = transform(img)img = torch.unsqueeze(img, 0)output = model(img)pred = torch.argmax(output, dim=1)image = torch.squeeze(img, 0)image = unloader(image)plt.subplot(5, 2, k + 1)plt.tight_layout()plt.imshow(image, cmap='gray', interpolation='none')plt.title("Number: {}, Prediction: {}".format(k, pred.item()))plt.xticks([])plt.yticks([])
plt.show()

3.2.2 测试图片结果


(虽然结果正确率挺高,但是那些图片看起来怎么是灰底呢!?)

4. 数据增强

由于我们需要对数据进行处理,因此需要单独将数据读取出来,再进行相应的处理后保存;

4.1 手动读取 MNIST 数据集

关于如何从 .gz 文件中读取图片和图片的 label,参考了这篇文章 手动读取 MNIST 数据集;
主要代码:

def load_mnist(folder, img_file_name, label_file_name):with gzip.open(os.path.join(folder, label_file_name), 'rb') as lbpath:y_set = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(os.path.join(folder, img_file_name), 'rb') as imgpath:x_set = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_set), 28, 28)return x_set, y_set

注意,offset 的0000-0003是 magic number,offset的0004-0007是items数目,所以跳过不读,因此将 offset 设置为 8 开始读取;同理:

将 offset 设置为 16,开始读取图片数据;

4.2 数据增强

4.2.1 像素反转

主要操作就是用 255 - 原像素,代码如下:

def all_divert(x, save_path):# 使 numpy 矩阵可以读写x = np.require(x, dtype='f4', requirements=['O', 'W'])for i in range(len(x)):for pixel in np.nditer(x[i], op_flags=['readwrite']):pixel[...] = 255 - pixelsave_img = Image.fromarray(x[i])save_img = save_img.convert('L')save_img.save(save_path + '{}.jpg'.format(i))

4.2.2 图像旋转

4.2.2.1 图像类别统计

在对图像进行旋转的时候,需要做到等量均类,这两个条件缺一不可(因为你不可能让一个人看到一个陌生的动物却能准确说出这个动物是什么),因此首先对图片数据根据它们的 label 进行一个统计。

def classify_img(y):cnt = {0: [], 1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: [], 8: [], 9: []}for i in range(len(y)):label = y[i]cnt[label].append(i)return cnt

这里返回的字典 cnt 中每个字典项保存有属于该 key(label) 的图像的编号;

4.2.2.2 根据类别进行等量均类划分

有了对每个 label 的统计,从中进行划分即可。此处是对全部图像进行 90°,180°,270° 这三类旋转,因此对于每个 label 都将其编号集合进行三等分

def all_rotate(x, cnt_seq, save_path):'''x: 图像数据集cnt_seq: 统计后的 cnt 字典save_path: 图像保存路径'''x = np.require(x, dtype='f4', requirements=['O', 'W'])for i in range(10):# 将数据集分为 3 份data_len = int(len(cnt_seq[i]) / 3)for split in range(3):left = split * data_lenif split == 2:# 最后一份包含剩下的所有图像right = len(cnt_seq[i])else:right = (split + 1) * data_lenfor j in range(left, right):# split + 1 表示旋转 90° 的 (split + 1) 倍x[cnt_seq[i][j]] = np.rot90(x[cnt_seq[i][j]], split + 1)save_img = Image.fromarray(x[cnt_seq[i][j]])save_img = save_img.convert('L')save_img.save(save_path + '{}.jpg'.format(cnt_seq[i][j]))

4.2.3 像素反转 + 图像旋转

就是上面两种操作的综合,只不过将原始数据集划分为 4 等分:

def divert_and_rotate(x, cnt_seq, save_path):'''x: 图像数据集cnt_seq: 统计后的 cnt 字典save_path: 图像保存路径'''x = np.require(x, dtype='f4', requirements=['O', 'W'])for i in range(10):# 将数据集分为 4 份data_len = int(len(cnt_seq[i]) / 4)for split in range(4):left = split * data_lenif split == 3:right = len(cnt_seq[i])else:right = (split + 1) * data_lenif split == 0:# 第一等份进行像素反转for j in range(left, right):for pixel in np.nditer(x[cnt_seq[i][j]], op_flags=['readwrite']):pixel[...] = 255 - pixelsave_img = Image.fromarray(x[cnt_seq[i][j]]).convert('L')save_img.save(save_path + '{}.jpg'.format(cnt_seq[i][j]))else:# 后面的进行图像旋转for j in range(left, right):x[cnt_seq[i][j]] = np.rot90(x[cnt_seq[i][j]], split)save_img = Image.fromarray(x[cnt_seq[i][j]]).convert('L')save_img.save(save_path + '{}.jpg'.format(cnt_seq[i][j]))

4.2.4 选择加载不同的处理后的数据集

上面的三个函数可以实现将不同的处理方式处理后的数据集进行保存,需要注意的是:测试集进行了划分,训练集也要进行划分! 因此下面的 split_and_save() 函数用来选择不同的处理模式;

def split_and_save(x, y, save_img_path, to_divert=False, to_rotate=False):'''x: 图像数据; y: label 数据save_img_path: 图像保存路径to_divert: 是否进行像素反转to_rotate: 是否进行图像旋转'''count_seq = classify_img(y)if to_divert is True and to_rotate is False:all_divert(x, save_img_path)elif to_divert is False and to_rotate is True:all_rotate(x, count_seq, save_img_path)elif to_divert is True and to_rotate is True:divert_and_rotate(x, count_seq, save_img_path)else:return

接下来就是图像处理了:

root_path = './data/MNIST/raw'
# 加载训练集
img_file_path = 'train-images-idx3-ubyte.gz'
label_file_path = 'train-labels-idx1-ubyte.gz'raw_x, raw_y = load_mnist(root_path, img_file_path, label_file_path)save_root_path = './testImgs'
split_and_save(raw_x, raw_y, save_root_path + '/divert/', True, False)
split_and_save(raw_x, raw_y, save_root_path + '/rotate/', False, True)
split_and_save(raw_x, raw_y, save_root_path + '/divert_and_rotate/', True, True)# 将训练集 label 保存在 label_train.txt 中
with open(save_root_path + '/label_train.txt', 'w') as f:for label in raw_y:f.write(str(label))f.write('\n')
f.close()# 加载测试集
img_file_path = 't10k-images-idx3-ubyte.gz'
label_file_path = 't10k-labels-idx1-ubyte.gz'raw_x, raw_y = load_mnist(root_path, img_file_path, label_file_path)
split_and_save(raw_x, raw_y, save_root_path + '/divert_test/', True, False)
split_and_save(raw_x, raw_y, save_root_path + '/rotate_test/', False, True)
split_and_save(raw_x, raw_y, save_root_path + '/divert_and_rotate_test/', True, True)# 将测试集 label 保存在 label_test.txt 中
with open(save_root_path + '/label_test.txt', 'w') as f:for label in raw_y:f.write(str(label))f.write('\n')
f.close()

4.3 完整代码

import gzip
import os
import numpy as np
from PIL import Imagedef load_mnist(folder, img_file_name, label_file_name):with gzip.open(os.path.join(folder, label_file_name), 'rb') as lbpath:y_set = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(os.path.join(folder, img_file_name), 'rb') as imgpath:x_set = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_set), 28, 28)return x_set, y_setdef all_divert(x, save_path):x = np.require(x, dtype='f4', requirements=['O', 'W'])for i in range(len(x)):for pixel in np.nditer(x[i], op_flags=['readwrite']):pixel[...] = 255 - pixelsave_img = Image.fromarray(x[i])save_img = save_img.convert('L')save_img.save(save_path + '{}.jpg'.format(i))def classify_img(y):cnt = {0: [], 1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: [], 8: [], 9: []}for i in range(len(y)):label = y[i]cnt[label].append(i)return cntdef all_rotate(x, cnt_seq, save_path):x = np.require(x, dtype='f4', requirements=['O', 'W'])for i in range(10):# 将数据集分为 3 份data_len = int(len(cnt_seq[i]) / 3)for split in range(3):left = split * data_lenif split == 2:right = len(cnt_seq[i])else:right = (split + 1) * data_lenfor j in range(left, right):x[cnt_seq[i][j]] = np.rot90(x[cnt_seq[i][j]], split + 1)save_img = Image.fromarray(x[cnt_seq[i][j]])save_img = save_img.convert('L')save_img.save(save_path + '{}.jpg'.format(cnt_seq[i][j]))def divert_and_rotate(x, cnt_seq, save_path):x = np.require(x, dtype='f4', requirements=['O', 'W'])for i in range(10):# 将数据集分为 4 份data_len = int(len(cnt_seq[i]) / 4)for split in range(4):left = split * data_lenif split == 3:right = len(cnt_seq[i])else:right = (split + 1) * data_lenif split == 0:for j in range(left, right):for pixel in np.nditer(x[cnt_seq[i][j]], op_flags=['readwrite']):pixel[...] = 255 - pixelsave_img = Image.fromarray(x[cnt_seq[i][j]]).convert('L')save_img.save(save_path + '{}.jpg'.format(cnt_seq[i][j]))else:for j in range(left, right):x[cnt_seq[i][j]] = np.rot90(x[cnt_seq[i][j]], split)save_img = Image.fromarray(x[cnt_seq[i][j]]).convert('L')save_img.save(save_path + '{}.jpg'.format(cnt_seq[i][j]))def split_and_save(x, y, save_img_path, to_divert=False, to_rotate=False):count_seq = classify_img(y)if to_divert is True and to_rotate is False:all_divert(x, save_img_path)elif to_divert is False and to_rotate is True:all_rotate(x, count_seq, save_img_path)elif to_divert is True and to_rotate is True:divert_and_rotate(x, count_seq, save_img_path)else:returnroot_path = './data/MNIST/raw'
img_file_path = 'train-images-idx3-ubyte.gz'
label_file_path = 'train-labels-idx1-ubyte.gz'raw_x, raw_y = load_mnist(root_path, img_file_path, label_file_path)save_root_path = './testImgs'
split_and_save(raw_x, raw_y, save_root_path + '/divert/', True, False)
split_and_save(raw_x, raw_y, save_root_path + '/rotate/', False, True)
split_and_save(raw_x, raw_y, save_root_path + '/divert_and_rotate/', True, True)with open(save_root_path + '/label_train.txt', 'w') as f:for label in raw_y:f.write(str(label))f.write('\n')
f.close()img_file_path = 't10k-images-idx3-ubyte.gz'
label_file_path = 't10k-labels-idx1-ubyte.gz'raw_x, raw_y = load_mnist(root_path, img_file_path, label_file_path)
split_and_save(raw_x, raw_y, save_root_path + '/divert_test/', True, False)
split_and_save(raw_x, raw_y, save_root_path + '/rotate_test/', False, True)
split_and_save(raw_x, raw_y, save_root_path + '/divert_and_rotate_test/', True, True)with open(save_root_path + '/label_test.txt', 'w') as f:for label in raw_y:f.write(str(label))f.write('\n')
f.close()

5. 模型再训练

模型再训练需要解决 2 个问题:

  • 怎么加载 split 后的数据?
  • 加载完成后怎么和原始数据合并,然后送入模型进行训练?

5.1 怎么加载 split 后的数据?

5.1.1 创建自己的 dataset 类

为了使自己的数据集和原始数据集进行合并,可以继承 torch.utils.data.Dataset 类开发自己的 my_dataset 类:

class my_dataset(torch.utils.data.Dataset):def __init__(self, img, label, transform=None):super(my_dataset, self).__init__()self.dataset = imgself.label = labelself.transform = transformdef __getitem__(self, item):data = self.dataset[item]lb = self.label[item]if self.transform is not None:data = self.transform(data)return data, lbdef __len__(self):return len(self.dataset)

这里,__init(self)__, __getitem(self, item)__, __len(self)__ 是必须实现的,当把之前分割好的 img, label 数据加载进来后,放入 my_dataset 类即可;

5.1.2 load 分割好的数据

加载分割好的数据,返回 my_dataset 对象;

def load(trans, to_divert=False, to_rotate=False, train=False):'''trans: torchvision.transforms 对象to_divert: 是否进行像素反转to_rotate: 是否进行图像旋转train: 是否是用于训练的数据'''x, y = [], []root_path = './testImgs/'# 加载训练数据if train is True:num = 6e4label_path = root_path + 'label_train.txt'if to_divert is True and to_rotate is False:load_path = root_path + 'divert/'elif to_divert is False and to_rotate is True:load_path = root_path + 'rotate/'elif to_divert is True and to_rotate is True:load_path = root_path + 'divert_and_rotate/'else:returnelse:num = 1e4label_path = root_path + 'label_test.txt'if to_divert is True and to_rotate is False:load_path = root_path + 'divert_test/'elif to_divert is False and to_rotate is True:load_path = root_path + 'rotate_test/'elif to_divert is True and to_rotate is True:load_path = root_path + 'divert_and_rotate_test/'else:returnfor i in range(int(num)):path = load_path + '{}.jpg'.format(i)img = Image.open(path).convert('L')x.append(img)# 加载 labelwith open(label_path, 'r') as f:for i in range(int(num)):label = f.readline()label = label.strip('\n')y.append(int(label))f.close()dataset = my_dataset(x, y, trans)return dataset

5.2 加载完成后怎么和原始数据合并,然后送入模型进行训练?

以原始训练集和数据增强后的训练集合并为例:

 path = './data'# get_data_loader() 就是基本实现中定义好的函数train_data_set, _ = get_data_loader(path, batch_size, transform, True)# 增强后的数据集enhanced_train_data_set = load(transform, True, True, True)# 采用 ConcatDataset() 进行连接train_data_set = torch.utils.data.dataset.ConcatDataset([enhanced_train_data_set, train_data_set])print('Training data loaded.')# 将 dataset 放入 loader 中train_data_loader = torch.utils.data.DataLoader(train_data_set, batch_size=batch_size, shuffle=True)

5.3 完整代码

完整代码和基本实现中的差不多,主要区别在于:

  • 增加了 my_dataset 类;
  • 加载自己的数据集的 load() 函数;
  • 主函数里面对数据进行了加载和合并;
import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Imageclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),)self.fullyConnected = nn.Sequential(nn.Flatten(),nn.Linear(in_features=7 * 7 * 64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=10),)def forward(self, img):output = self.conv1(img)output = self.conv2(output)output = self.conv3(output)output = self.fullyConnected(output)return outputclass my_dataset(torch.utils.data.Dataset):def __init__(self, img, label, transform=None):super(my_dataset, self).__init__()self.dataset = imgself.label = labelself.transform = transformdef __getitem__(self, item):data = self.dataset[item]lb = self.label[item]if self.transform is not None:data = self.transform(data)return data, lbdef __len__(self):return len(self.dataset)def get_device():if torch.cuda.is_available():train_device = torch.device('cuda')else:train_device = torch.device('cpu')return train_devicedef get_data_loader(dat_path, bat_size, trans, to_train=False):dat_set = torchvision.datasets.MNIST(root=dat_path, train=to_train, transform=trans, download=True)if to_train is True:dat_loader = torch.utils.data.DataLoader(dat_set, batch_size=bat_size, shuffle=True)else:dat_loader = torch.utils.data.DataLoader(dat_set, batch_size=bat_size)return dat_set, dat_loaderdef show_part_of_image(dat_loader, row, col):iteration = enumerate(dat_loader)idx, (exam_img, exam_label) = next(iteration)fig = plt.figure(num=1)for i in range(row * col):plt.subplot(row, col, i + 1)plt.tight_layout()plt.imshow(exam_img[i][0], cmap='gray', interpolation='none')plt.title('Number: {}'.format(exam_label[i]))plt.xticks([])plt.yticks([])plt.show()def train(network, dat_loader, device, epos, loss_function, optimizer):for epoch in range(1, epos + 1):network.train(mode=True)for idx, (train_img, train_label) in enumerate(dat_loader):train_img = train_img.to(device)train_label = train_label.to(device)outputs = network(train_img)optimizer.zero_grad()loss = loss_function(outputs, train_label)loss.backward()optimizer.step()if idx % 100 == 0:cnt = idx * len(train_img) + (epoch - 1) * len(dat_loader.dataset)print('epoch: {}, [{}/{}({:.0f}%)], loss: {:.6f}'.format(epoch,idx * len(train_img),len(dat_loader.dataset),(100 * cnt) / (len(dat_loader.dataset) * epos),loss.item()))print('------------------------------------------------')print('Training ended.')return networkdef test(network, dat_loader, device, loss_function):test_loss_avg, correct, total = 0, 0, 0test_loss = []network.train(mode=False)with torch.no_grad():for idx, (test_img, test_label) in enumerate(dat_loader):test_img = test_img.to(device)test_label = test_label.to(device)total += test_label.size(0)outputs = network(test_img)loss = loss_function(outputs, test_label)test_loss.append(loss.item())predictions = torch.argmax(outputs, dim=1)correct += torch.sum(predictions == test_label)test_loss_avg = np.average(test_loss)print('Total: {}, Correct: {}, Accuracy: {:.2f}%, AverageLoss: {:.6f}'.format(total, correct,correct / total * 100,test_loss_avg))def show_part_of_test_result(network, dat_loader, row, col):iteration = enumerate(dat_loader)idx, (exam_img, exam_label) = next(iteration)with torch.no_grad():outputs = network(exam_img)fig = plt.figure()for i in range(row * col):plt.subplot(row, col, i + 1)plt.tight_layout()plt.imshow(exam_img[i][0], cmap='gray', interpolation='none')plt.title('Number: {}, Prediction: {}'.format(exam_label[i], outputs.data.max(1, keepdim=True)[1][i].item()))plt.xticks([])plt.yticks([])plt.show()def load(trans, to_divert=False, to_rotate=False, train=False):x, y = [], []root_path = './testImgs/'if train is True:num = 6e4label_path = root_path + 'label_train.txt'if to_divert is True and to_rotate is False:load_path = root_path + 'divert/'elif to_divert is False and to_rotate is True:load_path = root_path + 'rotate/'elif to_divert is True and to_rotate is True:load_path = root_path + 'divert_and_rotate/'else:returnelse:num = 1e4label_path = root_path + 'label_test.txt'if to_divert is True and to_rotate is False:load_path = root_path + 'divert_test/'elif to_divert is False and to_rotate is True:load_path = root_path + 'rotate_test/'elif to_divert is True and to_rotate is True:load_path = root_path + 'divert_and_rotate_test/'else:returnfor i in range(int(num)):path = load_path + '{}.jpg'.format(i)img = Image.open(path).convert('L')x.append(img)with open(label_path, 'r') as f:for i in range(int(num)):label = f.readline()label = label.strip('\n')y.append(int(label))f.close()dataset = my_dataset(x, y, trans)return datasetif __name__ == '__main__':batch_size, epochs = 128, 10transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])my_device = get_device()path = './data'train_data_set, _ = get_data_loader(path, batch_size, transform, True)enhanced_train_data_set = load(transform, True, True, True)train_data_set = torch.utils.data.dataset.ConcatDataset([enhanced_train_data_set, train_data_set])print('Training data loaded.')train_data_loader = torch.utils.data.DataLoader(train_data_set, batch_size=batch_size, shuffle=True)show_part_of_image(train_data_loader, 3, 3)test_data_set, _ = get_data_loader(path, batch_size, transform)enhanced_test_data_set = load(transform, True, True, False)test_data_set = torch.utils.data.dataset.ConcatDataset([enhanced_test_data_set, test_data_set])print('Testing data loaded.')test_data_loader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size, shuffle=True)cnn = CNN()loss_func = nn.CrossEntropyLoss()optim = torch.optim.Adam(cnn.parameters(), lr=0.01)cnn = train(cnn, train_data_loader, my_device, epochs, loss_func, optim)test(cnn, test_data_loader, my_device, loss_func)show_part_of_test_result(cnn, test_data_loader, 5, 2)torch.save(cnn, './cnn2.pth')

5.4 训练结果

5.4.1 只进行像素反转

5.4.1.1 测试结果


其中一些超参数如下:

  • batch_size: 128
  • epochs: 10

模型在测试集上的准确率达到了 97.76%,从右侧的测试集采样结果来看,正确率也相对较高;

5.4.1.2 在自己的数据上测试

测试代码
import torch
# from test import CNN
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import matplotlib.pyplot as pltclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),)self.fullyConnected = nn.Sequential(nn.Flatten(),nn.Linear(in_features=7 * 7 * 64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=10),)def forward(self, input):output = self.conv1(input)output = self.conv2(output)output = self.conv3(output)output = self.fullyConnected(output)return outputmodel = torch.load('./cnn2.pth')
model.eval()transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])
unloader = transforms.ToPILImage()for k in range(10):infile = './testImgs/raw/' + '{}.jpg'.format(k)img = Image.open(infile)img = img.convert('L')img_array = np.array(img)img = Image.fromarray(img_array)# img.show()img = transform(img)img = torch.unsqueeze(img, 0)output = model(img)pred = torch.argmax(output, dim=1)image = torch.squeeze(img, 0)image = unloader(image)plt.subplot(5, 2, k + 1)plt.tight_layout()plt.imshow(image, cmap='gray', interpolation='none')plt.title("Number: {}, Prediction: {}".format(k, pred.item()))plt.xticks([])plt.yticks([])
plt.show()
测试结果


可以从右侧的结果中看出,准确率比较高(除了看起来比较讨厌的灰底?);

5.4.2 只进行图像旋转

5.4.2.1 测试结果


用到的 batch_size, epochs 和上面的一样。
模型在测试集上的准确率达到了 93.54%,从右侧的测试集采样结果来看,正确率也相对较高;

5.4.2.2 在自己的数据上测试

测试代码
import torch
# from test import CNN
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import matplotlib.pyplot as pltclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),)self.fullyConnected = nn.Sequential(nn.Flatten(),nn.Linear(in_features=7 * 7 * 64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=10),)def forward(self, input):output = self.conv1(input)output = self.conv2(output)output = self.conv3(output)output = self.fullyConnected(output)return outputmodel = torch.load('./cnn2.pth')
model.eval()transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])
unloader = transforms.ToPILImage()for k in range(10):infile = './testImgs/raw/' + 'r{}.jpg'.format(k)img = Image.open(infile)img = img.convert('L')img_array = np.array(img)# 注意进行需要是黑底白字的图片for i in range(28):for j in range(28):img_array[i, j] = 255 - img_array[i, j]img = Image.fromarray(img_array)img = transform(img)img = torch.unsqueeze(img, 0)output = model(img)pred = torch.argmax(output, dim=1)image = torch.squeeze(img, 0)image = unloader(image)plt.subplot(5, 2, k + 1)plt.tight_layout()plt.imshow(image, cmap='gray', interpolation='none')plt.title("Number: {}, Prediction: {}".format(k, pred.item()))plt.xticks([])plt.yticks([])
plt.show()
测试结果


比较有趣的就是 9 旋转 180° 就变成 6 了hhh

5.4.3 二者同时进行

5.4.3.1 测试结果


用到的 batch_size, epochs 同样是 128 和 10;
模型在测试集上的准确率达到了 95.38%,从右侧的测试集采样结果来看,正确率也相对较高;

5.4.3.2 在自己的数据上测试

测试代码
import torch
# from test import CNN
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import matplotlib.pyplot as pltclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),)self.fullyConnected = nn.Sequential(nn.Flatten(),nn.Linear(in_features=7 * 7 * 64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=10),)def forward(self, input):output = self.conv1(input)output = self.conv2(output)output = self.conv3(output)output = self.fullyConnected(output)return outputmodel = torch.load('./cnn2.pth')
model.eval()transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])
unloader = transforms.ToPILImage()# 前 3 张图片是像素反转后的图片,后 3 张是未作处理的图片,最后 4 张二者同时进行的图片(注意像素反转是指将黑底白字转换为白底黑字)
for k in range(10):if k < 3:infile = './testImgs/raw/' + '{}.jpg'.format(k)img = Image.open(infile)img = img.convert('L')img_array = np.array(img)elif 3 <= k < 6:infile = './testImgs/raw/' + 'r{}.jpg'.format(k)img = Image.open(infile)img = img.convert('L')img_array = np.array(img)for i in range(28):for j in range(28):img_array[i, j] = 255 - img_array[i, j]else:infile = './testImgs/raw/' + 'r{}.jpg'.format(k)img = Image.open(infile)img = img.convert('L')img_array = np.array(img)img = Image.fromarray(img_array)img = transform(img)img = torch.unsqueeze(img, 0)output = model(img)pred = torch.argmax(output, dim=1)image = torch.squeeze(img, 0)image = unloader(image)plt.subplot(5, 2, k + 1)plt.tight_layout()plt.imshow(image, cmap='gray', interpolation='none')plt.title("Number: {}, Prediction: {}".format(k, pred.item()))plt.xticks([])plt.yticks([])
plt.show()
测试结果

5.5 结果整合

原始数据 只进行像素反转 只进行图像旋转 二者同时进行
batch_size, epochs 64, 10 128, 10 128, 10 128, 10
accuracy 97.32% 97.76% 93.54% 95.38%

结语

这样的想法来源于在对 MNIST 手写数字识别进行基本实现并利用自己做的图进行进行测试的时候,开始由于没有认识到黑底白字和白底黑字的问题,因此模型测试结果很差;然后就是写的数字必须比较端正,否则测试结果也很差;
因此在学长的启发下对数据集进行了拓展,使之能够应用于更广的场景中;
另外,在这里我只进行了 90°, 180°, 270° 这三种旋转,如果有兴趣的话可以尝试更多不同角度的旋转;

pytorch MNIST 手写数字识别 + 使用自己的测试集 + 数据增强后再训练相关推荐

  1. 使用tf.keras搭建mnist手写数字识别网络

    使用tf.keras搭建mnist手写数字识别网络 目录 使用tf.keras搭建mnist手写数字识别网络 1.使用tf.keras.Sequential搭建序列模型 1.1 tf.keras.Se ...

  2. Caffe MNIST 手写数字识别(全面流程)

    目录 1.下载MNIST数据集 2.生成MNIST图片训练.验证.测试数据集 3.制作LMDB数据库文件 4.准备LeNet-5网络结构定义模型.prototxt文件 5.准备模型求解配置文件_sol ...

  3. MNIST手写数字识别一

    手写分类识别一 MNIST手写数字识别:分类问题 一.数据准备 MNIST数据集来自美国国家标准与技术研究所,National Institute of Standards and Technolog ...

  4. 深度学习练手项目(一)-----利用PyTorch实现MNIST手写数字识别

    一.前言 MNIST手写数字识别程序就不过多赘述了,这个程序在深度学习中的地位跟C语言中的Hello World地位并驾齐驱,虽然很基础,但很重要,是深度学习入门必备的程序之一. 二.MNIST数据集 ...

  5. 使用PYTORCH复现ALEXNET实现MNIST手写数字识别

    网络介绍: Alexnet网络是CV领域最经典的网络结构之一了,在2012年横空出世,并在当年夺下了不少比赛的冠军,下面是Alexnet的网络结构: 网络结构较为简单,共有五个卷积层和三个全连接层,原 ...

  6. Pytorch实现mnist手写数字识别

    2020/6/29 Hey,突然想起来之前做的一个入门实验,用pytorch实现mnist手写数字识别.可以在这个基础上增加网络层数,或是尝试用不同的数据集,去实现不一样的功能. Mnist数据集如图 ...

  7. PyTorch入门一:卷积神经网络实现MNIST手写数字识别

    先给出几个入门PyTorch的好的资料: PyTorch官方教程(中文版):http://pytorch123.com <动手学深度学习>PyTorch版:https://github.c ...

  8. 用PyTorch实现MNIST手写数字识别(非常详细)

    ​​​​​Keras版本: Keras入门级MNIST手写数字识别超级详细教程 2022/4/17 更新修复下代码.完善优化下文章结构,文末提供一个完整版代码. 可以在这里下载源码文件(免积分): 用 ...

  9. Pytorch入门——MNIST手写数字识别代码

    MNIST手写数字识别教程 本文仅仅放出该教程的代码 具体教程请看 Pytorch入门--手把手教你MNIST手写数字识别 import torch import torchvision from t ...

最新文章

  1. 数据洪流时代的芯片之变
  2. log4j的日志级别(ssm中log4j的配置)
  3. android 代码 日历 重复事件设置,日历重复/重复事件 - 最佳存储方法
  4. Node.js CLI 工具最佳实践
  5. java程序 开发工具_15款Java程序员必备的开发工具
  6. mysql主从切换机制torch_MySQL Replication设置(Master/Slave)实现主从复制
  7. link-cut-tree 简单介绍
  8. qn模块java脚本_Qn271 对于网络编程 反射 IO 线程的一些一本入门程序 多多联系会加快 速度 WinSock-NDIS 269万源代码下载- www.pudn.com...
  9. SPI、I2C、UART三种串行总线协议的区别和SPI接口介绍(转)
  10. 面向对象编程启思录——读《OOD启思录》有感
  11. wget命令下载文件并另存为不同的文件名
  12. Python中断多重循环的几种思路exit_flag
  13. JS中的slice和splice
  14. mysql与oracle存储过程_5分钟学会oracle与mysql存储过程insertinto
  15. 纯前端导出PDF分页截取问题处理
  16. 队列仿真一个银行叫号系统
  17. python就业方向那么广,为什么还是很多人找不到工作?
  18. sql中别名as,不写,以及使用双引号总结
  19. Java *1.11(人口估算)美国人口调查局基于以下假设进行人口估算:
  20. 【新手村专属】服创比赛总结

热门文章

  1. sqldeveloper fedora 20 [error occurred during error reporting (printing problematic frame), id 0xb]
  2. 周末在家学习echarts
  3. Python 工程师薪资再飙升,不止25K了!
  4. 如何让直播延迟低体验更好?做好这4点就可以了
  5. 警惕职场的竞业协议坑
  6. 浅谈目标检测中的特征冲突与不对齐问题
  7. 抖音3d相册html代码,抖音上很火的3D立体动态相册.html
  8. 【VSCode】Visual Studio Code软件使用入门
  9. 终于从树堆里爬出来了——堆排序(基于二叉树)基本思想、步骤、复杂度及python代码,欢迎交流
  10. 回顾社交游戏公司Zynga创业史