原文出处:
http://www.wildml.com/2015/09/implementing-a-neural-network-from-scratch/
Posted on September 3, 2015 by Denny Britz
这篇文章帮助我们用python实践一下从零开始训练一个神经网络。
以下是中文翻译:

获取代码
在这篇文章中,我们将从头开始实现一个简单的3层神经网络。 我们不会推导出所有需要的数学运算,但是我会尽量直观地解释我们正在做什么。 我也会指点资源给你阅读细节。
在这里,我假设你熟悉基本的微积分和机器学习的概念,例如 你知道什么是分类和正规化。 理想情况下,您也可以了解梯度下降等优化技术的工作原理。 但是,即使你不熟悉以上任何一点,这篇文章仍然会变得有趣;
但为什么从头开始实施一个神经网络呢? 即使您计划在将来使用像PyBrain这样的神经网络库,从头开始至少实施一次网络也是非常有价值的练习。 它可以帮助您了解神经网络是如何工作的,这对于设计有效的模型是至关重要的。
有一点要注意的是,这里的代码示例并不是非常有效。 他们的意思是很容易理解。 在即将发布的文章中,我将探讨如何使用Theano编写高效的神经网络实现。 (更新:现在可用)

生成数据集
让我们开始生成一个我们可以玩的数据集。 幸运的是,scikit-learn有一些有用的数据集生成器,所以我们不需要自己编写代码。 我们将使用make_moons函数。

# Generate a dataset and plot it
np.random.seed(0)
X, y = sklearn.datasets.make_moons(200, noise=0.20)
plt.scatter(X[:,0], X[:,1], s=40, c=y, cmap=plt.cm.Spectral)

我们生成的数据集有两个类,绘制成红色和蓝色的点。 你可以把蓝点看作是男性患者,将红点看作是女性患者,x轴和y轴是医学测量。

我们的目标是训练一个机器学习分类器,预测给定x和y坐标的正确类别(女性的男性)。 请注意,数据不是线性可分的,我们不能绘制一条直线来分隔两个类。 这意味着线性分类器(如Logistic回归)将无法适用数据,除非您手动设计对于给定数据集非常有效的非线性特征(例如多项式)。

事实上,这是神经网络的主要优势之一。 您不需要担心功能工程。 神经网络的隐藏层将为您学习功能。

Logistic回归

为了证明这一点,让我们训练一个Logistic回归分类器。 它的输入是x和y值,输出是预测的类(0或1)。 为了让我们的生活更轻松,我们使用scikit-learn的Logistic Regression类。

# Train the logistic rgeression classifier
clf = sklearn.linear_model.LogisticRegressionCV()
clf.fit(X, y)# Plot the decision boundary
plot_decision_boundary(lambda x: clf.predict(x))
plt.title("Logistic Regression")

该图显示了我们的Logistic回归分类器学到的决策边界。 它使用直线将数据尽可能分离,但无法捕捉数据的“月亮形状”。

训练一个神经网络

现在我们来构建一个具有一个输入层,一个隐藏层和一个输出层的三层神经网络。 输入层中节点的数量取决于我们数据的维数2.类似地,输出层中节点的数量是由我们所拥有的类的数量决定的,也是2.(因为我们只有2个类, 实际上只能有一个输出节点预测为0或1,但有2个可以使网络稍后扩展到更多类)。 网络的输入将是x和y坐标,其输出将是两个概率,一个是0级(“女性”),一个是1级(“男性”)。 它看起来像这样:

我们可以选择隐藏层的维数(节点数)。我们放入隐藏层的节点越多,我们就可以适应更复杂的功能。但更高的维度是有代价的。首先,需要更多的计算来进行预测并学习网络参数。更多的参数也意味着我们更容易过拟合我们的数据。

如何选择隐藏层的大小?虽然有一些一般的指导方针和建议,但它总是取决于你的具体问题,更多的是艺术而不是科学。稍后我们将使用隐藏的节点数来看看它是如何影响我们的输出的。

我们还需要为隐藏层选择一个激活函数。激活功能将图层的输入转换为其输出。非线性激活函数使我们能够拟合非线性假设。用于激活功能的常见选择是tanh,sigmoid函数或ReLU。我们将使用tanh,在许多场景中表现相当好。这些函数的一个很好的属性是可以使用原始函数值来计算它们的派生值。例如,tanh x的导数是1- (tanh x)^2。这很有用,因为它可以让我们计算一次tanh x并稍后重新使用它的值来得到导数。

因为我们希望我们的网络输出概率,输出层的激活函数将是softmax,这只是将原始分数转换为概率的一种方法。如果您熟悉逻辑功能,您可以将softmax视为对多个类的归纳

Python中从头开始实现神经网络 - 介绍相关推荐

  1. python流程控制语句-Python中流程控制语句的详细介绍

    除了刚才介绍的while语句之外,Python也从其他语言借鉴了其他流程控制语句,并做了相应改变.Python中流程控制语句的详细介绍 4.1 ifStatements 或许最广为人知的语句就是if语 ...

  2. php有lambda表达式吗,Python中lambda表达式的简单介绍(附示例)

    本篇文章给大家带来的内容是关于Python中lambda表达式的简单介绍(附示例),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助. 一:匿名函数的定义 lambda parameter ...

  3. chatgpt赋能python:Python中开区间和闭区间的介绍

    Python中开区间和闭区间的介绍 在Python编程中,经常需要使用区间(range)对象.区间对象是Python中自带的一种数据类型,它表示一系列连续的整数.Python中的区间对象支持开区间和闭 ...

  4. chatgpt赋能python:Python中的连接符:介绍与应用

    Python中的连接符:介绍与应用 在Python编程中,连接符起着关键性的作用,它是连接不同代码部分的纽带.本篇文章将重点介绍几种常用的Python连接符. 一.加号连接符(+) 加号连接符最常见, ...

  5. 【新手指南】Python中的listdir()函数的介绍

    [新手指南]Python中的listdir()函数的介绍 在用pytorch导入dataset的时候对listdir()函数产生的是文件还是文件夹一直都有疑问,所以自己先在网上找了一段小代码调试,先小 ...

  6. chatgpt赋能python:Python中的s.len()方法介绍

    Python中的s.len()方法介绍 Python中有各种字符串处理方法,其中s.len()方法是一个重要的方法之一.s.len()返回一个字符串s的长度.这是一个非常基本的方法,但是在很多情况下都 ...

  7. chatgpt赋能python:Python中的Tilde符号的介绍

    Python中的Tilde符号的介绍 在Python的编程环境中,有一个比较神秘的符号,就是波浪线符号,即 "~" 或称为 "Tilde" 符号.这个符号在Py ...

  8. Python中TensorFlow长短期记忆神经网络LSTM、指数移动平均法预测股票市场时间序列和可视化

    最近我们被客户要求撰写关于LSTM的研究报告,包括一些图形和统计输出. 本文探索Python中的长短期记忆(LSTM)网络,以及如何使用它们来进行股市预测. 相关视频:LSTM神经网络架构和工作原理及 ...

  9. python中knn_如何在python中从头开始构建knn

    python中knn k最近邻居 (k-Nearest Neighbors) k-Nearest Neighbors (KNN) is a supervised machine learning al ...

最新文章

  1. 【神经网络】(10) Resnet18、34 残差网络复现,附python完整代码
  2. 根据给定数据创建JSON并验证
  3. 加速!上海要做人工智能产业“领头雁”
  4. RocketMQ与kafka对比(18项差异)-转自阿里中间件
  5. oracle创建外键约束的两种方式
  6. OllyDBG 入门之四--破解常用断点设
  7. AI前沿线上大会,ALBERT一作、京东AI科学家等大咖亲临现场,限时免费,名额有限!...
  8. 小程序测试用例模板_微信小程序样式:高质量小程序样式模板大全
  9. js 1000+简写为K,10000+简写为W
  10. 我的世界java1如何安装mod_《我的世界》【教程】如何安装MOD【PC】
  11. 打开非遗文化新呈现方式 三七互娱“非遗广州红”游园会即将开幕
  12. [论文阅读] Scene Context-Aware Salient Object Detection
  13. 利用WebHook实现自动部署Git代码
  14. MySQL 如何利用一条语句实现类似于if-else条件语句的判断
  15. python怎么打断点_Pycharm如何打断点的方法步骤
  16. 第十一周项目6-回文素数(一)
  17. 音创服务器系统手动加歌,音创ktv点歌系统的教程
  18. Retinex网络模型学习笔记
  19. 量化学习:大数据时代的学习方式
  20. 刘彬20000词汇01

热门文章

  1. Redis学习(一)——
  2. JVM知识点总览:高级Java工程师面试必备
  3. SQL Server 默认跟踪报表
  4. RED5 安装及问题
  5. MSSQL 如何实现 MySQL 的 limit 查询方式【转存】
  6. php 当我添加数据成功后跳到首页 为什么刷新还会增加数据,使用post提交数据之后,有错误,页面刷新之后,想保持原有值...
  7. 相位噪声 matlab,相位噪声仿真方法.PDF
  8. Activiti之H2
  9. Tomcat启动设置环境变量
  10. 交叉报表问题 subDataset