tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用。

数据目录在data,data下放了汉字识别图片:

data$ ls
0  1  10  11  12  13  14  15  16  2  3  4  5  6  7  8  9
datag$ ls 0
xxx.png yyy.png ....

代码:

如果将get model里的模型层数加非常深,训练时候很可能不会收敛,精度一直停留下1%以内。

# -*- coding: utf-8 -*-from __future__ import division, print_function, absolute_importimport os
import numpy as np
import pickle
import tflearnfrom PIL import Image
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d, avg_pool_2d
from tflearn.layers.merge_ops import merge
from tflearn.layers.estimator import regression
from tflearn.data_utils import to_categorical, shuffle
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from tflearn.layers.conv import highway_conv_2d, max_pool_2d
from tflearn.layers.normalization import local_response_normalization, batch_normalizationdef resize_image(in_image, new_width, new_height, out_image=None,resize_mode=Image.ANTIALIAS):""" Resize an image.Arguments:in_image: `PIL.Image`. The image to resize.new_width: `int`. The image new width.new_height: `int`. The image new height.out_image: `str`. If specified, save the image to the given path.resize_mode: `PIL.Image.mode`. The resizing mode.Returns:`PIL.Image`. The resize image."""img = in_image.resize((new_width, new_height), resize_mode)if out_image:img.save(out_image)return imgdef convert_color(in_image, mode):""" Convert image color with provided `mode`. """return in_image.convert(mode)def pil_to_nparray(pil_image):""" Convert a PIL.Image to numpy array. """pil_image.load()return np.asarray(pil_image, dtype="float32")def iterbrowse(path):for home, dirs, files in os.walk(path):for filename in files:yield os.path.join(home, filename)def directory_to_samples(directory, flags):""" Read a directory, and list all subdirectories files as class sample """samples = []targets = []# label class is from 0 !!!label = 0try:  # Python 2classes = sorted(os.walk(directory).next()[1])except Exception:  # Python 3classes = sorted(os.walk(directory).__next__()[1])for c in classes:c_dir = os.path.join(directory, c)try:  # Python 2walk = os.walk(c_dir).next()except Exception:  # Python 3walk = os.walk(c_dir).__next__()for sample in walk[2]:if any(flag in sample for flag in flags):samples.append(os.path.join(c_dir, sample))targets.append(label)label += 1return samples, targets# Get the pixel from the given image
def get_pixel(image, i, j):# Inside image bounds?width, height = image.sizeif i > width or j > height:return None# Get Pixelpixel = image.getpixel((i, j))return pixel# Create a Grayscale version of the image
def convert_to_one_channel(image):# !!! I assume that the png file is grayscale. And R == G == B !!!! So I check it..."""for i in range(len(image)):for j in range(len(image[i])):pixel = image[i][j]# Get R, G, B values (This are int from 0 to 255)assert len(pixel) == 3red = pixel[0]green = pixel[1]blue = pixel[2]assert red == green == blueassert 0 <= red <= 1"""# Just extract 1 channel datareturn image[:, :, [0]]def image_dirs_to_samples(directory, resize=None, convert_gray=False,filetypes=None):print("Starting to parse images...")samples, targets = directory_to_samples(directory, flags=filetypes)for i, s in enumerate(samples):print("Process %d th file %s" % (i+1, s))samples[i] = Image.open(s)  # Load an image, returns PIL.Image.if resize:######################## TODO #######################samples[i] = resize_image(samples[i], resize[0],resize[1])######################### TODO ####################### convert to more data# if convert_gray:#    samples[i] = convert_color(samples[i], 'L')samples[i] = pil_to_nparray(samples[i])samples[i] /= 255.  # ormalize a list of sample image data in the range of 0 to 1samples[i] = convert_to_one_channel(samples[i]) # just want 1 channel dataprint("Parsing Done!")return samples, targetsdef load_data(dirname, resize_pics=(128, 128), shuffle_data=True):dataset_file = os.path.join(dirname, 'data.pkl')try:X, Y, org_labels = pickle.load(open(dataset_file, 'rb'))except Exception:# X, Y = image_dirs_to_samples(os.path.join(dirname, 'train/'), resize_pics, False, ['.jpg', '.png'])X, Y = image_dirs_to_samples(dirname, resize_pics, False,['.jpg', '.png'])  # TODO, memory is too small to load all dataorg_labels = YY = to_categorical(Y, np.max(Y) + 1)  # First class is '0', Convert class vector (integers from 0 to nb_classes)if shuffle_data:X, Y, org_labels = shuffle(X, Y, org_labels)pickle.dump((X, Y, org_labels), open(dataset_file, 'wb'))return X, Y, org_labelsclass EarlyStoppingCallback(tflearn.callbacks.Callback):def __init__(self, val_acc_thresh):# Store a validation accuracy threshold, which we can compare against# the current validation accuracy at, say, each epoch, each batch step, etc.self.val_acc_thresh = val_acc_threshdef on_epoch_end(self, training_state):"""This is the final method called in trainer.py in the epoch loop.We can stop training and leave without losing any information with a simple exception."""# print dir(training_state)print("Terminating training at the end of epoch", training_state.epoch)if training_state.val_acc >= self.val_acc_thresh and training_state.acc_value >= self.val_acc_thresh:raise StopIterationdef on_train_end(self, training_state):"""Furthermore, tflearn will then immediately call this method after we terminate training,(or when training ends regardless). This would be a good time to store any additionalinformation that tflearn doesn't store already."""print("Successfully left training! Final model accuracy:", training_state.acc_value)def get_model(width, height, classes=40):# TODO, modify model# Real-time data preprocessingimg_prep = tflearn.ImagePreprocessing()# Real-time data preprocessingimg_prep = tflearn.ImagePreprocessing()img_prep.add_featurewise_zero_center(per_channel=True)img_prep.add_featurewise_stdnorm()network = input_data(shape=[None, width, height, 1], data_preprocessing=img_prep)  # if RGB, 224,224,3network = conv_2d(network, 32, 3, activation='relu')network = max_pool_2d(network, 2)network = conv_2d(network, 64, 3, activation='relu')network = conv_2d(network, 64, 3, activation='relu')network = max_pool_2d(network, 2)network = fully_connected(network, 512, activation='relu')network = dropout(network, 0.5)network = fully_connected(network, classes, activation='softmax')network = regression(network, optimizer='adam',loss='categorical_crossentropy',learning_rate=0.001)model = tflearn.DNN(network, tensorboard_verbose=0)return modelif __name__ == "__main__":width, height = 32, 32X, Y, org_labels = load_data(dirname="data", resize_pics=(width, height))trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.2, random_state=666)print("sample data:")print(trainX[0])print(trainY[0])print(testX[-1])print(testY[-1])model = get_model(width, height, classes=100)filename = 'cnn_handwrite-acc0.8.tflearn'# try to load model and resume training#try:#    model.load(filename)#    print("Model loaded OK. Resume training!")#except:#    pass# Initialize our callback with desired accuracy threshold.early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.9)try:model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.show_metric=True, batch_size=32, callbacks=early_stopping_cb, run_id='cnn_handwrite')except StopIteration as e:print("OK, stop iterate!Good!")model.save(filename)# predict all data and calculate confusion_matrixmodel.load(filename)pro_arr =model.predict(X)predict_labels = np.argmax(pro_arr, axis=1)print(classification_report(org_labels, predict_labels))print(confusion_matrix(org_labels, predict_labels))

 运行效果:100个汉字2分钟就可以达到95%精度。

---------------------------------
Run id: cnn_handwrite
Log directory: /tmp/tflearn_logs/
---------------------------------
Preprocessing... Calculating mean over all dataset (this may take long)...
Mean: [ 0.89235026] (To avoid repetitive computation, add it to argument 'mean' of `add_featurewise_zero_center`)
---------------------------------
Preprocessing... Calculating std over all dataset (this may take long)...
STD: 0.192279 (To avoid repetitive computation, add it to argument 'std' of `add_featurewise_stdnorm`)
---------------------------------
Training samples: 19094
Validation samples: 4774
--
Training Step: 597  | total loss: 0.70288 | time: 40.959ss
| Adam | epoch: 001 | loss: 0.70288 - acc: 0.7922 | val_loss: 0.54380 - val_acc: 0.8460 -- iter: 19094/19094
--
Terminating training at the end of epoch 1Training Step: 1194  | total loss: 0.48860 | time: 40.245s
| Adam | epoch: 002 | loss: 0.48860 - acc: 0.8783 | val_loss: 0.37020 - val_acc: 0.8923 -- iter: 19094/19094
--
Terminating training at the end of epoch 2
Training Step: 1791  | total loss: 0.35790 | time: 41.315ss
| Adam | epoch: 003 | loss: 0.35790 - acc: 0.9090 | val_loss: 0.34719 - val_acc: 0.9049 -- iter: 19094/19094
--
Terminating training at the end of epoch 3
Successfully left training! Final model accuracy: 0.908959209919
OK, stop iterate!Good!precision    recall  f1-score   support0       1.00      0.99      0.99       2391       0.95      0.96      0.96       2372       0.91      0.98      0.94       2403       0.90      0.98      0.94       2394       0.96      0.98      0.97       2395       0.94      0.97      0.96       2396       0.98      0.98      0.98       2397       0.84      0.99      0.91       2408       0.99      0.87      0.93       2399       0.95      0.98      0.96       23910       0.97      0.94      0.96       24011       0.95      0.98      0.97       24012       0.92      0.99      0.95       24013       0.95      0.96      0.96       23914       0.94      0.94      0.94       23615       0.94      0.97      0.96       24016       0.94      0.98      0.96       24017       0.97      0.99      0.98       24018       0.94      0.93      0.94       24019       1.00      0.95      0.98       23920       0.96      0.98      0.97       24021       0.98      0.91      0.95       23922       0.97      0.95      0.96       23923       1.00      0.97      0.98       23924       0.94      0.98      0.96       24025       0.98      0.98      0.98       23726       0.91      1.00      0.95       23927       0.91      0.96      0.93       23928       0.97      0.88      0.92       23929       1.00      0.98      0.99       24030       0.99      0.94      0.96       23931       0.97      0.97      0.97       23732       0.94      0.98      0.96       23633       0.94      0.96      0.95       23934       0.98      0.99      0.98       23935       0.99      0.98      0.99       24036       0.96      0.92      0.94       23937       1.00      0.93      0.96       24038       0.96      0.99      0.98       23839       0.98      0.97      0.97       23840       0.92      0.90      0.91       24041       0.96      0.97      0.96       23742       0.98      0.97      0.97       24043       0.95      0.96      0.95       23944       0.97      0.96      0.97       23945       0.95      0.94      0.95       23946       0.93      0.96      0.94       23247       0.98      0.91      0.94       23748       0.95      0.97      0.96       23949       0.97      0.80      0.88       22650       0.90      0.95      0.92       24051       0.98      0.99      0.99       23652       0.96      0.90      0.93       24053       0.99      0.96      0.97       23554       0.97      0.93      0.95       24055       0.99      0.98      0.99       24056       0.97      0.92      0.95       23957       0.97      0.97      0.97       23958       1.00      0.98      0.99       23859       0.92      0.98      0.95       24060       0.99      0.90      0.94       24061       1.00      0.99      0.99       23862       0.92      0.95      0.94       23963       0.92      0.98      0.95       23864       0.98      0.92      0.95       24065       0.99      0.92      0.95       23966       0.98      0.99      0.99       24067       0.95      0.95      0.95       24068       0.96      0.98      0.97       23969       0.97      0.97      0.97       23970       0.98      0.94      0.96       24071       0.91      0.96      0.93       23972       0.98      0.97      0.97       23973       0.99      0.89      0.94       24074       0.97      0.99      0.98       23775       0.89      0.97      0.92       24076       0.97      0.96      0.97       24177       0.89      0.91      0.90       24078       1.00      0.89      0.94       23979       0.90      0.98      0.94       23980       0.89      0.96      0.92       24081       0.96      0.71      0.81       22582       0.95      1.00      0.97       23883       0.67      0.96      0.79       23984       0.97      0.85      0.91       24085       0.95      0.98      0.96       23986       0.99      0.93      0.96       24087       0.98      0.91      0.94       23988       0.97      0.97      0.97       24089       0.97      0.94      0.95       23990       0.97      0.96      0.96       23691       0.91      0.97      0.94       23992       0.98      0.95      0.96       24093       0.98      0.97      0.98       23994       0.98      0.95      0.97       24095       0.98      0.99      0.99       23996       0.95      0.97      0.96       24097       0.98      0.97      0.98       23998       0.95      0.98      0.97       23799       0.97      0.96      0.97       239avg / total       0.96      0.95      0.95     23868[[237   0   0 ...,   0   0   0][  0 228   0 ...,   0   0   0][  0   0 235 ...,   0   0   0]..., [  0   0   0 ..., 233   0   0][  0   0   0 ...,   0 233   0][  0   0   0 ...,   0   0 230]]

更多模型见:http://www.cnblogs.com/bonelee/p/8978060.html

将上述模型保存并给TensorFlow使用,仅仅在保存模型前加del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:],仅仅保留inference时候的OP(如果需要retrain注意),如下:

    model = get_model(width, height, classes=100)filename = 'cnn_handwrite-acc0.8.tflearn'# try to load model and resume training#try:#    model.load(filename)#    print("Model loaded OK. Resume training!")#except:#    pass# Initialize our callback with desired accuracy threshold.early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.8)try:model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.show_metric=True, batch_size=32, callbacks=early_stopping_cb, run_id='cnn_handwrite')except StopIteration as e:print("OK, stop iterate!Good!")model.save(filename)del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]"""# print op namewith tf.Session() as sess:init_op = tf.initialize_all_variables()sess.run(init_op)for v in sess.graph.get_operations():print(v.name)"""filename = 'cnn_handwrite-acc0.8.infer.tflearn'model.save(filename)

参考:http://www.cnblogs.com/bonelee/p/8445261.html 里的脚本,修改:

output_node_names = "FullyConnected/Softmax"通常为:
output_node_names = "FullyConnected/Softmax"或者
output_node_names = "FullyConnected_1/Softmax"
output_node_names = "FullyConnected_2/Softmax"就看你使用的全连接层数,上面分别是1,2,3层。最后,tensorflow里的使用:
def inference(image):print('inference')temp_image = Image.open(image).convert('L')temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size), Image.ANTIALIAS)temp_image = np.asarray(temp_image) / 255.0temp_image = temp_image.reshape([-1, 32, 32, 1])from tensorflow.python.platform import gfilewith tf.Graph().as_default():output_graph_def = tf.GraphDef()with open("frozen_model.pb", "rb") as f:output_graph_def.ParseFromString(f.read())tensors = tf.import_graph_def(output_graph_def, name="")#print tensorswith tf.Session() as sess:init = tf.global_variables_initializer()sess.run(init)op = sess.graph.get_operations()"""for m in op:print(m.values())"""op = sess.graph.get_tensor_by_name("FullyConnected_1/Softmax:0")input_tensor = sess.graph.get_tensor_by_name('InputData/X:0')probs = sess.run(op,feed_dict = {input_tensor:temp_image})print probsresult = []for word in probs:result.append(np.argsort(-word)[:3])return resultdef main(_):image_path = './data/test/00098/104405.png'#image_path = '../data/00010/17724.png'final_predict_val = inference(image_path)logger.info('the result info label {0} predict index {1}'.format(98, final_predict_val))

一般,输入TensorFlow input name默认为InputData/X,但只是op,如果要tensor的话,加上数字0,也就是:InputData/X:0

同理,FullyConnected_1/Softmax:0。

最后预测效果:

[[  8.42533936e-08   1.60850794e-11   2.60133332e-10   2.42555542e-144.96124599e-08   4.45251297e-15   3.98175590e-11   1.64476592e-117.03968351e-13   5.42319011e-12   8.55469237e-11   4.91866422e-131.77282828e-07   4.05237593e-10   3.13049003e-10   1.34780919e-112.05803235e-06   2.87827305e-07   1.47789994e-12   2.53391891e-113.77086790e-13   2.02639586e-10   9.03167027e-13   3.96698889e-111.30850096e-11   5.71980917e-12   3.03487374e-11   2.04132298e-146.25303683e-13   1.46122332e-07   2.17450633e-07   1.69623715e-096.80857757e-12   2.52643609e-13   6.56771096e-11   8.55152287e-161.34496514e-09   1.22644633e-06   1.12011307e-07   7.93476283e-058.24334611e-12   4.77531155e-14   9.39397757e-13   2.38438267e-142.11416329e-10   5.54395712e-08   2.30046147e-12   2.63584043e-104.70621564e-16   5.14432724e-12   6.42602327e-09   1.62485829e-137.39078274e-08   3.19146315e-12   5.25887156e-09   1.35877786e-131.39127886e-13   2.11998293e-13   9.09501097e-09   9.46486750e-072.47498733e-09   2.74523763e-12   1.02716433e-14   1.02069058e-173.09356682e-16   1.51022904e-15   9.34333665e-13   2.62195051e-143.38079781e-16   7.43019903e-13   1.92409091e-13   3.86611994e-132.61276265e-12   1.07969211e-09   1.30814548e-09   2.44038188e-149.79275905e-13   1.41007803e-10   6.15137758e-12   2.08893070e-101.34751668e-14   2.76824767e-15   7.84100464e-16   7.70873335e-155.45704757e-12   3.69386271e-18   2.06012223e-13   1.62567273e-141.54544960e-03   2.05292008e-06   1.31726174e-09   7.04993663e-094.11338266e-03   3.19344110e-07   3.96519717e-05   2.26919351e-122.39114349e-12   2.35558744e-07   9.94213998e-01   1.10125060e-11]]
the result info label 98 predict index [array([98, 92, 88])]

 
 

转载于:https://www.cnblogs.com/bonelee/p/8941654.html

tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用——模型层次太深,或者太复杂训练时候都不会收敛...相关推荐

  1. keras终止训练后显存不释放_Keras实现Large-scale Bisample Learning on ID vs. Spot Face Recognition...

    keras-lbl-IvS 论文地址:Large-scale Bisample Learning on ID vs. Spot Face Recognition 工程地址:keras-lbl-IvS ...

  2. 阿里 NIPS 2017 Workshop 论文:基于 TensorFlow 的深度模型训练 GPU 显存优化

    NIPS 2017 在美国长滩举办,场面非常热烈.阿里巴巴一篇介绍深度模型训练 GPU 显存优化的论文<Training Deeper Models by GPU Memory Optimiza ...

  3. NLP之PLUG:阿里达摩院发布最大中文预训练语言模型PLUG的简介、架构组成、模型训练、使用方法之详细攻略

    NLP之PLUG:阿里达摩院发布最大中文预训练语言模型PLUG的简介.架构组成.模型训练.使用方法之详细攻略 目录 PLUG的简介 PLUG的得分 PLUG的特点 PLUG的架构组成 PLUG的模型训 ...

  4. tensorflow 模型预训练后的参数restore finetuning

    之前训练的网络中有一部分可以用到一个新的网络中,但是不知道存储的参数如何部分恢复到新的网络中,也了解到有许多网络是通过利用一些现有的网络结构,通过finetuning进行改造实现的,因此了解了一下关于 ...

  5. 联邦学习【分布式机器学习技术】【①各客户端从服务器下载全局模型;②各客户端训练本地数据得到本地模型;③各客户端上传本地模型到中心服务器;④中心服务器接收各方数据后进行加权聚合操作,得全局模型】

    随着计算机算力的提升,机器学习作为海量数据的分析处理技术,已经广泛服务于人类社会. 然而,机器学习技术的发展过程中面临两大挑战: 一是数据安全难以得到保障,隐私数据泄露问题亟待解决: 二是网络安全隔离 ...

  6. AI周报丨中文巨量模型源1.0比GPT-3强在哪里?;谷歌用协同训练策略实现多个SOTA,单一ViT模型执行多模态多任务

    01 # 行业大事件 语言大模型的终极目标是什么? 在自然语言处理(NLP)领域,暴力美学仍在延续. 自 2018 年谷歌推出 BERT(3.4 亿参数)以来,语言模型开始朝着「大」演进.国内外先后出 ...

  7. 【机器学习】svm模型训练后的参数说明

    现简单对屏幕回显信息进行说明: #iter 为迭代次数, nu  与前面的操作参数 -n nu  相同, obj 为 SVM 文件转换为的二次规划求解得到的最小值, rho  为判决函数的常数项 b  ...

  8. python sklearn svm 模型训练后的参数说明

    在调用sklearn的SVM时,如果设置verbose=True,模型训练结束后会显示一些训练过程的说明信息,如下(以下是OCSVM的返回结果): * optimization finished, # ...

  9. [转载] tensorflow如何微调时如何只训练后两层_XLNet只存在于论文?都替你封装好了还不来用!...

    参考链接: 在Python中使用BERT Tokenizer和TensorFlow 2.0进行文本分类 相信前段时间大家都被各种XLNet的解读.解析轰炸了吧.好容易熬过了学会了,到网上一搜,诶!官方 ...

最新文章

  1. [转]linux下fms2流媒体服务器搭建之五-----flv播放器制作篇
  2. linux下JDK的安装
  3. 美团点评业务风控系统构建经验
  4. ABAP快速代码提示功能
  5. DataGridView的DataGridViewComboBoxColumn列点击一次
  6. 剑指Offer--数值的整数次方
  7. HTTPS协议在Tomcat中启用的配置
  8. Ubuntu 下使用SSH 代理
  9. Android利用soap WSDL与Webservice通信
  10. win 2008 64位IIS7出现数据库链接出错的解决办法
  11. vue 心跳监控_Vue中WebSocket加入心跳机制
  12. c语言高精度算法阶乘_学了这么久的C语言,原来可以这样解决算法问题...
  13. linux 读取权限目录权限,文件的读取与写入权限《 Linux 文件与目录权限 》
  14. 黑马程序员java整套视频地址 javaweb+ssh+ssm视频+源码+软件
  15. Linux Lite下打印机驱动安装及针式打印机校准
  16. aamp;m大学计算机科学,名校介绍丨美国 德克萨斯AM大学 Texas AM University
  17. C语言中常用math函数
  18. 其实大多数人没必要关注iPhone5
  19. Android扫描系统文件,安卓文档扫描仪
  20. [UNR #6]稳健型选手

热门文章

  1. linux模块化机制,Linux模块化机制和module_init
  2. android 按键kl文件,Android添加新按键
  3. 为提高访问速度建立本地文件服务器,html5 Application Cache——加快简历二次访问速度...
  4. c语言中void delay0.5(),单片机彩灯是怎样点亮
  5. laravel+vue.js的学习以及为什么浏览器中要有井号“#”
  6. 【大牛疯狂教学】java程序员大专找不到工作
  7. 基于Pytorch再次解读LeNet-5现代卷积神经网络
  8. python【数据结构与算法】01背包问题(附例题)
  9. 【Network Security!】虚拟化架构与系统部署
  10. open_basedir php.ini,关于PHP文件包含目录配置 open_basedir