代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as Data
import numpy as np
from sklearn.datasets import load_winedevice=torch.device('cuda' if torch.cuda.is_available else 'cpu')#借助gpu训练data = load_wine()#获取数据集
X = data.data
Y = data.targetx_train=torch.FloatTensor(X)#将数据集转化为tensor格式
y_train=torch.LongTensor(Y)class MyDataset(Data.Dataset):#Dataset与Dataloaderdef __init__(self,x_train,y_train):self.x_train=x_trainself.y_train=y_traindef __getitem__(self,idx):return x_train[idx],y_train[idx]def __len__(self):return len(x_train)train_dataset=MyDataset(x_train,y_train)
train_dataloader=Data.DataLoader(train_dataset,batch_size=16,shuffle=True)class MyModel(nn.Module):#搭建模型def __init__(self):super(MyModel, self).__init__()self.liner1=nn.Linear(13,64)self.activate1=nn.ReLU()self.liner2=nn.Linear(64,8)self.activate2=nn.ReLU()self.liner3=nn.Linear(8,3)def forward(self,x):output=self.liner1(x)output=self.activate1(output)output = self.liner2(output)output = self.activate2(output)output = self.liner3(output)return outputmodel=MyModel().to(device)
optimizer=optim.Adam(model.parameters(),lr=0.001)
loss_fn=nn.CrossEntropyLoss()Epoch=10000#没必要训练这么多轮
idx=0
for i in range(Epoch):#训练模型for x,y in train_dataloader:x,y=x.to(device),y.to(device)pred=model(x)loss=loss_fn(pred,y)idx+=1if idx%1000==0:print(loss)#三件套optimizer.zero_grad()loss.backward()optimizer.step()import random #测试模型
success=0
cnt=1000
with torch.no_grad():#不产生梯度,表明当前计算不需要反向传播for i in range(cnt):t=random.randint(0,len(x_train)-1)x_test=x_train[t].to(device)y_test=y_train[t]pred=model(x_test)
#         print(pred)result=np.argmax(pred.cpu().detach().numpy())if result==y_test.cpu().detach().numpy():success+=1
print(success/cnt*100,'%')
#这里测试模型有个问题,就是数据是从训练集中随机找到的,不是那么合理


测试结果:

测试结果居然是100%,大概率是因为过拟合了

Pytorch搭建网络训练葡萄酒分类数据集(三分类)相关推荐

  1. 使用 PyTorch 搭建网络 - predict_py篇

    predict_py篇 python中采用驼峰书写法且首字母大写的变量符号一般表示类名. 学习网络步骤:看原论文+看别人对原论文的理解,学习网络结构,看损失函数计算,看数据集,看别人写的代码,复现代码 ...

  2. caffe-MobileNet-ssd环境搭建及训练自己的数据集模型

    caffe-MobileNet-ssd环境搭建及训练自己的数据集模型 ***************************************************************** ...

  3. 搭建并训练多标签数据集的模型并将结果可视化

    #搭建并训练多标签数据集的模型并将结果可视化(tensorflow2) 1.数据集的介绍 该数据为拥有颜色与衣服类别两个标签的衣服识别,对于这样的数据集要求我们的神经网络需要两个输出,一个是类别,另一 ...

  4. ML之DT:基于简单回归问题训练决策树(DIY数据集+三种深度的二元DT性能比较)

    ML之DT:基于简单回归问题训练决策树(DIY数据集+三种深度的二元DT性能比较) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 for i in range(1, len(xPl ...

  5. 实例:【基于逻辑回归的鸢尾花二分类和三分类问题】

    基于逻辑回归的鸢尾花二分类和三分类问题 一.问题描述及数据集获取 二.逻辑回归 1.概述 2.应用 3. LogisticRegression回归算法 4. sklearn逻辑回归API 三.代码实现 ...

  6. pytorch下搭建网络训练并保存模型

    最近在学习pytorch,使用mnist数据集,搭建AlexNet训练并保存模型,将代码做一记录. 建立数据集的方法见pytorch建立自己的数据集(以mnist为例) 搭建网络的方法见用pytorc ...

  7. PyTorch搭建预训练AlexNet、DenseNet、ResNet、VGG实现猫狗图片分类

    目录 前言 AlexNet DensNet ResNet VGG 前言 在之前的文章中,利用一个简单的三层CNN猫狗图片分类,正确率不高,详见: CNN简单实战:PyTorch搭建CNN对猫狗图片进行 ...

  8. 神经网络为什么可以实现分类?---三分类网络0,1,2与弹性振子力学系统

    本文制作了一个三分类的网络来分类mnist数据集的0,1,2.并同时制作了一个力学模型,用来模拟这个三分类的过程,并用这个模型解释分类的原理. 上图可以用下列方程描述 只要ωx0,ωx1,ωx2,ωx ...

  9. torch从零开始搭建deeplabv3+训练自己的数据集!

    目录 一.制作自己数据集 1.1 torch数据加载原理 1.2 地理信息科学与深度学习的结合 1.3代码实现 1.4分批次加载数据集 二.训练网络 2.1参数选择 2.2训练过成可视化 三.执行预测 ...

最新文章

  1. dataframe如何理解df[df[‘type‘]==‘xType‘]
  2. js 浅拷贝直接赋值_js的浅拷贝和深拷贝的简单理解和使用方法
  3. TF之VGG系列:利用预先编制好的脚本data_convert .py文件将图片格式转换为tfrecord 格式
  4. Python之字符串的134个常用操作
  5. git学习笔记04-将本地仓库添加到GitHub远程仓库-git比svn先进的地方
  6. Docker日志收集最佳实践
  7. 哈佛第二、哥大第三,第一还是它!2020USNews美国大学排名发布!
  8. 项目练习(二)—微博数据结构化
  9. 慢慢的,就没有了,就像从未存在过(转载)
  10. 如何准备Java初级和高级的技术面试
  11. TCP协议如何保证可靠传输
  12. 利用Dbgview和OutputDebugString
  13. 第5课 电子商务基础
  14. 破解md5加密的方法
  15. android+网速监控源码,记录: Android测试网速实现
  16. mysql数据库修改初始密码
  17. 网络无法找到计算机6,手机可以搜到WiFi6路由器的信号,电脑却搜不到这是怎么回事?...
  18. io-nio-socket步步为营(七) IO模型-心得体会
  19. 编译优化之 - 通用循环优化
  20. 洛谷 1282 多米诺骨牌#线性动态规划#

热门文章

  1. “主播露真容,男粉丝销号”的真正启示是什么?
  2. “东华杯”2021年大学生网络安全邀请赛 暨第七届上海市大学生网络安全大赛线上赛MISC-Writeup
  3. 计算机初级职称多久能拿证,请问助理工程师多久可以评下来及费用多少
  4. 使用node搭建后台管理系统(1)
  5. 前端MVVM是什么?和jQuery的区别是什么?
  6. 关于css中“点“,“井号“,“逗号“,“空格“,“冒号“的用法
  7. Factorialize a N umber(计算阶乘)—freeCodeCamp上边的项目
  8. 服务器系统负荷,服务器的系统负载
  9. IOS ATS 配置
  10. Golang基础——统计字符串中汉字的数量