import numpy as np
import matplotlib.pyplot as plt
from sklearn.utils import  shuffle
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import minmax_scale
import pickle
import struct
import os"""定义几个要用到的函数"""
def load_mnist(path, kind='train'):"""Load MNIST data from `path`"""labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind)images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind)with open(labels_path, 'rb') as lbpath:magic, n = struct.unpack('>II', lbpath.read(8))labels = np.fromfile(lbpath, dtype=np.uint8)with open(images_path, 'rb') as imgpath:magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)return images, labelsdef softmax(x, axis=1):# 计算每行的最大值row_max = x.max(axis=axis)# 每行元素都需要减去对应的最大值,否则求exp(x)会溢出,导致inf情况row_max=row_max.reshape(-1, 1)x = x - row_max# 计算e的指数次幂x_exp = np.exp(x)x_sum = np.sum(x_exp, axis=axis, keepdims=True)s = x_exp / x_sumreturn sdef sigmoid(x, deriv=False):if deriv==True:return sigmoid(x)*(1-sigmoid(x))return 1/(1+np.exp(-x))def transform_one_hot(labels):# 转为one-hot编码n_labels = np.max(labels) + 1one_hot = np.eye(n_labels)[labels]return one_hot"""加载数据"""
images, labels = load_mnist(r'G:\demo\daily\raw')
Y = transform_one_hot(labels.astype(np.int)).astype(np.float16)
X = minmax_scale(images, feature_range=(0, 1), axis=1)
X, Y = shuffle(X, Y)
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size = 1/6)"""设置网络参数"""
hidden_layer = 100
learning_rate = 0.01
M = 1.05
batch_size = 200
epoch = 30"""初始化网络参数"""
# np.random.seed(2)
sample_size = X_train.shape[0]
input_layer_neurons = X_train.shape[1]
output_layer_neurons = y_train.shape[1]
w1 = 2*np.random.random((input_layer_neurons, hidden_layer))-1
w2 = 2*np.random.random((hidden_layer, output_layer_neurons))-1
b1 = np.zeros((1, hidden_layer))
b2 = np.zeros((1, output_layer_neurons))"""训练网络"""
losses = []
learning_accs = []
validation_accs = []
for i in range(epoch):alpha = M**i*learning_ratefor j in range(sample_size//batch_size):x_batch = X_train[j*batch_size:(j+1)*batch_size]y_batch = y_train[j*batch_size:(j+1)*batch_size]# 前向传播a1 = np.dot(x_batch, w1) + b1h1 = sigmoid(a1)a2 = np.dot(h1, w2) + b2# 计算损失loss = -y_batch.squeeze()*np.log(softmax(a2))# 反向传播dlda2 = softmax(a2) - y_batchdldw2 = np.dot(h1.T, dlda2)dldb2 = np.mean(dlda2, axis=0).reshape(1, -1)dldh1 = np.dot(dlda2, w2.T)dlda1 = dldh1 * sigmoid(a1, deriv=True)dldw1 = np.dot(x_batch.T, dlda1)dldb1 = np.mean(dlda1, axis=0).reshape(1, -1)# 更新权重w1 -= alpha * dldw1w2 -= alpha * dldw2b1 -= alpha * dldb1b2 -= alpha * dldb2# 打印训练进度if(sample_size//batch_size<=25):print('\r'+"Epoch: "+str(i)+"\t"+str(round((j+1)/(sample_size//batch_size)*100,1))+"%" +"\t"+(j+1)*'>'+(sample_size//batch_size-1-j)*'.', end="", flush=True)else:if (j+1)%(sample_size//batch_size//25)==0:print("\r"+"Epoch: "+str(i)+"\t"+str(round((j+1)/(sample_size//batch_size)*100,1))+"%" +"\t"+(j+1)//(sample_size//batch_size//25)*'>'+(sample_size//batch_size-1-j)//(sample_size//batch_size//25)*'.', end="", flush=True)# 模型评估ytrain_ = np.argmax(softmax(np.dot(sigmoid(np.dot(X_train[0:10000], w1) + b1), w2) + b2),axis=1)ytest_ = np.argmax(softmax(np.dot(sigmoid(np.dot(X_test[0:10000], w1) + b1), w2) + b2),axis=1)losses.append(loss.mean())learning_acc = np.sum(np.argmax(y_train[0:10000], axis=1)==ytrain_)/len(ytrain_)validation_acc = np.sum(np.argmax(y_test[0:10000], axis=1)==ytest_)/len(ytest_)learning_accs.append(learning_acc)validation_accs.append(validation_acc)# 打印迭代结果print("\tLoss:"+str(round(loss.mean(), 5))+"\tLearning_Acc:"+str(round(learning_acc*100, 2))+"%\tValidation_Acc:"+str(round(validation_acc*100, 2))+"%")"""画图"""
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax2.plot(losses,  '*-', linewidth=2, color='g')
ax1.plot(learning_accs, '^-', linewidth=2, label="learning_Acc")
ax1.plot(validation_accs, 'o-', linewidth=2, label='validation_Acc')
ax1.set_ylabel('Accuracy(%)')
ax2.set_ylabel('loss')
ax1.set_xlabel('epoch')
plt.grid()
ax1.legend(loc='best')
plt.show()

numpy实现全连接网络进行mnist训练测试相关推荐

  1. MNIST 训练测试

    如下是用的是普通的前馈神经网络,输入层节点为784,隐层节点500,输出层10.这是一个标准的讲解DeepLearning中,在Tensorflow, Pytorch框架下的典型的用于识别MNIST手 ...

  2. 从一到二:利用mnist训练集生成的caffemodel对mnist测试集与自己手写的数字进行测试...

    通过从零到一的教程,我们已经得到了通过mnist训练集生成的caffemodel,主要包含下面四个文件: 接下来就可以利用模型进行测试了.关于测试方法按照上篇教程还是选择bat文件,当然python. ...

  3. (caffe入门)windows caffe 之 mnist 训练

    文章目录 1. mnist 数据集下载 2. mnist 数据集转换 3. 修改网络结构文件 lenet_train_test.prototxt 和 网络求解文件 lenet_solver.proto ...

  4. 深度学习实战(七)——目标检测API训练自己的数据集(R-FCN数据集制作+训练+测试)

    TensorFlow提供的网络结构的预训练权重:https://cloud.tencent.com/developer/article/1006123 将voc数据集转换成.tfrecord格式供te ...

  5. 【MMDetection3D】环境搭建,使用PointPillers训练测试可视化KITTI数据集

    文章目录 前言 3D目标检测概述 KITTI数据集简介 MMDetection3D 环境搭建 数据集准备 训练 测试及可视化 绘制损失函数曲线 参考资料 前言 2D卷不动了,来卷3D,之后更多地工作会 ...

  6. MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试

    文章目录 MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试 一.题目要求 二.完整的目录结构说明 三.Mnist数据集及数据格式转换 四.BP神经网络相关知识 4.1 ...

  7. mask-rcnn训练测试自制数据集

    mask-rcnn训练测试自制数据集 本项目简介 本项目用于口腔模型分割,数据类型有7种,本文主要用于介绍如何使用自制数据集训练自己的模型 训练环境配置 操作系统:win10 GPU: GTX 108 ...

  8. vid2vid 代码调试+训练+测试(debug+train+test)(一)测试篇

    ## Prerequisites - Linux or macOS - Python 3 - NVIDIA GPU + CUDA cuDNN - PyTorch 0.4 但一般的话我们为了保护已有的环 ...

  9. 什么是端到端训练测试_为什么端到端测试对您的团队很重要

    什么是端到端训练测试 by Phong Huynh 由Phong Huynh 为什么端到端测试对您的团队很重要 (Why End-to-End Testing is Important for You ...

最新文章

  1. 真AI用钱表达:这家少年班毕业生创办的AI公司三年就盈利,增速300%
  2. python 如何获取列表(List)中指定元素的下标? index() enumerate() 获取重复元素下标
  3. mysql null的作用_MySQL中对于NULL值的理解和使用教程
  4. flash可以编辑html文本吗,flash中怎么插入并编辑文字字体样式?
  5. mysql5.7.23手动配置安装windows版
  6. ng-content的一个实际例子
  7. 让div垂直以及水平居中浏览器窗口
  8. Python案例:给出三角形构成方案
  9. FontExplorer X Pro for Mac字体管理软件
  10. Scrum项目6.0 和8910章读后感
  11. bui框架与php结合,bui框架前端自定义配色基础属性
  12. Apple Pay发展与安全
  13. IE无法打开internet网站已终止操作的解决的方法
  14. 严格对角占优矩阵特征值_对角占优矩阵的性质.doc
  15. 如何把pdf文件变小一点
  16. python语言要英语基础吗_学编程需要英语基础吗?
  17. alibaba/COLA 4.0框架 使用记录
  18. ThinkPHP5结合云之讯短信验证简单案例
  19. aria2的安装使用
  20. 简历之精通 熟练 掌握 熟悉 了解

热门文章

  1. python flask oauth_Flask之 flask_httpauth(HTTPTokenAuth)
  2. 计算器百分号如何用代码实现_如何用 100 行 Python 代码实现新闻爬虫?这样可算成功?...
  3. python中for循环-python中关于for循环的碎碎念
  4. 【script】python 解析 Windows日志(python-evtx)
  5. ubuntu下mysql5.7安装教程_Ubuntu 16.04 上安装 MySQL 5.7 教程
  6. android onclick执行顺序,浅谈onTouch先执行,还是onClick执行(详解)
  7. com/mysql/jdbc/statementimpl_com.mysql.jdbc.异常.jdbc4。通信异常:通信链路故障
  8. 程序员离职代码交接_程序员离职大半个月,被老板命令回单位讲代码,员工:一次1万...
  9. button点击后变色_炒丝瓜怎么不变色?鹏厨教你制作小窍门,健康美味、颜色碧绿...
  10. 多个字符合并成一个数组_一个excel多个sheet,需要合并为一个sheet