数据加载(简单看)

from __future__ import unicode_literals, print_function, division
from io import open
import glob
import os
import torch
def findFiles(path): return glob.glob(path)#print(findFiles('data/names/*.txt'))import unicodedata
import stringall_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)
print(all_letters,n_letters)
# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):return ''.join(c for c in unicodedata.normalize('NFD', s)if unicodedata.category(c) != 'Mn'and c in all_letters)print(unicodeToAscii('Ślusàrski'))

数据预处理(将人名按字符转化为tensor)

# Build the category_lines dictionary, a list of names per language
category_lines = {}#字典
all_categories = []# Read a file and split into lines
def readLines(filename):lines = open(filename, encoding='utf-8').read().strip().split('\n')#print([unicodeToAscii(line) for line in lines])return [unicodeToAscii(line) for line in lines]for filename in findFiles('data/names/*.txt'):category = os.path.splitext(os.path.basename(filename))[0]all_categories.append(category)lines = readLines(filename)#预处理为listcategory_lines[category] = lines#字典的每一个元素代表一个类别,每一个类别对应一个listn_categories = len(all_categories)
print(all_categories)
# Find letter index from all_letters, e.g. "a" = 0
def letterToIndex(letter):# 首字母的indexreturn all_letters.find(letter)# Just for demonstration, turn a letter into a <1 ,n_letters> Tensor,one-hot操作
def letterToTensor(letter):tensor = torch.zeros(1, n_letters)tensor[0][letterToIndex(letter)] = 1#只有一行,对应的字母的index设为1,其余为0return tensor# Turn a line into a <line_length , 1 , n_letters>,
# or an array of one-hot letter vectors
def lineToTensor(line):tensor = torch.zeros(len(line), 1, n_letters)for li, letter in enumerate(line):#对于输入人名的每一个字符,建立一个letterToTensor(1*n_letters)的张量tensor[li][0][letterToIndex(letter)] = 1#最终得到的也就是line_length*1*n_letters的张量return tensorprint(letterToTensor('J'))print(lineToTensor('Jones').size())
def categoryFromOutput(output):top_n, top_i = output.topk(1)#根据网络的输出(18分类的输出为一个18维的张量),选择最大可能性的类别category_i = top_i[0].item()#得到在18类中的indexreturn all_categories[category_i], category_iprint(categoryFromOutput(output))
import randomdef randomChoice(l):temp=l[random.randint(0, len(l) - 1)]print(temp)return tempdef randomTrainingExample():category = randomChoice(all_categories)#任选一个类别line = randomChoice(category_lines[category])#dict的key为category,value为一个list,在list中随机取一个符合以上类别的人名category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)#得到选取类别的indexline_tensor = lineToTensor(line)#将随机获取的人名按字符转换为张量return category, line, category_tensor, line_tensorfor i in range(10):category, line, category_tensor, line_tensor = randomTrainingExample()print('category =', category, '/ line =', line)
随机采样的数据(输出):
category = Korean / line(random_name) = Cho
category = Portuguese / line(random_name) = Castro
category = Polish / line(random_name) = Niemczyk
category = Russian / line(random_name) = Yaminsky
category = Spanish / line(random_name) = Marti
category = Portuguese / line(random_name) = Esteves
category = Vietnamese / line(random_name) = Vinh
category = French / line(random_name) = Laurent
category = Italian / line(random_name) = Napoletani
category = Spanish / line(random_name) = Gallego

构建 网络

import torch.nn as nnclass RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.i2h = nn.Linear(input_size + hidden_size, hidden_size)self.i2o = nn.Linear(input_size + hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)def forward(self, input, hidden):combined = torch.cat((input, hidden), 1)#RNN的思想,combined=本次输入+上次的输出hidden = self.i2h(combined)output = self.i2o(combined)output = self.softmax(output)return output, hiddendef initHidden(self):return torch.zeros(1, self.hidden_size)n_hidden = 128
rnn = RNN(n_letters, n_hidden, n_categories)

训练网络

learning_rate = 0.005 # If you set this too high, it might explode. If too low, it might not learn
def categoryFromOutput(output):top_n, top_i = output.topk(1)#根据网络的输出(18分类的输出为一个18维的张量),选择最大可能性的类别category_i = top_i[0].item()#得到在18类中的indexreturn all_categories[category_i], category_idef train(category_tensor, line_tensor):hidden = rnn.initHidden()rnn.zero_grad()for i in range(line_tensor.size()[0]):output, hidden = rnn(line_tensor[i], hidden)#将上一次迭代结果的hidden作为下一次迭代的输入实现RNN的思想loss = criterion(output, category_tensor)loss.backward()# Add parameters' gradients to their values, multiplied by learning ratefor p in rnn.parameters():p.data.add_(-learning_rate, p.grad.data)return output, loss.item()########################
n_iters = 100000
print_every = 5000# Keep track of losses for plotting
current_loss = 0
all_losses = []for iter in range(1, n_iters + 1):category, line, category_tensor, line_tensor = randomTrainingExample()#随机采样数据output, loss = train(category_tensor, line_tensor)#训练current_loss += loss#计算loss# 将采样的数据 计算出的output,由categoryFromOutput得到算法识别出的category,与实际category做对比if iter % print_every == 0:guess, guess_i = categoryFromOutput(output)correct = '✓' if guess == category else '✗ (%s)' % categoryprint('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct))

pytorch RNN实现分类相关推荐

  1. pytorch实现文本分类_使用变形金刚进行文本分类(Pytorch实现)

    pytorch实现文本分类 'Attention Is All You Need' "注意力就是你所需要的" New deep learning models are introd ...

  2. LESSON 10.110.210.3 SSE与二分类交叉熵损失函数二分类交叉熵损失函数的pytorch实现多分类交叉熵损失函数

    在之前的课程中,我们已经完成了从0建立深层神经网络,并完成正向传播的全过程.本节课开始,我们将以分类深层神经网络为例,为大家展示神经网络的学习和训练过程.在介绍PyTorch的基本工具AutoGrad ...

  3. 深度学习之循环神经网络(5)RNN情感分类问题实战

    深度学习之循环神经网络(5)RNN情感分类问题实战 1. 数据集 2. 网络模型 3. 训练与测试 完整代码 运行结果  现在利用基础的RNN网络来挑战情感分类问题.网络结构如下图所示,RNN网络共两 ...

  4. Pytorch RNN(详解RNN+torch.nn.RNN()实现)

    目录 一.RNN简介 二.RNN简介2 三.pytorch RNN 3.1    定义RNN()

  5. pytorch rnn 实现手写字体识别

    pytorch rnn 实现手写字体识别 构建 RNN 代码 加载数据 使用RNN 训练 和测试数据 构建 RNN 代码 import torch import torch.nn as nn from ...

  6. Pytorch RNN 实现新闻数据分类

    Pytorch RNN 实现新闻数据分类 概述 数据集 Text RNN 模型 评估函数 主函数 输出结果 概述 RNN (Recurrent Netural Network) 是用于处理序列数据的神 ...

  7. Pytorch搭建常见分类网络模型------VGG、Googlenet、ResNet50 、MobileNetV2(4)

    接上一节内容:Pytorch搭建常见分类网络模型------VGG.Googlenet.ResNet50 .MobileNetV2(3)_一只小小的土拨鼠的博客-CSDN博客 mobilenet系列: ...

  8. 手把手写深度学习(5)——Pytorch+RNN自动生成邓紫棋风格歌词

    前言:前面两篇文章讲了RNN的基础理论和用mxnet搭建一个RNN网络,自动生成歌词.本文是时候亮出我三十年邓紫棋歌迷的身份,用使用更广泛的Pytorch框架,搭建一个RNN模型,用来自动生成邓紫棋风 ...

  9. pytorch对MNIST分类

    深度学习 基础知识和各种网络结构实战 ... pytorch对MNIST分类 深度学习 前言 一.导入第三方库 二.下载MNIST数据集 三.创建神经网络模型 四.训练数据集 五.测试 完整代码 总结 ...

最新文章

  1. SaaS项目管理软件有什么用?
  2. 图解RxJava2(一)
  3. python自定义包_详解python自定义模块、包
  4. 撒花!吴恩达新书《Machine Learning Yearning》完整中文版pdf开放下载
  5. jsp前3章试题分析
  6. 拔号×××与站点×××的配置
  7. IT:如何在Windows Server 2008 R2上安装Hyper-V虚拟化
  8. 并行计算(一)——并行计算机系统及结构模型
  9. 02年六代雅阁的整备质量_2020年宝安第八批更新计划:联投地产5.4万㎡“工改”项目...
  10. Qt 字符串QString arg()用法总结
  11. spring4笔记----spring4构造注入
  12. 今天主要改了罗宾钢琴的首页图片缩放问题
  13. 关于 Java 的强制类型转换
  14. Java面向对象编程三大特征 - 继承
  15. 推荐系统系列 - 引导 - 5类系统推荐算法,非常好使,非常全
  16. matlab飞机降落过程模拟,scratch作品 “模拟飞机降落”---东风东路小学一年(14)班 沈宸玮...
  17. 如何选择关键词以及关键词分析优化
  18. iOS开发之常用第三方框架
  19. 一度智信|想要提高店铺流量,商家需要了解这些引流渠道
  20. 【架构】电商微服务架构图 -来源网络(非原创)

热门文章

  1. apache添加php语言模块,在apache中添加php处理模块-Go语言中文社区
  2. template不支持v-show
  3. 全面回顾2022年加密行业大事件:破后而立方能绝处逢生
  4. 如何理解vcc,vdd,vss
  5. VCC、VDD、VSS、GND等等V某某究竟是什么意思
  6. 搜狗搜索引擎+浏览器,双轮驱动读图时代
  7. Java程序员秋招面经大合集(BAT美团网易小米华为中兴等)
  8. 未来10年,C++5个非常有前景的就业方向
  9. vue引用public目录下文件
  10. Android HIDL HAL 接口定义语言详解