近几年来,考公的人数越来越多,而申论作为考公非常重要的一部分,也是另很多人头痛的一部分。很多人在考试之前都会背一些优秀范文或句段,以便在考试时派上用场。这里我用GPT2预训练很多篇申论范文,使之能在某个话题的提示下自动申成一片范文或句段。话不多说,直接上代码。

数据预处理

这里我找了500篇申论范文,不是很多,当然也可以多找点,最好是各类话题都有,越多越好。

造字典

将所有文章中的字,符号提取出来,去重后存入一个txt文档中
代码实现

import os
DIR_PATH = r"novels"
VOCAB_FILE = r"Vocab.txt"
words = set()
x=0
for i, filename in enumerate(os.listdir(DIR_PATH)):x=x+1f_path = os.path.join(DIR_PATH, filename)print(f_path)with open(f_path, "r+", encoding="utf-8") as f:w = f.read(1)while w:if w == '\n' or w == '\r' or w == ' ':# words.add('[SEP]')passelse:words.add(w)w = f.read(1)
print(x)
with open(VOCAB_FILE, "w+", encoding="utf-8") as f:f.write("[START] [SEQ] [UNK] [PAD] [END] ")f.write(" ".join(words))f.flush()

对文章进行编码

利用字典对文章进行编码,如字典中第12个字是“我”,则在原文中的“我”就用数字11代替,然后保存每篇文章的编码。
代码实现:

import os
SRC_DIR = r"novels"
DST_DIR = r"encoded_novels"
VOCAB_FILE = "Vocab.txt"
if not os.path.exists(DST_DIR):os.makedirs(DST_DIR)
with open(VOCAB_FILE, "r+", encoding="utf-8") as f:tokens = f.read().split()
count = 0
for i, filename in enumerate(os.listdir(SRC_DIR)):f_path = os.path.join(SRC_DIR, filename)print(f_path)with open(f_path, "r+", encoding="utf-8") as f:dst = ["0"]w = f.read(1)while w:if w == '\n' or w == '\r' or w == '\t' or ord(w) == 12288:dst.append("1")elif w == ' ':dst.append("3")else:try:dst.append(str(tokens.index(w)))except:dst.append("2")w = f.read(1)count+=1with open(os.path.join(DST_DIR, "{}.txt".format(count)), "w+", encoding="utf-8") as df:df.write(" ".join(dst))
print(count)

网络模型

我搭建的是带多头注意力的GPT模型,由于电脑GPU显存不大,所以头数设的12,模块数设的6,字的维数为768,最多可生成500字

# config文件
block_num = 6
head_num = 12
embed_dim = 768
vocab_num = 3012
pos_num =500
multi=4
stride=1
device = "cuda:0"
import torch
from torch import nn
import config as cfg
class Attention(nn.Module):def __init__(self, isMask=True):super().__init__()self.dk = (cfg.embed_dim // cfg.head_num) ** 0.5self.isMask = isMaskself.c_attn = nn.Linear(cfg.embed_dim, cfg.embed_dim * 3)self.attn_drop = nn.Dropout(0.1)self.resi_drop = nn.Dropout(0.1)self.c_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)if self.isMask:# self.register_buffer("mask", torch.tril(torch.ones(cfg.pos_num, cfg.pos_num)))self.mask = torch.tril(torch.ones(cfg.pos_num, cfg.pos_num)).cuda()def forward(self, x):x = self.c_attn(x) # x形状(N,S,V),N代表多少个句子,S代表多少个词,V代表每个词的维度x = x.reshape(*x.shape[:-1], cfg.head_num, -1)  # (N,S,V)——>(N,S,12,768/12*3)x = x.transpose(-2, -3)  # (N,S,12,768/12*3)——>(N,12,,S,768/12*3)q, k, v = x.chunk(3, dim=-1)w = (q @ k.transpose(-1, -2)) / self.dk  # (N,12,S,64)@(N,12,64,S)=(N,12,S,S)# if self.isMask:# mask=(self.mask if self.isMask else 1)mask=torch.tril(torch.ones(w.size(-2), w.size(-1))).cuda()w = w * mask - (1 - mask) * 1e5w = torch.softmax(w, dim=-1)w = self.attn_drop(w)a = w @ v  # (N,12,S,S)@(N,12,S,64)-->(N,12,S,64)a = a.transpose(-2, -3)  # (N,12,S,64)-->(N,S,12,64)a = a.reshape(*a.shape[:-2], cfg.embed_dim)  # (N,S,12,64)-->(N,S,768)h = self.c_proj(a)h = self.resi_drop(h)return h
class Block(nn.Module):def __init__(self, isMask=True):super().__init__()self.layer_normal_1 = nn.LayerNorm(cfg.embed_dim)self.attention = Attention(isMask)self.layer_normal_2 = nn.LayerNorm(cfg.embed_dim)self.proj = nn.Sequential(nn.Linear(cfg.embed_dim, cfg.multi * cfg.embed_dim),nn.LeakyReLU(),nn.Linear(cfg.multi * cfg.embed_dim, cfg.embed_dim),)self.dropout = nn.Dropout(0.1)def forward(self, x):h = self.layer_normal_1(x)a = self.attention(h)a = a + x  # 加一个残差a = self.layer_normal_2(a)h = self.proj(a)h = self.dropout(h)y = h + a  # 加一个残差return y
class GPT2(nn.Module):def __init__(self):super().__init__()self.vocab_embed = nn.Embedding(cfg.vocab_num, cfg.embed_dim) # 定义一个字典self.pos_embed = nn.Embedding(cfg.pos_num, cfg.embed_dim)   # 定义一个位置编码# self.type_embed = nn.Embedding(cfg.type_num, cfg.embed_dim)   # 定义一个类型编码self.blocks = []for _ in range(cfg.block_num):self.blocks.append(Block())self.drop = nn.Dropout(0.1)self.sequential = nn.Sequential(*self.blocks)self.output_layer = nn.Linear(cfg.embed_dim, cfg.vocab_num, bias=False)def forward(self, x, p):e = self.vocab_embed(x)  # 对输入进行词向量编码p = self.pos_embed(p)    # 对输入进行位置编码# t = self.type_embed(t)   # 对输入进行类型编码h = self.drop(e + p)h = self.sequential(h)return self.output_layer(h)

网络训练

生成训练数据

import torch, os
from torch.utils.data import Dataset
import config as cfg
class MyDataset(Dataset):def __init__(self, dir):self.dataset = []for filename in os.listdir(dir):with open(os.path.join(dir, filename), "r+") as f:ws = [int(x) for x in f.readline().split()]ws_len = len(ws)start = 0while ws_len - start > cfg.pos_num + 1:self.dataset.append(ws[start:start + cfg.pos_num + 1])start += cfg.strideelse:if ws_len > cfg.pos_num + 1:self.dataset.append(ws[ws_len - cfg.pos_num - 1:])def __len__(self):return len(self.dataset)def __getitem__(self, index):data = torch.tensor(self.dataset[index])return data[0:-1], data[1:]

训练


from module import *
from dataset import *
import torch, os
from torch import  optim
from torch.utils.data import DataLoader
from torch.nn import  functional as F
# def weight_init(m):
#     if isinstance(m, nn.Linear):
#         nn.init.xavier_normal_(m.weight)
#         if m.bias is not None:
#             nn.init.constant_(m.bias, 0)
save_path=r"网络参数"
class Trainer:def __init__(self):self.net = GPT2()self.weight_file = os.path.join(save_path, "gpt2_k.pt")if os.path.exists(self.weight_file):self.net.load_state_dict(torch.load(self.weight_file))# else:#     self.net.apply(weight_init)self.net.to(torch.device(cfg.device))self.opt = optim.Adam(self.net.parameters(), lr=0.0001)def train(self):myDataset = MyDataset(r"encoded_novels")print(len(myDataset))dataloader = DataLoader(myDataset, batch_size=4, shuffle=True)epoch=0while True:epoch=epoch+1sum_loss = 0for i, (x, y) in enumerate(dataloader):x, y = x.to(torch.device(cfg.device)), y.to(torch.device(cfg.device))p = torch.arange(0, x.shape[1])[None, :].repeat(x.shape[0], 1).to(torch.device(cfg.device))# print(p)_y = self.net(x, p).reshape(-1, cfg.vocab_num)y = y.reshape(-1)loss = F.cross_entropy(_y, y)self.opt.zero_grad()loss.backward()self.opt.step()print(loss.cpu().detach().item())sum_loss += loss.cpu().detach().item()if i % 1000 == 0 and i > 0:torch.save(self.net.state_dict(), self.weight_file)print("第{0}轮训练完毕".format(epoch))print("轮的平均损失为{0}".format(sum_loss / len(dataloader)))torch.save(self.net.state_dict(), self.weight_file)print("参数保存成功")

测试

from module import *
def transer(x):              # 索引到字的换算VOCAB_FILE = "Vocab.txt"with open(VOCAB_FILE, "r+", encoding="utf-8") as f:tokens = f.read().split()y=x[0]for i in y:print(tokens[i], end=" ")
def Transfer(str):          # 字到索引的换算VOCAB_FILE = "Vocab.txt"with open(VOCAB_FILE, "r+", encoding="utf-8") as f:tokens = f.read().split()idx=tokens.index(str)return idx
if __name__ == '__main__':gpt = GPT2()gpt.to(torch.device(cfg.device))gpt.eval()gpt.load_state_dict(torch.load(r"网络参数\gpt2_k.pt"))os = []x = torch.tensor([[Transfer("依"),Transfer("法"),Transfer("治"),Transfer("国")]]).cuda()  # 给定一个开始词p = torch.tensor([[0,1,2,3]]).cuda()  # 给定一个起始位置l=x.size()[1]for i in range(400):y = gpt(x, p)y = y[:, -1:]v, y = torch.topk(y, 8, dim=-1)v, y = v.reshape(-1, 8), y.reshape(-1, 8)v = torch.multinomial(torch.softmax(v, dim=-1), 1)y = torch.gather(y, -1, v)x = torch.cat([x, y], dim=1)p = torch.tensor([range(i + l + 1)]).cuda()print(transer(x))

比如,输入“人工智能”,则会生成如下片段:

人 工 智 能 , 网 上 购 物 , 物 联 网 , 各 种 新 兴 技 术 层 出 不 穷 , 各 种 创 新 思 想 不 断 迸 发 , 国 家 政 策 环 境 需 求 都 为 创 新 提 供 了 丰 富 的 土 壤 , 这 也 是 最 坏 的 时 代 , 自 主 品 牌 创 新 能 力 薄 弱 , 山 寨 产 品 盛 行 , 核 心 技 术 被 外 方 意 志 很 大 程 度 上 削 减 了 我 国 的 竞 争 力 , 究 其 原 因 , 一 方 面 是 企 业 缺 乏 竞 争 意 识 , 创 新 意 识 目 光 短 浅 所 致 , 而 另 一 方 面 在 于 人 才 的 流 失 , 由 于 学 术 界 浮 躁 的 氛 围 , 以 及 体 制 的 不 完 善 等 , 许 多 科 研 人 员 面 临 工 资 低 , 没 有 项 目 的 窘 境 , 为 了 改 善 环 境 , 降 低 生 存 压 力 , 转 而 流 向 其 他 的 领 域 , 因 此 想 要 中 国 品 牌 走 出 国 门 , 提 升 竞 争 力 , 创 新 是 关 键 。 打 造 中 国 品 牌 提 升 国 家 竞 争 力 , 融 入 民 族 精 神 是 重 点 。 中 国 品 牌 之 所 以 被 称 为 中 国 品 牌 , 关 键 在 于 其 拥 有 独 特 的 魅 力 , 不 同 于 其 他 国 家 , 必 须 有 中 国 的 特 色 , 必 须 有 中 国 的 文 化 , 与 文 化 紧 密 结 合 , 故 宫 博 物 院 的 文 创 产 品 , 就 是 将 这 一 融 合 发 挥 到 极 致 的 典 范 , 将 文 物 蕴 含 的 文 化 内 容 融 入 到 产 品 设 计 当 中 , 设 计 出 具 有 中 国 特 色 的 独 一 无 二 的 文 创 产 品 , 不 仅 能 够 吸 引 大 量 的 游 客 , 更 传 承 了 中 国 文 化 之 道 , 不 仅 打 造 了 品 牌 , 更 将 这 一 品 牌 销 往 国 外 , 可 见 , 打 造 中 国 品 牌 , 还 必 须 要 将 中 国 文 化 结 合 其 中 , 方 能 够 让 中 国 品 牌 脱 颖 而 出 , 与 众 不 同 , 方 能 体 现 中 国 竞 争 力 。

基于GPT2实现考公申论文章生成相关推荐

  1. GPT2实现考公申论文章生成

    向AI转型的程序员都关注了这个号???????????? 人工智能大数据与深度学习  公众号:datayx 近几年来,考公的人数越来越多,而申论作为考公非常重要的一部分,也是另很多人头痛的一部分.很多 ...

  2. 执着于考研考公却一再挫败,拿什么拯救你的职场后半生?

    今天之所以想写一篇这样的文章,确确实实是有感而发,因为从近来接触的学员身上,能够最直观地感受到:考公考研失败后的同学,他们内心的那种焦虑感远超往期! 用他们的话讲:"目前的状态就是感觉自己和 ...

  3. 再见吧,996!程序员开源考公指南获高赞:三人已成功上岸

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 整理 | 钰莹 转载自公众号:AI前线 近年来,互联网公司 996 ...

  4. 长见识!居然还有程序员考公指南这种东西?

    整理 | 王晓曼 出品 | 程序人生 (ID:coder _life) 最近,拼多多事件的发酵再次把互联网打工人的996推到了风口浪尖. 虽然并不是每一个猝死事件都能与"过劳"建立 ...

  5. 考研、考公还是找工作?别在大学因为迷茫这个问题浪费时间了

    在大学,千万不要因为分数去限制你的思维 如果看到此篇文章的你,是处于即将或已经步入大一的学弟学妹们,那首先我要恭喜你们通过自己的努力考上了大学.在这我告诉你们,只要进入大学的大门,就不要纠结于一些某些 ...

  6. 超全!互联网大厂职级薪资表,全国各地互联网大厂分布(校招/社招/考研/考公)

    中国互联网大厂从实力上划分,可以分为第一梯队.第二梯队.第三梯队 互联网巨头市值缩水排行榜 互联网大厂月薪情况 互联网大厂时薪排行榜 互联网大厂薪资&职级参考表 2022届校招薪资汇总 202 ...

  7. 程序员考公指南:逃离996的最强出路,拒绝秃顶的最佳方法

    最近,拼多多事件的发酵再次把互联网打工人的996推到了风口浪尖. 虽然并不是每一个猝死事件都能与"过劳"建立直接联系,但互联网行业超负荷加班处理Bug是家常便饭,虽然收入高于很多行 ...

  8. 程序员考公指南(逃离996的最强后路!!!)

    最近,拼多多事件的发酵再次把互联网打工人的996推到了风口浪尖. 虽然并不是每一个猝死事件都能与"过劳"建立直接联系,但互联网行业超负荷加班处理Bug是家常便饭,虽然收入高于很多行 ...

  9. 加餐1 | 考公、考编、军队文职以及事业编

    文章目录 一.简介 二.考公 1.国考 2.省考 三.考编(各种事业编) 四.军队文职 五.银行 一.简介 考编一般是指公务员和事业编公务员. 公务员分为国考.省考.选调生. 国考在每年的 11 月底 ...

  10. 失意互联网人,决定去考公

    深燃(shenrancaijing)原创 作者 | 邹帅  唐亚华 王敏 宛其 李秋涵 编辑 | 王敏 互联网的尽头是考公? 近年来,公务员考试越来越热,今年的竞争尤其激烈.据统计,国家公务员招录考试 ...

最新文章

  1. linux命令学习——file
  2. 使用预训练的卷积神经网络(猫狗图片分类)
  3. 插入排序(c++实现)
  4. 5、优化MySQL服务器
  5. 【JAVA多线程学习笔记】(1)实现线程的方式 线程生命周期 操作线程的方法
  6. 买房一定要知道的购房误区 买涨不买跌的心态可能得改
  7. C# 大文件分块下载
  8. cmd windows 命令sleep_最实在的9个黑客命令!确定不学习下?
  9. Vue教程:简介(一)
  10. 如何判断NSMutableDictionary是否有某个key
  11. Excel 使用技巧集锦—163种技巧
  12. ipadmini1iOS9.3.5降级8.4.1教程
  13. 实现数据结构中的栈---后进先出LIFO
  14. Android 11 : 隐私和安全
  15. 因果关系基本概念:后门标准
  16. SEED LABS初入
  17. 虚拟服务器的常用服务器选什么,如何选择合适的虚拟主机,虚拟主机选什么系统...
  18. poi excel 导出设置边框,自定义背景色,自定义字体
  19. oracle开放查询表权限_Oracle创建用户并给用户授权查询指定表或视图的权限
  20. Android 开发的两种框架 MVC和MVP 的简单分析

热门文章

  1. win10 WIFI连接无选项时的解决方法
  2. Xilinx 文件的编写
  3. mysql field in set_MySQL中的find_in_set()函数使用技巧心得与应用场景总结
  4. 一生只为一个女人奋斗
  5. m118w重置墨粉_富士施乐 Fuji Xerox DocuPrint M118w/M118z墨盒换粉加粉详解
  6. WordPress使用腾讯云CDN配置如何实现https访问?
  7. word修改表格和下方段落的间距
  8. 网络打印机 这台计算机上没有安装,Win7添加网络打印机时提示打印处理器不存在怎么办?...
  9. IEEE trans模板格式中左下角添加脚注的方法
  10. 将多张小图片合并成一张大图片 Python3