吴恩达机器学习ex1

  • 完整代码
  • 代码输出

完整代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt#获取一个文本的数据(即本例中的数据集)
path = r'D:\Project\Pycharm Project\py2022\ExData\ex1data1.txt'#数据集的存储地址
data = pd.read_csv(path, header = None, names = ['Population', 'Profit'])#定义获取数据的形式,一列为人口,一列为经济利润
data.head()#head方法用来显示表格前五行的数据
#显示并检查数据
data.plot(kind='scatter', x='Population', y='Profit', figsize=(12,8))
plt.show()# 接下来是定义计算代价函数
def cost_func(x, y, w, b):#获得每一点数据的误差(成本)cost_matrix = np.power(((x * w + b) - y), 2)  # 这个就是矩阵相乘得到一个n*1维的矩阵与n*1维的y矩阵进行相减的平方return np.sum(cost_matrix) / (2 * len(x))  # 利用sum函数将cost矩阵所有元素加起来,再除以2m(m为总共的数据数量)# 计算迭代之后的w和b参数,w为斜率,b为截距,alpha为学习率
def compute_new_wb(x, y, w, b, alpha):m = len(x)dj_dw = 0dj_db = 0 #用于计算梯度for i in range(m):#老参数w和b所计算的所有y值(y0,y1,y2 ……)f_wb_i = w * x[i] + b#偏导之后的第i个和项(具体为什么这样可以看吴恩达老师讲解线性回归的视频)dj_dw_i = (f_wb_i - y[i]) * x[i]dj_db_i = (f_wb_i - y[i])##进行累加dj_dw = dj_dw + dj_dw_idj_db = dj_db + dj_db_idj_dw = (1 / m) * dj_dwdj_db = (1 / m) * dj_db#此时的dj_dw,dj_db就为老参数情况下的偏导return w - alpha * dj_dw, b - alpha * dj_db   #返回新的w和b#获取X矩阵和Y矩阵,为了方便计算成本函数,进行转置变成了列向量
X = np.matrix(data['Population'].values).T
Y = np.matrix(data['Profit'].values).Tprint(compute_new_wb(X, Y, 0, 0, 0.001)) #测试一下w=0,b=0迭代之后的w和b
#初始将w和b设为1,学习率设为0.01
w = 1
b = 1
alpha = 0.01
#经过1500次迭代
for i in range(1500):w, b = compute_new_wb(X, Y, w, b, alpha)print('最后计算出的w和b:', w, ',', b)
print('最后的成本函数', cost_func(X, Y, w, b))# 接下来就是画图
# 在人口的最小值和最大值之间取100个点,返回一个矩阵
x = np.linspace(data.Population.min(), data.Population.max(), 100)
#w[0, 0]为数值,w[0][0]则为列表
y = w[0, 0] * x + b[0, 0] # 获得y矩阵fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(x, y, 'r', label='Prediction')#使用plot方法绘制预测直线
ax.scatter(data.Population, data.Profit, label='Training Data')
ax.legend(loc = 2)
ax.set_xlabel('Population')
ax.set_ylabel('Profit')
ax.set_title('Predicted Profit vs. Population Size')
plt.show()
#%%预测人口3500 和 7000时的经济收益
prediction1 = w[0, 0] * 3.5 + b[0, 0]
print("人口为3500时,小吃摊经济规模:", prediction1)
prediction2 = w[0, 0] * 7.0 + b[0, 0]
print("人口为7000时,小吃摊经济规模:", prediction2)

代码输出

打印出来的数据集

整个数据集分布图(注意:数据集必须下载到自己的本地,可以搜索其他博主https://blog.csdn.net/qq_39435411/article/details/109763239)


你也可以一开始改变w和b的值,例如w = 0,b = 0,不过最后预测出来的会有很小的误差

吴恩达机器学习ex1——通过人口预测小摊经济状况相关推荐

  1. 吴恩达机器学习ex1 Python实现

    ** ** 机器学习小白入门,在看吴恩达机器学习课程的同时找到了课后的练习.开贴用于记录学习(copy)过程. 学习参考:吴恩达机器学习ex1 单变量线性回归 题目描述:在本部分的练习中,您将使用一个 ...

  2. 吴恩达机器学习Ex1多元回归部分

    多元线性回归 提交作业情况: 背景:预测房价 数据集:房屋大小,卧室的数量,房价. Loading data ... First 10 examples from the dataset: x = [ ...

  3. 吴恩达机器学习Ex1

    本次是week2 Linear Regression 的作业情况. 作业得分情况 作业代码 通过执行ex1.m文件获得想要的结果,其他函数为该文件所调用. ex1.m文件 %% Machine Lea ...

  4. python分类预测降低准确率_python实现吴恩达机器学习练习3(多元分类器和神经网络)...

    Programming Exercise 3: Multi-class Classification and Neural Networks 吴恩达机器学习教程练习3,练习数据是5000个手写数字(0 ...

  5. 吴恩达机器学习中文版课后题(中文题目+数据集+python版答案)week1 线性回归

    一.单线性回归问题 参考:https://blog.csdn.net/qq_42333474/article/details/119100860 题目一: 您将使用一元线性回归来预测食品车的利润.假设 ...

  6. 1. 吴恩达机器学习课程-作业1-线性回归

    fork了别人的项目,自己重新填写,我的代码如下 https://gitee.com/fakerlove/machine-learning/tree/master/code 代码原链接 文章目录 1. ...

  7. 吴恩达机器学习作业ex2-python实现

    系列文章目录 吴恩达机器学习作业ex1-python实现 吴恩达机器学习作业ex2-python实现 吴恩达机器学习作业ex3-python实现 作业说明及数据集 链接:https://pan.bai ...

  8. 吴恩达机器学习视频作业(Matlab实现)

    吴恩达机器学习视频的课后作业,使用matlab实现 ex1  线性回归 1.热身 建立一个5*5矩阵 A=eye(5); 2.单变量的线性回归 需要根据城市人口数量,预测开小吃店的利润 数据在ex1d ...

  9. 吴恩达机器学习课后作业——线性回归(Python实现)

    1.写在前面 吴恩达机器学习的课后作业及数据可以在coursera平台上进行下载,只要注册一下就可以添加课程了.所以这里就不写题目和数据了,有需要的小伙伴自行去下载就可以了. 作业及数据下载网址:吴恩 ...

最新文章

  1. SpringBoot 快速开启事务(附常见坑点)
  2. 为什么LED灯会越用越暗?
  3. 衡量人体健康的“十大新标杆”
  4. 信息学奥赛 数论专题 2、带 余 除 法
  5. Android数据库升级、降级、创建(onCreate() onUpgrade() onDowngrade())的注意点
  6. js 点击闭包_【JS进阶】Javascript 闭包与Promise的碰撞
  7. android手机分享app,Android Pie如何快捷分享文件至特定App
  8. 一步一步学习Servlet之Session使用
  9. 使用Dockerfile构建Nginx,Tomcat,MySQL镜像
  10. 网络口碑Market,生来“苟且”?
  11. PR2018入门教程01-基础教程
  12. 大学生咖啡网页制作教程 表格布局网页模板 学生HTML静态美食网页设计作业成品 简单网页制作代码 学生美食网页作品免费设计
  13. html5 牧场游戏,手机QQ首批五款HTML5游戏名单 农场偷菜复活
  14. linux iozone测试工具,linux系列之常用工具:iozone测试磁盘性能
  15. 上传JSPX文件绕过网站后缀名检查
  16. 51Nod-1183-编辑距离
  17. 杭州地铁行业十四五发展可行性及投资机遇研究报告2022版
  18. Panel的基本用法
  19. Python 日期格式总结
  20. Python基础--集合创建、添加删除元素以及集合的交集、并集和差集运算

热门文章

  1. 小方块上升组成背景特效 html+css+js
  2. PHP再学习4—— slim框架学习和使用
  3. java正整数分解因数_java将一个正整数分解质因数
  4. 模式识别 | PRML概览
  5. 爬虫入门实战(如何分析页面和构建requests请求)
  6. css动画 翻开折叠生日贺卡
  7. win10/win1桌面图标锁定,防止桌面图标移动
  8. 力扣(104.101)补9.7
  9. C++ 加号运算符重载
  10. 【机器学习】07. 决策树模型DecisionTreeClassifier(代码注释,思路推导)