点上方计算机视觉联盟获取更多干货

仅作学术分享,不代表本公众号立场,侵权联系删除

转载于:作者丨wfnian@知乎(已授权)

来源丨https://zhuanlan.zhihu.com/p/396666255

编辑丨极市平台

AI博士笔记系列推荐

周志华《机器学习》手推笔记正式开源!可打印版本附pdf下载链接

目录如下:

  1. 导入包以及设置随机种子

  2. 以类的方式定义超参数

  3. 定义自己的模型

  4. 定义早停类(此步骤可以省略)

  5. 定义自己的数据集Dataset,DataLoader

  6. 实例化模型,设置loss,优化器等

  7. 开始训练以及调整lr

  8. 绘图

  9. 预测

一、导入包以及设置随机种子

import numpy as np
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as pltimport random
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

二、以类的方式定义超参数

class argparse():passargs = argparse()
args.epochs, args.learning_rate, args.patience = [30, 0.001, 4]
args.hidden_size, args.input_size= [40, 30]
args.device, = [torch.device("cuda:0"if torch.cuda.is_available() else"cpu"),]

三、定义自己的模型

class Your_model(nn.Module):def __init__(self):super(Your_model, self).__init__()passdef forward(self,x):passreturn x

四、定义早停类(此步骤可以省略)

class EarlyStopping():def __init__(self,patience=7,verbose=False,delta=0):self.patience = patienceself.verbose = verboseself.counter = 0self.best_score = Noneself.early_stop = Falseself.val_loss_min = np.Infself.delta = deltadef __call__(self,val_loss,model,path):print("val_loss={}".format(val_loss))score = -val_lossif self.best_score isNone:self.best_score = scoreself.save_checkpoint(val_loss,model,path)elif score < self.best_score+self.delta:self.counter+=1print(f'EarlyStopping counter: {self.counter} out of {self.patience}')if self.counter>=self.patience:self.early_stop = Trueelse:self.best_score = scoreself.save_checkpoint(val_loss,model,path)self.counter = 0def save_checkpoint(self,val_loss,model,path):if self.verbose:print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')torch.save(model.state_dict(), path+'/'+'model_checkpoint.pth')self.val_loss_min = val_loss

五、定义自己的数据集Dataset,DataLoader

class Dataset_name(Dataset):def __init__(self, flag='train'):assert flag in ['train', 'test', 'valid']self.flag = flagself.__load_data__()def __getitem__(self, index):passdef __len__(self):passdef __load_data__(self, csv_paths: list):passprint("train_X.shape:{}\ntrain_Y.shape:{}\nvalid_X.shape:{}\nvalid_Y.shape:{}\n".format(self.train_X.shape, self.train_Y.shape, self.valid_X.shape, self.valid_Y.shape))train_dataset = Dataset_name(flag='train')
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
valid_dataset = Dataset_name(flag='valid')
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=True)

六、实例化模型,设置loss,优化器等

model = Your_model().to(args.device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(Your_model.parameters(),lr=args.learning_rate)train_loss = []
valid_loss = []
train_epochs_loss = []
valid_epochs_loss = []early_stopping = EarlyStopping(patience=args.patience,verbose=True)

七、开始训练以及调整lr

for epoch in range(args.epochs):Your_model.train()train_epoch_loss = []for idx,(data_x,data_y) in enumerate(train_dataloader,0):data_x = data_x.to(torch.float32).to(args.device)data_y = data_y.to(torch.float32).to(args.device)outputs = Your_model(data_x)optimizer.zero_grad()loss = criterion(data_y,outputs)loss.backward()optimizer.step()train_epoch_loss.append(loss.item())train_loss.append(loss.item())if idx%(len(train_dataloader)//2)==0:print("epoch={}/{},{}/{}of train, loss={}".format(epoch, args.epochs, idx, len(train_dataloader),loss.item()))train_epochs_loss.append(np.average(train_epoch_loss))#=====================valid============================Your_model.eval()valid_epoch_loss = []for idx,(data_x,data_y) in enumerate(valid_dataloader,0):data_x = data_x.to(torch.float32).to(args.device)data_y = data_y.to(torch.float32).to(args.device)outputs = Your_model(data_x)loss = criterion(outputs,data_y)valid_epoch_loss.append(loss.item())valid_loss.append(loss.item())valid_epochs_loss.append(np.average(valid_epoch_loss))#==================early stopping======================early_stopping(valid_epochs_loss[-1],model=Your_model,path=r'c:\\your_model_to_save')if early_stopping.early_stop:print("Early stopping")break#====================adjust lr========================lr_adjust = {2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,10: 5e-7, 15: 1e-7, 20: 5e-8}if epoch in lr_adjust.keys():lr = lr_adjust[epoch]for param_group in optimizer.param_groups:param_group['lr'] = lrprint('Updating learning rate to {}'.format(lr))

八、绘图

plt.figure(figsize=(12,4))
plt.subplot(121)
plt.plot(train_loss[:])
plt.title("train_loss")
plt.subplot(122)
plt.plot(train_epochs_loss[1:],'-o',label="train_loss")
plt.plot(valid_epochs_loss[1:],'-o',label="valid_loss")
plt.title("epochs_loss")
plt.legend()
plt.show()

九、预测

# 此处可定义一个预测集的Dataloader。也可以直接将你的预测数据reshape,添加batch_size=1
Your_model.eval()
predict = Your_model(data)

-------------------

END

--------------------

我是王博Kings,985AI博士,华为云专家、CSDN博客专家(人工智能领域优质作者)。单个AI开源项目现在已经获得了2100+标星。现在在做AI相关内容,欢迎一起交流学习、生活各方面的问题,一起加油进步!

我们微信交流群涵盖以下方向(但并不局限于以下内容):人工智能,计算机视觉,自然语言处理,目标检测,语义分割,自动驾驶,GAN,强化学习,SLAM,人脸检测,最新算法,最新论文,OpenCV,TensorFlow,PyTorch,开源框架,学习方法...

这是我的私人微信,位置有限,一起进步!

王博的公众号,欢迎关注,干货多多

王博Kings的系列手推笔记(附高清PDF下载):

博士笔记 | 周志华《机器学习》手推笔记第一章思维导图

博士笔记 | 周志华《机器学习》手推笔记第二章“模型评估与选择”

博士笔记 | 周志华《机器学习》手推笔记第三章“线性模型”

博士笔记 | 周志华《机器学习》手推笔记第四章“决策树”

博士笔记 | 周志华《机器学习》手推笔记第五章“神经网络”

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(上)

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(下)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(上)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(下)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(上)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(下)

博士笔记 | 周志华《机器学习》手推笔记第九章聚类

博士笔记 | 周志华《机器学习》手推笔记第十章降维与度量学习

博士笔记 | 周志华《机器学习》手推笔记第十一章稀疏学习

博士笔记 | 周志华《机器学习》手推笔记第十二章计算学习理论

博士笔记 | 周志华《机器学习》手推笔记第十三章半监督学习

博士笔记 | 周志华《机器学习》手推笔记第十四章概率图模型

点分享

点收藏

点点赞

点在看

收藏 | 深度学习pytorch训练代码相关推荐

  1. 收藏 | PyTorch深度学习模型训练加速指南2021

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者:LORENZ KUHN 编译:ronghuaiyang ...

  2. 深度学习PyTorch,TensorFlow中GPU利用率较低,CPU利用率很低,且模型训练速度很慢的问题总结与分析

    在深度学习模型训练过程中,在服务器端或者本地pc端,输入nvidia-smi来观察显卡的GPU内存占用率(Memory-Usage),显卡的GPU利用率(GPU-util),然后采用top来查看CPU ...

  3. GPU信息查看以及确认Pytorch使用了GPU计算模块进行深度学习的训练

    GPU信息查看以及确认Pytorch使用了GPU计算模块进行深度学习的训练 目录 GPU信息查看以及确认Pytorch使用了GPU计算模块进行深度学习的训练 GPU基础信息查看 Pytorch是否使用 ...

  4. pytorch 训练过程acc_深度学习Pytorch实现分类模型

    今天将介绍深度学习中的分类模型,以下主要介绍Softmax的基本概念.神经网络模型.交叉熵损失函数.准确率以及Pytorch实现图像分类.01Softmax基本概念 在分类问题中,通常标签都为类别,可 ...

  5. 深度学习模型训练的一般方法(以DSSM为例)

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 本文主要用于记录DSSM模型学习期间遇到的问题及分析.处理经验.先统领性地提出深度学习模型训练 ...

  6. torch的拼接函数_从零开始深度学习Pytorch笔记(13)—— torch.optim

    前文传送门: 从零开始深度学习Pytorch笔记(1)--安装Pytorch 从零开始深度学习Pytorch笔记(2)--张量的创建(上) 从零开始深度学习Pytorch笔记(3)--张量的创建(下) ...

  7. 程序如何在两个gpu卡上并行运行_深度学习分布式训练相关介绍 - Part 1 多GPU训练...

    本篇文章主要是对深度学习中运用多GPU进行训练的一些基本的知识点进行的一个梳理 文章中的内容都是经过认真地分析,并且尽量做到有所考证 抛砖引玉,希望可以给大家有更多的启发,并能有所收获 介绍 大多数时 ...

  8. 伯禹公益AI《动手学深度学习PyTorch版》Task 07 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 07 学习笔记 Task 07:优化算法进阶:word2vec:词嵌入进阶 微信昵称:WarmIce 优化算法进阶 emmmm,讲实 ...

  9. 伯禹公益AI《动手学深度学习PyTorch版》Task 03 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 03 学习笔记 Task 03:过拟合.欠拟合及其解决方案:梯度消失.梯度爆炸:循环神经网络进阶 微信昵称:WarmIce 过拟合. ...

最新文章

  1. 【css】如何使页面压缩时文本内容不换行
  2. c语言小于n的素数和,关于求N以内素数的一点小问题(N小于一亿)
  3. 拒绝穿模!新方法让虚拟偶像自由互动无障碍“贴贴”,8000网友追着点赞
  4. Matconvnet安装:win7+VS2015(pro)+Matlab 2017a+cuda8.0+cudnn 5.1
  5. java线程在什么时候结束,java – 什么时候线程超出范围?
  6. Oracle Controlfile控制文件中记录的信息片段sections
  7. python的整数类型_Python int 数字整型类型 定义int()范围大小转换
  8. 知道答案吗?知道为什么是这个答案吗?
  9. Mapillary发布世界最大交通标志数据集,用于自动驾驶研究
  10. 某银行软件中心产品开发流程
  11. 使用CancellationToken——而不是Thread.Sleep
  12. 自动驾驶汽车也能聊天?
  13. Qt 编译完成拷贝文件 INSTALL
  14. android设计常用字体,界面设计必备!全方位科普常用的字体规范
  15. PyG快速安装(一键脚本,2021.7.14简单有效)
  16. 信用卡评分模型(R语言)
  17. 杂散发射干扰和阻塞干扰
  18. doucument.referrer部分安卓机型一直为空问题
  19. win10家庭版如何修改用户名对应的文件夹的名字(中文该成英文字符)
  20. matlab最小二乘法拟合 做图像,用MatLab画图(最小二乘法做曲线拟合)

热门文章

  1. 求1000以内的所有水仙数c语言,求1000以内的所有水仙花数
  2. 密码学原理与实践_到底什么是防火墙入侵检测密码学身份认证?如何高效建立网络安全知识体系?...
  3. java 图形应用有必要学吗_儿童英语口语怎么学?有必要报班吗?
  4. java 504错误怎么解决_求助java.lang.NoClassDefFoundError怎么解决,报错信息如下
  5. python自动下载邮件附件_Python批量下载电子邮件附件并汇总合并Excel文件
  6. 鸿蒙系统公布名单,鸿蒙系统首批升级名单公布_鸿蒙系统首批升级机型
  7. CentOS7.4安装nginx和php5.40
  8. 组织JSON数据、JSON转换
  9. java batch size_java – @BatchSize但在@ManyToOne案例中有很多往返
  10. 学生渐进片add如何给_【收藏】为青少年验配渐进多焦点时,如何选择合适ADD?...