第一种神经网络的写法

假设这里有一个二分类问题:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import seaborn as snsn_data = torch.ones(100,2)
x1 = torch.normal(n_data*2,1)
y1 = torch.ones(100).unsqueeze(1)x2 = torch.normal(-2*n_data,1)
y2 = torch.zeros(100).unsqueeze(1)print(x1.shape,y1.shape)  # torch.Size([100, 2]) torch.Size([100, 1])x = torch.cat((x1,x2),0)
y = torch.cat((y1,y2),0)
x, y = Variable(x),Variable(y)

我们画一下图:

# 我们大致画一下图,画图的时候需要将Variable类型转为numpy
import pandas as pd
sns.set()
x_plot = x.data.numpy()[:,0]
y_plot = x.data.numpy()[:,1]
special = y.data.numpy()[:,0]pdata = {"x_plot":x_plot, "y_plot":y_plot,"special":special}
df = pd.DataFrame(pdata)
sns.relplot(x="x_plot", y="y_plot", hue="special",data=df)


下面是定义神经网络:

class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.hidden = nn.Linear(2,10)self.predict = nn.Linear(10,2)def forward(self,x):x = self.hidden(x)x = F.relu(x)x = self.predict(x)return x
net = Net()
print(net)
-------------------------result-----------------------------
Net((hidden): Linear(in_features=2, out_features=10, bias=True)(predict): Linear(in_features=10, out_features=2, bias=True)
)

下面是训练过程:

optimizer = torch.optim.SGD(net.parameters(),lr = 0.01)
loss_func = nn.CrossEntropyLoss()  # 专门训练多分类问题的
for i in range(500):  # 训练500次out = net(x)loss = loss_func(out,y.long().squeeze()) # 这里的y必须转成squeeze,因为200*1 和 200 是不同的optimizer.zero_grad()loss.backward()optimizer.step()if i % 5 == 0:y_pre = torch.max(F.softmax(out),1)[1].squeeze()  # F.softmax,转成(0.1,0.9)这种样子的分类,而max返回最大值,[1] 返回最大值的下标,也就是预测出[0.3,0.7],返回0.7的下标1,这样就转化成了一个0,1分类问题sums = sum(y_pre == y.data.squeeze())accu = sums.numpy()/y.shape[0]  # 这里亲测,不转成numpy相除,会有问题print("准确率:",accu)

第二种神经网络的写法

上面的神经网络是通过定义一个类这种形式,下面换一种新的形式来写神经网络:

net2 = torch.nn.Sequential(torch.nn.Linear(2,10),torch.nn.ReLU(),torch.nn.Linear(10,2)
)
print(net2)
------------result-------------
Sequential((0): Linear(in_features=2, out_features=10, bias=True)(1): ReLU()(2): Linear(in_features=10, out_features=2, bias=True)
)  输入为2个chanel,隐藏层为10,输出层为2
-------------------------------
除了定义不一样之外,其他的用法和上面的无异

神经网络的保存

我们就拿上面训练好的神经网络net1,为例子:

torch.save(net1,"net1.pkl")   # 它保存的是整个神经网络
torch.save(net1.state_dict(),"net1_paras.pkl")  # 保存的是整个神经网络的参数,这个比上面那种保存方式快那么一点点

分批次训练

下面实现一下简单的分批:

import torch
import torch.utils.data as DataBATCH_SIZE = 5x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)torch_dataset = Data.TensorDataset(x,y)
loader = Data.DataLoader(dataset=torch_dataset,batch_size=BATCH_SIZE,shuffle=True,
)
for epoch in range(3):for step, (batch_x,batch_y) in enumerate(loader):print(epoch,step,batch_x.numpy(),batch_y.numpy())
------------------------result--------------------------------
0 0 [ 5.  3.  4.  6. 10.] [6. 8. 7. 5. 1.]
0 1 [8. 1. 9. 2. 7.] [ 3. 10.  2.  9.  4.]
1 0 [1. 8. 3. 2. 6.] [10.  3.  8.  9.  5.]
1 1 [ 4.  9.  7.  5. 10.] [7. 2. 4. 6. 1.]
2 0 [4. 2. 7. 5. 3.] [7. 9. 4. 6. 8.]
2 1 [ 8.  6.  9.  1. 10.] [ 3.  5.  2. 10.  1.]

Pytorch入门-2相关推荐

  1. PyTorch | (2)PyTorch 入门-张量

    PyTorch | (1)初识PyTorch PyTorch | (2)PyTorch 入门-张量 PyTorch 是一个基于 Python 的科学计算包,主要定位两类人群: NumPy 的替代品,可 ...

  2. PyTorch入门v2.pptx

    给本科生入门深度学习的PyTorch使用总结,图比较多. 若需要原版ppt,关注公众号,回复"pytorch入门"可以获取下载地址. 若需要原版ppt,关注公众号,回复" ...

  3. PyTorch 入门实战

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/qq_36556893/article/ ...

  4. numpy pytorch 接口对应_拆书分享篇深度学习框架PyTorch入门与实践

    <<深度学习框架PyTorch入门与实践>>读书笔记 <深度学习框架PyTorch入门与实践>读后感 小作者:马苗苗  读完<<深度学习框架PyTorc ...

  5. [pytorch] Pytorch入门

    Pytorch入门 简单容易上手,感觉比keras好理解多了,和mxnet很像(似乎mxnet有点借鉴pytorch),记一记. 直接从例子开始学,基础知识咱已经看了很多论文了... import t ...

  6. pytorch 矩阵相乘_深度学习 — — PyTorch入门(三)

    点击关注我哦 autograd和动态计算图可以说是pytorch中非常核心的部分,我们在之前的文章中提到:autograd其实就是反向求偏导的过程,而在求偏导的过程中,链式求导法则和雅克比矩阵是其实现 ...

  7. 基于pytorch实现图像分类——理解自动求导、计算图、静态图、动态图、pytorch入门

    1. pytorch入门 什么是PYTORCH? 这是一个基于Python的科学计算软件包,针对两组受众: 替代NumPy以使用GPU的功能 提供最大灵活性和速度的深度学习研究平台 1.1 开发环境 ...

  8. pytorch 入门学习多分类问题-9

    pytorch 入门学习多分类问题 运行结果 [1, 300] loss: 2.287[1, 600] loss: 2.137[1, 900] loss: 1.192 Accuracy on test ...

  9. pytorch 入门学习加载数据集-8

    pytorch 入门学习加载数据集 import torch import numpy as np import torchvision import numpy as np from torch.u ...

  10. pytorch 入门学习处理多维特征输入-7

    pytorch 入门学习处理多维特征输入 处理多维特征输入 import torch import numpy as np import torchvision import numpy as np ...

最新文章

  1. Nginx缓存设置教程
  2. PHP如何将表单提交给自己
  3. 维护一套同时兼容 iOS 6 和 iOS 7,并且能够自动适应两个系统的 UI 风格的代码...
  4. ASP.NET运行环境配置
  5. ThinkCMF 和 OneThink内容管理系统对比
  6. 避免关注底层硬件,Nvidia将机器学习与GPU绑定
  7. Mac在命令行中打开Finder
  8. 什么是php渲染,php数据渲染输出
  9. asp.net 网页做一个浮动层_网页的布局与定位看完这篇就够了
  10. java读取gpx文件格式,Gpx文件基本格式及常见错误——乱码
  11. 零基础自学软件测试-项目经验-电商项目实战-测试用例设计-促销中心
  12. Docker-基本命令和漏洞分享
  13. 足球足球裁判昏招大全裁判昏招大全
  14. 算法及其复杂性分析整理
  15. TX-LCN和Seata
  16. 利用计算机对多媒体进行综合处理,多媒体技术复习题及答案
  17. avr模拟串口通讯c语言,AVR的模拟串口的问题
  18. ip地址、网络地址、网关、域名
  19. 五大常用算法一(回溯,随机化,动态规划)
  20. 技术创业者必读:从验证想法到技术产品商业化的全方位解析

热门文章

  1. 20145234黄斐《java程序设计》第十三周代码检查
  2. Linux命令之mv
  3. Openfire性能优化与压力测试小结
  4. 每天看了哪些技术点,都记录在该文章下面,时常回过头来看看。
  5. 解决无法使用locate命令的方法
  6. QT与opencv(二)开启摄像头
  7. python运维开发笔记5
  8. 程序员养生(01) -- 心态
  9. 互斥对象与关键代码段的比较
  10. Docker持续交付部署类型