参考链接中的文章,有错误,我给更正了。

并且原文中是需要数据集文件的,我直接给替换成了一个数组,采用直接赋值的方式。

# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as pltclass SimpleDataReader(object):def __init__(self, data_file):self.train_file_name = data_fileself.num_train = 0self.XTrain = Noneself.YTrain = None# read data from filedef ReadData(self):# data = np.load(self.train_file_name)# self.XTrain = data["data"]# self.YTrain = data["label"]self.XTrain = np.array([0.95, 3, 4, 5.07, 6.03, 8.21, 8.85, 12.02, 15], dtype=float)self.YTrain = np.array([5.1, 8.7, 11.5, 13, 15.3, 18, 21, 26.87, 32.5], dtype=float)self.num_train = self.XTrain.shape[0]#end if# get batch training datadef GetSingleTrainSample(self, iteration):x = self.XTrain[iteration]y = self.YTrain[iteration]return x, ydef GetWholeTrainSamples(self):return self.XTrain, self.YTrainclass NeuralNet(object):def __init__(self, eta):self.eta = etaself.w = 0self.b = 0def __forward(self, x):z = x * self.w + self.breturn zdef __backward(self, x,y,z):dz = z - y                  # 原错误为:dz = x * (z - y)db = dzdw = dzreturn dw, dbdef __update(self, dw, db):self.w = self.w - self.eta * dwself.b = self.b - self.eta * dbdef train(self, dataReader):for i in range(dataReader.num_train):# get x and y value for one samplex,y = dataReader.GetSingleTrainSample(i)# get z from x,yz = self.__forward(x)# calculate gradient of w and bdw, db = self.__backward(x, y, z)# update w,bself.__update(dw, db)# end fordef inference(self, x):return self.__forward(x)if __name__ == '__main__':# read datasdr = SimpleDataReader('ch04.npz')sdr.ReadData()# create neteta = 0.1net = NeuralNet(eta)net.train(sdr)# resultprint("w=%f,b=%f" %(net.w, net.b))# 绘图部分trainX,trainY = sdr.GetWholeTrainSamples()fig = plt.figure()ax = fig.add_subplot(111)# 绘制散点图ax.scatter(trainX,trainY)# 绘制线性回归x = np.arange(0, 15, 0.01)f = np.vectorize(net.inference, excluded=['x'])plt.plot(x,f(x),color='red')# 显示图表plt.show()

Ref:

  1. 通过神经网络实现线性回归的拟合

【Python】纯代码通过神经网络实现线性回归的拟合相关推荐

  1. [Python人工智能] 七.加速神经网络、激励函数和过拟合

    从本系列文章开始,作者正式开始研究Python深度学习.神经网络及人工智能相关知识.前六篇文章讲解了神经网络基础概念.Theano库的安装过程及基础用法.theano实现回归神经网络.theano实现 ...

  2. python编程代码示例_python编程线性回归代码示例

    用python进行线性回归分析非常方便,有现成的库可以使用比如:numpy.linalog.lstsq例子.scipy.stats.linregress例子.pandas.ols例子等. 不过本文使用 ...

  3. Python实现多项式回归实战——以及与线性回归的拟合效果对比

    对于给出的 数据做出散点图,可以大致看出模型是否适合做线性回归,但是,线性回归一定是拟合最好的模型吗?答案是否定的.有时候,多项式回归会得出拟合效果更好的模型,但是也需要注意过拟合的线性. 下面,还是 ...

  4. 【机器学习入门】(8) 线性回归算法:正则化、岭回归、实例应用(房价预测)附python完整代码和数据集

    各位同学好,今天我和大家分享一下python机器学习中线性回归算法的实例应用,并介绍正则化.岭回归方法.在上一篇文章中我介绍了线性回归算法的原理及推导过程:[机器学习](7) 线性回归算法:原理.公式 ...

  5. 理解神经网络,从简单的例子开始(1)7行python代码构建神经网络

    理解神经网络,从简单的例子开始(1)7行python代码构建神经网络 前言 本文分为两个部分,第一个部分是一个简单的实例:9行Python代码搭建神经网络,这篇文章原文为:原文链接, 其中中文翻译版来 ...

  6. 独家 | 手把手教你用Python创建简单的神经网络(附代码)

    作者:Michael J.Garbade 翻译:陈之炎 校对:丁楠雅 本文共2000字,建议阅读9分钟. 本文将为你演示如何创建一个神经网络,带你深入了解神经网络的工作方式. 了解神经网络工作方式的最 ...

  7. 9行Python代码搭建神经网络

    9行Python代码搭建神经网络 Kaiser谈笑风生 4月前发表至趣味项目,5995次访问 原文:How to build a simple neural network in 9 lines of ...

  8. python编程例子 输入 输出-推荐 :手把手教你用Python创建简单的神经网络(附代码)...

    原标题:推荐 :手把手教你用Python创建简单的神经网络(附代码) 作者:Michael J.Garbade:翻译:陈之炎:校对:丁楠雅 本文共2000字,9分钟. 本文将为你演示如何创建一个神经网 ...

  9. python 自动化-Python API 自动化实战详解(纯代码)

    主要讲如何在公司利用Python 搞API自动化. 1.分层设计思路 dataPool :数据池层,里面有我们需要的各种数据,包括一些公共数据等 config :基础配置 tools : 工具层 co ...

最新文章

  1. 2021年大数据Flink(三十一):​​​​​​​Table与SQL案例准备 依赖和​​​​​​​程序结构
  2. SpringBoot-JPA入门
  3. JavaScript初学者编程题(25)
  4. 值域范围 tf.clip_by_value的用法
  5. 分布式事务中间件 Fescar - 全局写排它锁解读
  6. #论文 《Wide Deep Learning for Recommender System》翻译
  7. JAVA语言中的反射机制
  8. ALGO-1 区间k大数查询
  9. Java反序列化json内存溢出_反序列化JSON时出现线程错误
  10. oracle和sql server取第一条记录的区别以及rownum详解
  11. 2. linux的日志文件在哪个目录,位于/var/log目录下的20个Linux日志文件
  12. python 白盒测试_白盒测试教程 - 颜丽的个人空间 - OSCHINA - 中文开源技术交流社区...
  13. docker 安装mysql5.6
  14. 苹果发布无人车安全报告,内容竟只有7页?
  15. 另类的缓存技术(存储数据)
  16. java8之StringJoiner。终于有像guava类库里的功能了
  17. python异常类父类_python【第五篇】异常处理
  18. 决策支持系统是什么?
  19. 寻找百度图片搜索接口--two
  20. 液化石油气(LPG)的全球与中国市场2022-2028年:技术、参与者、趋势、市场规模及占有率研究报告

热门文章

  1. 改造我们的学习:有钱不会花,抱着金库抓瞎
  2. mySql中使用命令行建表基本操作
  3. JavaScript的案例(数据校验,js轮播图,页面定时弹窗)
  4. Abiword 编辑事件设计
  5. soj1209- 最短的距离(精度问题)
  6. [导入]C#向Sql Server中插入记录时单引号的处理
  7. UA MATH567 高维统计专题1 稀疏信号及其恢复7 LASSO的预测误差与变量选择一致性
  8. Linux磁盘管理基础学习
  9. 3d数学基础学习总结
  10. NHibernate重要概念的解释和说明