every blog every motto: Never say die

0. 前言

本节实战tf.GradientTape与tf.keras结合使用

1. 代码部分

1. 导入模块

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import kerasprint(tf.__version__)
print(sys.version_info)
for module in mpl,np,pd,sklearn,tf,keras:print(module.__name__,module.__version__)

2. 读取数据

from sklearn.datasets import fetch_california_housing# 房价预测
housing = fetch_california_housing()
print(housing.DESCR)
print(housing.data.shape)
print(housing.target.shape)

3. 划分样本

# 划分样本
from sklearn.model_selection import train_test_splitx_train_all,x_test,y_train_all,y_test = train_test_split(housing.data,housing.target,random_state=7)
x_train,x_valid,y_train,y_valid = train_test_split(x_train_all,y_train_all,random_state=11)print(x_train.shape,y_train.shape)
print(x_valid.shape,y_valid.shape)
print(x_test.shape,y_test.shape)

4. 数据归一化

# 归一化
from sklearn.preprocessing import StandardScalerscaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.transform(x_valid)
x_test_scaled = scaler.transform(x_test)

5. metric的使用

# metric 的使用metric = keras.metrics.MeanSquaredError()
print(metric([5.],[2.]))
print(metric([0.],[1.]))
print(metric.result())metric.reset_states() # 不累加数据
metric([1.],[3.])
print(metric.result())

6. 搭建模块与训练

# fit 中做的事
# 1. batch 遍历数据集 metric
#   1.1 自动求导
# 2. epoch 结束 验证集 metric# 准备工作
epochs = 100
batch_size = 32
steps_per_epoch = len(x_train_scaled) // batch_size
optimizer = keras.optimizers.SGD()
metric = keras.metrics.MeanSquaredError()def random_batch(x,y,batch_size=32):idx = np.random.randint(0,len(x),size=batch_size)return x[idx],y[idx]# 搭建模型
model = keras.models.Sequential([keras.layers.Dense(30,activation='relu',input_shape=x_train.shape[1:]),keras.layers.Dense(1),])for epoch in range(epochs):metric.reset_states()for step in range(steps_per_epoch):x_batch,y_batch = random_batch(x_train_scaled,y_train,batch_size)with tf.GradientTape() as tape:y_pred = model(x_batch)loss = tf.reduce_mean(keras.losses.mean_squared_error(y_batch,y_pred))metric(y_batch,y_pred)grads = tape.gradient(loss,model.variables)grads_and_vars = zip(grads,model.variables)optimizer.apply_gradients(grads_and_vars)print('\rEpoch',epoch,'train mse:',metric.result().numpy(),end='')y_valid_pred = model(x_valid_scaled)valid_loss = tf.reduce_mean(keras.losses.mean_squared_error(y_valid_pred,y_valid))print('\t','valid mse:',valid_loss.numpy())

从零基础入门Tensorflow2.0 ----三、11. tf.GradientTape与tf.keras结合使用相关推荐

  1. 九宫怎么排列和使用_剪映零基础入门教程第三十七篇:一学就会系列之九宫格小程序配音...

    很多玩儿抖音的朋友都看过九宫格视频,但是并不是每个玩抖音的人都会制作这个九宫格视频,实际这个需要借助小工具来帮忙,而常用抖音的朋友们会对剪映更加熟悉一些,且九宫格视频在剪映内的制作方式则比较简单.那么 ...

  2. 视频编码零基础入门(0):零基础,史上最通俗视频编码技术入门

    [来源申明]本文引用了微信公众号"鲜枣课堂"的<视频编码零基础入门>文章内容.为了更好的内容呈现,即时通讯网在引用和收录时内容有改动,转载时请注明原文来源信息,尊重原作 ...

  3. Java好学吗?零基础入门Java,三个就业方向实现月入过万!

    Java好学吗?零基础入门Java容易吗?据统计,这是很多人学习前最常问也是最关心的问题之一. 不可否认,大家在开始接受新事物的时候都会陷入困境,但学习是循序渐进的,零基础入门Java到底难不难,只有 ...

  4. Pr零基础入门指南笔记三-------------视频效果与转场

    目录 精剪第一步--视频效果与转场 1.三大面板 2.位置 3.效果库 4.常用视频效果 [干货]PR零基础入门指南第四集:PR常用的效果和转场,视频防抖.宽银幕效果.设置默认效果等_哔哩哔哩_bil ...

  5. 日语零基础入门至初级“三步走”

    被日本动漫.日剧深深迷住,但看见似熟非熟的汉字假名却无从下手?!没关系,沪江网校推出日语入门系列班级,要想日语零基础入门你只需要"三步走"!首先,我们来了解看看什么是"三 ...

  6. SQL零基础入门学习(三)

    SQL零基础入门学习(二) SQL WHERE 子句 WHERE 子句用于提取那些满足指定条件的记录. SQL WHERE 语法 SELECT column1, column2, ... FROM t ...

  7. 分支程序设计02 - 零基础入门学习C语言11

    第四章:分支程序设计02 让编程改变世界 Change the world by program if语句 用if语句可以构成分支结构.它根据给定的条件进行判断,以决定执行某个分支程序段.C语言的if ...

  8. 零基础入门学习Python(11)-列表(3)

    列表的一些常用操作符 比较操作符 逻辑操作符 连接操作符 重复操作符 当有多个元素时,默认是从第0个元素比较的 字符串比较的是每一个字符对应的ASCII码值的大小 什么是ASSII码? 是Americ ...

  9. 【转】Dynamics CRM 365零基础入门学习(三)Dynamics 通过Web API 来调用自定义的Action(使用插件)

    今天想实现一个Search Product的功能,首先要将数据展示在页面,然后前端根据查询需求进行处理.之前是在salesforce中实现的,可以定义一个Search Product的页面,然后在页面 ...

  10. Arduino Uno零基础入门学习笔记——三针脚声音传感器

    一.电路接线 声音传感器 声音传感器引脚 Arduino引脚 VCC 5V GND GND OUT 6 LED LED引脚 Arduino引脚 正极 8 GND GND 二.代码 int val; i ...

最新文章

  1. 利用源代码搭建lnmp环境
  2. WebRTC各种资料集合
  3. Swift 十进制二进制转换 (How to convert a decimal number to binary in Swift)
  4. 积性函数与Dirichlet卷积 学习小记
  5. 洛谷 [P1352] 没有上司的舞会
  6. 业余爱好者linux_如何从业余爱好者变成专业开发人员
  7. libpng warning: iCCP: known incorrect sRGB profile
  8. YII 规则rule 里面 min,max 提示错误信息
  9. 497.非重叠矩形中的随机点
  10. 重写和重载的区别和理解
  11. cosx的麦克劳林级数是多少_余弦函数的泰勒级数
  12. 【年终总结】——回忆过往,不畏将来
  13. instrument Time Profiler总结
  14. 我在2022北大夏令营被吊起来打
  15. 计算机考研没有获奖没有科研难吗,大学期间没有什么获奖经历和科研成果, 对考研的影响大吗?...
  16. 拥有一台云服务器可以干什么?
  17. U盘中毒之后打不开怎么办
  18. 如何更合规、更安全的使用人脸识别进行身份管理
  19. vue使用watch监听拿到props的传值
  20. Xcode8.1 真机测试 ,添加iOS10.3的idk到Xcode8.1中

热门文章

  1. 每日一道剑指offer-二叉树的镜像
  2. mysql c3p0 释放连接池_mysql – 如何阻止c3p0连接池隐藏连接异常的原因?
  3. cannot+connect+mysql_mysqlnd cannot connect to MySQL 4.1+ using the old insecure
  4. Linux环境下安装Hadoop(完全分布式)
  5. 华为智慧屏 鸿蒙如何获得,荣耀智慧屏得鸿蒙助力,玩法超多
  6. java实现用户分组,java实现分组算法,根据每组多少人来进行分组
  7. C#:访问web.config中的常量
  8. Javascript特效:进度条
  9. Linux 系统中随机数在 KVM 中的应用
  10. CVPR 2021奖项出炉:最佳论文花落马普所,何恺明获提名,首届黄煦涛纪念奖颁布