前言

大家好,我是Kay,小白一个。以下是我完成斯坦福 cs231n-assignment1-two_layer_net 这份作业的做题过程、思路、踩到的哪些坑、还有一些得到的启发和心得。希望下面的文字能对所有像我这样的小白有所帮助。

两层的网络处理方法与之前的SVM/Softmax 的处理方法类似,关键在于函数和梯度的计算。

TODO1:计算 scores

【思路】公式是:W2 * max(0, W1*x) 用代码实现之即可。

scores_hid = np.dot(X, W1) + b1

scores_hid = np.maximum(0, scores_hid)

scores =np.dot(scores_hid, W2) + b2

结果正确。

TODO2:计算 loss

【思路】用 softmax 的公式,不要忘记加上正则惩罚项。

scores = scores - np.max(scores, axis = 1,keepdims=True)

exp_sum = np.sum(np.exp(scores), axis = 1,keepdims=True)

loss = -np.sum(scores[range(N), y]) +np.sum(np.log(exp_sum))

loss = loss / N + 0.5 * reg * (np.sum(W1 *W1) + np.sum(W2 * W2))

【开始 Debug】

这个 bug 我是找得真是要气死自己啊!感觉公式都没记错啊!Run 了好几次 loss 还是不够小,郁闷之下跑去百度,贴了别人的代码结果 loss 也是这么多但是他们的结果很小,喵喵喵?我定眼一看,别人的 reg 都是 0.1!凭什么我给的是0.05。想哭、难受。

改成 0.1 后,果然我的代码也是正确的。

【思考提升】

我顺手看了下别人的代码,有的人没有对scores 做处理就开 e 的幂,这是不对的哦~小心数值被爆掉哦~

TODO3:利用反向传播计算梯度

【思路】画出“计算图”,一步步往回做,靠公式得到:

d_b2= 1 * d_scores

d_W2= h1 * d_scores (h1 是 max(0, f1) )

d_b1 = 1 * d_f1

d_W1= X * d_f1

prob = score / exp_sum

prob[range(N), y] -= 1

d_scores= np.dot(X.T, prob)

d_scores /= N

grads['b2'] = np.sum(d_scores, axis = 0)

grads['W2'] = np.dot(scores_hid.T, d_scores)

d_f1 = np.dot(d_scores, W2.T)

d_f1[scores_hid <= 0] = 0

grads['b1'] = np.sum(d_f1, axis = 0)

grads['W1'] =np.dot(X.T, d_f1)

【开始 Debug】这里我遇到了特别多错误,果然思考还是不够严谨。

1.    d_scores 求错了,不是 np.dot(X.T, prob),不能生搬以为是 softmax 里的 dW,d_scores 就是 prob!

2.    prob 也求错了,在 softmax 里,我分子上的 scores 是有做 e 幂的,但是这里还没处理就直接拿去用了,还是一处生搬旧思想的错误。

3.    两个 dW 没有加正则项。

【修改代码】

prob = np.exp(scores) / exp_sum

prob[range(N), y] -= 1

d_scores = prob / N

grads['b2'] = np.sum(d_scores, axis = 0)

grads['W2'] = np.dot(scores_hid.T, d_scores) + reg * W2

d_f1 = np.dot(d_scores, W2.T)

d_f1[scores_hid <= 0] = 0

grads['b1'] = np.sum(d_f1, axis = 0)

grads['W1'] =np.dot(X.T, d_f1) + reg * W1

结果正确。

【思考提升】其实像这些所谓的“小错”是很让人沮丧的,错的又不是大方向,找起bug 时又往往是从整体思路开始怀疑自己,因此找到这点小错是很耗费精力的,要怎么加快 debug 的效率呢?错的地方是思路、还是小瑕疵?这是应该训练的地方了。

TODO4:完成 train 函数和 predict 函数

【思路】

SGD:利用特定一张图像对我们的各个参数进行 update

y_pred就是谁的分数大就取那个标签作为 y_pred

三段代码都贴在这里了。

sample_indices = np.random.choice(range(num_train), batch_size)

X_batch = X[sample_indices]

y_batch =y[sample_indices]

self.params['W1'] -= learning_rate * grads['W1']

self.params['b1'] -= learning_rate * grads['b1']

self.params['W2'] -= learning_rate * grads['W2']

self.params['b2'] -= learning_rate * grads['b2']

y_pred = np.argmax(self.loss(X), axis=1)

结果Final training loss: 0.017143643532923747

TODO5:进行超参数的调参

【思路】这里的超参数有:hidden layer size, learning rate, numer of training epochs, andregularization strength

首先,我们还是先调最重要的lr 和 reg,接着再考虑其他超参数。

best_val = -1

input_size = 32 * 32 * 3

hidden_size = 100

num_classes = 10

net = TwoLayerNet(input_size, hidden_size,num_classes)

learing_rates = [1e-3, 1.5e-3, 2e-3]

regularizations = [0.2, 0.35, 0.5]

for lr in learing_rates:

for reg in regularizations:

stats = net.train(X_train, y_train, X_val, y_val,

num_iters=1500,batch_size=200,

learning_rate=lr,learning_rate_decay=0.95,

reg=reg, verbose=False)

val_acc = (net.predict(X_val) == y_val).mean()

if val_acc > best_val:

best_val = val_acc

best_net = net

print ("lr ",lr, "reg ", reg, "val accuracy:", val_acc)

print ("best validation accuracyachieved during cross-validation: ", best_val)

结果差强人意。

【思考提升】老师说:“Tuning the hyperparameters and developing intuition for how theyaffect the final performance is a large part of using Neural Networks.”可是就目前而言,我还是没有认识到超参数的“重要性”?所以,我需要训练一种超参如何影响神经网络的直觉。

总结

这份作业要求我们掌握正向传递分数函数和损失函数,同时反向传递梯度给每个变量。

Delta = “本地梯度”*“上沿梯度”

有趣的是,

变量间做“加法”,传回的梯度都是那份“上沿梯度”,相当于是一个广播器;

变量间做“max()”,传回的梯度是那份“上沿梯度”给最大的值,其他的梯度是0,相当于是一个路由器;

变量间做“乘法”,传回的梯度都是那份“上沿梯度”*对方本身的值,相当于是一个(带放大“上沿梯度”倍)交换器。

这三个典例,应该能帮助我们直观地理解 backpropagation。

最后,关于缩进,我已经放弃治疗了。

【学习笔记】cs231n-assignment1-two_layer_net相关推荐

  1. 转载CS231n课程学习笔记

    CS231n课程学习笔记 CS231n网易云课堂链接 CS231n官方笔记授权翻译总集篇发布 - 智能单元 - 知乎专栏 https://zhuanlan.zhihu.com/p/21930884 C ...

  2. CS231n 学习笔记(2)——神经网络 part2 :Softmax classifier

    *此系列为斯坦福李飞飞团队的系列公开课"cs231n convolutional neural network for visual recognition "的学习笔记.本文主要 ...

  3. CS231n 学习笔记(1)——神经网络 part1 :图像分类与数据驱动方法

    *此系列为斯坦福李飞飞团队的系列公开课"cs231n convolutional neural network for visual recognition "的学习笔记.本文主要 ...

  4. 深度学习总结——CS231n课程深度学习(机器视觉相关)笔记整理

    深度学习笔记整理 说明 基本知识点一:模型的设置(基本) 1. 激活函数的设置 2. 损失函数的设置 (1) 分类问题 (2) 属性问题 (3) 回归问题 3. 正则化方式的设置 (1) 损失函数添加 ...

  5. cs231n学习笔记——图像分类

    cs231n学习笔记--图像分类及代码实现 写在前面的废话 1.图像分类 2.数据驱动 3.图形分类流程 4.L1距离(曼哈顿距离) 5.L2距离(欧氏距离) 6. Nearest Neighbor分 ...

  6. CS231n 学习笔记(2)——神经网络 part2 :线性分类器,SVM

    *此系列为斯坦福李飞飞团队的系列公开课"cs231n convolutional neural network for visual recognition "的学习笔记.本文主要 ...

  7. 深度学习入门之PyTorch学习笔记:深度学习介绍

    深度学习入门之PyTorch学习笔记:深度学习介绍 绪论 1 深度学习介绍 1.1 人工智能 1.2 数据挖掘.机器学习.深度学习 1.2.1 数据挖掘 1.2.2 机器学习 1.2.3 深度学习 第 ...

  8. 三元组法矩阵加法java_计算机视觉学习笔记(2.1)-KNN算法中距离矩阵的计算

    本笔记系列以斯坦福大学CS231N课程为大纲,海豚浏览器每周组织一次授课和习题答疑.具体时间地点请见微信公众号黑斑马团队(zero_zebra)和QQ群(142961883)发布.同时课程通过腾讯课堂 ...

  9. cs224n学习笔记

    *. Tips: loga(x)log_a(x)loga​(x)在机器学习中默认为ln(x)ln(x)ln(x) 0. 主要参考: CS224n专栏 AILearners/blog/nlp/cs224 ...

  10. 斯坦福cs224n教程--- 学习笔记1

    一.前言 自然语言是人类智慧的结晶,自然语言处理是人工智能中最为困难的问题之一,而对自然语言处理的研究也是充满魅力和挑战的. 通过经典的斯坦福cs224n教程,让我们一起和自然语言处理共舞!也希望大家 ...

最新文章

  1. axure动态登录和html5,Axure8原型设计实战案例:如何实现登录功能?
  2. LogMiner日志分析工具的使用
  3. 科大讯飞AIUI(1)
  4. 麦子的第一个注解+spring小案例 欢迎指点学习。
  5. 补丁更新选项的禁用与恢复
  6. 文末福利 | Python3 网络爬虫:老板,需要特殊服务吗?
  7. Thrall’s Dream HRBUST - 2048【BFS or 强连通分量】
  8. linux脚本登录启动失败,linux-从bash脚本启动进程失败
  9. php 后台运行函数,php守护进程函数 后台执行脚本的实例详解
  10. MyBatis Demo 编写(1)基础功能搭建
  11. python语言磁力搜索引擎源码公开,基于DHT协议,十二分有技术含量的技术博客...
  12. nbu备份软件异机恢复需要注意问题
  13. 今天解封了,该递交作业了,我做了个智能机器人
  14. 编写函数(fun),通过函数调用,输入存款金额和存款年限,计算到期总金额和利息。
  15. redirectType=Found和redirectType=Permanent哪个是301哪个是302?
  16. 快速学习html、css的经典笔记
  17. 关于汉字转化为简码的方法
  18. OsgEarth中设置模型运动路径,并绘制雷达扫描、动态实时绘制运动轨迹、跟随彩带
  19. 使用Glade3.0进行界面开发
  20. 数据大方送之全球10米土地利用数据

热门文章

  1. 计算机四级 信息安全工程师——计算机网络题库
  2. Adobe Flash Player30.0.0.113离线安装包
  3. SAP MM批次管理
  4. windows下的DataX的安装和使用教程
  5. 硬盘的接口,总线,协议知识点总结
  6. gensler逻辑学导论_学逻辑学,哪本书入门合适?
  7. vmware之VMware Remote Console (VMRC) SDK(二)
  8. 数据采集程序(网页小偷)点滴心得
  9. html如何嵌入手机,手机嵌入页面
  10. 零基础学习嵌入式:嵌入式linux视频教程免费分享