Pytorch入门-2
第一种神经网络的写法
假设这里有一个二分类问题:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import seaborn as snsn_data = torch.ones(100,2)
x1 = torch.normal(n_data*2,1)
y1 = torch.ones(100).unsqueeze(1)x2 = torch.normal(-2*n_data,1)
y2 = torch.zeros(100).unsqueeze(1)print(x1.shape,y1.shape) # torch.Size([100, 2]) torch.Size([100, 1])x = torch.cat((x1,x2),0)
y = torch.cat((y1,y2),0)
x, y = Variable(x),Variable(y)
我们画一下图:
# 我们大致画一下图,画图的时候需要将Variable类型转为numpy
import pandas as pd
sns.set()
x_plot = x.data.numpy()[:,0]
y_plot = x.data.numpy()[:,1]
special = y.data.numpy()[:,0]pdata = {"x_plot":x_plot, "y_plot":y_plot,"special":special}
df = pd.DataFrame(pdata)
sns.relplot(x="x_plot", y="y_plot", hue="special",data=df)
下面是定义神经网络:
class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.hidden = nn.Linear(2,10)self.predict = nn.Linear(10,2)def forward(self,x):x = self.hidden(x)x = F.relu(x)x = self.predict(x)return x
net = Net()
print(net)
-------------------------result-----------------------------
Net((hidden): Linear(in_features=2, out_features=10, bias=True)(predict): Linear(in_features=10, out_features=2, bias=True)
)
下面是训练过程:
optimizer = torch.optim.SGD(net.parameters(),lr = 0.01)
loss_func = nn.CrossEntropyLoss() # 专门训练多分类问题的
for i in range(500): # 训练500次out = net(x)loss = loss_func(out,y.long().squeeze()) # 这里的y必须转成squeeze,因为200*1 和 200 是不同的optimizer.zero_grad()loss.backward()optimizer.step()if i % 5 == 0:y_pre = torch.max(F.softmax(out),1)[1].squeeze() # F.softmax,转成(0.1,0.9)这种样子的分类,而max返回最大值,[1] 返回最大值的下标,也就是预测出[0.3,0.7],返回0.7的下标1,这样就转化成了一个0,1分类问题sums = sum(y_pre == y.data.squeeze())accu = sums.numpy()/y.shape[0] # 这里亲测,不转成numpy相除,会有问题print("准确率:",accu)
第二种神经网络的写法
上面的神经网络是通过定义一个类这种形式,下面换一种新的形式来写神经网络:
net2 = torch.nn.Sequential(torch.nn.Linear(2,10),torch.nn.ReLU(),torch.nn.Linear(10,2)
)
print(net2)
------------result-------------
Sequential((0): Linear(in_features=2, out_features=10, bias=True)(1): ReLU()(2): Linear(in_features=10, out_features=2, bias=True)
) 输入为2个chanel,隐藏层为10,输出层为2
-------------------------------
除了定义不一样之外,其他的用法和上面的无异
神经网络的保存
我们就拿上面训练好的神经网络net1,为例子:
torch.save(net1,"net1.pkl") # 它保存的是整个神经网络
torch.save(net1.state_dict(),"net1_paras.pkl") # 保存的是整个神经网络的参数,这个比上面那种保存方式快那么一点点
分批次训练
下面实现一下简单的分批:
import torch
import torch.utils.data as DataBATCH_SIZE = 5x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)torch_dataset = Data.TensorDataset(x,y)
loader = Data.DataLoader(dataset=torch_dataset,batch_size=BATCH_SIZE,shuffle=True,
)
for epoch in range(3):for step, (batch_x,batch_y) in enumerate(loader):print(epoch,step,batch_x.numpy(),batch_y.numpy())
------------------------result--------------------------------
0 0 [ 5. 3. 4. 6. 10.] [6. 8. 7. 5. 1.]
0 1 [8. 1. 9. 2. 7.] [ 3. 10. 2. 9. 4.]
1 0 [1. 8. 3. 2. 6.] [10. 3. 8. 9. 5.]
1 1 [ 4. 9. 7. 5. 10.] [7. 2. 4. 6. 1.]
2 0 [4. 2. 7. 5. 3.] [7. 9. 4. 6. 8.]
2 1 [ 8. 6. 9. 1. 10.] [ 3. 5. 2. 10. 1.]
Pytorch入门-2相关推荐
- PyTorch | (2)PyTorch 入门-张量
PyTorch | (1)初识PyTorch PyTorch | (2)PyTorch 入门-张量 PyTorch 是一个基于 Python 的科学计算包,主要定位两类人群: NumPy 的替代品,可 ...
- PyTorch入门v2.pptx
给本科生入门深度学习的PyTorch使用总结,图比较多. 若需要原版ppt,关注公众号,回复"pytorch入门"可以获取下载地址. 若需要原版ppt,关注公众号,回复" ...
- PyTorch 入门实战
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/qq_36556893/article/ ...
- numpy pytorch 接口对应_拆书分享篇深度学习框架PyTorch入门与实践
<<深度学习框架PyTorch入门与实践>>读书笔记 <深度学习框架PyTorch入门与实践>读后感 小作者:马苗苗 读完<<深度学习框架PyTorc ...
- [pytorch] Pytorch入门
Pytorch入门 简单容易上手,感觉比keras好理解多了,和mxnet很像(似乎mxnet有点借鉴pytorch),记一记. 直接从例子开始学,基础知识咱已经看了很多论文了... import t ...
- pytorch 矩阵相乘_深度学习 — — PyTorch入门(三)
点击关注我哦 autograd和动态计算图可以说是pytorch中非常核心的部分,我们在之前的文章中提到:autograd其实就是反向求偏导的过程,而在求偏导的过程中,链式求导法则和雅克比矩阵是其实现 ...
- 基于pytorch实现图像分类——理解自动求导、计算图、静态图、动态图、pytorch入门
1. pytorch入门 什么是PYTORCH? 这是一个基于Python的科学计算软件包,针对两组受众: 替代NumPy以使用GPU的功能 提供最大灵活性和速度的深度学习研究平台 1.1 开发环境 ...
- pytorch 入门学习多分类问题-9
pytorch 入门学习多分类问题 运行结果 [1, 300] loss: 2.287[1, 600] loss: 2.137[1, 900] loss: 1.192 Accuracy on test ...
- pytorch 入门学习加载数据集-8
pytorch 入门学习加载数据集 import torch import numpy as np import torchvision import numpy as np from torch.u ...
- pytorch 入门学习处理多维特征输入-7
pytorch 入门学习处理多维特征输入 处理多维特征输入 import torch import numpy as np import torchvision import numpy as np ...
最新文章
- Nginx缓存设置教程
- PHP如何将表单提交给自己
- 维护一套同时兼容 iOS 6 和 iOS 7,并且能够自动适应两个系统的 UI 风格的代码...
- ASP.NET运行环境配置
- ThinkCMF 和 OneThink内容管理系统对比
- 避免关注底层硬件,Nvidia将机器学习与GPU绑定
- Mac在命令行中打开Finder
- 什么是php渲染,php数据渲染输出
- asp.net 网页做一个浮动层_网页的布局与定位看完这篇就够了
- java读取gpx文件格式,Gpx文件基本格式及常见错误——乱码
- 零基础自学软件测试-项目经验-电商项目实战-测试用例设计-促销中心
- Docker-基本命令和漏洞分享
- 足球足球裁判昏招大全裁判昏招大全
- 算法及其复杂性分析整理
- TX-LCN和Seata
- 利用计算机对多媒体进行综合处理,多媒体技术复习题及答案
- avr模拟串口通讯c语言,AVR的模拟串口的问题
- ip地址、网络地址、网关、域名
- 五大常用算法一(回溯,随机化,动态规划)
- 技术创业者必读:从验证想法到技术产品商业化的全方位解析