DL之LiR&DNN&CNN:利用LiR、DNN、CNN算法对MNIST手写数字图片(csv)识别数据集实现(10)分类预测

目录

输出结果

设计思路

核心代码


输出结果

数据集:Dataset之MNIST:MNIST(手写数字图片识别+csv文件)数据集简介、下载、使用方法之详细攻略

设计思路

核心代码

classifier = skflow.TensorFlowLinearClassifier(n_classes=10, learning_rate=0.01)
classifier.fit(X_train, y_train)
linear_y_predict = classifier.predict(X_test)classifier = skflow.TensorFlowDNNClassifier(hidden_units=[200, 50, 10], n_classes = 10,learning_rate=0.01)
classifier.fit(X_train, y_train)
dnn_y_predict = classifier.predict(X_test)classifier = skflow.TensorFlowEstimator(model_fn=conv_model, n_classes=10, steps=20000,learning_rate=0.001)
classifier.fit(X_train, y_train)
classifier.predict(X_test)
class TensorFlowDNNClassifier(TensorFlowEstimator, ClassifierMixin):"""TensorFlow DNN Classifier model.Parameters:hidden_units: List of hidden units per layer.n_classes: Number of classes in the target.tf_master: TensorFlow master. Empty string is default for local.batch_size: Mini batch size.steps: Number of steps to run over data.optimizer: Optimizer name (or class), for example "SGD", "Adam","Adagrad".learning_rate: If this is constant float value, no decay function is used.Instead, a customized decay function can be passed that acceptsglobal_step as parameter and returns a Tensor.e.g. exponential decay function:def exp_decay(global_step):return tf.train.exponential_decay(learning_rate=0.1, global_step,decay_steps=2, decay_rate=0.001)class_weight: None or list of n_classes floats. Weight associated withclasses for loss computation. If not given, all classes are suppose to haveweight one.tf_random_seed: Random seed for TensorFlow initializers.Setting this value, allows consistency between reruns.continue_training: when continue_training is True, once initializedmodel will be continuely trained on every call of fit.num_cores: Number of cores to be used. (default: 4)early_stopping_rounds: Activates early stopping if this is not None.Loss needs to decrease at least every every <early_stopping_rounds>round(s) to continue training. (default: None)max_to_keep: The maximum number of recent checkpoint files to keep.As new files are created, older files are deleted.If None or 0, all checkpoint files are kept.Defaults to 5 (that is, the 5 most recent checkpoint files are kept.)keep_checkpoint_every_n_hours: Number of hours between each checkpointto be saved. The default value of 10,000 hours effectively disables the feature."""def __init__(self, hidden_units, n_classes, tf_master="", batch_size=32, steps=200, optimizer="SGD", learning_rate=0.1, class_weight=None, tf_random_seed=42, continue_training=False, num_cores=4, verbose=1, early_stopping_rounds=None, max_to_keep=5, keep_checkpoint_every_n_hours=10000):self.hidden_units = hidden_unitssuper(TensorFlowDNNClassifier, self).__init__(model_fn=self._model_fn, n_classes=n_classes, tf_master=tf_master, batch_size=batch_size, steps=steps, optimizer=optimizer, learning_rate=learning_rate, class_weight=class_weight, tf_random_seed=tf_random_seed, continue_training=continue_training, num_cores=4, verbose=verbose, early_stopping_rounds=early_stopping_rounds, max_to_keep=max_to_keep, keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)def _model_fn(self, X, y):return models.get_dnn_model(self.hidden_units, models.logistic_regression)(X, y)@propertydef weights_(self):"""Returns weights of the DNN weight layers."""weights = []for layer in range(len(self.hidden_units)):weights.append(self.get_tensor_value('dnn/layer%d/Linear/Matrix:0' % layer))weights.append(self.get_tensor_value('logistic_regression/weights:0'))return weights@propertydef bias_(self):"""Returns bias of the DNN's bias layers."""biases = []for layer in range(len(self.hidden_units)):biases.append(self.get_tensor_value('dnn/layer%d/Linear/Bias:0' % layer))biases.append(self.get_tensor_value('logistic_regression/bias:0'))return biases

DL之LiRDNNCNN:利用LiR、DNN、CNN算法对MNIST手写数字图片(csv)识别数据集实现(10)分类预测相关推荐

  1. DL之DNN:利用DNN【784→50→100→10】算法对MNIST手写数字图片识别数据集进行预测、模型优化

    DL之DNN:利用DNN[784→50→100→10]算法对MNIST手写数字图片识别数据集进行预测.模型优化 导读 目的是建立三层神经网络,进一步理解DNN内部的运作机制 目录 输出结果 设计思路 ...

  2. DL之DNN:利用DNN算法对mnist手写数字图片识别数据集(sklearn自带,1797*64)训练、预测(95%)

    DL之DNN:利用DNN算法对mnist手写数字图片识别数据集(sklearn自带,1797*64)训练.预测(95%) 目录 数据集展示 输出结果 设计代码 数据集展示 先查看sklearn自带di ...

  3. TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%)

    TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%) 目录 设计思路 全部代码 设计思路 全部代码 #TF:利用是Softmax回归+GD ...

  4. TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)

    TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练.评估(偶尔100%准确度,交叉熵验证) 目录 输出结果 设计思路 代码设计 输出结果 第 0 accuracy 0. ...

  5. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 下边两张 ...

  6. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 1.10 ...

  7. TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别

    TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别 目录 输出结果 代码设计 输出结果 代码设计 ...

  8. TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类

    TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类 目录 设计思路 实现代码 设计思路 更新-- 实现代码 # -*- coding:utf-8 -*- import ten ...

  9. TF之DNN:利用DNN【784→500→10】对MNIST手写数字图片识别数据集(TF自带函数下载)预测(98%)+案例理解DNN过程

    TF之DNN:利用DNN[784→500→10]对MNIST手写数字图片识别数据集(TF自带函数下载)预测(98%)+案例理解DNN过程 目录 输出结果 案例理解DNN过程思路 代码设计 输出结果 案 ...

最新文章

  1. P1996 约瑟夫问题
  2. 《Ossim应用指南》入门篇
  3. 不再颓废,重新开始,牛客第一题1016. 部分A+B (15)
  4. 转载:opencv错误rect错误
  5. linux subversion 根目录检出,经验总结:详解Linux下Subversion的安装配置记录 下
  6. mysql 恢复空密码_mysql 找回密码
  7. java生成word带多级标题,word文档怎样设置自动生成多级标题
  8. 如何对接小发猫的伪原创API
  9. 2020大学计算机答案,超星2020大学计算机基础答案 全
  10. [转]C#中的global关键字(global::)
  11. 诺基亚系列手机型号命名研究(转)
  12. tplink怎么进去_在TP-Link工作体验如何?
  13. [ZJOI2007]矩阵游戏(二分图匹配、匈牙利算法)
  14. MySql-主从复制
  15. FPGA学习之串口篇
  16. 项目记录 / 基于AT89C51的环境检测系统
  17. 谷歌浏览器如何彻底关闭右下角弹出的广告弹窗
  18. ####好好#####利用各种信息作为因子的股票价格预测模型研究过程
  19. 服务器ssl证书安装
  20. Google拼音输入法的问题

热门文章

  1. 数据类型_分享redis中除5种基础数据类型以外的高级数据类型
  2. 关于路径搜索的算法, 可能用到
  3. 《Cocos2d 跨平台游戏开发指南(第2版)》一1.9 添加动作到精灵
  4. MySQL中的条件赋值
  5. 使用Go语言从零编写PoS区块链
  6. 也许,DOM 不是答案
  7. Java:异常处理的一些注意事项
  8. Android --- ConnectTimeout 和 ReadTimeout 所代表的意义
  9. 什么是脱离文档流?什么是文档流?
  10. linux多线程编写哲学家,Linux系统编程(三) ------ 多线程编程