模型定义:

import torch
import torch.nn as nnclass DF(nn.Module):def __init__(self, nb_classes):super(DF, self).__init__()self.block1 = nn.Sequential(         nn.Conv1d(in_channels=1,              out_channels=32,            kernel_size=8,              stride=1,                   padding=0,                 ),  nn.BatchNorm1d(32),                   nn.ELU(alpha=1.0),                     nn.Conv1d(32, 32, 8, 1, 0),nn.BatchNorm1d(32),nn.ELU(alpha=1.0),nn.MaxPool1d(8, 4, 0), nn.Dropout(0.1), )self.block2 = nn.Sequential(nn.Conv1d(32, 64, 8, 1, 0),nn.BatchNorm1d(64),nn.ReLU(),nn.Conv1d(64, 64, 8, 1, 0),nn.BatchNorm1d(64),nn.ReLU(),nn.MaxPool1d(8, 4, 0),nn.Dropout(0.1),)self.block3 = nn.Sequential(nn.Conv1d(64, 128, 8, 1, 0),nn.BatchNorm1d(128),nn.ReLU(),nn.Conv1d(128, 128, 8, 1, 0),nn.BatchNorm1d(128),nn.ReLU(),nn.MaxPool1d(8, 4, 0),nn.Dropout(0.1),)self.block4 = nn.Sequential(nn.Conv1d(128, 256, 8, 1, 0),nn.BatchNorm1d(256), nn.ReLU(),nn.Conv1d(256, 256, 8, 1, 0),nn.BatchNorm1d(256),nn.ReLU(),nn.MaxPool1d(8, 4, 0),nn.Dropout(0.1),)self.fc1 = nn.Sequential(         nn.Flatten(),nn.Linear(3328,512),nn.BatchNorm1d(512),nn.ReLU(),nn.Dropout(0.7),              )self.fc2 = nn.Sequential(nn.Linear(512, 512),nn.BatchNorm1d(512),nn.ReLU(),nn.Dropout(0.5),              )self.out = nn.Sequential(nn.Linear(512, nb_classes),)   def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.fc1(x)x = self.fc2(x)          output = self.out(x)return output, x

模型训练:

train_loader=train_dl
NB_CLASSES = 50
EPOCH = 100
BATCH_SIZE = 128
LR = 0.001cnn = DF(NB_CLASSES).float().cuda()loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
validation_loader=test_dl
validation_size=450
train_size=4050for epoch in range(EPOCH):for step, (b_x, b_y) in enumerate(train_loader):# 128 1 5000b_x = b_x.cuda()b_y = b_y.cuda()output = cnn(b_x.unsqueeze(-2).float())[0]loss = loss_func(output, b_y.squeeze().long())resnet_loss.append(loss)optimizer.zero_grad()loss.backward()optimizer.step()if step % 1 == 0:corrects = 0avg_loss = 0for _, (b_x, b_y) in enumerate(validation_loader):b_x = b_x.cuda()b_y = b_y.cuda()logit = cnn(b_x.unsqueeze(-2).float())[0]loss = loss_func(logit, b_y.squeeze().long())avg_loss += loss.item()corrects += (torch.max(logit, 1)[1].view(b_y.size()).data == b_y.data).sum()size = validation_sizeavg_loss /= sizeaccuracy = 100.0 * corrects / sizeresnet_accuracy.append(accuracy)print('Epoch: {:2d}({:6d}/{}) Evaluation - loss: {:.6f}  acc: {:3.4f}%({}/{})'.format(epoch,step * 128,train_size,avg_loss, accuracy, corrects, size))

一个比较完整的pytorch项目相关推荐

  1. 分享一个比较完整的Vue2+项目供大家交流学习

    分享一个比较完整的Vue2+项目供大家交流学习,这个项目的英文简介:Awesome douban DEMO created with Vue2.x + Vuex + Vue-router + vue- ...

  2. [vue]开源一个精致完整的Vue项目(豆瓣)

    各位道友,开源一个Vue2+项目: Awesome douban DEMO created with Vue2.x + Vuex + Vue-router + vue-resource 项目地址: h ...

  3. Intellij Idea 搭建一个完整的JavaWeb项目(二)

    手把手搭建一个完整的JavaWeb项目 本案例使用Servlet+jsp制作,用Intellij Idea IDE和Mysql数据库进行搭建,详细介绍了搭建过程及知识点. 主要功能有: 1.用户注册 ...

  4. 机器学习入门系列(2)--如何构建一个完整的机器学习项目(一)

    上一篇机器学习入门系列(1)–机器学习概览简单介绍了机器学习的一些基本概念,包括定义.优缺点.机器学习任务的划分等等. 接下来计划通过几篇文章来介绍下,一个完整的机器学习项目的实现步骤会分为几步,最后 ...

  5. 一个完整的pytorch预训练实现图像分类,模型融合

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 利用pytorch实现图像分类,其中包含的densenet,resnext,mobilenet ...

  6. 【React进阶-1】从0搭建一个完整的React项目(入门篇)

    这篇文章带领大家从零开始手动撸一个React项目的基础框架,集成React全家桶.万字长文,请各位有足够的时间时再来阅读和学习. 概述 平时工作中一直在用React提供的脚手架工具搭建React项目, ...

  7. 一个完整的网络项目,如何根据需求配置交换机?值得收藏学习

    通过实例来详细讲解一个完整的网络项目从规划到交换机配置的详细过程. 一.案例要求拓扑图 小型园区中,分为两个部门,每个部门相互独立,却又通信,进行组网如下图. 二.分析 在拿到项目后,首先就是对项目进 ...

  8. 分享一个完整的社区项目(Android端加后台)

    '乡吧'社区APP安卓端代码 简介 "乡吧"APP是我之前的一个练手项目,此App主要是提供一个同乡的社区交流平台, 用户可以选择自己家乡的'乡吧'进行发帖.评论.创建群或者直接与 ...

  9. 如何去实践一个完整的数据挖掘项目?

    [每日一问] 如何去实践一个完整的数据挖掘项目? 机器学习项目 1 抽象成数学问题(明确问题) 2 获取数据 3 特征预处理与特征选择 4 训练模型与调优 5 模型诊断 6 模型融合(非必须) 7 上 ...

  10. Scikit-Learn TensorFlow机器学习实用指南(二):一个完整的机器学习项目【上】

    机器学习实用指南(二):一个完整的机器学习项目[上] 作者:LeonG 本文参考自:<Hands-On Machine Learning with Scikit-Learn & Tens ...

最新文章

  1. 【Matlab】dde23解时滞时延微分方程
  2. 数据蒋堂 | 怎样看待存储过程的移植困难
  3. 理解Silverlight的路径填充规则
  4. 关于glusterfs-3.0.4中AFR修复的一个bug
  5. python连接不上数据库_绕不过去的Python连接MySQL数据库
  6. sublime html整理阶梯,sublime text 之添加插件 并使用
  7. mysql事务管理及spring声明式事务中主动异常抛出使数据库回滚
  8. 导航猫(NaviCat for MySql)建立表的方法
  9. 读书笔记 effective c++ Item 34 区分接口继承和实现继承
  10. java原生开发项目-快递e栈
  11. 调查称谷歌占北美25%互联网流量
  12. 判定是否支持XHTML
  13. CodeForce 589J Cleaner Robot
  14. Arcgis使用教程(十)ARCGIS地图制图之统一修改地图符号样式的边框
  15. 2020年数学建模亚太赛赛后分享总结
  16. 机器学习_深度学习毕设题目汇总——运动活动动作
  17. 台式电脑怎么连接蓝牙和无线网
  18. picpick文字竖排了怎么变成横排
  19. 吐泡泡_via牛客网
  20. 算命的理科生——顺口说说算命的事......

热门文章

  1. 【空间分析】0 基本空间分析工具
  2. php实现的进度条功能示例,PHP 进度条函数的简单实例
  3. java反射机制的实现机制_Java反射机制实践
  4. python选取tensor某一维_Python按维数操作多维张量,Pytorch,对,Tensor,维度
  5. php 魔方,PHP解密:魔方二代-免费解密代码详解
  6. oracle goldengate 触发器,Oracle goldengate的触发器错误 OGG-00869
  7. 存储端显示主机链路降级_链路优化、产品升级,腾讯广告让汽车营销更轻松
  8. 空间变量php,PHP名称空间可以包含变量吗?
  9. java hql 连接查询,java – 如何从HQL表单中的两个连接表查询中选择*?
  10. linux监控文件是否传输,利用SecureCRT在linux与Windows之间传输文件