代码:

  • 导入包
import keras
import numpy as np
import matplotlib.pyplot as plt
# Sequential按顺序构成的模型
from keras.models import Sequential
# Dense全连接层
from keras.layers import Dense,Activation
from keras.optimizers import SGD
  • 生成随机数据
# 使用Numpy生成200个-0.5~0.5之间的值
x_data = np.linspace(-0.5, 0.5, 200)
noise = np.random.normal(0, 0.02, x_data.shape)# y_data= x_data**2 + noise
y_data = np.square(x_data) + noise # 效果与上面一致# 显示随机点
plt.scatter(x_data, y_data)
plt.show()

  • 创建模型+训练
    加入隐藏层拟合更加复杂模型
    加入激活函数来拟合非线性模型
# 建立一个顺序模型
model = Sequential()
# 1-10-1: 加入一个隐藏层(10个神经元):来拟合更加复杂的线性模型。添加激活函数,来计算函数的非线性model.add(Dense(units=10, input_dim=1, activation='relu'))# 全连接层:输入一维数据,输出10个神经元
# model.add(Activation('tanh')) # 也可以直接在Dense里面加激活函数
model.add(Dense(units=1, activation='tanh')) # 全连接层:由于有上一层的添加,所以输入维度默认是10(可以不用写),输出1个值(要写)
# model.add(Activation('tanh'))# 自定义优化器SDG , 学习率默认是0.01(太小,导致要迭代好多次才能较好的拟合数据)
sgd = SGD(lr=0.3)
model.compile(optimizer=sgd, loss='mse')# 训练3000次数据
for step in range(3001):cost = model.train_on_batch(x_data, y_data)if step%500 == 0:print('cost: ',cost)# x_data输入神经网络中,得到预测值y_pred
y_pred = model.predict(x_data)# 显示随机点
plt.scatter(x_data, y_data)
plt.plot(x_data, y_pred,'r-', lw=3)
plt.show()

总代码:

import keras
import numpy as np
import matplotlib.pyplot as plt
# Sequential按顺序构成的模型
from keras.models import Sequential
# Dense全连接层
from keras.layers import Dense,Activation
from keras.optimizers import SGD# 使用Numpy生成200个-0.5~0.5之间的值
x_data = np.linspace(-0.5, 0.5, 200)
noise = np.random.normal(0, 0.02, x_data.shape)# y_data= x_data**2 + noise
y_data = np.square(x_data) + noise # 效果与上面一致# 显示随机点
plt.scatter(x_data, y_data)
plt.show()# 建立一个顺序模型
model = Sequential()
# 1-10-1: 加入一个隐藏层(10个神经元):来拟合更加复杂的线性模型。添加激活函数,来计算函数的非线性model.add(Dense(units=10, input_dim=1, activation='relu'))# 全连接层:输入一维数据,输出10个神经元
# model.add(Activation('tanh')) # 也可以直接在Dense里面加激活函数
model.add(Dense(units=1, activation='tanh')) # 全连接层:由于有上一层的添加,所以输入维度默认是10(可以不用写),输出1个值(要写)
# model.add(Activation('tanh'))# 自定义优化器SDG , 学习率默认是0.01(太小,导致要迭代好多次才能较好的拟合数据)
sgd = SGD(lr=0.3)
model.compile(optimizer=sgd, loss='mse')# 训练3000次数据
for step in range(3001):cost = model.train_on_batch(x_data, y_data)if step%500 == 0:print('cost: ',cost)# x_data输入神经网络中,得到预测值y_pred
y_pred = model.predict(x_data)# 显示随机点
plt.scatter(x_data, y_data)
plt.plot(x_data, y_pred,'r-', lw=3)
plt.show()

参考:

视频: 覃秉丰老师的“Keras入门”:http://www.ai-xlab.com/course/32
博客参考:https://www.cnblogs.com/XUEYEYU/tag/keras%E5%AD%A6%E4%B9%A0/

3. 使用Keras-神经网络来拟合非线性模型相关推荐

  1. keras神经网络回归预测_如何使用Keras建立您的第一个神经网络来预测房价

    keras神经网络回归预测 by Joseph Lee Wei En 通过李维恩 一步一步的完整的初学者指南,可使用像Deep Learning专业版这样的几行代码来构建您的第一个神经网络! (A s ...

  2. 从零开始学keras之过拟合与欠拟合

    在预测电影评论.主题分类和房价回归中,模型在留出验证数据上的性能总是在几轮后达到最高点,然后开始下降.也就是说,模型很快就在训练数据上开始过拟合.过拟合存在于所有机器学习问题中.学会如何处理过拟合对掌 ...

  3. Keras神经网络实现泰坦尼克号旅客生存预测

    Keras神经网络实现泰坦尼克号旅客生存预测 介绍 数据集介绍 算法 学习器 分类器 实现 数据下载与导入 预处理 建立模型 训练 可视化 评估,预测 结果 代码 介绍 参考资料: 网易云课堂的深度学 ...

  4. Keras神经网络的学习与使用(1)

    Keras神经网络层学习与使用 Keras的简单介绍 Keras框架中的方法介绍 Compile()方法 fit()方法 summary()方法 evaluate()方法 perdict()方法 Ke ...

  5. Keras神经网络集成技术

    Keras神经网络集成技术 create_keras_neuropod 将Keras模型打包为神经网络集成包.目前,上文已经支持TensorFlow后端. create_keras_neuropod( ...

  6. 避免神经网络过拟合的5种技术(附链接) | CSDN博文精选

    作者 | Abhinav Sagar 翻译 | 陈超 校对 | 王琦 来源 | 数据派THU(ID:DatapiTHU) (*点击阅读原文,查看作者更多精彩文章) 本文介绍了5种在训练神经网络中避免过 ...

  7. 神经网络+过拟合+避免

    神经网络+过拟合+避免 作者:Abhinav Sagar & THU 最近一年我一直致力于深度学习领域.这段时间里,我使用过很多神经网络,比如卷积神经网络.循环神经网络.自编码器等等.我遇到的 ...

  8. 独家 | 避免神经网络过拟合的5种技术(附链接)

    作者:Abhinav Sagar 翻译:陈超 校对:王琦 本文约1700字,建议阅读8分钟. 本文介绍了5种在训练神经网络中避免过拟合的技术. 最近一年我一直致力于深度学习领域.这段时间里,我使用过很 ...

  9. 避免神经网络过拟合的5种技术

    作者:Abhinav Sagar 翻译:陈超 校对:王琦 本文约1700字,建议阅读8分钟. 本文介绍了5种在训练神经网络中避免过拟合的技术. 最近一年我一直致力于深度学习领域.这段时间里,我使用过很 ...

  10. 解决神经网络过拟合问题—Dropout方法、python实现

    解决神经网络过拟合问题-Dropout方法 一.what is Dropout?如何实现? 二.使用和不使用Dropout的训练结果对比 一.what is Dropout?如何实现? 如果网络模型复 ...

最新文章

  1. Python中使用Flask、MongoDB搭建简易图片服务器
  2. Apriori算法简介及实现(python)
  3. django 1.8 官方文档翻译:5-1-2 表单API
  4. spark学习-61-源代码:ShutdownHookManager虚拟机关闭钩子管理器
  5. centos java进程号_centos中分析java占用大量CPU资源的原因
  6. 【图像处理】基于matlab GUI图像形态学处理【含Matlab源码 1287期】
  7. 通过IP地址获取对方MAC地址的命令
  8. 【前端安全】web缓存投毒
  9. 波音承认 737MAX 飞行模拟器存在缺陷;韩国政府计划从 Win7 迁移到 Linux
  10. dell电脑更新win11后黑屏但有鼠标(已解决)
  11. ffmpeg学习十二:滤镜(实现视频缩放,裁剪,水印等)
  12. 利用注册表修改文件关联
  13. Microsoft Edge
  14. java months between,ORACLE函数MONTHS_BETWEEN
  15. Python之心算练习程序
  16. C++笔记:输入输出、变量、变量加减乘除
  17. Android Studio连接安卓手机驱动
  18. Java实验报告手写_java实验1实验报告(20135232王玥)
  19. 20155314 2016-2017-2 《Java程序设计》实验三 敏捷开发与XP实践
  20. 机器中的上帝-人工智能,冠状病毒,种族主义和宗教

热门文章

  1. BZOJ 4300: 绝世好题 动态规划
  2. 基于IDEA 最新Spirng3.2+hibernate4+struts2.3 全注解配置 登录
  3. 《机电传动控制》第三次作业
  4. apache 配置用户级目录
  5. Android获得全局进程信息以及进程使用的内存情况
  6. 按比例缩小图片的CSS代码
  7. oracle10g   RMAN增量备份策略
  8. 概率论与数理统计 习题篇
  9. response对象设置返回状态_爬虫代理之设置
  10. php pdo exec,PDO::exec