CNN 神经网络tricks 学习总结
TRICKS IN DEEP LEARNING
IN THIS DOC , ONLY WITH LITTLE BRIEF EXPLANATION, RECORED IN DAILY STUDY
Last update 2018.4.7
############################################################################
1、变量初始化
-----初始化变量
var = tf.Variable(tf.random_normal([2, 3], stddev=0.2, mean=0.0))
tf.random_normal()
tf.truncated_normal()
tf.random_uniform()
tf.random_gamma()
############################################################################
2、Loss Func
--A--交叉熵H(p,q)刻画的是两个概率分布之间的距离,常用于分类问题
-- y_表示真实值
cross_entropy = -tf.reduce_mean(
y_*tf.clip_by_value(y, 1e-10, 1.0)
)
cross_entropy = tf.reduce_mean( -tf.reduce_sum(y_ * tf.log( y), reduction_indices=[1]))
因为交叉熵一般会与softmax 回归一起使用,所以Tensorflow封装了函数
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(y,y_)
得到softmax回归之后的交叉熵
--B---MSE 均方误差,常用于回归问题
-- y_表示真实值
mse = tf.reduce_mean(tf.square( y_ - y))
------自定义损失函数常用基本函数
tf.reduce_sum(); tf.select();tf.greater()
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
reduction_indices=[1]))
############################################################################
3、weights_with_L2_loss
def weights_with_loss(shape, wl=None):
"""
获取带有L2_Loss的权重, 并添加到collection loss 中
最后我们可以使用 loss = tf.add_n(tf.get_collection("loss"), name='total_loss')
计算出总体loss
weights_with_loss 一般不用于第一层和最后一层,多见于全连接层
:param shape: weights_shape
:param wl: weights_loss_ratio
:return: weights
"""
w = tf.Variable(tf.truncated_normal(shape=shape, stddev=0.01, dtype=tf.float32))
if wl is not None:
weights_loss = tf.multiply(tf.nn.l2_loss(w), wl, name='weights_loss')
tf.add_to_collection("loss", weights_loss)
return w
############################################################################
4、batch_normallization
def batch_normalization(self, input, decay=0.9, eps=1e-5):
"""
Batch Normalization
Result in:
* Reduce DropOut
* Sparse Dependencies on Initial-value(e.g. weight, bias)
* Accelerate Convergence
* Enable to increase training rate
Usage: apply to (after)conv_layers
Args: output of convolution or fully-connection layer
Returns: Normalized batch
"""
shape = input.get_shape().as_list()
n_out = shape[-1]
beta = tf.Variable(tf.zeros([n_out]))
gamma = tf.Variable(tf.ones([n_out]))
if len(shape) == 2:
batch_mean, batch_var = tf.nn.moments(input, [0])
else:
batch_mean, batch_var = tf.nn.moments(input, [0, 1, 2])
ema = tf.train.ExponentialMovingAverage(decay=decay)
def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean, batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var)
mean, var = tf.cond(self.train_phase, mean_var_with_update,
lambda: (ema.average(batch_mean), ema.average(batch_var)))
return tf.nn.batch_normalization(input, mean, var, beta, gamma, eps)
############################################################################
5、LRN
def LRN(x, R, alpha, beta, name=None, bias=1.0):
"""
LRN apply to (after)conv_layers
:param x: input_tensor
:param R: depth_radius
:param alpha: alpha in math formula
:param beta: beta in match formula
:param name:
:param bias:
:return:
"""
return tf.nn.local_response_normalization(x, depth_radius=R, alpha=alpha,
beta=beta, bias=bias, name=name)
############################################################################
5、gradient decent
----gradient decent & backpropagation
gradient decent :主要用于优化单个参数的取值 【所谓梯度就是一阶导数】
backpropagation: 给出了一个高效的方式在所有参数上使用梯度下降法
需要注意:
(1)gradient decent不能保证全局最优
(2)损失函数实在所有训练数据上的损失和,故gradient decent计算时间很长
|
^
gradient decent Adam(折中方式,每次计算一个batch的损失函数和) SGD
############################################################################
6、learning rate
-----learning_rate 决定了参数每次更新的幅度
-----decayed_learning_rate:
global_step = tf.Variable(0,tf.int32)
learning_rate = tf.train.exponential_decay(
0.1, global_step, 100, 0.96, staircase = True)
.....
learning_rate = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step )
每100轮过后 lr 乘以 0.96
############################################################################
7、full connection
----经典全连接层:
tf.nn.relu(tf,matmul(x,w)+biases)
----全连接层一般会和 dropout连用, 防止过拟合
############################################################################
8、PCA
def RGB_PCA(images):pixels = images.reshape(-1, images.shape[-1])idx = np.random.random_integers(0, pixels.shape[0], 1000000)pixels = [pixels[i] for i in idx]pixels = np.array(pixels, dtype=np.uint8).Tm = np.mean(pixels)/256.C = np.cov(pixels)/(256.*256.)l, v = np.linalg.eig(C)return l, v, mdef RGB_variations(image, eig_val, eig_vec):a = np.random.randn(3)v = np.array([a[0]*eig_val[0], a[1]*eig_val[1], a[2]*eig_val[2]])variation = np.dot(eig_vec, v)return image + variationl,v,m = RGB_PCA(img)
img = RGB_variations(img,l,v)
imshow(img)
CNN 神经网络tricks 学习总结相关推荐
- cnn是深度神经网络吗,cnn神经网络算法
1.深度学习和有效学习的区别 深度学习和有效学习的区别分别是: 1.深度学习是:Deep Learning,是一种机器学习的技术,由于深度学习在现代机器学习中的比重和价值非常巨大,因此常常将深度学习单 ...
- 马里奥AI实现方式探索 ——神经网络+增强学习
首先,对于实现马里奥AI当中涉及到的神经网络和增强学习的相关概念进行整理,之后对智能通关的两种方式进行阐述.(本人才疏学浅,在神经网络和增强学习方面基本门外汉,如有任何纰漏,还请大神指出,我会第一时间 ...
- pytorch卷积神经网络_资源|卷积神经网络迁移学习pytorch实战推荐
点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 一.资源简介 这次给大家推荐一篇关于卷积神经网络迁移学习的实战资料,卷积神经网络迁移学 ...
- 对比图像分类五大方法:KNN、SVM、BPNN、CNN和迁移学习
选自Medium 机器之心编译 参与:蒋思源.黄小天.吴攀 图像分类是人工智能领域的基本研究主题之一,研究者也已经开发了大量用于图像分类的算法.近日,Shiyu Mou 在 Medium 上发表 ...
- CNN神经网络猫狗分类经典案例,深度学习过程中间层激活特征图可视化
AI:CNN神经网络猫狗分类经典案例,深度学习过程中间层激活特征图可视化 基于前文 https://zhangphil.blog.csdn.net/article/details/103581736 ...
- 【图神经网络】图神经网络(GNN)学习笔记:图分类
图神经网络GNN学习笔记:图分类 1. 基于全局池化的图分类 2. 基于层次化池化的图分类 2.1 基于图坍缩的池化机制 1 图坍缩 2 DIFFPOOL 3. EigenPooling 2.2 基于 ...
- 手写汉字数字识别详细过程(构建数据集+CNN神经网络+Tensorflow)
手写汉字数字识别(构建数据集+CNN神经网络) 期末,P老师布置了一个大作业,自己构建数据集实现手写汉字数字的识别.太捞了,记录一下过程.大概花了一个下午加半个晚上,主要是做数据集花时间. 一.构建数 ...
- 《机器学习》第四章 人工神经网络 深度学习启蒙篇
神经网络是一门重要的机器学习技术.它是目前最为火热的研究方向--深度学习的基础.学习神经网络不仅可以让你掌握一门强大的机器学习方法,同时也可以更好地帮助你理解深度学习技术. 本文以一种简单的,循序的方 ...
- 基于TensorFlow实现的CNN神经网络 花卉识别系统Demo
基于TensorFlow实现的CNN神经网络 花卉识别系统Demo Demo展示 登录与注册 主页面 模型训练 识别 神经网络 训练 Demo下载 Demo展示 登录与注册 主页面 模型训练 识别 神 ...
最新文章
- php控制css,div控制css样式
- linux 自学系列:wc命令
- android垂直公告,【Android之垂直翻页公告】
- 分解和合并:Java 也擅长轻松的并行编程!
- MySQL 基本数据类型
- RDLC 示例 文章 1
- git21天打卡-day8 本地分支push到远程服务器
- 在苹果mac中使用excel时,如何快速求和多行数值?
- 小知识--局域网内的文件共享
- 全自动mysql数据监控平台_Prometheus+Grafana打造Mysql监控平台
- 数据抽取oracle_【跟我学】特征抽取算法与应用
- Linux - vim编辑器,tmux的简单使用
- css video 样式,css自定义video播放器样式的方法
- 老笔记本_Win7_U盘_ReadyBoost
- 华为手机怎样无线与电脑连接服务器,华为手机如何与电脑远程连接服务器
- bat脚本执行sql脚本
- 今天玩了一款游戏,真不错哦,英文的
- @media 的使用规范
- Blender Rigify版Walker绑定下载
- linux系统的格式化说明,格式化[说明]如何用LINUX命令格式化U盘
热门文章
- TED演讲——人生的12条法则
- Google怎么做(1.相关提示)
- hdu 1392 Surround the Trees
- (~解题报告~)L1-019 谁先倒 (15分) ——17行代码AC
- 算法竞赛入门经典(第二版) | 习题3-10 盒子 (pair结构体)(UVa1587,Box)
- ArrayList方法源码
- 编码utf-8的不可映射字符_建议永远不要在MySQL中使用UTF8
- 路由重分发中尽然忘记了这件事
- github 头像生成 java_Java 如何根据头像地址生成圆形的头像?
- mysql 列目录_Linux ls命令:查看目录下文件