神经网络解常微分方程(ODE)
原文
1 原理简介
微分方程可以写成2部分:
- 第一部分满足初始和边界条件并包含不可调节参数
- 第二部分不会影响第一部分,这部分涉及前馈神经网络,包含可调节参数(权重)。
因此在构建微分方程的函数时,要满足上述两个条件,今天就来简单看下。
假设存在以下微分方程:
上述微分方程f
对应着一个函数u(t)
,同时满足初始条件u(0)=u_0
,为此可以令:
则NN(t)
的导数为:
根据以上等式,NN(t)
的导数近似于:
可以把上式转换成损失函数:
简而言之,就是已知微分函数,然后用神经网络去拟合该微分函数的原函数,然后用微分公式作为损失函数去逼近原微分函数。
微分公式:
此外,还需要将初始条件考虑进去:
上述并不是一个好的方法,损失项越多会影响稳定性。为此会定义一个新函数,该函数要满足初始条件同时是t
的函数:
则损失函数为:
注意,神经微分网络目前主要是去近似一些简单的微分函数,复杂的比较消耗时间以及需要高算力。
2 实践
假设存在下述微分函数和网络:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as npnp.random.seed(123)
tf.random.set_seed(123)"""微分初始条件以及相应参数定义"""
f0 = 1 # 初始条件 u(0)=1# 用于神经网络求导,无限小的小数
inf_s = np.sqrt(np.finfo(np.float32).eps) learning_rate = 0.01
training_steps = 500
batch_size = 100
display_step = training_steps/10"""神经网络参数定义"""
n_input = 1 # 输入维度
n_hidden_1 = 32 # 第一层输出维度
n_hidden_2 = 32 # 第二层输出维度
n_output = 1 # 最后一层输出维度
weights = {'h1': tf.Variable(tf.random.normal([n_input, n_hidden_1])),
'h2': tf.Variable(tf.random.normal([n_hidden_1, n_hidden_2])),
'out': tf.Variable(tf.random.normal([n_hidden_2, n_output]))
}
biases = {'b1': tf.Variable(tf.random.normal([n_hidden_1])),
'b2': tf.Variable(tf.random.normal([n_hidden_2])),
'out': tf.Variable(tf.random.normal([n_output]))
}
"""优化器"""
optimizer = tf.optimizers.SGD(learning_rate)"""定义模型和损失函数"""
"""多层感知机"""
def multilayer_perceptron(x):x = np.array([[[x]]], dtype='float32')layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['b1'])layer_1 = tf.nn.sigmoid(layer_1)layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])layer_2 = tf.nn.sigmoid(layer_2)output = tf.matmul(layer_2, weights['out']) + biases['out']return output"""近似原函数"""
def g(x):return x * multilayer_perceptron(x) + f0"""微分函数"""
def f(x):return 2*x"""定义损失函数逼近导数"""
def custom_loss():summation = []# 注意这里,没有定义数据,根据函数中t的范围选取了10个点进行计算for x in np.linspace(0,1,10):dNN = (g(x+inf_s)-g(x))/inf_ssummation.append((dNN - f(x))**2)return tf.reduce_mean(tf.abs(summation))"""训练函数"""
def train_step():with tf.GradientTape() as tape:loss = custom_loss()trainable_variables=list(weights.values())+list(biases.values())gradients = tape.gradient(loss, trainable_variables)optimizer.apply_gradients(zip(gradients, trainable_variables))"""训练模型"""
for i in range(training_steps):train_step()if i % display_step == 0:print("loss: %f " % (custom_loss()))"""绘图"""
from matplotlib.pyplot import figure
figure(figsize=(10,10))# True Solution (found analitically)
def true_solution(x):return x**2 + 1X = np.linspace(0, 1, 100)
result = []
for i in X:result.append(g(i).numpy()[0][0][0])S = true_solution(X)
plt.plot(X, S, label="Original Function")
plt.plot(X, result, label="Neural Net Approximation")
plt.legend(loc=2, prop={'size': 20})
plt.show()
参考:
https://towardsdatascience.com/using-neural-networks-to-solve-ordinary-differential-equations-a7806de99cdd
神经网络解常微分方程(ODE)相关推荐
- 【Matlab】一、解常微分方程ODE
文章目录 求解常微分方程 ODE (1)求解解析解 (2)求解数值解 求解常微分方程 ODE 在matlab中,我们可以求解常微分方程的解析解,和数值解,一般使用dsolve来求解常微分方程的解析 ...
- 求微分方程用c语言怎么表达,使用C语言解常微分方程 C ODE
. 解常微分方程 名字:文森 年级:2010,学号:1033 * * *组编号:5(组),4(大组) 1.数值方法: 我们的实验目标是求解常微分方程,包括几类问题.一阶常微分初值问题,高阶常微分初值问 ...
- 图神经网络解偏微分方程系列(一)
图神经网络解偏微分方程系列(一) 1. 标题和概述 Learning continuous-time PDEs from sparse(稀疏) data with graph neural netwo ...
- matlab解常微分方程
ODE 常微分方程ordinary differential equation的缩写,此种表述方式常见于编程,如MATLAB中Simulink求解器solver已能提供了7种微分方程求解方法:ode4 ...
- python解常微分方程
一.sympy.dsolve 首先,感觉最科学的是用sympy的dsolve解常微分方程,直接贴代码 import sympy as sydef differential_equation(x,f): ...
- Matlab 解常微分方程常用工具包
最近使用RNN网络解决优化问题时需要解常微分方程,最开始使用matlab包ode45解,发现在某个数据集中出现ode45跑不出结果的情况(不报错)经过搜索发现,ode45只能用来解决非刚性的常微分方程 ...
- C言语实现半隐式Euler解常微分方程(附完整源码)
实现半隐式Euler解常微分方程 实现以下几个相关接口 实现半隐式Euler解常微分方程的完整源码(定义,实现,main函数测试) 实现以下几个相关接口 void problem(const doub ...
- [Matlab科学计算] 四阶Runge-Kutta法解常微分方程
四阶Runge-Kutta法格式的详细推导请查找相关数值分析书籍,这里直接给出四阶Runge-Kutta法的经典格式和Matlab代码 Matlab代码如下:自行修改常微分方程即可 %% 四阶Rung ...
- Adams隐式4阶方法解常微分方程,python实现
Adams隐式4阶方法解常微分方程,由4阶Runge-Kutta方法提供初值,隐式方法比显式复杂一些,主要是因为需要解方程.这里使用弦截法解微分方程. import math import numpy ...
最新文章
- T1187 强制 NTLM 认证
- php中redis怎么使用,redis 怎么使用
- Android IOC模块,利用了Java反射和Java注解
- 小米air耳机重新配对_小米发布 399 元真无线蓝牙耳机,除了小爱同学还支持其他手机语音助手...
- awk 系列Part5:如何使用 awk 复合表达式
- Android实现3D旋转效果
- freeCodeCamp纳什维尔十月聚会回顾
- Makefile之嵌套执行(9)
- 计算机句法分析的研究现状,计算机理论论文融合语义和句型信息的中文句法分析方法研究与实现...
- 如何设置二进制某一位的值_mysql参数设置--max_allowed_packet 值如何调整?
- JAVA环境变量配置步骤及测试(JDK的下载、安装和环境配置教程)
- sketchup 计算机配置,草图大师2020对电脑配置要求
- 医院信息系统(HIS系统)如何接入短信/语音功能
- freeswitch简介
- maven profile <filtering>true</filtering>的作用
- Visio中旋转文本框与箭头平行
- 贝叶斯网络的联合概率到底有什么用:贝叶斯理论(4)
- 如何离线安装所有依赖包
- 简单模拟struts框架,了解strusts的框架实现机制
- MuleSoft知识总结-13.Mule组件(Set Variable,For Each,Choice)