以下代码均来自bilibili:[适用于初学者的Pytorch编程教学]

以下为完整代码,复制即可运行。

import torch
import time
import json
import torchvision
import torchvision.transforms as transforms # 将图像数据转化为torch.tensor张量import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F # 用到了F.relu 和 F.cross_entropy
import torch.nn as nn
import torch.optim as optim # optim.Adam 优化器from IPython.display import clear_output #
from torch.utils.data import DataLoader
from itertools import product
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import confusion_matrix # 绘制混淆矩阵
from collections import OrderedDict
from collections import namedtupletorch.set_printoptions(linewidth=120)
torch.set_grad_enabled(True) # 打开梯度(grad)追踪
<torch.autograd.grad_mode.set_grad_enabled at 0x19e1fbc2f48>
# 这里建立了一个神经网络的模型(其实就是一个函数,用来预测输入张量的类)
class Network(nn.Module):def __init__(self):super(Network,self).__init__()self.conv1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,stride=1) # 输入通道的 1 代表输入的图像是单通道(灰度)的 stride表示的是滤波器 # 每次移动的步长self.conv2 = nn.Conv2d(in_channels=6,out_channels=12,kernel_size=5)self.fc1 = nn.Linear(in_features=12*4*4,out_features=120,bias = True) # bias 表示 偏差self.fc2 = nn.Linear(in_features=120,out_features=60)self.out = nn.Linear(in_features=60,out_features=10) # 这里的 10 的意思是,分成10个类def __repr__(self):return "this is the string of the lizard"def forward(self,t):# (1) input layert = t#(2) hidden conv layert = F.max_pool2d(F.relu(self.conv1(t)),kernel_size=2,stride=2) # 在2X2的子矩阵中选出最大值#(3) hidden conv layert = F.max_pool2d(F.relu(self.conv2(t)),kernel_size=2,stride=2)#(4) hidden Liner layert = t.reshape(-1,12*4*4)t = F.relu(self.fc1(t))#(5) hidden Liner layert = F.relu(self.fc2(t))#(6) output layert = self.out(t)return t
# 加载数据集,如果相应文件下没有数据集合化就下载 download = True
train_set = torchvision.datasets.FashionMNIST( # 下载数据集root = './data/FashionMNIST', # 数据集存放的位置train=True, # True表示下载的是训练集download=True, # 如果上述位置中不存在数据集的话就下载transform = transforms.Compose( # 将数据集转化为我们需要的张量的类型[transforms.ToTensor()]))
class RunBuilder():@staticmethoddef get_runs(params):Run = namedtuple('Run',params.keys())runs = []for v in product(*params.values()):runs.append(Run(*v))return runs
class RunManager():def __init__(self):self.epoch_count = 0self.epoch_loss = 0self.epoch_num_correct = 0self.epoch_start_time = 0self.run_params = Noneself.run_count = 0self.run_data = []self.run_start_time = Noneself.network = Noneself.loader = Noneself.tb = Nonedef begin_run(self, run, network, loader):self.run_start_time = time.time()self.run_params = runself.run_count += 1self.network = networkself.loader = loaderself.tb = SummaryWriter(comment = f'-{run}')images,labels = next(iter(self.loader))grid = torchvision.utils.make_grid(images)self.tb.add_image('images',grid)self.tb.add_graph(network,images)def end_run(self):self.tb.close()self.epoch_count = 0def begin_epoch(self):self.epoch_start_time = time.time()self.epoch_count += 1self.epoch_loss = 0self.epoch_num_correct = 0def end_epoch(self):epoch_duration = time.time() - self.epoch_start_timerun_duration = time.time() - self.run_start_timeloss = self.epoch_loss / len(self.loader) # .datasetaccuracy = self.epoch_num_correct.item() / len(self.loader.dataset)*100self.tb.add_scalar('Loss', loss, self.epoch_count)self.tb.add_scalar('Accuracy', accuracy, self.epoch_count)for name,param in self.network.named_parameters():self.tb.add_histogram(name, param, self.epoch_count)self.tb.add_histogram(f'{name}.grad', param.grad, self.epoch_count)results = OrderedDict()results['run'] = self.run_countresults['epoch'] = self.epoch_countresults['loss'] = lossresults['accuracy'] = f'{accuracy:.2f}%'results['epoch duration'] = epoch_durationresults['run duration'] = run_durationfor k,v in self.run_params._asdict().items():results[k] = vself.run_data.append(results)df = pd.DataFrame.from_dict(self.run_data, orient='columns')clear_output(wait = True)display(df)#         results = OrderedDict()def track_loss(self, loss):self.epoch_loss += loss.item()def track_num_correct(self, preds, labels):self.epoch_num_correct += self._get_num_correct(preds, labels)@torch.no_grad()def _get_num_correct(self,preds,labels):return torch.argmax(F.softmax(preds),dim=1).eq(labels).sum()def save(self, filename):pd.DataFrame.from_dict(self.run_data,orient = 'columns').to_csv(f'{filename}.csv')with open(f'{filename}.json', 'w+', encoding='utf-8') as f:for dic in self.run_data:json.dump(dic, f, ensure_ascii=False, indent=4) #
params = OrderedDict(lr = [.01],batch_size = [1000]
#     ,shuffle = [True, False],num_workers = [1]
)
m = RunManager()
for  run in RunBuilder.get_runs(params):network = Network()loader = DataLoader(train_set,batch_size=run.batch_size,num_workers=run.num_workers)optimizer = optim.Adam(network.parameters(), lr=run.lr)m.begin_run(run, network, loader)for epoch in range(10):m.begin_epoch()for batch in loader:images, labels = batchpreds = network(images)loss = F.cross_entropy(preds,labels)optimizer.zero_grad()loss.backward()optimizer.step()m.track_loss(loss)m.track_num_correct(preds,labels)m.end_epoch()m.end_run()
m.save('results')
run epoch loss accuracy epoch duration run duration lr batch_size num_workers
0 1 1 0.943128 63.89% 5.532876 10.976489 0.01 1000 1
1 1 2 0.537875 78.96% 5.653928 16.732941 0.01 1000 1
2 1 3 0.442508 83.57% 5.632912 22.457873 0.01 1000 1
3 1 4 0.387121 85.72% 5.735993 28.281893 0.01 1000 1
4 1 5 0.355424 86.92% 5.648000 34.019917 0.01 1000 1
5 1 6 0.330770 87.83% 5.721980 39.835942 0.01 1000 1
6 1 7 0.315759 88.22% 5.666653 45.596620 0.01 1000 1
7 1 8 0.296206 88.97% 5.591570 51.281192 0.01 1000 1
8 1 9 0.289356 89.29% 5.714649 57.087968 0.01 1000 1
9 1 10 0.278595 89.71% 5.812043 62.992033 0.01 1000 1
# 保存模型 方法 1 (这个方法需要在读取模型的文件中import相应包和对模型定义)
from sklearn.externals import joblib
import os dirs = 'testModels'
if not os.path.exists(dirs):os.makedirs(dirs)joblib.dump(network, dirs+'/network.pkl')# 读取模型
read_network = joblib.load(dirs+'/network.pkl')
['testModels/network.pkl']

# 读取模型 方法 2
# 保存
torch.save(network,dirs+'/network.pt')
# 读取
read_model = torch.load(dirs+'/network.pt')

python_torch_加载数据集_构建模型_构建训练循环_保存和调用训练好的模型相关推荐

  1. 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)

    目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...

  2. R语言构建xgboost模型:使用xgb.DMatrix保存、加载数据集、使用getinfo函数抽取xgb.DMatrix结构中的数据

    R语言构建xgboost模型:使用xgb.DMatrix保存.加载数据集.使用getinfo函数抽取xgb.DMatrix结构中的数据 目录

  3. Pytorch加载数据集的方式总结

    Pytorch加载数据集的方式总结 一.自己重写定义(Dataset.DataLoader) 二.用Pytorch自带的类(ImageFolder.datasets.DataLoader) 2.1 加 ...

  4. FlexCell控件初始化以及加载数据集[原创]

    '================================写在之前的话 抱歉,一直没有时间,所以FlexCell作者给我的几种加载数据集方法的代码一直没有发出来. 同时再次感谢FlexCell ...

  5. pytorch 入门学习加载数据集-8

    pytorch 入门学习加载数据集 import torch import numpy as np import torchvision import numpy as np from torch.u ...

  6. Pytorch深度学习(五):加载数据集以及mini-batch的使用

    Pytorch深度学习(五):加载数据集以及mini-batch的使用 参考B站课程:<PyTorch深度学习实践>完结合集 传送门:<PyTorch深度学习实践>完结合集 一 ...

  7. pytorch创建自己的Dataset加载数据集

    文章目录 创建一个类并继承torch.utils.data.dataset.Datase类 创建__getitem__方法 加载数据集 创建一个类并继承torch.utils.data.dataset ...

  8. python加载数据集卡住 dmesg报错Nvidia xid31

    在一次运维中发现客户加载数据集会卡住,物理机总共是4块显卡.使用k8s独占显卡进行任务训练,其中有三块显卡在跑任务训练加载数据集时卡住,同时查看dmesg报错 (xid 31). [Tue Apr 1 ...

  9. URLError: <urlopen error [Errno 11004] getaddrinfo failed>关于使用seabron加载数据集报错的解决方案

    在使用seaborn加载内置数据集时,出现以下错误: dataset = sns.load_dataset("iris") dataset.head() 解决方案: 一.原因需要连 ...

最新文章

  1. C语言与OpenCL的编程示例比较
  2. 第十五届全国大学生智能车竞赛各分赛区赛道数量以及比赛系统数量
  3. Science发布2021年度十大科学突破榜单:除了AlphaFold2,还有哪些大丰收?
  4. JavaScript高级应用(二)
  5. java写微信小程序答辩问题_微信小程序毕业设计选题和毕业论文怎么写,答辩流程是怎样的?...
  6. dubbo原理_dubbo实现原理介绍
  7. MPU6050参考代码
  8. EBS R12.2 ADOP (R12.2 AD Online Patching) - 5
  9. 云炬Android开发笔记 3-1项目架构初始化
  10. OpenCV拼接细节stitching detailed的实例(附完整代码)
  11. 动人配乐是如何炼成的?带您了解《花之灵》背景原声的幕后制作秘辛
  12. 这里有一份面筋请查收(六)
  13. inventor如何钣金出弧面_Inventor教程之钣金多规则
  14. python嵌套html开发gui_python GUI库图形界面开发之PyQt5表单布局控件QFormLayout详细使用方法与实例...
  15. 数据库原理—数据库基础(二)
  16. Android系统性能调优工具介绍
  17. 阿里云运行python项目_荐个人博客开发-06:Nginx + uWSGI + Django项目部署到阿里云服务器运行...
  18. 版本 tomcat_Tomcat爆出安全漏洞!Spring Cloud/Boot框架多个版本受影响
  19. 泰坦尼克 (有剧透)
  20. 节气朔望时刻计算和日食月食预测

热门文章

  1. 行车记录仪改家用监控求助
  2. 华工计算机工程学院,计算机工程学院赴华工兄弟学院交流学习
  3. 【Java整合Milvus】SpringBoot整合Milvus向量数据库以及虹软SDK实现以图搜图
  4. 【Flutter】Dart 数据类型 List 集合类型 ( 定义集合 | 初始化 | 泛型用法 | 初始化后添加元素 | 集合生成函数 | 集合遍历 )
  5. FFmpeg和RTMP结合编译
  6. android 选择答题功能,Android实现简单的答题系统
  7. URL详细分析及在python中处理URL
  8. STM32平衡小车 TB6612电机驱动学习
  9. 云队友丨人生的管理,就是目标的管理——管理工具大盘点
  10. 元宇宙迷思:科幻世界内外,“元宇宙”都几乎没有意义……