本文还是以MNIST的CNN分析为例

loss函数一般有MSE均方差函数、交叉熵损失函数,说明见

https://blog.csdn.net/John_xyz/article/details/61211422

另外一部分为正则化部分,这里实际上了解图像的会理解较深,就是防止过拟合的一些方式,符合图像先验的正则化项会给图像恢复带来很大的效果,简单讲神经网络常见的正则化则是

1.对权重加入L2-norm或L1-norm

2.dropout

3.训练数据扩增

可以看

https://blog.csdn.net/u012162613/article/details/44220115

见修改的代码:

#tf可以认为是全局变量,从该变量为类,从中取input_data变量
import tensorflow.examples.tutorials.mnist.input_data as input_data
import tensorflow as tf
#读取数据集
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)#这里用CNN方法进行训练
#函数定义部分
def weight_variable(shape):initial=tf.truncated_normal(shape,stddev=0.1)#随机权重赋值,不过truncated_normal代表如果是2倍标准差之外的结果重新选取该值return tf.Variable(initial)def bias_variable(shape):initial=tf.constant(0.1,shape=shape)#偏置项return tf.Variable(initial)def conv2d(x,W):return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')#SAME表示输出补边,这里输出与输入尺寸一致def max_pool_2x2(x):return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')#ksize代表池化范围的大小,stride为扫描步长# 这里是变量的占位符,一般是输入输出使用该部分
x=tf.placeholder(tf.float32,[None,784])
y_=tf.placeholder("float",[None,10])
x_image=tf.reshape(x,[-1,28,28,1])#-1表示自动计算该维度
#建立第一层
W_conv1=weight_variable([5,5,1,32])
b_conv1=bias_variable([32])
h_conv1=tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1)
h_pool1=max_pool_2x2(h_conv1)
#第二层
W_conv2=weight_variable([5,5,32,64])
b_conv2=bias_variable([64])
h_conv2=tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2)
h_pool2=max_pool_2x2(h_conv2)#第三层,而且这里是全连接层
W_fc1=weight_variable([7*7*64,1024])
b_fc1=bias_variable([1024])h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64])
h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1)+b_fc1)
#dropout,注意这里也是有一个输入参数的,和x以及y一样
keep_prob=tf.placeholder(tf.float32)
h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)W_fc2=weight_variable([1024,10])
b_fc2=bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2)# 评价函数,该部分是今天想要着重说明的部分
# 常用的函数有方差代价函数,交叉熵代价函数
#cross_entropy = tf.nn.softmax_cross_entropy_with_logits(y_conv, y_)为交叉熵的用法,其中y_conv应该是没有经过softmax的,这里的y_conv=tf.matmul(h_fc1_drop,W_fc2)+b_fc2
#具体差异见https://blog.csdn.net/John_xyz/article/details/61211422
#cross_entropy =tf.nn.sparse_softmax_cross_entropy_with_logits(y_conv, y)也是交叉熵,差异在于这里的标签可以认为是排它的
#方差函数:mse = tf.reduce_mean(tf.square(y_- y_conv))
#分类型的优化函数loss = tf.reduce_sum(tf.select(tf.greater(v1,v2),loss1,loss2)),代表v1>=v2时,使用loss1函数,否则使用loss2函数
#应用场景:危险品的鉴别#第二种,抑制过拟合
#常用方法,加入权重L1正则项、L2正则项、dropout、训练数据扩展
#可以参考网址https://blog.csdn.net/u012162613/article/details/44220115
#我们这里使用L2正则化#L2的正则化项,一些具体的细节使用见https://www.jianshu.com/p/6ffd815e2d11
tf.add_to_collection(tf.GraphKeys.WEIGHTS,W_fc1)
tf.add_to_collection(tf.GraphKeys.WEIGHTS,W_fc2)
regularizer = tf.contrib.layers.l2_regularizer(scale=100/50000)#这里需要和你输入的样品数成正比
reg_term = tf.contrib.layers.apply_regularization(regularizer)cross_entropy=-tf.reduce_sum(y_*tf.log(y_conv))+reg_term
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction=tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))# 启动模型,Session建立这样一个对象,然后指定某种操作,并实际进行该步
init=tf.initialize_all_variables()
sess=tf.Session()
sess.run(init)#数据读取部分
for i in range(1000):batch_xs, batch_ys = mnist.train.next_batch(50)#这里貌似是代表读取50张图像数据#run第一个参数是fetch,可以是tensor也可以是Operation,第二个feed_dict是替换tensor的值'''if i % 10 == 0:train_accuracy = accuracy.eval(feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})print("step:%d,accuracy:%g" % (i, train_accuracy))'''sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})#sess.run第一个参数是想要运行的位置,一般有train,accuracy,initdeng#第二个参数feed_dict,一般是输入参数,该代码里有x,y以及drop的参数if i%20==0 :print(i)print("train accuracy:%g"%sess.run(accuracy, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5}))
print("test accuracy:%g"%sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1}))

运行结果:

和之前的结果对比一下:提高了0.3%,不知道这算不算提高。。。其实在正则项的超参数选择不好时,一般结果会比较差,一些超参数的选择可以见https://blog.csdn.net/u012162613/article/details/44265967

tensorflow学习(4.loss函数以及正则化的使用 )相关推荐

  1. 关于机器学习 Machine Learning中loss函数参数正则化的一点思考

    1 致谢 感谢 Andrew Ng教授的讲述! 2 前言 今天在学习机器学习中对loss函数中的参数进行正则化~ 3 关于机器学习中loss函数参数正则化 在机器学习中,有一项防止过拟合的技巧就是(参 ...

  2. tensorflow分类的loss函数_Tensorflow Keras的loss函数总结

    一.二分类与多分类交叉熵损失函数的理解 交叉熵是分类任务中的常用损失函数,在不同的分类任务情况下,交叉熵形式上有很大的差别, 二分类任务交叉熵损失函数: 多分类任务交叉熵损失函数: 这两个交叉熵损失函 ...

  3. tensorflow分类的loss函数_Tensorflow入门教程(三十三)——图像分割损失函数FocalLoss...

    常见的图像分割损失函数有交叉熵,dice系数,FocalLoss等.今天我将分享图像分割FocalLoss损失函数及Tensorflow版本的复现. 1.FocalLoss介绍 FocalLoss思想 ...

  4. tensorflow学习之常用函数总结:tensorflow官方例子中的诸如tf.reduce_mean()这类函数

    前言 tensorflow官网给的例子用到了很多函数,然后并没有具体说明,还要自己去翻文档,有些函数是很常用的,下面来一一总结. 正文 一,tensorflow中有一类在tensor的某一维度上求值的 ...

  5. tensorflow分类的loss函数_tensorflow 分类损失函数使用小记

    多分类损失函数 label.shape:[batch_size]; pred.shape: [batch_size, num_classes] 使用 tf.keras.losses.sparse_ca ...

  6. tensorflow分类的loss函数_tensorflow中loss函数

    交叉熵函数 1)sigmoid_cross_entropy_with_logits(二分类问题) 输入是logits和targets,logits就是神经网络模型中的 W * X矩阵,不需要经过sig ...

  7. tensorflow学习之常用函数总结:tensorflow.placeholder()函数

    tensorflow.placeholder()函数 tensorflow.placeholder(dtype, shape=None, name=None) 此函数可以理解为形参,用于定义过程,在执 ...

  8. tensorflow学习之常用函数总结:tensorflow.cast()函数

    tensorflow.cast()类型转换函数     tf.cast(x, dtype, name=None)     此函数是类型转换函数     参数 x:输入 dtype:转换目标类型 nam ...

  9. tensorflow学习之常用函数总结:tensorflow.argmax()函数

    tensorflow.argmax()函数 tf.argmax(input, axis=None, name=None, dimension=None) 此函数是对矩阵按行或列计算最大值 参数 inp ...

最新文章

  1. HA: SHERLOCK 靶机渗透取证
  2. do{ ...}while(0)应用技巧
  3. boost::contract模块实现private protected的测试程序
  4. Java微服务(二)【idea中文插件安装】(手把手编写,超级详细)
  5. linux内核模块间通信
  6. 【Flink】Flink 1.12.2 TaskSlot
  7. Output argument fuse (and maybe others) not assigned during call to
  8. 职场“35岁危机”:这是我看过的最棒建议
  9. 3. mysql的注解驱动的三种方式_注册 Jdbc 驱动程序的三种方式及Class.forName 的作用...
  10. Mac上有没有好用的WiFi无线网络管理工具?看这里
  11. X-Pacific / Elasticsearch-ESClientRHL
  12. python死循环_Python for死循环
  13. 【Windows 逆向】CE 地址遍历工具 ( CE 结构剖析工具 | 人物数据内存结构 | 人物三维坐标数据分析 )
  14. IDEA展示隐藏文件夹
  15. Counting Bloom Filter
  16. 为海思u-boot快速生成reg_info.bin文件
  17. 第一次搭建React项目以及错误:getaddrinfo ENOTFOUND registry.npmjs.org解决办法
  18. iOS7.0.4完美越狱后safari闪退解决方法
  19. BMS(Battery Management System)是什么?
  20. 强大的CSS:颜色、背景和剪切

热门文章

  1. 卧槽,B站联名键盘!
  2. 指数随机变量 泊松过程跳_《常见随机过程》(一)
  3. python开发闹钟_「玩转树莓派」为女朋友打造一款智能语音闹钟
  4. (JAVA学习笔记) 关于稀疏数组
  5. 聊聊买卖股票的最佳时机
  6. 毕业后两三月的本科毕业生,他们都怎么样了
  7. 爬取虎牙之一:(王者荣耀主播信息普通爬取)
  8. Android开启adb
  9. 计算机用户删除 cmd,用命令行删XP中的用户。除administrator外
  10. php判断表单修改内容,JavaScript判断用户是否对表单进行了修改的方法_javascript技巧...