作者 | 李秋键

编辑 | Carol

出品 | CSDN(ID:CSDNnews)

自然语言处理作为人工智能的一个重要分支,在我们的生活中得到了广泛应用。其中RNN算法作为自然语言处理的经典算法之一,是文本生成的重要手段。而今天我们就将利用RNN算法建立一个写歌词的软件。其中的界面如下:

RNN指的是循环神经网络,Recurrent Neural Network。不同于前馈神经网络的是,RNN可以利用它内部的记忆来处理任意时序的输入序列,这让它可以更容易处理如不分段的手写识别、语音识别等。

RNN模型有比较多的变种,这里介绍最主流的RNN模型结构如下:

上图中左边是RNN模型没有按时间展开的图,如果按时间序列展开,则是上图中的右边部分。我们重点观察右边部分的图。

这幅图描述了在序列索引号tt附近RNN的模型。其中:

  1. x(t)x(t)代表在序列索引号tt时训练样本的输入。同样的,x(t−1)x(t−1)和x(t+1)x(t+1)代表在序列索引号t−1t−1和t+1t+1时训练样本的输入。

  2. h(t)h(t)代表在序列索引号tt时模型的隐藏状态。h(t)h(t)由x(t)x(t)和h(t−1)h(t−1)共同决定。

  3. o(t)o(t)代表在序列索引号tt时模型的输出。o(t)o(t)只由模型当前的隐藏状态h(t)h(t)决定。

  4. L(t)L(t)代表在序列索引号tt时模型的损失函数。

  5. y(t)y(t)代表在序列索引号tt时训练样本序列的真实输出。

  6. U,W,VU,W,V这三个矩阵是我们的模型的线性关系参数,它在整个RNN网络中是共享的,这点和DNN很不相同。也正因为是共享了,它体现了RNN的模型的“循环反馈”的思想。

基于以上认知,我们开始搭建我们的软件。

实验前的准备

首先我们使用的python版本是3.6.5所用到的库有TensorFlow,是用来训练和加载神经网络常见的框架,常常用于数值计算的开源软件库。节点表示数学操作,线则表示在节点间相互联系的多维数据数组,即张量(tensor);tkinter用来绘制GUI界面的库;

Pillow库在此项目中用来处理图片和字体等问题。因为我们的软件不是空白背景的。需要借助Image函数添加背景。

RNN算法搭建

1、数据集处理和准备:

我们训练的数据集使用各种歌手的歌词本作为训练集。其中数据集放在date.txt里,其中部分数据集如下:

2、模型的训练:

模型训练的代码直接运行train.py即可训练。其中流程如下:

  1. 首先要读取数据集

  2. 设定训练批次、步数等等

  3. 数据载入RNN进行训练即可

其中代码如下:

def train:filename = 'date.txt'with open(filename, 'r', encoding='utf-8') as f:text = f.readreader = TxtReader(text=text, maxVocab=3500)reader.save('voc.data')array = reader.text2array(text)generator = GetBatch(array, n_seqs=100, n_steps=100)model = CharRNN(numClasses = reader.vocabLen,mode ='train',numSeqs = 100,numSteps = 100,lstmSize = 128,numLayers = 2,lr = 0.001,Trainprob = 0.5,useEmbedding = True,numEmbedding = 128)model.train(generator,logStep = 10,saveStep = 1000,maxStep = 100000)

3、RNN网络搭建:

RNN算法的搭建,我们定义整个神经网络类,然后分别定义初始化、输入、神经元定义等函数。损失函数和优化器使用均方差和AdamOptimizer优化器即可。

部分代码如下:

# 创建输入def buildInputs(self):numSeqs = self.numSeqsnumSteps = self.numStepsnumClasses = self.numClassesnumEmbedding = self.numEmbeddinguseEmbedding = self.useEmbeddingwith tf.name_scope('inputs'):self.inData = tf.placeholder(tf.int32, shape=(numSeqs, numSteps), name='inData')self.targets = tf.placeholder(tf.int32, shape=(numSeqs, numSteps), name='targets')self.keepProb = tf.placeholder(tf.float32, name='keepProb')# 中文if useEmbedding:with tf.device("/cpu:0"):embedding = tf.get_variable('embedding', [numClasses, numEmbedding])self.lstmInputs = tf.nn.embedding_lookup(embedding, self.inData)# 英文else:self.lstmInputs = tf.one_hot(self.inData, numClasses)# 创建单个Celldef buildCell(self, lstmSize, keepProb):basicCell = tf.nn.rnn_cell.BasicLSTMCell(lstmSize)drop = tf.nn.rnn_cell.DropoutWrapper(basicCell, output_keep_prob=keepProb)return drop# 将单个Cell堆叠多层def buildLstm(self):lstmSize = self.lstmSizenumLayers = self.numLayerskeepProb = self.keepProbnumSeqs = self.numSeqsnumClasses = self.numClasseswith tf.name_scope('lstm'):multiCell = tf.nn.rnn_cell.MultiRNNCell([self.buildCell(lstmSize, keepProb) for _ in range(numLayers)])self.initial_state = multiCell.zero_state(numSeqs, tf.float32)self.lstmOutputs, self.finalState = tf.nn.dynamic_rnn(multiCell, self.lstmInputs, initial_state=self.initial_state)seqOutputs = tf.concat(self.lstmOutputs, 1)x = tf.reshape(seqOutputs, [-1, lstmSize])with tf.variable_scope('softmax'):softmax_w = tf.Variable(tf.truncated_normal([lstmSize, numClasses], stddev=0.1))softmax_b = tf.Variable(tf.zeros(numClasses))self.logits = tf.matmul(x, softmax_w) + softmax_bself.prediction = tf.nn.softmax(self.logits, name='prediction')# 计算损失def buildLoss(self):numClasses = self.numClasseswith tf.name_scope('loss'):targets = tf.one_hot(self.targets, numClasses)targets = tf.reshape(targets, self.logits.get_shape)loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=targets)self.loss = tf.reduce_mean(loss)# 创建优化器def buildOptimizer(self):gradClip = self.gradCliplr = self.lrtrainVars = tf.trainable_variables# 限制权重更新grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, trainVars), gradClip)trainOp = tf.train.AdamOptimizer(lr)self.optimizer = trainOp.apply_gradients(zip(grads, trainVars))# 训练def train(self, data, logStep=10, saveStep=1000, savepath='./models/', maxStep=100000):if not os.path.exists(savepath):os.mkdir(savepath)Trainprob = self.Trainprobself.session = tf.Sessionwith self.session as sess:step = 0sess.run(tf.global_variables_initializer)state_now = sess.run(self.initial_state)for x, y in data:step += 1feed_dict = {self.inData: x,self.targets: y,self.keepProb: Trainprob,self.initial_state: state_now}loss, state_now, _ = sess.run([self.loss, self.finalState, self.optimizer], feed_dict=feed_dict)if step % logStep == 0:print('[INFO]: : {}/{}, loss: {:.4f}'.format(step, maxStep, loss))if step % saveStep == 0:self.saver.save(sess, savepath, global_step=step)if step > maxStep:self.saver.save(sess, savepath, global_step=step)break# 从前N个预测值中选def GetTopN(self, preds, size, top_n=5):p = np.squeeze(preds)p[np.argsort(p)[:-top_n]] = 0p = p / np.sum(p)c = np.random.choice(size, 1, p=p)[0]return c

4、歌词的生成:

设置关键词变量,读取模型文件,输出结果即可。

代码如下:

def main(_):reader = TxtReader(filename='voc.data')model = CharRNN(numClasses = reader.vocabLen,mode = 'test',lstmSize = 128,numLayers = 2,useEmbedding = True,numEmbedding = 128)checkpoint = tf.train.latest_checkpoint('./models/')model.load(checkpoint)key="雪花"prime = reader.text2array(key)array = model.test(prime, size=reader.vocabLen, n_samples=300)print("《"+key+"》")print(reader.array2text(array))

界面的定义和调用

界面中我们的布局是文本框、编辑框和按钮控件。程序的调用使用批处理文件调用以达到显示运行过程的效果。因为如果没有运行过程,难免会导致用户不清楚程序流程而强制运行容易导致卡死的情况。

其中Bat里直接写入:

python song.py

其中过程效果如下:

1、界面布局:

界面布局使用canvas画布以达到添加背景图片的效果。背景图片设置为1.jpg,按钮背景图片设置为3.jpg。图片也可以自己更换掉。然后文本框作为提示的效果,分别定义字体,大小等等即可

代码如下:

root = tk.Tkroot.title('AI写歌词')# 背景canvas = tk.Canvas(root, width=800, height=500, bd=0, highlightthickness=0)imgpath = '1.jpg'img = Image.open(imgpath)photo = ImageTk.PhotoImage(img)imgpath2 = '3.jpg'img2 = Image.open(imgpath2)photo2 = ImageTk.PhotoImage(img2)canvas.create_image(700, 400, image=photo)canvas.packlabel=tk.Label(text="请输入关键词:

rnn按时间展开_作词家下岗系列:教你用 RNN 算法做一个写词软件相关推荐

  1. 作词家下岗系列:教你用 RNN 算法做一个写词软件

    作者 | 李秋键 编辑 | Carol 出品 | CSDN(ID:CSDNnews) 自然语言处理作为人工智能的一个重要分支,在我们的生活中得到了广泛应用.其中RNN算法作为自然语言处理的经典算法之一 ...

  2. rnn按时间展开_双向RNN的理解

    我们在学习某种神经网络模型时,一定要把如下几点理解透了,才算真正理解了这种神经网络. 网络的架构:包含那些层,每层的输入和输出,有那些模型参数是待优化的 前向传播算法 损失函数的定义 后向传播算法 什 ...

  3. rnn按时间展开_一文搞懂RNN(循环神经网络)基础篇

    神经网络基础 神经网络可以当做是能够拟合任意函数的黑盒子,只要训练数据足够,给定特定的x,就能得到希望的y,结构图如下: 将神经网络模型训练好之后,在输入层给定一个x,通过网络之后就能够在输出层得到特 ...

  4. 微信怎么at所有人_任正非被遗漏的讲话:怎么做一个谦虚的领导者?

    任正非曾在市场大会做过一个叫<做谦虚的领导者>的讲话,少有媒体和社会层面的关注,甚至内部也未有什么波澜.但细看起来,这个讲话基本上是华为的管理纲要,如:以利润为中心:建立大区协调机制:坚持 ...

  5. 卡通化图片python实现代码_媳妇儿喜欢玩某音中的动漫特效,那我就用python做一个图片转化软件。...

    ​    最近某音上的动漫特效特别火,很多人都玩着动漫肖像,我媳妇儿也不例外.看着她这么喜欢这个特效,我决定做一个图片处理工具,这样媳妇儿的动漫头像就有着落了. 编码 为了快速实现我们的目标,我们就不 ...

  6. python网易云歌词做成词云图_讨好女朋友:用Python给女朋友做一个歌曲词云图

    今天咋们来看看网易云赵雷的歌曲歌词,并做一个词云图.这篇文章可以学习到什么是词云,爬虫的基本流程,简单的可视化操作 一 什么是词云 可视化有很多种,好的数据可视化,可以使得数据分析的结果更加通俗易通. ...

  7. 实现时间排序_面试官:手撕十大排序算法,你会几种?

    推荐阅读: 去面试大厂被 Kafka 虐了,后悔没有早点看到这份Kafka手写笔记 面试阿里,京东,百度,快手归来,三年Java开发总结了这些经验 阿里,字节,腾讯,面试题都涵盖了,这一份Java面试 ...

  8. 用python做一个数据查询软件_使用Python实现NBA球员数据查询小程序功能

    本文的文字及图片来源于网络,仅供学习.交流使用,不具有任何商业用途,如有问题请及时联系我们以作处理. 以下文章来源于早起Python ,作者投稿君 一.前言 有时将代码转成带有界面的程序,会极大地方便 ...

  9. 一步一步教你如何用python做词云_一步一步教你如何用Python做词云

    前言 在大数据时代,你竟然会在网上看到的词云,例如这样的. 看到之后你是什么感觉?想不想自己做一个? 如果你的答案是正确的,那就不要拖延了,现在我们就开始,做一个词云分析图,Python是一个当下很流 ...

最新文章

  1. Tomcat 源码阅读记录(1)
  2. ajax点赞只能点一次,php+mysql+ajax局部刷新点赞取消点赞功能(每个账号只点赞一次).pdf...
  3. Android Studio导入第三方类库的方法
  4. python制作excel表格-excel表格中怎么制作一份燃尽图表
  5. SQL小技巧系列 --- 行转列合并
  6. AIX中查找端口号和进程
  7. hdu4091(暴力)
  8. 数据库mysql_row_MYSQL数据库mysql found_row()使用详解
  9. oracle10官网下载安装,oracle11g安装(win10)下载安装
  10. JPA EntityManager详解
  11. Swif语法基础 要点归纳(一)
  12. 在Hibernate,EhCache,Quartz,DBCP和Spring中启用JMX
  13. 「声明」本博客自动采集于博客园-niceyoo
  14. Linux——grep文本搜索命令
  15. 腾讯PCG光影研究室招聘计算机视觉算法/实习生
  16. css3的新属性 新增的颜色--- 透明度---两种渐变---定义多张背景图--background-size...
  17. 三个杯子的倒水问题(BFS)
  18. 谷歌云服务器的ip是什么ip,看到有人在问谷歌云的IP段问题,我推荐几个自用觉得不错的...
  19. 阿里云智能语音交互服务-录音文件识别采样率不支持-UNSUPPORTED_SAMPLE_RATE 解决方案
  20. 0130更新:完美wine QQ2011正式版(5074)

热门文章

  1. android加载转圈动画,android 围绕中心旋转动画
  2. centos php mcrypt,CentOS yum php mcrypt 扩展安装方法
  3. explain mysql 权限_explain命令为什么可能会修改MySQL数据
  4. mysql 创建删除表_mysql创建删除表的实例详解
  5. spark executor内存分配_二十二、Spark之图解Executor端内存管理
  6. qt creator创建cmake构建的程序,无法启动调试(点左下角运行不出结果 No executable specified.)
  7. PyQt编程之模态与非模态对话框(二)
  8. 到底什么是面向对象,面试中怎么回答。面向过程和面向对象的区别是什么。java跨平台特性以及java和C++的区别。面向对象的三大特性——封装、继承和多态。面向对象的高拓展性以及低耦合度怎么体现?
  9. mysql 下载地址及安装教程
  10. python土木_土木和结构工程师用Python-Python for civil and structural engineers