手撕 CNN 之 AlexNet(PyTorch 实战篇)
大家好,我是红色石头!
在上一篇文章:
手撕 CNN 经典网络之 AlexNet(理论篇)
详细介绍了 AlexNet 的网络结构,今天我们将使用 PyTorch 来复现AlexNet网络,并用AlexNet模型来解决一个经典的Kaggle图像识别比赛问题。
正文开始!
1. 数据集制作
在论文中AlexNet作者使用的是ILSVRC 2012比赛数据集,该数据集非常大(有138G),下载、训练都很消耗时间,我们在复现的时候就不用这个数据集了。由于MNIST、CIFAR10、CIFAR100这些数据集图片尺寸都较小,不符合AlexNet网络输入尺寸227x227的要求,因此我们改用kaggle比赛经典的“猫狗大战”数据集了。
该数据集包含的训练集总共25000张图片,猫狗各12500张,带标签;测试集总共12500张,不带标签。我们仅使用带标签的25000张图片,分别拿出2500张猫和狗的图片作为模型的验证集。我们按照以下目录层级结构,将数据集图片放好。
为了方便大家训练,我们将该数据集放在百度云盘,下载链接:
链接:https://pan.baidu.com/s/1UEOzxWWMLCUoLTxdWUkB4A
提取码:cdue
1.1 制作图片数据的索引
准备好数据集之后,我们需要用PyTorch来读取并制作可以用来训练和测试的数据集。对于训练集和测试集,首先要分别制作对应的图片数据索引,即train.txt和test.txt两个文件,每个txt中包含每个图片的目录和对应类别class(cat对应的label=0,dog对应的label=1)。示意图如下:
制作图片数据索引train.txt和test.txt两个文件的python脚本程序如下:
import ostrain_txt_path = os.path.join("data", "catVSdog", "train.txt")
train_dir = os.path.join("data", "catVSdog", "train_data")
valid_txt_path = os.path.join("data", "catVSdog", "test.txt")
valid_dir = os.path.join("data", "catVSdog", "test_data")def gen_txt(txt_path, img_dir):f = open(txt_path, 'w')for root, s_dirs, _ in os.walk(img_dir, topdown=True): # 获取 train文件下各文件夹名称for sub_dir in s_dirs:i_dir = os.path.join(root, sub_dir) # 获取各类的文件夹 绝对路径img_list = os.listdir(i_dir) # 获取类别文件夹下所有png图片的路径for i in range(len(img_list)):if not img_list[i].endswith('jpg'): # 若不是png文件,跳过continue#label = (img_list[i].split('.')[0] == 'cat')? 0 : 1 label = img_list[i].split('.')[0]# 将字符类别转为整型类型表示if label == 'cat':label = '0'else:label = '1'img_path = os.path.join(i_dir, img_list[i])line = img_path + ' ' + label + '\n'f.write(line)f.close()if __name__ == '__main__':gen_txt(train_txt_path, train_dir)gen_txt(valid_txt_path, valid_dir)
运行脚本之后就在./data/catVSdog/目录下生成train.txt和test.txt两个索引文件。
1.2 构建Dataset子类
PyTorch 加载自己的数据集,需要写一个继承自torch.utils.data中Dataset类,并修改其中的__init__方法、__getitem__方法、__len__方法。默认加载的都是图片,__init__的目的是得到一个包含数据和标签的list,每个元素能找到图片位置和其对应标签。然后用__getitem__方法得到每个元素的图像像素矩阵和标签,返回img和label。
from PIL import Image
from torch.utils.data import Datasetclass MyDataset(Dataset):def __init__(self, txt_path, transform = None, target_transform = None):fh = open(txt_path, 'r')imgs = []for line in fh:line = line.rstrip()words = line.split()imgs.append((words[0], int(words[1]))) # 类别转为整型intself.imgs = imgs self.transform = transformself.target_transform = target_transformdef __getitem__(self, index):fn, label = self.imgs[index]img = Image.open(fn).convert('RGB') #img = Image.open(fn)if self.transform is not None:img = self.transform(img) return img, labeldef __len__(self):return len(self.imgs)
getitem是核心函数。self.imgs是一个list,self.imgs[index]是一个str,包含图片路径,图片标签,这些信息是从上面生成的txt文件中读取;利用Image.open对图片进行读取,注意这里的img是单通道还是三通道的;self.transform(img)对图片进行处理,这个transform里边可以实现减均值、除标准差、随机裁剪、旋转、翻转、放射变换等操作。
1.3 加载数据集和数据预处理
当Mydataset构建好,剩下的操作就交给DataLoder来加载数据集。在DataLoder中,会触发Mydataset中的getiterm函数读取一张图片的数据和标签,并拼接成一个batch返回,作为模型真正的输入。
pipline_train = transforms.Compose([#随机旋转图片transforms.RandomHorizontalFlip(),#将图片尺寸resize到227x227transforms.Resize((227,227)),#将图片转化为Tensor格式transforms.ToTensor(),#正则化(当模型出现过拟合的情况时,用来降低模型的复杂度)transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))#transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
])
pipline_test = transforms.Compose([#将图片尺寸resize到227x227transforms.Resize((227,227)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))#transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
])
train_data = MyDataset('./data/catVSdog/train.txt', transform=pipline_train)
test_data = MyDataset('./data/catVSdog/test.txt', transform=pipline_test)#train_data 和test_data包含多有的训练与测试数据,调用DataLoader批量加载
trainloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(dataset=test_data, batch_size=32, shuffle=False)
# 类别信息也是需要我们给定的
classes = ('cat', 'dog') # 对应label=0,label=1
在数据预处理中,我们将图片尺寸调整到227x227,符合AlexNet网络的输入要求。均值mean = [0.5, 0.5, 0.5],方差std = [0.5, 0.5, 0.5],然后使用transforms.Normalize进行归一化操作。
我们来看一下最终制作的数据集图片和它们对应的标签:
examples = enumerate(trainloader)
batch_idx, (example_data, example_label) = next(examples)
# 批量展示图片
for i in range(4):plt.subplot(1, 4, i + 1)plt.tight_layout() #自动调整子图参数,使之填充整个图像区域img = example_data[i]img = img.numpy() # FloatTensor转为ndarrayimg = np.transpose(img, (1,2,0)) # 把channel那一维放到最后img = img * [0.5, 0.5, 0.5] + [0.5, 0.5, 0.5]plt.imshow(img)plt.title("label:{}".format(example_label[i]))plt.xticks([])plt.yticks([])
plt.show()
2. 搭建AlexNet神经网络结构,并定义前向传播的过程
class AlexNet(nn.Module):"""Neural network model consisting of layers propsed by AlexNet paper."""def __init__(self, num_classes=2):"""Define and allocate layers for this neural net.Args:num_classes (int): number of classes to predict with this model"""super().__init__()# input size should be : (b x 3 x 227 x 227)# The image in the original paper states that width and height are 224 pixels, but# the dimensions after first convolution layer do not lead to 55 x 55.self.net = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4), # (b x 96 x 55 x 55)nn.ReLU(),nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2), # section 3.3nn.MaxPool2d(kernel_size=3, stride=2), # (b x 96 x 27 x 27)nn.Conv2d(96, 256, 5, padding=2), # (b x 256 x 27 x 27)nn.ReLU(),nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),nn.MaxPool2d(kernel_size=3, stride=2), # (b x 256 x 13 x 13)nn.Conv2d(256, 384, 3, padding=1), # (b x 384 x 13 x 13)nn.ReLU(),nn.Conv2d(384, 384, 3, padding=1), # (b x 384 x 13 x 13)nn.ReLU(),nn.Conv2d(384, 256, 3, padding=1), # (b x 256 x 13 x 13)nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2), # (b x 256 x 6 x 6))# classifier is just a name for linear layersself.classifier = nn.Sequential(nn.Dropout(p=0.5, inplace=True),nn.Linear(in_features=(256 * 6 * 6), out_features=500),nn.ReLU(),nn.Dropout(p=0.5, inplace=True),nn.Linear(in_features=500, out_features=20),nn.ReLU(),nn.Linear(in_features=20, out_features=num_classes),)def forward(self, x):"""Pass the input through the net.Args:x (Tensor): input tensorReturns:output (Tensor): output tensor"""x = self.net(x)x = x.view(-1, 256 * 6 * 6) # reduce the dimensions for linear layer inputreturn self.classifier(x)
在构建AlexNet网络里,参数num_classes指的是类别的数量,由于论文中AlexNet的输出是1000个类别,我们这里的数据集只有猫和狗两个类别,因此这里的全连接层的神经元个数做了微调。num_classes=2,输出层也是两个神经元,不是原来的1000个神经元。FC6由原来的4096个神经元改为500个神经元,FC7由原来的4096个神经元改为20个神经元。
这里的改动大家注意一下,根据实际数据集的类别数量进行调整。整个网络的其它结构跟论文中的完全一样。
3. 将定义好的网络结构搭载到GPU/CPU,并定义优化器
#创建模型,部署gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AlexNet().to(device)
#定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)
4. 定义训练过程
def train_runner(model, device, trainloader, optimizer, epoch):#训练模型, 启用 BatchNormalization 和 Dropout, 将BatchNormalization和Dropout置为Truemodel.train()total = 0correct =0.0#enumerate迭代已加载的数据集,同时获取数据和数据下标for i, data in enumerate(trainloader, 0):inputs, labels = data#把模型部署到device上inputs, labels = inputs.to(device), labels.to(device)#初始化梯度optimizer.zero_grad()#保存训练结果outputs = model(inputs)#计算损失和#多分类情况通常使用cross_entropy(交叉熵损失函数), 而对于二分类问题, 通常使用sigmodloss = F.cross_entropy(outputs, labels)#获取最大概率的预测结果#dim=1表示返回每一行的最大值对应的列下标predict = outputs.argmax(dim=1)total += labels.size(0)correct += (predict == labels).sum().item()#反向传播loss.backward()#更新参数optimizer.step()if i % 100 == 0:#loss.item()表示当前loss的数值print("Train Epoch{} \t Loss: {:.6f}, accuracy: {:.6f}%".format(epoch, loss.item(), 100*(correct/total)))Loss.append(loss.item())Accuracy.append(correct/total)return loss.item(), correct/total
5. 定义测试过程
def test_runner(model, device, testloader):#模型验证, 必须要写, 否则只要有输入数据, 即使不训练, 它也会改变权值#因为调用eval()将不启用 BatchNormalization 和 Dropout, BatchNormalization和Dropout置为Falsemodel.eval()#统计模型正确率, 设置初始值correct = 0.0test_loss = 0.0total = 0#torch.no_grad将不会计算梯度, 也不会进行反向传播with torch.no_grad():for data, label in testloader:data, label = data.to(device), label.to(device)output = model(data)test_loss += F.cross_entropy(output, label).item()predict = output.argmax(dim=1)#计算正确数量total += label.size(0)correct += (predict == label).sum().item()#计算损失值print("test_avarage_loss: {:.6f}, accuracy: {:.6f}%".format(test_loss/total, 100*(correct/total)))
6. 运行
#调用
epoch = 20
Loss = []
Accuracy = []
for epoch in range(1, epoch+1):print("start_time",time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))loss, acc = train_runner(model, device, trainloader, optimizer, epoch)Loss.append(loss)Accuracy.append(acc)test_runner(model, device, testloader)print("end_time: ",time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())),'\n')print('Finished Training')
plt.subplot(2,1,1)
plt.plot(Loss)
plt.title('Loss')
plt.show()
plt.subplot(2,1,2)
plt.plot(Accuracy)
plt.title('Accuracy')
plt.show()
经历 20 次 epoch 的 loss 和 accuracy 曲线如下:
经过20个epoch的训练之后,accuracy达到了87.94%。
7. 保存模型
print(model)
torch.save(model, './models/alexnet-catvsdog.pth') #保存模型
AlexNet 的模型会打印出来,并将模型模型命令为 alexnet-catvsdog.pth 保存在固定目录下。
8. 模型测试
下面使用一张猫狗大战测试集的图片进行模型的测试。
from PIL import Image
import numpy as npif __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = torch.load('./models/alexnet-catvsdog.pth') #加载模型model = model.to(device)model.eval() #把模型转为test模式#读取要预测的图片# 读取要预测的图片img = Image.open("./images/test_cat.jpg") # 读取图像#img.show()plt.imshow(img) # 显示图片plt.axis('off') # 不显示坐标轴plt.show()# 导入图片,图片扩展后为[1,1,32,32]trans = transforms.Compose([transforms.Resize((227,227)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])img = trans(img)img = img.to(device)img = img.unsqueeze(0) #图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]# 预测 # 预测 classes = ('cat', 'dog')output = model(img)prob = F.softmax(output,dim=1) #prob是2个分类的概率print("概率:",prob)value, predicted = torch.max(output.data, 1)predict = output.argmax(dim=1)pred_class = classes[predicted.item()]print("预测类别:",pred_class)
输出:
概率: tensor([[1.0000e+00, 5.8714e-13]], grad_fn=<SoftmaxBackward>)
预测类别: cat
模型预测结果正确!
好了,以上就是使用 PyTorch 复现 AlexNet 网络的核心代码。建议大家根据文章内容完整码一下代码,可以根据实际情况使用自己的数据集,并调整FC6、FC7、Output Layer的神经元个数。
完整代码我已经放在了 GitHub 上,地址:
https://github.com/RedstoneWill/CNN_PyTorch_Beginner/tree/main/AlexNet
手撕 CNN 系列:
手撕 CNN 经典网络之 LeNet-5(理论篇)
手撕 CNN 经典网络之 LeNet-5(MNIST 实战篇)
手撕 CNN 经典网络之 LeNet-5(CIFAR10 实战篇)
手撕 CNN 经典网络之 LeNet-5(自定义实战篇)
手撕 CNN 经典网络之 AlexNet(理论篇)
如果觉得这篇文章有用的话,麻烦点个在看或转发朋友圈!
推荐阅读
(点击标题可跳转阅读)
干货 | 公众号历史文章精选
我的深度学习入门路线
我的机器学习入门路线图
重磅!
AI有道年度技术文章电子版PDF来啦!
扫描下方二维码,添加 AI有道小助手微信,可申请入群,并获得2020完整技术文章合集PDF(一定要备注:入群 + 地点 + 学校/公司。例如:入群+上海+复旦。
长按扫码,申请入群
(添加人数较多,请耐心等待)
感谢你的分享,点赞,在看三连
手撕 CNN 之 AlexNet(PyTorch 实战篇)相关推荐
- 【深度学习】手撕 CNN 之 AlexNet(PyTorch 实战篇)
今天我们将使用 PyTorch 来复现AlexNet网络,并用AlexNet模型来解决一个经典的Kaggle图像识别比赛问题. 正文开始! 1. 数据集制作 在论文中AlexNet作者使用的是ILSV ...
- 手撕 CNN 经典网络之 VGGNet(PyTorch实战篇)
大家好,我是红色石头! 在上一篇文章: 手撕 CNN 经典网络之 VGGNet(理论篇) 详细介绍了 VGGNet 的网络结构,今天我们将使用 PyTorch 来复现VGGNet网络,并用VGGNet ...
- 手撕 CNN 经典网络之 AlexNet(理论篇)
大家好,我是红色石头! 第一个典型的CNN是LeNet5网络,而第一个大放异彩的CNN却是AlexNet.2012年在全球知名的图像识别竞赛 ILSVRC 中,AlexNet 横空出世,直接将错误率降 ...
- 手撕 CNN 经典网络之 VGGNet(理论篇)
2014年,牛津大学计算机视觉组(Visual Geometry Group)和Google DeepMind公司一起研发了新的卷积神经网络,并命名为VGGNet.VGGNet是比AlexNet更深的 ...
- svm手写数字识别_KNN 算法实战篇如何识别手写数字
上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数据集是一个用于图像处理的数据集,这些数据描绘了 [0, 9] 的数字,我们可以用KNN 算 ...
- string list 查找_手撕java集合源码——List篇
阅读list集合观察它们底层是如何实现的,以及集合面试中提出的问题进行实践. list集合中常用的类为Arraylist.LinkedLIst. 两者的区别 区别 Arraylist LinkedLi ...
- 手撕CNN神经网络(二)反向传播(back propagation)
反向传播简单来说就是通过求偏导数从而更新神经网络中权重w的过程 首先先简单复习一下前向传播(forward propagation) 篇幅有限,bias(b)在图中就没画出来,大家知道有这么个东西就好 ...
- Python3《机器学习实战》学习笔记(八):支持向量机原理篇之手撕线性SVM
原 Python3<机器学习实战>学习笔记(八):支持向量机原理篇之手撕线性SVM 置顶 2017年09月23日 17:50:18 阅读数:12644 转载请注明作者和出处: https: ...
- 关键点检测之直接回归(逐行手撕Pytorch)
直接回归的整个流程如下: 用卷积提特征 用全连接层进行回归得到关键点 与标签做损失并优化损失 当图像很大的时候,神经网络是在整张图像中去寻找一个像素点,所以直接回归这一解决办法让神经网络的工作是非常难 ...
最新文章
- Oracle的sys_connect_by_path函数
- 员外带你读论文:SeqGAN论文分享
- C# 系统应用之获取IE浏览记录和IE地址栏输入网址
- 11.30 如何取得当事人的银行账号?
- 附加SQL2005数据库失败解决方法
- 21天学通JAVA——学习笔记
- word中有软回车(每行后面有向下的箭头)的解决方法
- oracle创建一个永久性表空间,Oracle表空间简单管理永久表空间
- django安装mysqlclient报错mand errored out with exit status 1: python setup.py egg_info Check the logs f
- pgpool-Ⅱ一主两从同步流读写分离高可用方案测试
- vue组件传参(父传子)
- 2020年春分,新的开始
- IDEA查看maven的依赖树
- jmeter中控制器的使用
- Material Design设计之【色彩】实践
- iOS银联支付(最新)
- Sofaboot学习整理
- Swagger体验版
- 易语言词达人爬取题库
- 用友U872安装无法继续,iewebcontrol安装不成功解决办法