前言

本文是对Zhang XuYao等的大作Drawing and Recognizing Chinese Characters with Recurrent Neural Network的简化复现。本文分为两部分,第一部分是基于在Layer层自定义的Stacked LSTM Layer进行Online汉字的识别,第二部分是基于在Cell层自定义的GRU Cell进行汉字生成。

数据取自模式识别国家重点实验室提供的手写汉字数据集,其中Online数据是包含次序的数据序列,Offline则不包含次序,类似位图。因此处理Online数据可以优先考虑RNN。

1.汉字识别

1.1数据集准备

对于Tensorflow的Dataset来说,本文称其数据分为x部与y部(对于汉字识别,x部为样本,y部为标签)。

x部的数据准备方法文章里讲得很清楚,也很容易理解,这里仅说结果。处理结果是将汉字转化成长度为n的序列,即形状为 [n, 6] 的二维数组。对于不同的汉字书写实例,n很可能是不同的,即序列的长度是不定的,也正是RNN模型的特点——接受变长序列。

准备好的x部(若干汉字的若干书写实例,包含data augment)是一个元素为 [n, 6] 二维数组的list(此处n虽然不一,但都是已知的),在喂给RNN模型时,我们还要把n相等的元素进行合并,即打成批次(batch),这对提高运算速度有重要意义。

从以下使用generator来构造Dataset的代码可以看到是如何做batch的。在Tensorflow中,在模型构建时未定,每次运行时注入才确定的维度,用None表示。

y部比较简单,本文在汉字识别部分只取了10个字,标签就是字的类别的one-hot编码。

with open(ss.data_path + "x_y_la_n_" + str(ss.la_total) + "_s_" + str(ss.la_per_sample) + "_dist_" + str(ss.la_remove_dist_th) + "_ang_" + str(ss.la_remove_ang_th), "rb") as f:x, y = pickle.load(f)def train_generator():i = 0while True:curlen = np.size(x[i], 0)xout = [x[i]]yout = [y[i]]i += 1if i == len(x):returnwhile np.size(x[i], 0) == curlen:xout = np.append(xout, [x[i]], axis=0)yout = np.append(yout, [y[i]], axis=0)i += 1if i == len(x):returnyield (xout, np.expand_dims(yout, 1))dataset = tf.data.Dataset.from_generator(train_generator, output_types=(tf.float32, tf.float32),output_shapes=([None, None, 6], [None, 1, 10]))

可以看到generator是将等长序列合并提供的(此前序列已按长度排序),因此output_shapes多了一维:从左到右依次是batch_size,time_step(序列长度),input_feature。

1.2模型构建及训练

这一部分的模型构建可以直接采用keras的现成组件,网上也有不少教程,这里推荐一下论文二作的示范代码GitHub - YifeiY/hanzi_recognition。

本文不采用现成的组件,自行实现一个LSTM层。

对于基于keras的RNN模型而言,层次结构为Cell → Layer → Model,本文自定义一个S_LSTM类(Stacked LSTM Layer),派生自keras.layers.Layer,因此需要重载Layer类的各标准接口。

LSTM的实现逻辑参考资料:

Creating a simple RNN from scratch with TensorFlow

Implementing a LSTM from scratch with Numpy – Christina's blog

Unfolding RNNs II - Vanilla, GRU, LSTM RNNs from scratch in Tensorflow

GitHub - suriyadeepan/rnn-from-scratch: Use tensorflow's tf.scan to build vanilla, GRU and LSTM RNNs

class S_LSTM(keras.layers.Layer):def __init__(self, h_size, nlayer, **kwargs):super(S_LSTM, self).__init__(**kwargs)self.h_size = h_sizeself.nlayer = nlayerself.Wxi, self.Wxf, self.Wxc, self.Wxo = [], [], [], []self.Whi, self.Whf, self.Whc, self.Who = [], [], [], []self.Whsi, self.Whsf, self.Whsc, self.Whso = [], [], [], []self.Wc = []self.bi, self.bf, self.bc, self.bo = [], [], [], []def build(self, input_shape):for i in range(self.nlayer):self.Wxi.append(self.add_weight(shape=(input_shape[2], self.h_size), initializer=keras.initializers.glorot_uniform))self.Wxf.append(self.add_weight(shape=(input_shape[2], self.h_size), initializer=keras.initializers.glorot_uniform))self.Wxc.append(self.add_weight(shape=(input_shape[2], self.h_size), initializer=keras.initializers.glorot_uniform))self.Wxo.append(self.add_weight(shape=(input_shape[2], self.h_size), initializer=keras.initializers.glorot_uniform))if i > 0:self.Whsi.append(self.add_weight(shape=(self.h_size, self.h_size), initializer=keras.initializers.glorot_uniform))self.Whsf.append(self.add_weight(shape=(self.h_size, self.h_size), initializer=keras.initializers.glorot_uniform))self.Whsc.append(self.add_weight(shape=(self.h_size, self.h_size), initializer=keras.initializers.glorot_uniform))self.Whso.append(self.add_weight(shape=(self.h_size, self.h_size), initializer=keras.initializers.glorot_uniform))self.Whi.append(self.add_weight(shape=(self.h_size, self.h_size), initializer=keras.initializers.glorot_uniform))self.Whf.append(self.add_weight(shape=(self.h_size, self.h_size), initializer=keras.initializers.glorot_uniform))self.Whc.append(self.add_weight(shape=(self.h_size, self.h_size), initializer=keras.initializers.glorot_uniform))self.Who.append(self.add_weight(shape=(self.h_size, self.h_size), initializer=keras.initializers.glorot_uniform))self.Wc.append(self.add_weight(shape=(self.h_size, self.h_size), initializer=keras.initializers.glorot_uniform))self.bi.append(self.add_weight(shape=(self.h_size,), initializer=keras.initializers.zeros))self.bf.append(self.add_weight(shape=(self.h_size,), initializer=keras.initializers.zeros))self.bc.append(self.add_weight(shape=(self.h_size,), initializer=keras.initializers.zeros))self.bo.append(self.add_weight(shape=(self.h_size,), initializer=keras.initializers.zeros))def call(self, inputs, **kwargs):in_shape = tf.shape(inputs)batch_size = in_shape[0]init_c = []init_h = []for i in range(self.nlayer):init_c.append(tf.zeros([batch_size, self.h_size]))init_h.append(tf.zeros([batch_size, self.h_size]))def time_step(prev, xt):c_1, h_1 = tf.unstack(prev)c, h = [], []for i in range(self.nlayer):_i = tf.math.sigmoid(tf.matmul(xt, self.Wxi[i]) + (i > 0 and tf.matmul(h[i - 1], self.Whsi[i - 1])) + tf.matmul(h_1[i], self.Whi[i]) + self.bi[i])_f = tf.math.sigmoid(tf.matmul(xt, self.Wxf[i]) + (i > 0 and tf.matmul(h[i - 1], self.Whsf[i - 1])) + tf.matmul(h_1[i], self.Whf[i]) + self.bf[i])_o = tf.math.sigmoid(tf.matmul(xt, self.Wxo[i]) + (i > 0 and tf.matmul(h[i - 1], self.Whso[i - 1])) + tf.matmul(h_1[i], self.Who[i]) + self.bo[i])_c = tf.math.tanh(tf.matmul(xt, self.Wxc[i]) + (i > 0 and tf.matmul(h[i - 1], self.Whsc[i - 1])) + tf.matmul(h_1[i], self.Whc[i]) + self.bc[i])c.append(tf.multiply(_f, c_1[i]) + tf.multiply(_i, _c))h.append(tf.multiply(_o, tf.tanh(c[i])))return tf.stack([c, h])outputs = tf.scan(time_step, tf.transpose(inputs, [1, 0, 2]), tf.stack([init_c, init_h]))return tf.transpose(outputs[:, 1, self.nlayer-1, :, :], [1, 0, 2])def get_config(self):config = super(S_LSTM, self).get_config()config.update({"h_size": self.h_size, "nlayer": self.nlayer})return config

这样自定义的LSTM层能完美融入keras的Sequential模型中:

s_lstm_layer = S_LSTM(16, 2)model = keras.Sequential([keras.layers.Input(shape=(None, 6), dtype=tf.float32, ragged=False),s_lstm_layer,keras.layers.Dropout(0.2),keras.layers.TimeDistributed(keras.layers.Dense(10, activation="softmax")),
])model.compile(optimizer=keras.optimizers.Adam(1e-4), loss=keras.losses.CategoricalCrossentropy(from_logits=False), metrics=['accuracy'])
model.summary(line_length=200)
model.fit(take_batches, steps_per_epoch=500, epochs=300)

可以直接使用keras的内置优化器、损失函数和训练方法。本文的训练准确率可达到90%以上。

2.汉字生成

2.1基本思路

在准备数据之前,先分析一下如何实现汉字生成。思路实际上是对序列数据的针对性预测,即让RNN模型根据输入数据和历史信息来预测下一个数据,此类任务的入门推荐Tensorflow的官方教程循环神经网络(RNN)文本生成 。原论文所引用的经典大作Generating Sequences With Recurrent Neural Networks,和其经典实现https://github.com/hardmaru/write-rnn-tensorflow,也是本文的重要参考之一。

回到本文,生成一个汉字直观来说就是笔尖不断移动的过程,因此需要预测下一个点的位置和书写状态(落笔 [1, 0, 0],提笔 [0, 1, 0],终止 [0, 0, 1])。

我们先讲对下一个点的位置的预测。为了增加生成的多变性(笔者猜测),原论文与Generating Sequences With Recurrent Neural Networks对下一个点的位置的预测并不是直接产出一个二维向量 [x, y],而是产出一个用于随机生成下一个点的位置的高斯混合模型(GMM)的一组参数(有点儿绕,意思是每当要写下一笔时,就产出一组参数,由这组参数确定一个GMM,再由该GMM随机生成落笔点)。关于GMM的讲解:详解EM算法与混合高斯模型(Gaussian mixture model, GMM)_林立民爱洗澡-CSDN博客_gaussian mixture model。针对GMM的损失函数长这样:

可以这样认为,当训练使得该损失函数越来越小时,GMM生成的落笔点与实际的落笔点越相吻合。建议此处结合本文代码与原论文多加理解。

对于书写状态的预测比较简单,直接预测就行,损失函数采用带权重的交叉熵。

2.2数据集准备

先说y部,使用下一点的实际位置+实际书写状态即可。经过笔者实验,这里下一点的实际位置信息采用绝对位置优于采用相对位移,有助于使训练出的模型具备“整体纠错”能力。

再说x部,原论文是将当前点的实际位置+实际书写状态作为输入提供。经过笔者实验分析,认为下一点的位置与书写状态主要与当前是哪个字的第几笔有关,而可以与当前的点位和书写状态无关,甚至这种无关性有助于提升“整体纠错”的能力。因此在本文的简化复现中,x部只输入字的类别(当前是第几笔的信息由RNN的隐藏状态表达)。

2.3模型构建及训练

本文自行实现一个GRU Cell,派生自AbstractRNNCell,实现各标准接口,使其能被keras.layers.RNN直接调用。GRU相当于一个简化版的LSTM,相信在上一部分能够熟悉理解LSTM的实现代码后,GRU的代码就比较容易了。参考资料:

GitHub - suriyadeepan/rnn-from-scratch: Use tensorflow's tf.scan to build vanilla, GRU and LSTM RNNs

class SGRUCell(AbstractRNNCell, keras.layers.Layer):def __init__(self, units, nclass, **kwargs):super(SGRUCell, self).__init__(**kwargs)self.units = unitsself.nclass = nclass@propertydef state_size(self):return self.unitsdef build(self, input_shape):self.bias_z = self.add_weight(shape=(self.units), initializer=keras.initializers.constant(5), name='bias_z')self.recurrent_kernel = self.add_weight(shape=(self.units, self.units * 4), initializer='orthogonal', name='recurrent_kernel')self.kernel_c = self.add_weight(shape=(self.nclass, self.units * 4), initializer='glorot_uniform', name='kernel_c')self.bias = self.add_weight(shape=(self.units * 3), initializer='zeros', name='bias')self.built = Truedef call(self, inputs, states, training):tf.debugging.assert_all_finite(inputs, 'sgrucell inputs ill')h_tm1 = states[0] if nest.is_sequence(states) else states  # previous memorych = tf.cast(inputs[:, 0], tf.int32)_ch = tf.one_hot(ch, self.nclass)z = tf.sigmoid(tf.matmul(h_tm1, self.recurrent_kernel[:, :self.units])+ tf.matmul(_ch, self.kernel_c[:, :self.units])+ self.bias_z)r = tf.sigmoid(tf.matmul(h_tm1, self.recurrent_kernel[:, self.units:self.units * 2])+ tf.matmul(_ch, self.kernel_c[:, self.units:self.units * 2])+ self.bias[:self.units])hh = tf.tanh(tf.matmul(r * h_tm1, self.recurrent_kernel[:, self.units * 2:self.units * 3])+ tf.matmul(_ch, self.kernel_c[:, self.units * 2:self.units * 3])+ self.bias[self.units * 1:self.units * 2])h = z * h_tm1 + (1 - z) * hho = tf.tanh(tf.matmul(h, self.recurrent_kernel[:, self.units * 3:])+ tf.matmul(_ch, self.kernel_c[:, self.units * 3:])+ self.bias[self.units * 2:])new_state = [h] if nest.is_sequence(states) else hreturn o, new_statedef get_config(self):config = super(SGRUCell, self).get_config()config.update({'units': self.units, 'nclass': self.nclass})return config

为了提升运算效率,将部分操作移至一个PostProcess层:

class PostProcess(keras.layers.Layer):def __init__(self, M, **kwargs):super(PostProcess, self).__init__(**kwargs)self.M = Mdef build(self, input_shape):input_dim = input_shape[-1]self.Wgmm = self.add_weight(shape=(input_dim, self.M * 5), initializer='glorot_uniform', name='Wgmm')self.bgmm = self.add_weight(shape=(self.M * 5), initializer='zeros', name='bgmm')self.Wsoftmax = self.add_weight(shape=(input_dim, 3), initializer='glorot_uniform', name='Wsoftmax')self.bsoftmax = self.add_weight(shape=(3), initializer='zeros', name='bsoftmax')self.built = Truedef call(self, inputs, **kwargs):tf.debugging.assert_all_finite(inputs, 'postprocess inputs ill')R5M = tf.matmul(inputs, self.Wgmm) + self.bgmm_pi = R5M[:, :, :self.M]pi = exp_safe(_pi) / tf.reduce_sum(exp_safe(_pi), axis=-1, keepdims=True)mux = R5M[:, :, self.M:self.M * 2]muy = R5M[:, :, self.M * 2:self.M * 3]sigmax = exp_safe(R5M[:, :, self.M * 3:self.M * 4])sigmay = exp_safe(R5M[:, :, self.M * 4:])R3 = tf.matmul(inputs, self.Wsoftmax) + self.bsoftmaxp = exp_safe(R3) / tf.reduce_sum(exp_safe(R3), axis=-1, keepdims=True)return tf.concat([pi, mux, muy, sigmax, sigmay, p], axis=-1)def get_config(self):config = super(PostProcess, self).get_config()config.update({"M": self.M})return config

这样仍然能完美融入keras的Sequential模型:

def construct_model(rnn_cell_units, nclass, M, stateful, batch_shape):rnn_cell = SGRUCell(units=rnn_cell_units, nclass=nclass)rnn_layer = tf.keras.layers.RNN(rnn_cell, return_state=False, return_sequences=True, stateful=stateful)model = tf.keras.Sequential([tf.keras.layers.Input(batch_shape=batch_shape),rnn_layer,PostProcess(M=M)])return model

损失函数:

def exp_safe(x):return tf.clip_by_value(tf.exp(x), clip_value_min=1e-10, clip_value_max=1e10)def log_safe(x):return tf.clip_by_value(tf.math.log(x), clip_value_min=-20, clip_value_max=20)def N(x, mu, sigma):return exp_safe(-(x - mu)**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))def Loss(y, pred):tf.debugging.assert_all_finite(pred, 'loss inputs ill')pi, mux, muy, sigmax, sigmay,  = tf.split(pred[:, :, :-3], 5, axis=-1)p = pred[:, :, -3:]xtp1 = tf.expand_dims(y[:, :, 0], axis=-1)ytp1 = tf.expand_dims(y[:, :, 1], axis=-1)stp1 = y[:, :, 2:5]w = tf.constant([1, 5, 100], dtype=tf.float32)lPd = log_safe(tf.reduce_sum(pi * N(xtp1, mux, sigmax) * N(ytp1, muy, sigmay), axis=-1))lPs = tf.reduce_sum(w * stp1 * log_safe(p), axis=-1)return - (lPd + lPs)

训练:

with open(ss.data_path + "x_y_lb100_n_" + str(ss.nclass) + "_r_" + str(ss.repeat) + "_dist_" + str(ss.lb_remove_dist_th) + "_ang_" + str(ss.lb_remove_ang_th) + "_drop_" + str(ss.drop) + "_np_" + str(ss.noise_prob) + "_nr_" + str(ss.noise_ratio), 'rb') as f:x, y = pickle.load(f)dataset = tf.data.Dataset.from_generator(lambda: iter(zip(x, y)), output_types=(tf.float32, tf.float32),output_shapes=([None, None, 1], [None, None, 5]))
take_batches = dataset.repeat().shuffle(5000)sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.45)))
tf.compat.v1.keras.backend.set_session(sess)loss_tracker = keras.metrics.Mean(name="loss")
mae_metric = keras.metrics.MeanAbsoluteError(name="mae")class CustomModel(keras.Model):def train_step(self, data):x, y = datawith tf.GradientTape() as tape:y_pred = self(x, training=True)loss = Loss(y, y_pred)trainable_vars = self.trainable_variables_gradients = tape.gradient(loss, trainable_vars)gradients = []for i in range(len(_gradients)):_g = tf.clip_by_norm(_gradients[i], clip_norm=100)gradients.append(_g)self.optimizer.apply_gradients(zip(gradients, trainable_vars))loss_tracker.update_state(loss)return {"loss": loss_tracker.result()}@propertydef metrics(self):return [loss_tracker]class CustomCallback(keras.callbacks.Callback):def __init__(self, model2, ckp_path):self.model2 = model2self.ckp_path = ckp_pathdef on_epoch_begin(self, epoch, logs=None):passtf.random.set_seed(123 + epoch)def on_epoch_end(self, epoch, logs=None):self.model2.load_weights(tf.train.latest_checkpoint(self.ckp_path))draw_chars(x, y, self.model2, [0, 1, 2, 3, 4], 50, False, self.ckp_path + 'epoch_' + str(epoch + 1))rnn_cell = SGRUCell(units=ss.units, nclass=ss.nclass)
rnn_layer = tf.keras.layers.RNN(rnn_cell, return_state=False, return_sequences=True, stateful=False)
postprocess = PostProcess(M=ss.M)inputs = keras.Input(batch_shape=[None, None, 1])
rnn_out = rnn_layer(inputs)
outputs = postprocess(rnn_out)
model = CustomModel(inputs, outputs)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001))
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=ss.checkpoint_path + 'ck_{epoch}', save_weights_only=True)
custom_callback = CustomCallback(construct_model(ss.units, ss.nclass, ss.M, True, [1, 1, 1]), ss.checkpoint_path)
model.run_eagerly = False
#  model.load_weights(tf.train.latest_checkpoint(ss.checkpoint_path))
model.fit(take_batches, steps_per_epoch=ss.steps_per_epoch, epochs=ss.epochs, initial_epoch=0,callbacks=[checkpoint_callback, custom_callback])

训练效果(上为预测数据,下为真实样本):

第1轮:

第50轮:

以上代码均为摘录,并不完整,不能直接运行。完整项目代码见本文Github地址:https://github.com/xiexi1990/rnn_chinese_new。

使用自定义的Layer和Cell实现手写汉字生成(Tensorflow2)相关推荐

  1. 自定义非等高 Cell

    1.自定义非等高 Cell介绍 1.1 代码自定义(frame) 新建一个继承自 UITableViewCell 的类. 重写 initWithStyle:reuseIdentifier: 方法. 添 ...

  2. 自定义等高的cell(storyboard)

    /*storyboard自定义cell1.创建一个继承自UITabelViewCell的子类,比如tgCell2.在storyboard中:往cell里面增加需要用到的子控件:设置cell的重用标识; ...

  3. CorelDRAW VBA - 在图层执行自定义命令 Layer.CustomCommand

    Layer.CustomCommand 方法用来执行组件提供的特定于某个图层的命令.例如,您可以使用 CustomCommand方法来执行表相关的命令,方法是使用 Table 开头的类. 参数说明 参 ...

  4. python自定义随机数_python:numpy.random模块生成随机数

    简介 所谓生成随机数,即按照某种概率分布,从给定的区间内随机选取一个数.常用的分布有:均匀分布(uniform distribution),正态分布(normal distribution),泊松分布 ...

  5. php 自定义表格并统计,PHP 使用Echarts生成数据统计报表的实现

    这篇文章主要介绍了PHP 使用Echarts生成数据统计报表的实现代码,需要的朋友可以参考下 echarts统计,简单示例 先看下效果图 看下代码 HTML页面 为ECharts准备一个Dom,宽高自 ...

  6. java操作跨页的word cell,利用itext 生成pdf,处理cell 跨页问题 [转]

    处理方法: PdfPTable table =newPdfPTable(1); table.setSplitLate(false); table.setSplitRows(true); 开发中的例子: ...

  7. 回调函数自定义传参_koroFileHeader:一个用于生成文件头部注释和函数注释的插件...

    小金子 读完需要 2分钟 速读仅需 1 分钟 大家好,我是你们的小金子. 今天给大家分享的这个工具呢?对于使用 VS Code 的同学来讲,是一个好东西. koroFileHeader,一个在 vsc ...

  8. 基于Conditional Layer Normalization的条件文本生成

    作者丨苏剑林 单位丨追一科技 研究方向丨NLP,神经网络 个人主页丨kexue.fm 从文章从语言模型到Seq2Seq:Transformer如戏,全靠Mask中我们可以知道,只要配合适当的 Atte ...

  9. java自定义表单_JSP实现用于自动生成表单标签html代码的自定义表单标签

    本文实例讲述了JSP实现用于自动生成表单标签HTML代码的自定义表单标签.分享给大家供大家参考.具体如下: 这个是自己写的一个简单的JSP表单标签,用于自动生成checkBox,select,radi ...

最新文章

  1. 设计模式---------门面模式
  2. 计算机报名锁定后可以修改吗,网上报名正式提交后 报名信息即被锁定 无法修改...
  3. instagram架构_如何为亚马逊,Instagram,Zalando和天猫生成产品图像
  4. 【Guava】对Guava类库的注释类型 VisibleForTesting的理解
  5. git pull出现错误的解决办法
  6. MySQL 实例空间使用率过高的原因和解决方法
  7. 有谁还遇到同样的问题?
  8. windows虚拟显示器SDK开发和提供
  9. 【软件工程】对软件工程课程的希望及个人目标
  10. KiB、MiB与KB、MB的区别
  11. Excel2013 单元格锁定
  12. win11提示windows许可证即将过期
  13. 【清华大学陈渝】第一章 操作系统概述
  14. 【Android 逆向】Android 逆向用途 | Android 逆向原理
  15. Linux_加密和安全详细介绍
  16. http://www.cnblogs.com/xd502djj/p/3473516.html
  17. managed, unmanaged
  18. 巨型计算机卡通,动漫史上十大超巨型机体
  19. Android录制桌面视频screenrecord
  20. Linux文件类型发布啦!

热门文章

  1. OpenFeign 夺命连环 9问
  2. IPC Send timeout/node eviction etc with high packet reassembles failure
  3. 大剖析:中国数万亿家装市场,为何出不了一个30亿美金的Houzz?
  4. win10用html文件做壁纸,利用win10自带工具制作动态壁纸的简单方法
  5. 解决每次新建word都有页眉和页脚
  6. MacBook Pro安装homebrew
  7. 在阅读中培养自己的注意力
  8. sql Mirroring
  9. vue canvas typescript 绘制时间标尺
  10. xlsx无法导入MySQL?