【PyTorch深度学习实践】P9 kaggle otto商品分类作业(含注释)
《PyTorch深度学习实践》-刘二大人 Otto Group Product Classification作业
将商品进行十分类,输入为93个特征10个类别的商品数据集,输出为预测数据集里商品是哪个类别
数据集可以在https://www.kaggle.com/c/otto-group-product-classification-challenge下载
代码及注释如下
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch
import torch.optim as optim
import matplotlib.pyplot as pltdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 函数将字符型标签转换为数值标签,方便计算交叉熵
def lables2id(lables):target_id = []target_lables = ['Class_1', 'Class_2', 'Class_3', 'Class_4', 'Class_5', 'Class_6', 'Class_7', 'Class_8', 'Class_9']for lable in lables:target_id.append(target_lables.index(lable))return target_id# 定义数据集类
class TrainDataset(Dataset):def __init__(self, filepath):data = pd.read_csv(filepath)lables = data['target']self.len = data.shape[0] # shape(多少行,多少列)self.x_data = torch.tensor(np.array(data)[:, 1:-1].astype(float))self.y_data = lables2id(lables)def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.lentrain_dataset = TrainDataset('D:/Research/Deep learning/pytorch刘二大人/otto-group-product-classification-challenge/train.csv')
# 数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=0)class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(93, 64)self.linear2 = torch.nn.Linear(64, 32)self.linear3 = torch.nn.Linear(32, 16)self.linear4 = torch.nn.Linear(16, 9)self.activate = torch.nn.ReLU()def forward(self, x):x = self.activate(self.linear1(x))x = self.activate(self.linear2(x))x = self.activate(self.linear3(x))x = self.linear4(x)return xmodel = Model()#定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.6)loss_list = []def train(epoch):running_loss = 0.0for batch_idx, data in enumerate(train_loader):inputs, target = datainputs = inputs.float()optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, target)loss_list.append(loss.item())loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 300 == 299: # 每300轮打印一次结果print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))running_loss = 0.0# 开始训练
if __name__ == '__main__':for epoch in range(50):train(epoch)# 预测保存函数,用于保存预测结果
def predict_save():with torch.no_grad():test_data = pd.read_csv('D:/Research/Deep learning/pytorch刘二大人/otto-group-product-classification-challenge/test.csv')x_text = torch.tensor(np.array(test_data)[:, 1:].astype(float))y_pred = model(x_text.float())_, predicted = torch.max(y_pred, dim=1) # 这里先取出最大概率的索引,即是所预测的类别。out = pd.get_dummies(predicted) # get_dummies 利用pandas实现one hot encode,方便保存为预测文件。lables = ['Class_1', 'Class_2', 'Class_3', 'Class_4', 'Class_5', 'Class_6', 'Class_7', 'Class_8', 'Class_9']# 添加列标签out.columns = lables# 插入id行out.insert(0, 'id', test_data['id'])result = pd.DataFrame(out)result.to_csv('my_predict.csv', index=False)#画损失函数曲线plt.plot(range(len(loss_list)), loss_list)plt.xlabel('step')plt.ylabel('loss')plt.show()predict_save()
输出结果
[1, 300] loss: 1.424
[1, 600] loss: 0.863
[1, 900] loss: 0.742
…
[25, 300] loss: 0.467
[25, 600] loss: 0.461
[25, 900] loss: 0.477
…
[50, 300] loss: 0.421
[50, 600] loss: 0.423
[50, 900] loss: 0.426
损失函数曲线如下
可以尝试不同的optimizer,参数,进一步处理数据等等再优化。
【PyTorch深度学习实践】P9 kaggle otto商品分类作业(含注释)相关推荐
- 《PyTorch 深度学习实践》第10讲 卷积神经网络(基础篇)
文章目录 1 卷积层 1.1 torch.nn.Conv2d相关参数 1.2 填充:padding 1.3 步长:stride 2 最大池化层 3 手写数字识别 该专栏内容为对该视频的学习记录:[&l ...
- PyTorch深度学习实践概论笔记9-SoftMax分类器
上一讲PyTorch深度学习实践概论笔记8-加载数据集中,主要介绍了Dataset 和 DataLoader是加载数据的两个工具类.这一讲介绍多分类问题如何解决,一般会用到SoftMax分类器. 0 ...
- 【PyTorch】PyTorch深度学习实践|视频学习笔记|P6-P9
PyTorch深度学习实践 逻辑斯蒂回归及实现 背景与概念 基于分类问题中属性是类别性的,所以不能采取基于序数的线性回归模型,而提出了新的分类模型--逻辑斯蒂回归模型,输出每个样本在各个预测值上的概率 ...
- PyTorch深度学习实践
根据学习情况随时更新. 2020.08.14更新完成. 参考课程-刘二大人<PyTorch深度学习实践> 文章目录 (一)课程概述 (二)线性模型 (三)梯度下降算法 (四)反向传播 (五 ...
- 【Pytorch深度学习实践】B站up刘二大人之SoftmaxClassifier-代码理解与实现(8/9)
这是刘二大人系列课程笔记的倒数第二个博客了,介绍的是多分类器的原理和代码实现,下一个笔记就是basicCNN和advancedCNN了: 写在前面: 这节课的内容,主要是两个部分的修改: 一是数据集: ...
- 【Pytorch深度学习实践】B站up刘二大人课程笔记——目录与索引(已完结)
从有代码的课程开始讨论 [Pytorch深度学习实践]B站up刘二大人之LinearModel -代码理解与实现(1/9) [Pytorch深度学习实践]B站up刘二大人之 Gradient Desc ...
- PyTorch 深度学习实践 第13讲
PyTorch 深度学习实践 第13讲 引言 代码 结果 引言 近期学习了B站 刘二大人的PyTorch深度学习实践,传送门PyTorch 深度学习实践--循环神经网络(高级篇),感觉受益匪浅,发现网 ...
- 刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归
刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归 P6 逻辑斯蒂回归 1.torchversion 提供的数据集 2.基本概念 3.代码实现 P6 逻辑斯蒂回归 1.torchversi ...
- 《PyTorch深度学习实践》 课堂笔记 Lesson7 神经网络多维特征输入的原理推导与实现
文章目录 1.为什么使用多维的特征输入 2. 多维特征向量输入推导 3.实现过程 3.1源代码 3.2训练结果 写在最后 1.为什么使用多维的特征输入 对于现实世界来说,影响一个事物发展的因素有很多种 ...
最新文章
- 深度学习人体姿态估计算法综述
- Python到底是个什么东西
- 封装成vla函数_第四章:Python之函数
- 使用API​​身份验证的Spring Security
- 5g pdu session_运营商下架4G套餐,用户被5G!
- android c 11 编译,Android NDK r9b和编译C 11
- 纯新手DSP编程--5.18--调试
- 智能手机linux系统下载软件,智能手机下载工具
- 小郡肝火锅点餐系统——部分代码实现
- 2021年计算机能力挑战赛真题总结C++版
- PubMed插件:分区、影响因子和即时IF一目了然,还能秒下文献(亲测有效)
- 用于自动化的 10 个杀手级 Python 脚本
- ATS667LSG:真零速、高精度齿传感器 IC
- CSMA/CD和CSMD/CA
- FileStream写入文件
- ICV:超声波雷达迎来数字化变革,2026年全球市场规模将达145亿美元
- 英语四级作文模板(一)
- Oracle IMS DB2都属于,IMS数据库 IMS database
- java的三层架构是什么_java中的三层架构
- 算法描述怎么写伪代码java_伪代码描述算法
热门文章
- 计算机专业和物联网专业课,物联网工程专业课程有哪些
- 【电源模块】ME3116 DCDC降压模块设计
- 基于运放的放大电路分析
- 正确率,精确率,召回率.
- **6-7 十进制转换二进制 (10分)**
- [Unity2D]Tilemap Collider2D只给部分地图瓦片加上Collider的方法
- CEGUI安装、编译、运行总结
- W32Dasm反编译教程+工具
- 网络购物09.22-09.28
- 黑莓9000软件测试面试,初步测试有5大发现_黑莓9000 Bold - CNMO