The dataset contains 100 Chinese character classes, stored under the data directory. Put the script below in the same directory as data and run it directly.

Key parameters:

run_mode = "train" trains the model. Change it to "validation" to measure prediction accuracy on the 100 test images, or to "inference" to recognize the handwritten character in './data/00098/102544.png' and return the top-3 predictions.

charset_size = 100 is the number of character classes; for the full dataset it is 3755.
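The script restricts training to the first charset_size classes by comparing class-directory names lexicographically against a zero-padded cutoff (this is the `truncate_path` logic in DataIterator below). A minimal sketch of that selection, using synthetic directory names rather than a real dataset:

```python
# Sketch of the class-truncation logic in DataIterator.__init__.
# Class folders are zero-padded five-digit names: 00000, 00001, ..., 03754.
charset_size = 100
data_dir = './data/'
truncate_path = data_dir + ('%05d' % charset_size)  # './data/00100'

# Any folder path that sorts lexicographically below './data/00100'
# belongs to one of the first 100 classes.
all_dirs = [data_dir + ('%05d' % i) for i in range(3755)]
kept = [d for d in all_dirs if d < truncate_path]
print(len(kept))  # 100
```

This works because zero-padded names of equal length sort in numeric order, so '00099' < '00100' but '00100' is excluded.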

The code is based on: https://github.com/burness/tensorflow-101/blob/master/chinese_hand_write_rec/src/chinese_rec.py

On top of that it adds: (1) random image rotation of up to 30 degrees as augmentation; (2) resuming training from the latest checkpoint; (3) an extra convolutional layer for higher accuracy, following https://github.com/AmemiyaYuko/HandwrittenChineseCharacterRecognition
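Features (1) and (2) are both small, self-contained mechanisms. The rotation augmentation samples an angle up to 30 degrees and converts it to radians for tf.contrib.image.rotate; checkpoint resumption recovers the global step from the `-<step>` suffix that tf.train.Saver appends to the checkpoint name. Both can be sketched in plain Python (the checkpoint path below is a made-up example):

```python
import math
import random

# (1) Random rotation: sample an angle up to 30 degrees and convert
# to radians, as done in data_augmentation() before calling
# tf.contrib.image.rotate.
angle = random.randint(0, 30) * math.pi / 180.0

# (2) Resuming training: the global step is parsed from the checkpoint
# file name suffix (hypothetical path; real ones come from
# tf.train.latest_checkpoint).
ckpt = './checkpoint2/my-model-4001'
start_step = int(ckpt.split('-')[-1])
print(start_step)  # 4001
```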

```python
import tensorflow as tf
import os
import random
import math
import tensorflow.contrib.slim as slim
import time
import logging
import numpy as np
import pickle
from PIL import Image

logger = logging.getLogger('Training a chinese write char recognition')
logger.setLevel(logging.INFO)
# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
logger.addHandler(ch)

run_mode = "train"
charset_size = 100  # 3755
max_steps = 12002
save_steps = 2000

"""
# for online 3755 words training
checkpoint_dir = '/aiml/dfs/checkpoint/'
train_data_dir = '/aiml/data/train/'
test_data_dir = '/aiml/data/test/'
log_dir = '/aiml/dfs/'
"""
checkpoint_dir = './checkpoint2/'
train_data_dir = './data/'
test_data_dir = './data/'
log_dir = './'

tf.app.flags.DEFINE_string('mode', run_mode, 'Running mode. One of {"train", "valid", "test"}')
tf.app.flags.DEFINE_boolean('random_flip_up_down', True, "Whether to random flip up down")
tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to random contrast")
tf.app.flags.DEFINE_integer('charset_size', charset_size, "Choose the first `charset_size` characters to conduct our experiment.")
tf.app.flags.DEFINE_integer('image_size', 64, "Needs to provide same value as in training.")
tf.app.flags.DEFINE_boolean('gray', True, "whether to change the rgb to gray")
tf.app.flags.DEFINE_integer('max_steps', max_steps, 'the max training steps')
tf.app.flags.DEFINE_integer('eval_steps', 50, "the step num to eval")
tf.app.flags.DEFINE_integer('save_steps', save_steps, "the steps to save")
tf.app.flags.DEFINE_string('checkpoint_dir', checkpoint_dir, 'the checkpoint dir')
tf.app.flags.DEFINE_string('train_data_dir', train_data_dir, 'the train dataset dir')
tf.app.flags.DEFINE_string('test_data_dir', test_data_dir, 'the test dataset dir')
tf.app.flags.DEFINE_string('log_dir', log_dir, 'the logging dir')
##############################
# resume training
tf.app.flags.DEFINE_boolean('restore', True, 'whether to restore from checkpoint')
##############################
tf.app.flags.DEFINE_integer('epoch', 10, 'Number of epochs')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Validation batch size')
FLAGS = tf.app.flags.FLAGS


class DataIterator:
    def __init__(self, data_dir):
        # Set FLAGS.charset_size to a small value if available computation power is limited.
        truncate_path = data_dir + ('%05d' % FLAGS.charset_size)
        print(truncate_path)
        self.image_names = []
        for root, sub_folder, file_list in os.walk(data_dir):
            if root < truncate_path:
                self.image_names += [os.path.join(root, file_path) for file_path in file_list]
        random.shuffle(self.image_names)
        self.labels = [int(file_name[len(data_dir):].split(os.sep)[0]) for file_name in self.image_names]

    @property
    def size(self):
        return len(self.labels)

    @staticmethod
    def data_augmentation(images):
        if FLAGS.random_flip_up_down:
            # images = tf.image.random_flip_up_down(images)
            images = tf.contrib.image.rotate(images, random.randint(0, 30) * math.pi / 180, interpolation='BILINEAR')
        if FLAGS.random_brightness:
            images = tf.image.random_brightness(images, max_delta=0.3)
        if FLAGS.random_contrast:
            images = tf.image.random_contrast(images, 0.8, 1.2)
        return images

    def input_pipeline(self, batch_size, num_epochs=None, aug=False):
        images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
        labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
        input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)
        labels = input_queue[1]
        images_content = tf.read_file(input_queue[0])
        images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)
        if aug:
            images = self.data_augmentation(images)
        new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)
        images = tf.image.resize_images(images, new_size)
        image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size,
                                                          capacity=50000, min_after_dequeue=10000)
        return image_batch, label_batch


def build_graph(top_k):
    # with tf.device('/cpu:0'):
    keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
    images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1], name='image_batch')
    labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')

    conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1')
    max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME')
    conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2')
    max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME')
    conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3')
    max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME')
    conv_4 = slim.conv2d(max_pool_3, 512, [3, 3], [2, 2], scope="conv4", padding="SAME")
    max_pool_4 = slim.max_pool2d(conv_4, [2, 2], [2, 2], padding="SAME")

    flatten = slim.flatten(max_pool_4)
    fc1 = slim.fully_connected(slim.dropout(flatten, keep_prob), 1024, activation_fn=tf.nn.tanh, scope='fc1')
    logits = slim.fully_connected(slim.dropout(fc1, keep_prob), FLAGS.charset_size, activation_fn=None, scope='fc2')
    # logits = slim.fully_connected(flatten, FLAGS.charset_size, activation_fn=None, reuse=reuse, scope='fc')

    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))

    global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
    rate = tf.train.exponential_decay(2e-4, global_step, decay_steps=2000, decay_rate=0.97, staircase=True)
    train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(loss, global_step=global_step)
    probabilities = tf.nn.softmax(logits)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    merged_summary_op = tf.summary.merge_all()
    predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
    accuracy_in_top_k = tf.reduce_mean(tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32))

    return {'images': images,
            'labels': labels,
            'keep_prob': keep_prob,
            'top_k': top_k,
            'global_step': global_step,
            'train_op': train_op,
            'loss': loss,
            'accuracy': accuracy,
            'accuracy_top_k': accuracy_in_top_k,
            'merged_summary_op': merged_summary_op,
            'predicted_distribution': probabilities,
            'predicted_index_top_k': predicted_index_top_k,
            'predicted_val_top_k': predicted_val_top_k}


def train():
    print('Begin training')
    train_feeder = DataIterator(FLAGS.train_data_dir)
    test_feeder = DataIterator(FLAGS.test_data_dir)
    with tf.Session() as sess:
        train_images, train_labels = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size)
        graph = build_graph(top_k=1)
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()

        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')
        start_step = 0
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print("restore from the checkpoint {0}".format(ckpt))
                start_step += int(ckpt.split('-')[-1])
        logger.info(':::Training Start:::')
        try:
            while not coord.should_stop():
                start_time = time.time()
                train_images_batch, train_labels_batch = sess.run([train_images, train_labels])
                feed_dict = {graph['images']: train_images_batch,
                             graph['labels']: train_labels_batch,
                             graph['keep_prob']: 0.8}
                _, loss_val, train_summary, step = sess.run(
                    [graph['train_op'], graph['loss'], graph['merged_summary_op'], graph['global_step']],
                    feed_dict=feed_dict)
                train_writer.add_summary(train_summary, step)
                end_time = time.time()
                logger.info("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))
                if step > FLAGS.max_steps:
                    break
                if step % FLAGS.eval_steps == 1:
                    test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                    feed_dict = {graph['images']: test_images_batch,
                                 graph['labels']: test_labels_batch,
                                 graph['keep_prob']: 1.0}
                    accuracy_test, test_summary = sess.run([graph['accuracy'], graph['merged_summary_op']],
                                                           feed_dict=feed_dict)
                    test_writer.add_summary(test_summary, step)
                    logger.info('===============Eval a batch=======================')
                    logger.info('the step {0} test accuracy: {1}'.format(step, accuracy_test))
                    logger.info('===============Eval a batch=======================')
                if step % FLAGS.save_steps == 1:
                    logger.info('Save the ckpt of {0}'.format(step))
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'),
                               global_step=graph['global_step'])
        except tf.errors.OutOfRangeError:
            logger.info('==================Train Finished================')
            saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'), global_step=graph['global_step'])
        finally:
            coord.request_stop()
            coord.join(threads)


def validation():
    print('validation')
    test_feeder = DataIterator(FLAGS.test_data_dir)
    final_predict_val = []
    final_predict_index = []
    groundtruth = []

    with tf.Session() as sess:
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=1)
        graph = build_graph(top_k=3)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())  # initialize test_feeder's inside state
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print("restore from the checkpoint {0}".format(ckpt))
        print(':::Start validation:::')
        try:
            i = 0
            acc_top_1, acc_top_k = 0.0, 0.0
            while not coord.should_stop():
                i += 1
                start_time = time.time()
                test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                feed_dict = {graph['images']: test_images_batch,
                             graph['labels']: test_labels_batch,
                             graph['keep_prob']: 1.0}
                batch_labels, probs, indices, acc_1, acc_k = sess.run(
                    [graph['labels'], graph['predicted_val_top_k'], graph['predicted_index_top_k'],
                     graph['accuracy'], graph['accuracy_top_k']], feed_dict=feed_dict)
                final_predict_val += probs.tolist()
                final_predict_index += indices.tolist()
                groundtruth += batch_labels.tolist()
                acc_top_1 += acc_1
                acc_top_k += acc_k
                end_time = time.time()
                logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1) {3}(top_k)"
                            .format(i, end_time - start_time, acc_1, acc_k))
        except tf.errors.OutOfRangeError:
            logger.info('==================Validation Finished================')
            acc_top_1 = acc_top_1 * FLAGS.batch_size / test_feeder.size
            acc_top_k = acc_top_k * FLAGS.batch_size / test_feeder.size
            logger.info('top 1 accuracy {0} top k accuracy {1}'.format(acc_top_1, acc_top_k))
        finally:
            coord.request_stop()
            coord.join(threads)
    return {'prob': final_predict_val, 'indices': final_predict_index, 'groundtruth': groundtruth}


def inference(image):
    print('inference')
    temp_image = Image.open(image).convert('L')
    temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size), Image.ANTIALIAS)
    temp_image = np.asarray(temp_image) / 255.0
    temp_image = temp_image.reshape([-1, 64, 64, 1])
    with tf.Session() as sess:
        logger.info('========start inference============')
        # images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1])
        # Pass a shadow label 0. This label will not affect the computation graph.
        graph = build_graph(top_k=3)
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
        predict_val, predict_index = sess.run([graph['predicted_val_top_k'], graph['predicted_index_top_k']],
                                              feed_dict={graph['images']: temp_image, graph['keep_prob']: 1.0})
    return predict_val, predict_index


def main(_):
    print(FLAGS.mode)
    if FLAGS.mode == "train":
        train()
    elif FLAGS.mode == 'validation':
        dct = validation()
        result_file = 'result.dict'
        logger.info('Write result into {0}'.format(result_file))
        with open(result_file, 'wb') as f:
            pickle.dump(dct, f)
        logger.info('Write file ends')
    elif FLAGS.mode == 'inference':
        image_path = './data/00098/102544.png'
        final_predict_val, final_predict_index = inference(image_path)
        logger.info('the result info label {0} predict index {1} predict_val {2}'
                    .format(190, final_predict_index, final_predict_val))


if __name__ == "__main__":
    tf.app.run()
```
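The top-3 result returned in inference mode is just tf.nn.top_k applied to the softmax of the logits. The same computation can be sketched in plain Python for a single example (the logit values below are made up for illustration):

```python
import math

def softmax(logits):
    """Numerically stable softmax, mirroring tf.nn.softmax on one row."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probs, k=3):
    """Return (values, indices) like tf.nn.top_k for a single example."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [probs[i] for i in order], order

# Made-up logits for a 5-class toy example.
logits = [0.5, 2.0, -1.0, 3.0, 0.0]
vals, idx = top_k(softmax(logits))
print(idx)  # [3, 1, 0]
```

Because softmax is monotonic, the top-k indices of the probabilities are the same as the top-k indices of the raw logits; only the reported confidence values differ.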

Reposted from: https://www.cnblogs.com/bonelee/p/8952748.html
