pytorch 入门学习多分类问题

运行结果

 [1,  300] loss: 2.287[1,  600] loss: 2.137[1,  900] loss: 1.192
Accuracy on test set: 78 % [2,  300] loss: 0.560[2,  600] loss: 0.422[2,  900] loss: 0.361
Accuracy on test set: 90 % [3,  300] loss: 0.307[3,  600] loss: 0.292[3,  900] loss: 0.258
Accuracy on test set: 93 % [4,  300] loss: 0.228[4,  600] loss: 0.221[4,  900] loss: 0.201
Accuracy on test set: 94 % [5,  300] loss: 0.178[5,  600] loss: 0.178[5,  900] loss: 0.158
Accuracy on test set: 95 % [6,  300] loss: 0.141[6,  600] loss: 0.139[6,  900] loss: 0.144
Accuracy on test set: 96 % [7,  300] loss: 0.129[7,  600] loss: 0.116[7,  900] loss: 0.114
Accuracy on test set: 96 % [8,  300] loss: 0.107[8,  600] loss: 0.100[8,  900] loss: 0.106
Accuracy on test set: 96 % [9,  300] loss: 0.091[9,  600] loss: 0.088[9,  900] loss: 0.089
Accuracy on test set: 96 % [10,  300] loss: 0.079[10,  600] loss: 0.074[10,  900] loss: 0.080
Accuracy on test set: 96 % Process finished with exit code 0
import  torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim#step1 准备数据集batch_size = 64
transform = transforms.Compose([transforms.ToTensor(), #对图像数据进行处理,归一化到0到1transforms.Normalize((0.137,),(0.3081,)) #平均值和标准差,mnist数据集提前算出来的
])train_dataset = datasets.MNIST(root='../dataset/mnist',train=True,download=True,transform=transform)train_loder = DataLoader(train_dataset,shuffle=True,batch_size=batch_size)test_dataset = datasets.MNIST(root='../dataset/mnist',train=False,download=True,transform=transform)test_loder = DataLoader(test_dataset,shuffle=False,batch_size=batch_size)#step2 搭建网络
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.l1 = torch.nn.Linear(784,512)self.l2 = torch.nn.Linear(512,256)self.l3 = torch.nn.Linear(256,128)self.l4 = torch.nn.Linear(128,64)self.l5 = torch.nn.Linear(64,10)def forward(self,x):x = x.view(-1,784)x = F.relu(self.l1(x))x = F.relu(self.l2(x))x = F.relu(self.l3(x))x = F.relu(self.l4(x))return self.l5(x)model = Net()criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)#step3 训练
def train(epoch):running_loss = 0.0for batch_idx,data in enumerate(train_loder,0):inputs,target = dataoptimizer.zero_grad()    #梯度清零#forward + backward + updateoutputs = model(inputs)loss = criterion(outputs,target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:print(' [%d,%5d] loss: %.3f' % (epoch + 1,batch_idx + 1, running_loss / 300))running_loss = 0.0def test():correct = 0total = 0with torch.no_grad():      #不计算梯度for data in test_loder:images,labels = dataoutputs = model(images)_,predicted = torch.max(outputs.data,dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on test set: %d %% '%(100 * correct / total))if __name__ == '__main__':for epoch in range(10):train(epoch)test()

pytorch 入门学习多分类问题-9相关推荐

  1. pytorch 入门学习使用逻辑斯蒂做二分类-6

    pytorch 入门学习使用逻辑斯蒂做二分类 使用pytorch实现逻辑斯蒂做二分类 import torch import torchvision import numpy as np import ...

  2. pytorch 入门学习加载数据集-8

    pytorch 入门学习加载数据集 import torch import numpy as np import torchvision import numpy as np from torch.u ...

  3. pytorch 入门学习处理多维特征输入-7

    pytorch 入门学习处理多维特征输入 处理多维特征输入 import torch import numpy as np import torchvision import numpy as np ...

  4. pytorch 入门学习 实现线性回归-5

    pytorch 入门学习实现线性回归 使用pytorch实现线性回归 import numpy as np import matplotlib.pyplot as plt import torch#p ...

  5. pytorch 入门学习反向传播-4

    pytorch 入门学习反向传播 反向传播 import numpy as np import matplotlib.pyplot as plt import torchdef forward(x): ...

  6. 程序媛养成第0天--pytorch入门学习

    本篇基于<深度学习框架-pytorch入门与实践>陈云 有一起监督学习打卡的小伙伴请私信 2.2 pytorch入门第一步 2.2.1 Tensor # 分配矩阵空间但不初始化 #使用 [ ...

  7. pytorch 入门学习 MSE

    <PyTorch深度学习实践>完结合集-线性模型 import numpy as np import matplotlib.pyplot as pltx_data = [1.0,2.0,3 ...

  8. PyTorch入门学习-4.自然语言分类任务

    一. 情感分析 1. 准备数据 TorchText中的一个重要概念是Field.Field决定了你的数据会被怎样处理.在我们的情感分类任务中,我们所需要接触到的数据有文本字符串和两种情感," ...

  9. PyTorch入门-简单图片分类

    一. CNN图像分类 PyTorch Version: 1.0.0 import torch import torch.nn as nn import torch.nn.functional as F ...

最新文章

  1. 跳表SkipList
  2. vins中imu融合_VINS-Mono代码分析与总结(最终版)
  3. Android高版本开机广播,android3.1以上,假如程序没有启动过,怎么获取开机广播呢?...
  4. 跟我一起学.NetCore之配置变更监听
  5. HTTP之Last-Modified、Etage、If-Modified-Since理论与实践(C++ Qt实现)
  6. 遍历列表python_python中列表的遍历
  7. java ajax传递到action_ajax传值到action,后台取不到值。
  8. SPOJ371 Boxes(最小费用最大流)
  9. 慢慢看Spring源码
  10. vim文件时自动添加作者、时间、版权等信息
  11. 未处理的异常: 0xC0000005: 读取位置 0x00000000 时发生访问冲突
  12. Java课程报告实验总结,java实验报告总结 [Java课程设计实验报告]
  13. 人月到底有多少神话色彩
  14. 图片像素大小怎么调整,批量调整图片像素
  15. 无线传感网路由协议(一)
  16. 【校招VIP】前端专业课考点之tcp与udp
  17. 微信小程序制作简单的商品列表页,实现价格求和
  18. 今年电商圈618活动很安静!
  19. 20200607:根据中证800指数最近十年历史P/b分位数确认基金目标仓位
  20. 【python--爬虫】爬取淘女郎照片

热门文章

  1. python爬虫中文乱码解决方法
  2. 在提交消息中链接到GitHub上的问题编号
  3. (配置消息转换器) sso单点登入之jsonp改进版
  4. 用java写猜拳游戏,Java写人机猜拳游戏(可扩展其他游戏或其他参与者)
  5. 读取usb口数据_Mixly 第12课 模拟值读取实验串口使用
  6. python离线安装flask_离线环境下安装flask
  7. 简单实现x的n次方pta_数学学霸的解题思路1“降低次方和次元”
  8. MTK:内存管理机制简单分析
  9. fun php,fun.php
  10. 计算机文档我的文档丢失,“我的文档”不见了如何找回?几种解决“我的文档不见了”的办法...