先看看我摘录的一些结果吧

七字春联,开头两个字分别为 虎年、虎气、春节

虎年啸虎春虎虎
虎气伏虎牛龙龙

虎年虎虎展啸风
春节啸春虎风虎

虎年虎啸啸千啸
春节萝啸气风春

虎年啸浩一讯欢
春节回旧鹤绣舞

虎年啸一着有处
春节回一福舞福

虎年啸月翼业来
春节回旧精绣福

我语文水平不太行,乍一眼看感觉还是满高级的。

训练数据 春联.csv

0
龙腾虎啸
腊尽春回
高崖伏虎啸
茅庐卧龙飞
虎啸风声远
龙腾海浪高
龙引千江水
虎越万重山
密林藏伏虎
萝峰染晴云
山月流古雪
风虎浴清泉
事事都如意
虎虎有生气
云中熊虎将
天上凤凰儿
百尺飞泉鸣震谷
一声长啸势惊天
丙部琳琅春馥郁
寅宾璀灿日光华
长白虎啸林中日
江南猿啼岭上风
赤县奔腾如虎跃
神州崛起似龙飞
丑旧寅新宏图展
牛归虎跃春意浓
丑去寅来千里锦
牛奔虎啸九州春
丑去寅来人益健
牛奔虎跃春愈新
春风浩荡神州绿
虎气升腾岳麓雄
春风着意随人愿
虎气生威壮国魂
春光春色源春意
虎将虎年扬虎威
春节乍闻春有喜
虎年乐见虎生风
春雷巨响山河动
月夜旋风草木飞
春晓寅回人起舞
岁祯虎啸物昭苏
电闪金光夸五色
雷鸣巨吼动千山
憨厚忠诚牛品德
高昂奋勇虎精神
虎步奔腾开胜景
春风浩荡展鸿图
虎踞龙盘今胜昔
花得鸟语旧更新
虎踞龙盘今胜昔
莺歌燕舞呈吉祥
虎年赢得春风意
喜讯唤来燕子情
虎气顿生年属虎
春风常驻户迎春
虎气频催翻旧景
春风浩荡著新篇
虎添双翼前程远
国展宏图事业新
虎啸大山山献宝
龙腾祖国国扬威
虎啸密林凤万壑
鹤眠苍松月千岩
虎啸青山千里锦
风拂绿柳万家春
虎啸一声山海动
龙腾三界吉祥来
虎跃龙腾生紫气
风调雨顺兆丰年
虎跃龙腾兴骏业
莺歌燕舞羡鹏程
虎跃神州千业旺
春临盛世万民欢
花事才逢花好日
虎年更有虎威风
黄牛虽去精神在
猛虎初来气象新
江山秀丽春增色
事业辉煌虎更威
江山一统腾龙日
岁月三春入虎年
皆称飞虎一身胆
不负英雄千古名
金牛昂首高歌去
玉虎迎春敛福来
金牛辞岁寒风尽
白虎迎春喜气来
金牛辞岁千仓满
玉虎迎春百业兴
金牛奋蹄奔大道
乳虎添翼舞新春
金牛奋蹄开锦绣
乳虎添翼会风云
金牛送旧千家乐
玉虎迎新万户欢
龙腾虎跃人间景
鸟语花香天地春
绿野春深禾涌碧
神州虎啸青山来
门庭虎踞平安岁
柳浪莺歌锦绣春
门浴春风梅吐艳
户生虎气鸟争鸣
年逢寅虎群情奋
岁别丑牛大地春
牛肥马壮丰收岁
虎跃龙腾大有年
牛肥马壮家家富
虎跃龙腾处处春
牛奋千程荣盛世
虎驮五福贺新春
牛奋四蹄开锦绣
虎添双翼会风云
牛耕绿野千仓满
虎啸青山万木荣
牛耕沃野扬长去
虎啸群山大步来
千载难逢新世纪
万民谱写虎春秋
乾元启运三阳泰
斗丙回寅万户春
人逢盛世精神壮
虎跃奇峰气势雄
人间喜庆康平世
虎岁承欢幸福春
人民气魄如龙虎
祖国江山似画图
人入虎年鼓虎劲
门添春色发春辉
人添志气虎添翼
雪舞丰年燕舞春
人效黄牛心自贵
岁朝寅虎劲更高
山明水秀风光丽
虎跃龙腾日月新
生气联吟欣虎虎
留春伴读奋年年
四海龙腾抒壮志
千山虎啸振雄风
四海三江春气息
千家万户虎精神
四海笙歌迎虎岁
九州英杰跃鹏程
唯大英雄能伏虎
是真俊杰敢擒龙
啸一声惊天动地
睁双眼照耀乾坤
新年捷报虎添翼
大路朝阳马奋蹄
兴伟业仍须牛劲
展宏图更壮虎威
一代英豪生虎气
三春杨柳动莺歌
英雄气概如龙虎
祖国江山似画图
英雄时代英雄业
龙虎精神龙虎年
莺歌燕舞新春日
虎跃龙腾大治年
迎春节莺歌遍地
兴中华虎劲冲天
云喷笔花腾虎豹
雨翻墨浪走蛟龙
宅后青山金虎踞
门前绿水玉龙盘
致富脱贫添虎翼
开山治水展鹏程
丙穴鱼生人间改岁
寅方斗指天下皆春
春风浩荡花香鸟语
岁月峥嵘虎跃龙腾
虎跃龙腾九州焕彩
风调雨顺五谷丰登
牛奔福地普天献瑞
虎卧华堂满院生辉
势如破竹人欢马叫
安若泰山虎踞龙盘
紫气东来江山如画
红旗招展龙虎扬威
祖国富强神龙活虎
人民幸福舞燕飞莺
白虎替青牛招财进宝
黄莺鸣翠柳辞旧迎新
虎跃龙腾创人间奇迹
莺歌燕舞描大地春光
虎跃龙腾有天皆丽日
花香鸟语无地不春风
花团锦簇江山添异彩
虎啸龙吟华夏壮神威
金牛辞旧携凯歌而去
乳虎迎春带捷报新来
瑞雪兆丰年年年大吉
丑牛接寅虎虎虎生威
岁月逢春山河添锦绣
人民思治龙虎振精神
效虎豪吟放怀歌富岁
闻鸡起舞挥笔颂春光
祖国腾飞大鹏振羽翼
宏图再展乳虎显神通
庆虎岁把酒高吟虎跃曲
祝丰年扶犁又唱丰收谣
迎虎年敢逐改革拦路虎
送牛岁勇当奉献老黄牛
迎新春处处呈文明气象
入虎岁人人当改革先锋
虎跃龙腾碧海黄山妆玉宇
莺歌燕舞春风旭日蔚神州
虎跃龙腾华夏人民多俊杰
莺歌燕舞阳春山水尽朝晖
牛耕广野丑年犁出文明路
虎跃深山寅岁图开舜尧天
岁步寅年喜庆团圆同把酒
珠还合浦欢歌一统共迎春
喜庆牛年两制先迎香港还
欢歌兔岁亿民再赞澳门归
栽竹栽松竹隐凤凰松隐鹤
培山培水山藏虎豹水藏龙
丑岁建奇功香港回归昌国运
寅年兴大业宏图展现壮情怀
虎年喜虎劲攻关夺隘皆如虎
春节焕春光绣水描山总是春
牛年虽过去牛劲更增多奉献
虎岁喜临门虎威大振有精神
寅时入虎年十亿人民振虎劲
佳节描春色九州大地荡春潮
忆旧岁牛劲冲霄汉神鞭一指神州巨变
看今朝虎威壮中华众志成城经济腾飞

滑动窗口设置为2,即两个字符预测下一个字符,滑动预测5个字符。
但是还是存在较大问题,因为我是直接将所有的对联都拼接起来然后滑动选取训练数据和标签,而不是每句滑动选取,这样就导致会串,但是懒得切了,就这样吧。

完整代码

# coding: utf-8
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
import torch.nn.functional as Fclass lstm_model(nn.Module):def __init__(self, vocab, hidden_size, num_layers, dropout=0.5):super(lstm_model, self).__init__()self.vocab = vocab  # 字符数据集# 索引,字符self.int_char = {i: char for i, char in enumerate(vocab)}self.char_int = {char: i for i, char in self.int_char.items()}# 对字符进行one-hot encodingself.encoder = OneHotEncoder(sparse=True).fit(vocab.reshape(-1, 1))self.hidden_size = hidden_sizeself.num_layers = num_layers# lstm层self.lstm = nn.LSTM(len(vocab), hidden_size, num_layers, batch_first=True, dropout=dropout)# 全连接层self.linear = nn.Linear(hidden_size, len(vocab))def forward(self, sequence, hs=None):out, hs = self.lstm(sequence, hs)  # lstm的输出格式(batch_size, sequence_length, hidden_size)output = self.linear(out[:, -1])  # linear的输出格式,(batch_size * sequence_length, vocab_size)return output, hsdef onehot_encode(self, data):return self.encoder.transform(data)def onehot_decode(self, data):return self.encoder.inverse_transform(data)def label_encode(self, data):return np.array([self.char_int[ch] for ch in data])def label_decode(self, data):return np.array([self.int_char[ch] for ch in data])def get_batches(data, batch_size, seq_len):''':param data: 源数据,输入格式(num_samples, num_features):param batch_size: batch的大小:param seq_len: 序列的长度(精度):return: (batch_size, seq_len, num_features)'''num_features = data.shape[1]num_chars = batch_size * seq_len  # 一个batch_size的长度num_batches = int(np.floor(data.shape[0] / num_chars))  # 计算出有多少个batchesneed_chars = num_batches * num_chars  # 计算出需要的总字符量targets = np.vstack((data[1:].A, data[0].A))  # 可能版本问题,取成numpy比较好reshapeinputs = data[:need_chars].A.astype("int")  # 从原始数据data中截取所需的字符数量need_wordstrain_data = np.zeros((inputs.shape[0] - seq_len, seq_len, num_features))train_label = np.zeros((inputs.shape[0] - seq_len, num_features))for i in range(0, inputs.shape[0] - seq_len, 1):train_data[i] = inputs[i:i+seq_len]train_label[i] = inputs[i+seq_len-1]for i in range(0, inputs.shape[0] - seq_len,  batch_size):if i + batch_size > inputs.shape[0] - seq_len:breakx = train_data[i:i+batch_size]y = train_label[i:i+batch_size]yield x, ydef train(model, data, batch_size, seq_len, epochs, lr=0.01, valid=None):device = 'cuda' if torch.cuda.is_available() else 'cpu'model = model.to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr)criterion = nn.CrossEntropyLoss()if valid is not None:data = model.onehot_encode(data.reshape(-1, 1))valid = model.onehot_encode(valid.reshape(-1, 1))else:data = model.onehot_encode(data.reshape(-1, 1))train_loss = []val_loss = []for epoch in range(epochs):model.train()hs = None  # hs等于hidden_size隐藏层节点train_ls = 0.0val_ls = 0.0for x, y in get_batches(data, batch_size, seq_len):optimizer.zero_grad()x = torch.tensor(x).float().to(device)out, hs = model(x, hs)hs = ([h.data for h in hs])y = y.reshape(-1, len(model.vocab))y = model.onehot_decode(y)y = model.label_encode(y.squeeze())y = torch.from_numpy(y).long().to(device)loss = criterion(out, y.squeeze())loss.backward()optimizer.step()train_ls += loss.item()if valid is not None:model.eval()hs = Nonewith torch.no_grad():for x, y in get_batches(valid, batch_size, seq_len):x = torch.tensor(x).float().to(device)  # x为一组测试数据,包含batch_size * seq_len个字out, hs = model(x, hs)# out.shape输出为tensor[batch_size * seq_len, vocab_size]hs = ([h.data for h in hs])  # 更新参数y = y.reshape(-1, len(model.vocab))  # y.shape为(128,100,43),因此需要转成两维,每行就代表一个字了,43为字典大小y = model.onehot_decode(y)  # y标签即为测试数据各个字的下一个字,进行one_hot解码,即变为字符# 但是此时y 是[[..],[..]]形式y = model.label_encode(y.squeeze())  # 因此需要去掉一维才能成功解码# 此时y为[12...]成为一维的数组,每个代表自己字典里对应字符的字典序y = torch.from_numpy(y).long().to(device)# 这里y和y.squeeze()出来的东西一样,可能这里没啥用,不太懂loss = criterion(out, y.squeeze())  # 计算损失值val_ls += loss.item()val_loss.append(np.mean(val_ls))train_loss.append(np.mean(train_ls))print("val_loss", val_ls)print("train_loss:", train_ls)plt.plot(train_loss, label="train_loss")plt.plot(val_loss, label="val loss")plt.title("loop vs epoch")plt.legend()plt.show()model_name = "lstm_model.net"with open(model_name, 'wb') as f:  # 训练完了保存模型torch.save(model.state_dict(), f)def predict(vocab_size, model, char, top_k=None, hidden_size=None):device = 'cuda' if torch.cuda.is_available() else 'cpu'model.to(device)model.eval()  # 固定参数with torch.no_grad():char = np.array([char])  # 输入一个字符,预测下一个字是什么,先转成numpychar = char.reshape(-1, 1)  # 变成二维才符合编码规范char_encoding = model.onehot_encode(char).A  # 对char进行编码,取成numpy比较方便reshapechar_encoding = char_encoding.reshape(1, -1, vocab_size)# 转成模型输入格式 char_tensor = torch.tensor(char_encoding, dtype=torch.float32)  # 转成tensorchar_tensor = char_tensor.to(device)out, hidden_size = model(char_tensor, hidden_size)  # 放入模型进行预测,out为结果probs = F.softmax(out, dim=1).squeeze()  # 计算预测值,即所有字符的概率print(probs.shape)if top_k is None:  # 选择概率最大的top_k个indices = np.arange(vocab_size)else:probs, indices = probs.topk(top_k)indices = indices.cpu().numpy()probs = probs.cpu().numpy()char_index = np.random.choice(indices, p=probs/probs.sum())  # 随机选择一个字符索引作为预测值char = model.int_char[char_index]  # 通过索引找出预测字符return char, hidden_sizedef sample(vocab_size, seq_len, model, length,sentence, top_k=None):hidden_size = Nonenew_sentence = [char for char in sentence]for i in range(length):next_char, hidden_size = predict(vocab_size, model, new_sentence[-seq_len:], top_k=top_k, hidden_size=hidden_size)new_sentence.append(next_char)return "".join(new_sentence)def main():hidden_size = 512num_layers = 4batch_size = 128seq_len = 2epochs = 30lr = 0.0001f = pd.read_csv("春联.csv")f = f["0"]text = list(f)text = ".".join(text).replace(".", "")vocab = np.array(sorted(set(text)))  # 建立字典vocab_size = len(vocab)print("vocab_size", vocab_size)val_len = int(np.floor(0.2 * len(text)))  # 划分训练测试集trainset = np.array(list(text[:-val_len]))validset = np.array(list(text[-val_len:]))model = lstm_model(vocab, hidden_size, num_layers)  # 模型实例化train(model, trainset, batch_size, seq_len, epochs, lr=lr, valid=validset)  # 训练模型model.load_state_dict(torch.load("lstm_model.net"))  # 调用保存的模型new_text1 = sample(vocab_size, seq_len, model, 5, "虎年", top_k=7)  # 预测模型,生成100个字符,预测时选择概率最大的前5个new_text2 = sample(vocab_size, seq_len, model, 5, "春节", top_k=7)  # 预测模型,生成100个字符,预测时选择概率最大的前5个print(new_text1)  # 输出预测文本print(new_text2)if __name__ == "__main__":main()

基于pytorch使用LSTM进行虎年春联生成相关推荐

  1. 【项目实战课】基于Pytorch的StyleGAN v1人脸图像生成实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的StyleGAN v1人脸图像生成实战>. 所谓项目实战课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题, ...

  2. 【图像分类】基于PyTorch搭建LSTM实现MNIST手写数字体识别(双向LSTM,附完整代码和数据集)

    写在前面: 首先感谢兄弟们的关注和订阅,让我有创作的动力,在创作过程我会尽最大能力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌. 在https://blog.csdn.net/A ...

  3. 基于pytorch的LSTM模型构建

    上文我们利用pytorch构建了BP神经网络,LeNet,这次我们利用LSTM网络实现对MNIST数据集的分类,具体的数据获取方法本文不详细介绍,这里只要讲解搭建LSTM网络的方法以及参数设置. 这里 ...

  4. 【图像分类】基于PyTorch搭建LSTM实现MNIST手写数字体识别(单向LSTM,附完整代码和数据集)

    写在前面: 首先感谢兄弟们的关注和订阅,让我有创作的动力,在创作过程我会尽最大能力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌. 提起LSTM大家第一反应是在NLP的数据集上比较 ...

  5. 基于Pytorch的LSTM实战160万条评论情感分类

    数据以及代码的github地址    说明:训练速度使用cpu会很慢 # 目标:情感分类 # 数据集 Sentiment140, Twitter上的内容 包含160万条记录,0 : 负面, 2 : 中 ...

  6. 基于pytorch使用LSTM进行文本情感分析

    前言 大家好,我是阿光. 本专栏整理了<PyTorch深度学习项目实战100例>,内包含了各种不同的深度学习项目,包含项目原理以及源码,每一个项目实例都附带有完整的代码+数据集. 正在更新 ...

  7. 基于PyTorch的LSTM模型的IMBD情感分类遇到的问题

    今天想学LSTM的情感分类,结果碰到了一系列问题,耽误了很多时间.特此记录! 一.项目来源 lesson53-情感分类实战 B站视频 二.碰到的问题 1.报错AttributeError: modul ...

  8. 基于Pytorch实现LSTM(多层LSTM,双向LSTM)进行文本分类

    LSTM原理请看这:点击进入 LSTM: nn.LSTM(input_size, hidden_size, num_layers=1, nonlinearity=tanh, bias=True, ba ...

  9. 基于pytorch的Bi-LSTM中文文本情感分类

    基于pytorch的Bi-LSTM中文文本情感分类 目录 基于pytorch的Bi-LSTM中文文本情感分类 一.前言 二.数据集的准备与处理 2.1 数据集介绍 2.2 文本向量化 2.3 数据集处 ...

  10. 【项目实战课】基于Pytorch的SRGAN图像超分辨实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的SRGAN图像超分辨实战>.所谓项目实战课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的实战 ...

最新文章

  1. wp7上MD5加密类
  2. oracle 孟硕_关于几大主机厂的阿里云论坛用户知识和技术交流
  3. 【python教程】append()与extend()方法的区别教程
  4. 不要在foreach循环里进行元素的remove/add操作。remove元素请使用iteratot方式,如果并发操作,需要对Iterator对象加锁
  5. jQuery自适应倒计时插件
  6. 区块链技术公司谈找到合适的激励机制
  7. ios 旋转屏幕试图切换_TCL·XESS 旋转智屏 A200Pro 评测:方向一换,体验大不相同...
  8. 全球首例无人车撞人致死事故判决:Uber无罪,安全员要进一步调查
  9. javascript的bind方法绑定深入理解
  10. 二手轻型载货车报价图片_业主坐地提价, 新房抢客, 10月广州二手房成交跌了24%...
  11. nlp gpt论文_开放AI革命性的新NLP模型GPT-3
  12. Diamond types are not supported at language level ‘5‘ 解决方法
  13. 电力系统的常用仿真模块MATLAB/SIMULINK(1)
  14. MSYS以及MinGW安装
  15. OCR应用:OCR识图取字
  16. C语言的自动关机程序和一个用来整人的小程序
  17. windows7旗舰版序列号[经测试,第一枚即可完成升级!]
  18. Unity UGUI-Canvas
  19. 真的是没有底线了,重新认识Java
  20. 学计算机的用双核CPU够吗,电脑cpu核数越多越好吗

热门文章

  1. 云服务器哪家强?AWS、Azure、阿里云、腾讯云、华为云深度评测
  2. 学计算机的一直对画画感兴趣,[电脑绘画兴趣小组教学总结]sai电脑绘画入门教学...
  3. 直面程序人生,始于当下,奔赴未来!
  4. android 三点参数,iqoo3参数配置详情-iqoo3参数配置手机参数详细表
  5. vba打开html文件,VBA调用浏览器打开指定网页的几种方法
  6. word排版技巧:如何撤销删除自动编号
  7. GLSL内置数学函数部分解析
  8. jupyter连接失败
  9. 百度校招社招-知识图谱部门直推 机会多多
  10. 微课在小学计算机教学中的应用,微课技术在小学信息技术课堂中的应用