环境准备:

1、anaconda官网下载

下载地址
https://www.anaconda.com/distribution/
注意选用该电脑相应的系统和64/32位。

已安装Python使用环境的请跳过此步骤。

已安装Python使用环境的请跳过此步骤。

2、pytorch安装

https://pytorch.org/get-started/previous-versions/

基础框架:

1、数据准备

假设我们现在用开盘价、收盘价、最高价和最低价来预测下一天涨幅。其中,开盘价、收盘价、最高价、最低价使用连续5天的数据,那么模型输入数据为4x5的矩阵。矩阵的每一行分别对应开盘价、收盘价、最高价、最低价。模型输出为下一天涨幅。下面采用随机的方法生成10个样本数据,输入为维度为10x4x5,目标涨幅为10x1。

# 构建输入数据集
class Data_set(Dataset):def __init__(self, transform=None):self.transform = transformself.x_data = np.random.randint(0, 10, (10,4,5))self.y_data = np.random.randint(0, 10, (10,1))def __getitem__(self, index):x, y = self.pull_item(index)return x, ydef __len__(self):return self.x_data.shape[0]   def pull_item(self, index):return self.x_data[index, :, :], self.y_data[index, :]

2、定义神经网络模型

这里简单定义一个两层模型,第一层为卷积层,第二层为全连接层。

class MyModel(nn.Module):def __init__(self, num_classes=10):super(MyModel, self).__init__()self.model_name = "test"self.conv = nn.Conv1d(4, 1, 3, 1, 1)self.fc   = nn.Linear(5, 1)def forward(self, x):x = self.conv(x)x = x.view(-1, 5)return self.fc(x)

3、损失函数定义

损失函数用于计算模型输出与真实结果之间的误差,可以自己定义,也可以直接使用pytorch中的损失函数。

class MyLoss(nn.Module):def __init__(self):super().__init__()def forward(self, x, y):return torch.mean(torch.pow((x - y), 2))

4、完整代码

# -*- coding: utf-8 -*-
"""
Created on Fri Mar 13 20:30:41 2020@author: yehx
"""import argparse
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable# 构建输入数据集
class Data_set(Dataset):def __init__(self, transform=None):self.transform = transformself.x_data = np.random.randint(0, 10, (10,4,5))self.y_data = np.random.randint(0, 10, (10,1))def __getitem__(self, index):x, y = self.pull_item(index)return x, ydef __len__(self):return self.x_data.shape[0]   def pull_item(self, index):return self.x_data[index, :, :], self.y_data[index, :]class MyModel(nn.Module):def __init__(self, num_classes=10):super(MyModel, self).__init__()self.model_name = "test"self.conv = nn.Conv1d(4, 1, 3, 1, 1)self.fc   = nn.Linear(5, 1)def forward(self, x):x = self.conv(x)x = x.view(-1, 5)return self.fc(x)class MyLoss(nn.Module):def __init__(self):super().__init__()def forward(self, x, y):return torch.mean(torch.pow((x - y), 2))if __name__ == "__main__":parser = argparse.ArgumentParser(description='基础模型参数配置')train_set = parser.add_mutually_exclusive_group()parser.add_argument('--batch_size', default=2, type=int,help='Batch size for training')args = parser.parse_args()dataset = Data_set()data_loader = DataLoader(dataset, args.batch_size, shuffle=True)Net = MyModel()criterion = MyLoss()#criterion = nn.MSELoss() #也可使用pytorch自带的损失函数optimzer = torch.optim.SGD(Net.parameters(), lr=0.001)Net.train()loss_list = []num_epoches = 20for epoch in range(num_epoches):for i, data in enumerate(data_loader):inputs, labels = datainputs, labels = Variable(inputs).float(), Variable(labels).float()out = Net(inputs)loss = criterion(out, labels)  # 计算误差optimzer.zero_grad()  # 清除梯度loss.backward()optimzer.step()loss_list.append(loss.item())if (epoch+1) % 10 == 0:print('[INFO] {}/{}: Loss: {:.4f}'.format(epoch+1, num_epoches, loss.item()))#作图:误差loss在迭代过程中的变化情况plt.plot(loss_list, label='loss for every epoch')plt.legend()plt.show()  #训练的模型参数   print('[INFO] 训练后模型的参数:')for name,parameters in Net.named_parameters():print(name,':',parameters)#测试模型结果print('[INFO] 计算某个样本模型运算结果:')Net.eval()x_data = np.random.randint(0, 10, (4,5))x_data = torch.tensor(x_data).float()x_data = x_data.unsqueeze(0)y_data = Net(x_data)print(y_data.item())#模型保存torch.save(Net, 'model0.pth')#模型加载print('[INFO] 验证模型加载运算结果:')model0 =torch.load('model0.pth')y_data = model0 (x_data)print(y_data.item())

Pytorch 神经网络模型量化分析基本框架相关推荐

  1. 定点 浮点 神经网络 量化_神经网络模型量化论文小结

    神经网络模型量化论文小结 发布时间:2018-07-22 13:25, 浏览次数:278 现在"边缘计算"越来越重要,真正能落地的算法才是有竞争力的算法.随着卷积神经网络模型堆叠的 ...

  2. 深度学习修炼(五)——基于pytorch神经网络模型进行气温预测

    文章目录 5 基于pytorch神经网络模型进行气温预测 5.1 实现前的知识补充 5.1.1 神经网络的表示 5.1.2 隐藏层 5.1.3 线性模型出错 5.1.4 在网络中加入隐藏层 5.1.5 ...

  3. 神经网络模型量化论文小结

    现在"边缘计算"越来越重要,真正能落地的算法才是有竞争力的算法.随着卷积神经网络模型堆叠的层数越来越多,网络模型的权重参数数量也随之增长,专用硬件平台可以很好的解决计算与存储的双重 ...

  4. 神经网络模型通用性分析,神经网络模型可解释性

    BP神经网络的可行性分析 神经网络的是我的毕业论文的一部分4.人工神经网络人的思维有逻辑性和直观性两种不同的基本方式. 逻辑性的思维是指根据逻辑规则进行推理的过程:它先将信息化成概念,并用符号表示,然 ...

  5. 神经网络模型量化方法简介

    笔记mark: jpg算法中就用到了量化. png压缩算法中用到了霍夫曼编码. 本文主要梳理了模型量化算法的一些文章,阐述了每篇文章主要的内核思想和量化过程,整理了一些对这些文章的分析和看法. Dee ...

  6. 整合量化分析和基础研究——投资的艺术和科学

    作者:W.乔治.格雷戈 CHINAQIR编译整理 简介 利用量化分析和基础研究的投资过程一直都存在.许多基础研究使用量化方法帮助其筛选出一定量的可以重点考虑的公司.有一些量化方法使得分析师能够超越那些 ...

  7. 通过pytorch建立神经网络模型 分析遗传基因数据

    DNA双螺旋(已对齐)合并神经网络(黄色) 我最近进行了有关基因序列的研究工作.我想到的主要问题是:"哪一种最简单的神经网络能与遗传数据最匹配".经过大量文献回顾,我发现与该主题相 ...

  8. PyTorch框架:(2)使用PyTorch框架构建神经网络模型---气温预测

    目录 第一步:数据导入 第二步:将时间转换成标准格式(比如datatime格式) 第三步: 展示数据:(画了4个子图) 第四步:做独热编码 第五步:指定输入与输出 第六步:对数据做一个标准化 第七步: ...

  9. PyTorch | (4)神经网络模型搭建和参数优化

    PyTorch | (1)初识PyTorch PyTorch | (2)PyTorch 入门-张量 PyTorch | (3)Tensor及其基本操作 PyTorch | (4)神经网络模型搭建和参数 ...

  10. 人工神经网络模型定义,人工神经网络基本框架

    人工神经网络评价法 人工神经元是人工神经网络的基本处理单元,而人工智能的一个重要组成部分又是人工神经网络.人工神经网络是模拟生物神经元系统的数学模型,接受信息主要是通过神经元来进行的. 首先,人工神经 ...

最新文章

  1. 中国信通院《新型智慧城市发展研究报告》
  2. 三星15TB固态硬盘开卖 售价高达10000美元
  3. 《包青天》中的《鸳鸯蝴蝶梦》单元,剧中有一个很漂亮的女子叫“离垢”
  4. 工作以后如何有效学习
  5. 怎样使用Mendeley高效地管理中文文献
  6. Shovels Shop
  7. 【C/C++9】天气APP:Oracle的虚表/日期/序列,索引/视图/链路/同义词,数据库高可用性
  8. sqlyog如何设置.时提示字段名_雷神新用户手册:拿到新电脑时如何简易设置参数!...
  9. XnSay临时网盘程序v1.0全开源
  10. c语言 整数除以分数,2019年六年级数学上册 3.1分数除法(第1课时)分数除法的意义和整数除以分数练习题 新人教版 (I).doc...
  11. java awt point_100分 解决java import java.awt.Point;import java.awt.Rectangle;
  12. 线程同步之条件变量和信号量(生产者消费者模型)
  13. Linux将鼠标解放,DwellClick:让鼠标下岗 解放你的手指
  14. R语言Γ(gamma)分布
  15. qgis中加载矢量切片
  16. 手机连接360免费WIFI一直显示正在获取IP地址、无法连接的解决方法
  17. 怎么给win10进行分区?
  18. 解决outlook2016 中邮件中,点击链接提示(您的组织策略阻止我们为您完成此操作)解决方案
  19. 泡泡龙游戏开发系列教程(二)
  20. java 坦克大战画坦克_【JAVA语言程序设计基础篇】--JAVA实现坦克大战游戏--画出坦克(二)...

热门文章

  1. yolov3/yolov4/yolov5/yolov6/yolov7/lite/fastdet/efficientdet各系列模型开发、项目交付、组合改造创新之—桥梁基建隧道裂痕裂缝检测实战
  2. python re.sub和lambda_【python学习笔记】 re.sub()
  3. elasticsearch源码:unicast列表解析
  4. 大富豪5.3全网首发,真正的5.3正版破解授权,不是高防端
  5. Dissect Eclipse Plugin Framework
  6. JAVA一些方法技巧
  7. BNUOJ 52506 Captcha Cracker
  8. 基于Python的招聘信息可视化分析研究
  9. 无线局域网安全(一)———WEP加密
  10. 【JVM】Java IDEA 配置项目的JVM运行内存大小