实现的是预测 低 出生 体重 的 概率。
尼克·麦克卢尔(Nick McClure). TensorFlow机器学习实战指南 (智能系统与技术丛书) (Kindle 位置 1060-1061). Kindle 版本.

# Logistic Regression
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve logistic regression.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
#  y = 0 or 1 = low birth weight
#  x = demographic and medical history dataimport matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import requests
from tensorflow.python.framework import ops
import os.path
import csvops.reset_default_graph()# Create graph
sess = tf.Session()###
# Obtain and prepare data for modeling
#### Set name of data file
birth_weight_file = 'birth_weight.csv'# Download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'birth_file = requests.get(birthdata_url)birth_data = birth_file.text.split('\r\n')birth_header = birth_data[0].split('\t')birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]with open(birth_weight_file, 'w', newline='') as f:writer = csv.writer(f)writer.writerow(birth_header)writer.writerows(birth_data)f.close()# Read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:csv_reader = csv.reader(csvfile)birth_header = next(csv_reader)for row in csv_reader:birth_data.append(row)birth_data = [[float(x) for x in row] for row in birth_data]# Pull out target variable
y_vals = np.array([x[0] for x in birth_data])
# Pull out predictor variables (not id, not target, and not birthweight)
x_vals = np.array([x[1:8] for x in birth_data])# Set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)# Split data into train/test = 80%/20%
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]# Normalize by column (min-max norm)
def normalize_cols(m):col_max = m.max(axis=0)col_min = m.min(axis=0)return (m-col_min) / (col_max - col_min)x_vals_train = np.nan_to_num(normalize_cols(x_vals_train))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test))###
# Define Tensorflow computational graph¶
#### Declare batch size
batch_size = 25# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)###
# Train model
#### Initialize variables
init = tf.global_variables_initializer()
sess.run(init)# Actual Prediction
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)# Training loop
loss_vec = []
train_acc = []
test_acc = []
for i in range(15000):rand_index = np.random.choice(len(x_vals_train), size=batch_size)rand_x = x_vals_train[rand_index]rand_y = np.transpose([y_vals_train[rand_index]])sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})loss_vec.append(temp_loss)temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})train_acc.append(temp_acc_train)temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})test_acc.append(temp_acc_test)if (i+1)%300==0:print('Loss = ' + str(temp_loss))###
# Display model performance
#### Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

转载于:https://www.cnblogs.com/bonelee/p/8996496.html

tensorflow 实现逻辑回归——原以为TensorFlow不擅长做线性回归或者逻辑回归,原来是这么简单哇!...相关推荐

  1. 逻辑回归(Logistic Regression):线性回归与逻辑回归的来龙去脉

    文章目录 Intro Logistic Regression 1. 回归的预测形式 1.1 线性回归的单变量形式 1.2 线性回归的多变量形式 1.3 逻辑回归:将线性回归转化为概率模型 1.4 逻辑 ...

  2. lasso回归_一文读懂线性回归、岭回归和Lasso回归

    (图片由AI科技大本营付费下载自视觉中国) 作者 | 文杰 编辑 | yuquanle 本文介绍线性回归模型,从梯度下降和最小二乘的角度来求解线性回归问题,以概率的方式解释了线性回归为什么采用平方损失 ...

  3. 学习笔记1:线性回归和逻辑回归、AUC

    复习笔记1--线性回归和逻辑回归 文章目录 复习笔记1--线性回归和逻辑回归 一.机器学习基本概念 1.1 什么是模型 1.2 极大似然估计 1.3为啥使用梯度下降法求解 1.4 梯度下降法本质 1. ...

  4. [机器学习-实践篇]学习之线性回归、岭回归、Lasso回归,tensorflow实现的线性回归

    线性回归.岭回归.Lasso回归 前言 1.线性回归 2. 岭回归 3. Lasso回归 4. tensorflow利用梯度下降实现的线性回归 前言 本章主要介绍线性回归.岭回归.Lasso回归,te ...

  5. 个人总结:从 线性回归 到 逻辑回归 为什么逻辑回归又叫对数几率回归?

    逻辑回归不是回归算法,是分类算法,可以处理二元分类以及多元分类. 线性回归 线性回归的模型是求出特征向量Y和输入样本矩阵X之间的线性关系系数θ,满足Y = Xθ.此时Y是连续的,所以是回归模型. 对应 ...

  6. 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门

    2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...

  7. 线性回归与逻辑回归/朴素贝叶斯

    一.线性回归与逻辑回归 (一)线性回归 1. 算法概述 回归的目的是预测数值型的目标值. 线性回归的优点:结果易于理解,计算上不复杂.缺点:对非线性的数据拟合不好.适用数据类型:数值型和标称型数据. ...

  8. ESP32 Tensorflow Lite (二)TensorFlow Lite Hello World

    TensorFlow Lite Hello World TensorFlow Lite Hello World 1. 导入依赖 2. 生成数据 3. 添加噪声 4. 数据分割 5. 设计模型 6. 训 ...

  9. 线性回归、逻辑回归及SVM

    1,回归(Linear Regression) 回归其实就是对已知公式的未知参数进行估计.可以简单的理解为:在给定训练样本点和已知的公式后,对于一个或多个未知参数,机器会自动枚举参数的所有可能取值(对 ...

最新文章

  1. __slots__(面向对象进阶)
  2. nagios出现乱码
  3. python中的open函数
  4. 上届作品回顾丨如何在 Innovation 2021 开发者大赛中脱颖而出?
  5. 图像相似度算法的C#实现及测评
  6. [leetcode] 72.编辑距离
  7. fiddler之数据统计(statistics)
  8. 内核初始化流程start_kernel
  9. webpack配置信息说明
  10. git cherry pick用法
  11. 智能家居通信协议科普,什么户型选择什么产品一文看懂
  12. 【生活中的逻辑谬误】止于分析和简化主义
  13. 学计算机怎能不知道电脑配置
  14. SWIG和MapGuide Web API
  15. 【点云处理技术之PCL】随机采样一致算法(Random sample consensus,RANSAC)
  16. matlab的损失函数mse,MSELoss损失函数
  17. 物联网创业项目(物联网创业点子大全500个)
  18. OSChina 娱乐弹弹弹——凉风有信,秋月无边
  19. 数据挖掘BUC算法实现
  20. 织梦可以不用mysql吗_织梦dedecms不用功能精简及安全设置

热门文章

  1. 类的加载顺序和对象的实例化
  2. Lua和C语言的交互——C API
  3. Warning: post-commit hook failed (exit code 255) with no output.
  4. 服务器无限火力时间,LOL无限火力2018时间表6月具体开启时间 无限火力模式什么时候出...
  5. html制作虚拟人物,一种虚拟人物角色直播系统的制作方法
  6. mysql bin的过期时间_Mysql设置binlog过期时间并自动删除
  7. 爱丁堡大学计算机专业alevel,爱丁堡大学alevel要求?
  8. python subprocess_python subprocess - 刘江的python教程
  9. Android程序员如何有效提升学习效率?帮你突破瓶颈
  10. 【深度学习】人脸识别和口罩检测的应用