模型的保存

import numpy as np
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
# 训练集
train_data = datasets.MNIST(root="./", # 存放位置train = True, # 载入训练集transform=transforms.ToTensor(), # 把数据变成tensor类型download = True # 下载)
# 测试集
test_data = datasets.MNIST(root="./",train = False,transform=transforms.ToTensor(),download = True)
# 批次大小
batch_size = 64
# 装载训练集
train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
# 装载测试集
test_loader = DataLoader(dataset=test_data,batch_size=batch_size,shuffle=True)
for i,data in enumerate(train_loader):inputs,labels = dataprint(inputs.shape)print(labels.shape)break
# 定义网络结构
class LSTM(nn.Module):def __init__(self):super(LSTM,self).__init__()# 初始化self.lstm = torch.nn.LSTM(input_size = 28, # 表示输入特征的大小hidden_size = 64, # 表示lstm模块的数量num_layers = 1, # 表示lstm隐藏层的层数batch_first = True # lstm默认格式input(seq_len,batch,feature)等于True表示input和output变成(batch,seq_len,feature))self.out = torch.nn.Linear(in_features=64,out_features=10)self.softmax = torch.nn.Softmax(dim=1)def forward(self,x):# (batch,seq_len,feature)x = x.view(-1,28,28)# output:(batch,seq_len,hidden_size)包含每个序列的输出结果# 虽然lstm的batch_first为True,但是h_n,c_n的第0个维度还是num_layers# h_n :[num_layers,batch,hidden_size]只包含最后一个序列的输出结果# c_n:[num_layers,batch,hidden_size]只包含最后一个序列的输出结果output,(h_n,c_n) = self.lstm(x)output_in_last_timestep = h_n[-1,:,:]x = self.out(output_in_last_timestep)x = self.softmax(x)return x
# 定义模型
model = LSTM()
# 定义代价函数
mse_loss = nn.CrossEntropyLoss()# 交叉熵
# 定义优化器
optimizer = optim.Adam(model.parameters(),lr=0.001)# 随机梯度下降
# 定义模型训练和测试的方法
def train():# 模型的训练状态model.train()for i,data in enumerate(train_loader):# 获得一个批次的数据和标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# 交叉熵代价函数out(batch,C:类别的数量),labels(batch)loss = mse_loss(out,labels)# 梯度清零optimizer.zero_grad()# 计算梯度loss.backward()# 修改权值optimizer.step()def test():# 模型的测试状态model.eval()correct = 0 # 测试集准确率for i,data in enumerate(test_loader):# 获得一个批次的数据和标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# 获得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 预测正确的数量correct += (predicted==labels).sum()print("Test acc:{0}".format(correct.item()/len(test_data)))correct = 0for i,data in enumerate(train_loader): # 训练集准确率# 获得一个批次的数据和标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# 获得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 预测正确的数量correct += (predicted==labels).sum()print("Train acc:{0}".format(correct.item()/len(train_data)))
# 训练
for epoch in range(5):print("epoch:",epoch)train()test()

torch.save(model.state_dict(),"./my_model.pth") # 模型的保存

模型加载

import numpy as np
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
# 训练集
train_data = datasets.MNIST(root="./", # 存放位置train = True, # 载入训练集transform=transforms.ToTensor(), # 把数据变成tensor类型download = True # 下载)
# 测试集
test_data = datasets.MNIST(root="./",train = False,transform=transforms.ToTensor(),download = True)
# 批次大小
batch_size = 64
# 装载训练集
train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
# 装载测试集
test_loader = DataLoader(dataset=test_data,batch_size=batch_size,shuffle=True)
for i,data in enumerate(train_loader):inputs,labels = dataprint(inputs.shape)print(labels.shape)break
# 定义网络结构
class Net(nn.Module):def __init__(self):super(Net,self).__init__()# 全连接层self.fc1 = nn.Linear(784,10)self.softmax = nn.Softmax(dim=1)def forward(self,x):x = x.view(x.size()[0],-1)x = self.fc1(x)x = self.softmax(x)return x
# 定义模型
model = Net()
# 模型载入
model.load_state_dict(torch.load("./my_model.pth")) # 模型的加载
# 定义代价函数
mse_loss = nn.MSELoss()
# 定义优化器
optimizer = optim.Adam(model.parameters(),lr=0.001)# 随机梯度下降
def train():for i,data in enumerate(train_loader):# 获得一个批次的数据和标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# to onehot 把数据标签变成独热编码labels = labels.reshape(-1,1) # 先把1维变成2维(64)-(64,1)# tensor.scatter(dim,index,src)# dim:对那个维度进行独热编码# index:要将src中对应的值放到tensor那个位置# src:插入index的数值one_hot = torch.zeros(inputs.shape[0],10).scatter(1,labels,1)# 计算loss   mse_loss的两个数据的shape要一致loss = mse_loss(out,one_hot)# 梯度清零optimizer.zero_grad()# 计算梯度loss.backward()# 修改权值optimizer.step()def test():correct = 0for i,data in enumerate(test_loader):# 获得一个批次的数据和标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# 获得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 预测正确的数量correct += (predicted==labels).sum()print("Test acc:{0}".format(correct.item()/len(test_data)))correct = 0for i,data in enumerate(train_loader): # 训练集准确率# 获得一个批次的数据和标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# 获得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 预测正确的数量correct += (predicted==labels).sum()print("Train acc:{0}".format(correct.item()/len(train_data)))
# # 训练
# for epoch in range(2):
#     print("epoch:",epoch)
#     train()
test()

PyTorch基础-模型的保存和加载-09相关推荐

  1. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  2. PyTorch | 模型的保存和加载

    PyTorch | 模型的保存和加载 一.模型参数的保存和加载 二.完整模型的保存和加载 一.模型参数的保存和加载 torch.save(module.state_dict(), path):使用mo ...

  3. pytorch模型的保存和加载、checkpoint

    pytorch模型的保存和加载.checkpoint 其实之前笔者写代码的时候用到模型的保存和加载,需要用的时候就去度娘搜一下大致代码,现在有时间就来整理下整个pytorch模型的保存和加载,开始学习 ...

  4. numpy将所有数据变为0和1_PyTorch 学习笔记(二):张量、变量、数据集的读取、模组、优化、模型的保存和加载...

    一. 张量 PyTorch里面最基本的操作对象就是Tensor,Tensor是张量的英文,表示的是一个多维的矩阵,比如零维就是一个点,一维就是向量,二维就是一般的矩阵,多维就相当于一个多维的数组,这和 ...

  5. 线性回归之模型的保存和加载

    线性回归之模型的保存和加载 1 sklearn模型的保存和加载API from sklearn.externals import joblib   [目前这行代码报错,直接写import joblib ...

  6. paddlepaddle模型的保存和加载

    导读 深度学习中模型的计算图可以被分为两种,静态图和动态图,这两种模型的计算图各有优劣. 静态图需要我们先定义好网络的结构,然后再进行计算,所以静态图的计算速度快,但是debug比较的困难,因为只有当 ...

  7. tensorflow 模型的保存和加载

    为了让训练结果可以复用,需要将训练得到的神经网络模型持久化,也就是把模型的参数保存下来,并保证可以持久化后的模型文件中还原出保存的模型. 1. 保存模型 tensorflow提供了一个API可以方便的 ...

  8. 调gensim库,word2vec模型的保存和加载

    一.模型的保存 模型保存可以有很多种格式,根据格式的不同可以分为2种,一种是保存为.model的文件,一种是非.model文件的保存.我常用的保存格式是.model和.vector直接上代码和结果: ...

  9. pytorch 不同设备下保存和加载模型,需要指定设备

最新文章

  1. JavaScript中使用console调试程序的坑
  2. rust图形编程_国产编程语言“木兰”,你以为是个王者,结果是个玩笑
  3. 简单的客户端,服务端通信
  4. IELE:区块链的一个新虚拟机
  5. 【jenkins】jenkins build项目的三种方式
  6. 拉格朗日插值_拉格朗日插值定理的理论基础
  7. 微服务SpringCloud系列
  8. 调用百度地图 API 移动地图时 maker 始终在地图中间 并根据maker 经纬度 返回地址...
  9. 教师国培计算机计划,教师国培计划大全
  10. jboss forge整合 hibersap
  11. Tera Term (串口工具)永久保存设置的字体和框体大小
  12. php mysql 博客_基于PHP+MySQL的个人博客系统
  13. ccf---导弹防御系统
  14. iis 支持apk json ipa下载
  15. 携程和12306解绑
  16. MATLAB使用基本操作
  17. CAP 和 Zookeeper
  18. 生物统计学(biostatistics)学习笔记(四)统计推断(已知样本推总体)
  19. CentOS8 解决SSH Secure Shell 报错 Algorithm negotiation failes
  20. 成功解决Solving environment: failed with initial frozen solve. Retrying with flexible solve.

热门文章

  1. g11 android 4.4,HTC G11 Incredible S 稳定流畅Android4.0.4华丽体验Sense4.1 省电耐用
  2. Curl http_code 状态码
  3. PHP使用fpdf生成pdf文件(含中文类)
  4. php 生成ai文件,php_Generator php 生成器
  5. python中显示第三行数据_Python从零开始第三章数据处理与分析①python中的dplyr(1)...
  6. 计算机心得300,计算机实训总结计算机实训心得300
  7. optional判断是否为空_乐字节Java8核心特性之Optional
  8. php网页轮播图,JavaScript_JavaScript实现图片轮播的方法,本文实例讲述了JavaScript实现图 - phpStudy...
  9. hutool读取和导出excel_Java编程第44讲——非常好用的hutool工具介绍
  10. c语言复制后无法运行,刚学C语言,在Linux下写的代码能正常编译,复制到VC下就无法运行...