由于新型冠状肺炎疫情一直没能开学,在家自己学习了一下pytorch,本来说按着官网的60分钟教程过一遍的,但是CIFAR-10数据库的下载速度太慢了……

这台电脑里也没有现成的数据库,想起之前画了一些粒子的动量分量分布图,干脆拿来用了,也没期待它能表现得多好,主要图一个练手。(事实证明它表现相当差,不过这也在意料之中)

那么开始。

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

接下来定义读取和处理图片的函数,图片尺寸是432x288,把它切成中间的288x288,再缩小成32x32。这样的处理单纯是为了让模型训练得快一点,毕竟这次练手本身的目的不是训练一个高精度的模型,而是训练一个模型。(而且话说回来这电脑也莫得英伟达高性能图形处理器,(笑))

这里其实两个函数写成一个就行,但是我懒得改了。

PATH = '/Users/huangyige/Downloads/fig/'def load_img(imgname):#here, only consider pion at 7.7 and 14.5 GeV, px.img = Image.open(imgname).convert('RGB')return imgdef process_img(img):img = img.crop((72,0,360,288))img = img.resize((32,32))return img

然后随意拉张图进来看看。

img = load_img(PATH+'Pion-7.7GeV-7-P1.png')
img = process_img(img)
plt.imshow(img)

只能隐约能看出来有两个峰2333333,这样的数据集能训练出来个鬼咯~

然后生成两个数据集的文件名列表(附带标签)的文档。

def generate_file(name,num_range):with open('./'+name+'.txt','w') as f:for energy in ['7.7','14.5']:for _ in num_range:imgname = PATH + 'Pion-' + energy + 'GeV-' + str(_+1) + '-P1.png'f.write(imgname+' '+energy+'n')return
generate_file('train',range(0,70))
generate_file('test',range(70,90))

就别问我为什么训练集就70张图,测试集就20张图了,只有这么点数据……可以打开文档看看效果。

差不多就这样,没什么问题。接下来定义自定义Dataset类。

class sets(torch.utils.data.Dataset):def __init__(self,datatxt,transform=None):super(sets,self).__init__()imgs = []with open(datatxt,'r') as f:for line in f:line = line.rstrip('n')words = line.split(' ')imgs.append((words[0],words[1]))self.imgs = imgsself.transform = transformreturndef __getitem__(self,index):imgname,label_o = self.imgs[index]img = load_img(imgname)img = process_img(img)if label_o == '7.7':label = 0else:label = 1if self.transform is not None:img = self.transform(img)return img,labeldef __len__(self):return len(self.imgs)

以及DataLoader。

train_set = sets('./train.txt',transforms.ToTensor())
test_set = sets('./test.txt',transforms.ToTensor())train_loader = torch.utils.data.DataLoader(dataset=train_set,batch_size=1,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_set,batch_size=1)

batch_size选1是不是很扯,哈哈哈我也这么觉得。如果需要做数据增强,在初始化sets时,transform参数用transforms.Compose[transforms.ToTensor(),...]这样多填几个就行了。

然后是模型的结构。

class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.conv1 = nn.Conv2d(3,6,5)self.pool = nn.MaxPool2d(2,2)self.conv2 = nn.Conv2d(6,16,5)self.fc1 = nn.Linear(16*5*5,120)self.fc2 = nn.Linear(120,84)#self.fc3 = nn.Linear(84,10)self.fc3 = nn.Linear(84,2)returndef forward(self,x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1,16*5*5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()

直接把pytorch官网的tutorial里CIFAR-10的模型拉出来用了,正好我已经把数据变成了32x32,参数都不用改。(修改:最后一个全链接层的神经元数应该是2而不是10,还是得改一下的)

选损失函数和优化器。

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)

pytorch没有现成的算accuracy的函数所以自己写一个。

def accuracy(net,test_loder):correct = 0total = 0with torch.no_grad():for data in test_loader:inputs,labels = dataoutputs = net(inputs)_,pred = torch.max(outputs.data,1)total += labels.size(0)correct += (pred==labels).sum().item()acc = 100.0*correct/totalreturn acc

然后就可以开始训练了,本来也只是个玩具模型,所以2代就够了。不得不插一嘴,keras用起来确实要方便一点。

for epoch in range(2):running_loss = 0.0for i,data in enumerate(train_loader,0):inputs,labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs,labels)loss.backward()optimizer.step()running_loss += loss.item()acc = accuracy(net,test_loader)print('r[epoch %d >%3.d<] loss:%.3f,acc:%.1f%%'%(epoch+1,i+1,loss,acc),end='')print('')
print('Done!')

看看效果:

精度精准地锁定在50%,也就是说这模型在纯猜~whatever,这次练手主要是熟悉pytorch怎么用,模型本身的质量不重要。

最后保存一下模型:

torch.save(net.state_dict(),'./model.pth')

想开学啊啊啊啊啊……

pytorch 训练过程acc_pytorch入门练手:一个简单的CNN模型相关推荐

  1. pytorch 训练过程acc_Pytorch之Softmax多分类任务

    在上一篇文章中,笔者介绍了什么是Softmax回归及其原理.因此在接下来的这篇文章中,我们就来开始动手实现一下Softmax回归,并且最后要完成利用Softmax模型对Fashion MINIST进行 ...

  2. 一个适合于Python 初学者的入门练手项目

    随着人工智能的兴起,国内掀起了一股Python学习热潮,入门级编程语言,大多选择Python,有经验的程序员,也开始学习Python,正所谓是人生苦短,我用Python 有个Python入门练手项目, ...

  3. python新手项目-推荐:一个适合于Python新手的入门练手项目

    原标题:推荐:一个适合于Python新手的入门练手项目 随着人工智能的兴起,国内掀起了一股Python学习热潮,入门级编程语言,大多选择Python,有经验的程序员,也开始学习Python,正所谓是人 ...

  4. python新手小项目-推荐:一个适合于Python新手的入门练手项目

    随着人工智能的兴起,国内掀起了一股Python学习热潮,入门级编程语言,大多选择Python,有经验的程序员,也开始学习Python,正所谓是人生苦短,我用Python 有个Python入门练手项目, ...

  5. python新手程序_推荐:一个适合于Python新手的入门练手项目

    随着人工智能的兴起,国内掀起了一股Python学习热潮,入门级编程语言,大多选择Python,有经验的程序员,也开始学习Python,正所谓是人生苦短,我用Python 有个Python入门练手项目, ...

  6. python新手入门项目推荐_推荐:一个适合于Python新手的入门练手项目

    随着人工智能的兴起,国内掀起了一股Python学习热潮,入门级编程语言,大多选择Python,有经验的程序员,也开始学习Python,正所谓是人生苦短,我用Python 有个Python入门练手项目, ...

  7. 适合新手入门的8个python项目_推荐:一个适合于Python新手的入门练手项目

    随着人工智能的兴起,国内掀起了一股Python学习热潮,入门级编程语言,大多选择Python,有经验的程序员,也开始学习Python,正所谓是人生苦短,我用Python 有个Python入门练手项目, ...

  8. 一个适合于Python初学者的入门练手项目

    随着人工智能的兴起,国内掀起了一股Python学习热潮,入门级编程语言,大多选择Python,有经验的程序员,也开始学习Python,正所谓是人生苦短,我用Python 有个Python入门练手项目, ...

  9. Android Studio 插件开发详解一:入门练手

    转载请标明出处:http://blog.csdn.net/zhaoyanjun6/article/details/78112003 本文出自[赵彦军的博客] 系列目录 Android Gradle使用 ...

最新文章

  1. CUSTOMDRAW msdn网站
  2. 八大编程知名编程语言或系统的发展简史
  3. Java程序中AB类可调用_19年【石油大学】《Java语言程序设计》二次在线作业(100分)...
  4. [置顶] 状态压缩DP 简单入门题 11题
  5. linux命令行聊天,Linux 下使用talk 进行聊天
  6. python3+requests:get、post请求(python get、post)
  7. 使用Qemu模拟Cortex-A9运行U-boot和Linux
  8. Java 反射(Class class相关)
  9. 现代计算机内补码是多少进制,二进制:关于10000000如何表示-128的问题
  10. “AI超人”李开复慢下来的投资节奏
  11. Laravel Eloquent ORM字段处理
  12. 简单的MediaPlayer+SurfaceView实现视频横竖屏播放
  13. Jenkins流水线部署java项目
  14. 原神—薄樱初绽时(html+css+js仿原神2.5首页,前端课设)
  15. OpenCV第五章练习p163_5~8
  16. ESP8266+Flash基本操作
  17. Java多线程——线程同步
  18. 【音频处理和分析工具】上海道宁与NUGEN Audio助力您更轻松地提供高质量、合规的音频
  19. Leetcode 1231:分享巧克力(超详细的解法!!!)
  20. word怎么在下一页添加表头_简单!用2个小妙招,让Word跨页表格自动加表头!

热门文章

  1. golang 简单web服务
  2. 【职场攻略】是什么决定了我们的工资
  3. 在winform程序中启动wpf窗体
  4. windowSoftInputMode属性详解
  5. 【后缀自动机】SPOJ 1812-LCSII
  6. iOS学习笔记---oc语言第八天
  7. 13个风格独特的关于页面(About Pages)设计
  8. Python中的matplotlib xticks
  9. Leetcode_No.66 Plus One
  10. 用华为eNSP模拟器配置Hybrid、Trunk和Access三种链路类型端口