下面是用MNIST手写字符数据从数据loader到全连接网络设计、模型训练、模型测试、模型存储的全过程完整代码,仔细品味可供学习使用。

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# Hyper-parameters
input_size = 784
hidden_size = 500
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data', train=True, transform=transforms.ToTensor(),  download=True)test_dataset = torchvision.datasets.MNIST(root='../../data', train=False, transform=transforms.ToTensor())# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# Fully connected neural network with one hidden layer
class NeuralNet(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(NeuralNet, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size) self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, num_classes)  def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return outmodel = NeuralNet(input_size, hidden_size, num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):  # Move tensors to the configured deviceimages = images.reshape(-1, 28*28).to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, total_step, loss.item()))# Test the model
# In test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.reshape(-1, 28*28).to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

Pytorch实战从入门到精通第一部分——手写字符识别全流程相关推荐

  1. pytorch实战从入门到精通第三部分——数据处理

    计算数据集的均值和标准差 import os import cv2 import numpy as np from torch.utils.data import Dataset from PIL i ...

  2. pytorch实战从入门到精通第二部分——卷积神经网络

    两层卷积网络的示例 # convolutional neural network (2 convolutional layers) class ConvNet(nn.Module):def __ini ...

  3. Pytorch实战 | 第P1周:实现mnist手写数字识别

  4. 【Python】Python实战从入门到精通之四 -- 教你使用Python中字典

    本文是Python实战–从入门到精通系列的第四篇文章: Python实战从入门到精通第一讲–Python中的变量和数据类型 Python实战从入门到精通第二讲–Python中列表操作详解 Python ...

  5. 【Python】Python实战从入门到精通之五 -- 教你使用文件写入

    本文是<Python实战从入门到精通>系列之第5篇 [Python]Python实战从入门到精通之一 -- 教你深入理解Python中的变量和数据类型 [Python]Python实战从入 ...

  6. 【Python】Python实战从入门到精通之一 -- 教你深入理解Python中的变量和数据类型

    本文是Python实战–从入门到精通系列的第一篇文章: Python实战从入门到精通之一 – 教你深入理解Python中的变量和数据类型 文章目录 1.变量 1.1 变量命名规则 1.2 变量名称错误 ...

  7. 黑客零基础入门教程:「黑客攻防实战从入门到精通(第二版)」堪称黑客入门天花板

    前言 您知道在每天上网时,有多少黑客正在浏览您计算机中的重要数据吗﹖黑客工具的肆意传播,使得即使是稍有点计算机基础的人,就可以使用简单的工具对网络中一些疏于防范的主机进行攻击,在入侵成功之后,对其中的 ...

  8. unity应用开发实战案例_Unity3D游戏引擎开发实战从入门到精通

    Unity3D游戏引擎开发实战从入门到精通(坦克大战项目实战.NGUI开发.GameObject) 一.Unity3D游戏引擎开发实战从入门到精通是怎么样的一门课程(介绍) 1.1.Unity3D游戏 ...

  9. 【Python】Python实战从入门到精通之七 -- 教你深入理解异常处理

    本文是<Python实战从入门到精通>系列之第7篇 [Python]Python实战从入门到精通之一 -- 教你深入理解Python中的变量和数据类型 [Python]Python实战从入 ...

最新文章

  1. 用于半监督语义分割的基于掩码的数据增强
  2. 编译php时错误make ***[libphp5.la] Error 1
  3. 设置/修改centos上的swap交换分区的方法
  4. JSP完全自学手册图文教程
  5. 解决内存瓶颈和计算负载问题,韩松团队提出 MCUNetV2
  6. word 插入代码_突破Word页码困境,这招简单又实用的自动更新法,90%的人还不会!...
  7. 《LeetBook》leetcode题解(5):Longest Palindromic [M]——回文串判断
  8. 析构函数为虚函数的必要性
  9. 算法----迷宫问题
  10. redhat7 防火墙设置
  11. 使用apt更新和升级系统软件
  12. arcgis字段计算器赋值_ArcGIS中62个常用应用技巧汇总【必须收藏】
  13. HTML页面转PDF 思路
  14. 星界边境文本自动翻译机(高级版)使用说明
  15. vSphere Client连接主机提示远程服务器响应时间过长
  16. 中兴echat_公网对讲机都有哪些平台?
  17. 小程序----个人中心页面
  18. 张家口北方学院计算机是专科,河北北方学院有哪些专科专业
  19. 【雕爷学编程】Arduino动手做(5)---热敏温度传感器模块
  20. Linux下安装CMake的方法

热门文章

  1. encapsulation
  2. 双马尾机器人(???)
  3. CF988 D. Points and Powers of Two【hash/数学推理】
  4. vue基础知识之vue-resource/axios
  5. jQuery链式操作[转]
  6. PLSA隐变量主题模型的公式推导解惑
  7. 关于提高网站性能的几点建议(二)
  8. dubbo升级spring4与cxf
  9. ios学习8_KVC和字典转模型
  10. iis 支持html执行php输出