看别人的代码和自己写代码,两种的难度和境界真是不一样。昨天和今天尝试着写一个简单的全连接神经网络,用来学习一个基本的模型,在实现的过程中遇到了不少的坑,虽然我已经明白了其中的原理。

我想了一个教材上面没有的简单例子,尝试着自己构造训练数据集和测试集。

我希望训练一个能够区分红点和蓝点的模型。在我构造的数据集中,当x < 1的时候,为蓝点;当x >1的时候为红点。

对于这个全连接网络,输入节点只有一个,表示x轴的坐标。有一个隐藏层,隐藏层的节点数量为3.最后是输出层,有两个节点。对于输出层,如果为[1, 0]表示蓝点,[0, 1]为红点。也就是区分两种不同的结果。

下面是代码实现:(使用了tensorflow框架)


import tensorflow as tf;
import numpy as np;def train_data():x = [[0.2], [0.4], [0.7], [1.2], [1.4], [1.8], [1.9], [2], [0.11], [0.16], [0.5]];y = [[1, 0], [1, 0], [1, 0], [0, 1], [0, 1], [0, 1], [0, 1], [1, 0], [1, 0], [1, 0], [1, 0]];return (x, y);def test_data():x = [[0.3], [0.6], [0.8], [1.3], [1.5]];y = [[1, 0], [1, 0], [1, 0], [0, 1], [0,1]];return (x, y);# 数据数据集
x = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='x-input');
# 训练数据集中的label
y_ = tf.placeholder(dtype=tf.float32, shape=[None, 2], name='y-input');
# 输入数据和隐藏层连接的权重
w1 = tf.get_variable('weight1', shape=[1, 3],initializer=tf.random_normal_initializer(stddev=1, dtype=tf.float32))
# 输入层和隐藏层之间的偏移量,个数等于隐藏层节点的个数。
b1 = tf.get_variable('biase1', shape=[3],initializer=tf.random_normal_initializer(stddev=1, dtype=tf.float32))
# 隐藏层和输出层链接的权重
w2 = tf.get_variable('weight2', shape=[3, 2],initializer=tf.random_normal_initializer(stddev=1, dtype=tf.float32))
# 隐藏层和输出层之间的偏移量,个数等于输出层节点的个数。
b2 = tf.get_variable('biase2', shape=[2],initializer=tf.random_normal_initializer(stddev=1, dtype=tf.float32))layer1 = tf.nn.sigmoid(tf.matmul(x, w1) + b1);
y = tf.matmul(layer1, w2) + b2; # 模型预测的y值loss = tf.nn.l2_loss(y-y_); # 使用预测的值和训练数据的label的方差作为损失函数,方差越小越好。
# 开始训练过程。
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(loss);with tf.Session() as sess:# 初始化所有的张量(变量)sess.run(tf.global_variables_initializer());x_train = train_data()[0];y_train = train_data()[1];for i in range(10000):# 迭代一万次sess.run(train_op, feed_dict={x:x_train, y_:y_train})# 代码执行到这里就已经训练完成了。下面是测试。# 测试的思路是:比较预测的值和真实值x_test = test_data()[0];y_test = test_data()[1];count = 0;y_max_value_index = np.argmax(y_test, axis=1); for i in range(5):y_value = sess.run(y[i], feed_dict={x:x_test});if np.equal(y_max_value_index[i], np.argmax(y_value)):count += 1;print("the right proportion: %f"%(count/len(y_test)));

下面写一下所遇到的坑
1. 每一层的数据、权重和偏移量的维数需要严格对应。从输入数据到输出数据,它们的维数为:[none, 1] ->[1, 3] -> [3, 2], 这里none表示为不确定。也就是输入数据集是一个只有一列行数不确定的数据。隐藏层的权重矩阵是一个一行3列的矩阵。输出层为三行2列的矩阵。

结束感谢.

使用tensorflow实现全连接神经网络的简单示例,含源码相关推荐

  1. python下的orm基本操作(1)--Mysql下的CRUD简单操作(含源码DEMO)

    最近逐渐打算将工作的环境转移到ubuntu下,突然发现对于我来说,这ubuntu对于我这种上上网,收收邮件,写写博客,写写程序的时实在是太合适了,除了刚接触的时候会不怎么完全适应命令行及各种权限管理, ...

  2. 身份证问题讲解全连接神经网络

    标题 身份证问题讲解全连接神经网络(简单demo) #导入tf模块与随机函数模块 import tensorflow as tf import random#当seed()没有参数时,每次生成的随机数 ...

  3. 简单的全连接神经网络(tensorflow实现)

    简单的全连接神经网络,网络结构为2-2-1 代码如下: #encoding='utf-8' """ created on 2018-08-10 @author wt &q ...

  4. TF之DNN:TF利用简单7个神经元的三层全连接神经网络【2-3-2】实现降低损失到0.000以下

    TF之DNN:TF利用简单7个神经元的三层全连接神经网络实现降低损失到0.000以下(输入.隐藏.输出层分别为 2.3 . 2 个神经元) 目录 输出结果 实现代码 输出结果 实现代码 # -*- c ...

  5. 深度学习框架 TensorFlow:张量、自动求导机制、tf.keras模块(Model、layers、losses、optimizer、metrics)、多层感知机(即多层全连接神经网络 MLP)

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 安装 TensorFlow2.CUDA10.cuDNN7.6. ...

  6. Tensorflow【实战Google深度学习框架】全连接神经网络以及可视化

    文章目录 1 可视化 神经网络的二元分类效果 2 全连接神经网络 3 TensorFlow搭建一个全连接神经网络 3.1 读取MNIST数据 3.2 建立占位符 3.3 建立模型 3.4 正确率 3. ...

  7. 【TensorFlow】TensorFlow从浅入深系列之十 -- 教你认识卷积神经网络的基本网路结构及其与全连接神经网络的差异

    本文是<TensorFlow从浅入深>系列之第10篇 TensorFlow从浅入深系列之一 -- 教你如何设置学习率(指数衰减法) TensorFlow从浅入深系列之二 -- 教你通过思维 ...

  8. python——tensorflow使用和两层全连接神经网络搭建

    一.Tensorflow使用 1.Tensorflow简介 TensorFlow是一个软件库,封装了建构神经网络的API,类似于MatLab的神经网络工具箱,是Google于2015年推出的深度学习框 ...

  9. 【神经网络与深度学习】 Numpy 实现全连接神经网络

    1.实验名称 Numpy 实现全连接神经网络实验指南 2.实验要求 用 python 的 numpy 模块实现全连接神经网络. 网络结构为一个输入层.一个隐藏层.一个输出层. 隐藏层的激活函数为 Re ...

最新文章

  1. 用神经网络分类无理数2**0.5和3**0.5
  2. 2020 华工 数据结构-平时作业_【激光】从上海工博会看华工激光的差异化路线...
  3. 前端学习(1841):前端面试题之redux管理状态机制
  4. asp.net datatable 导出为 txt
  5. hibernate不能保存时分秒处理
  6. VB 详细枚举指定目录、文件夹文件列表
  7. 预处理,编译,汇编,链接程序的区别
  8. Web开发(初级)- 常用css总结,方便查询
  9. 服务器怎么关闭终端依然运行node,关闭控制台后如何永久运行node.js应用程序?...
  10. 应用ruby打造个性化的有道单词本 (二)
  11. GD32F103串口DMA收发
  12. linux 编译chromium,chromium(linux环境)指定版本下载和编译教程
  13. 遇见未来 | 对话叶毓睿:人类文明运行在软件之上(上篇)
  14. 人件札记:团队的化学反应
  15. ICPC-无限路之城(数学+思维)
  16. 二分法查找--Dichotomy search
  17. Win 10出现bitlocke恢复,蓝屏错误代码0x1600007e
  18. 计算机基础知识在教学的应用,计算机基础知识中项目教学法的应用
  19. 关于CMNET和CMWAP联网实践
  20. Unity Manual阅读记录——Animation(version 2019.4)

热门文章

  1. android sdcard 不存在,在android中显示sdcard上不存在的文件的提醒
  2. access两字段同时升序排序_7 天时间,我整理并实现了这 9 种常见的排序算法
  3. 第二周Access总结
  4. Java-NIO(九):管道 (Pipe)
  5. haproxy参数优化
  6. mysql主从切换(正常切换)
  7. KindEditor得不到textarea值的解决方法----摘至天涯
  8. they're hiring
  9. linux注销、关机、重启
  10. linux apache web服务器