TensorFlow from Beginner to Expert (8): Titanic Passenger Survival Prediction
"You jump, I jump" is the famous line from the classic romance film Titanic: as Rose stands at the bow about to leap into the sea, Jack, a stranger, wins her back with those words. When an unknown man is willing to die for an unknown woman for no reason at all, she falls for him.
Of course, that backstory has little to do with this tutorial. Here we will train a neural network to predict the survival probabilities of Jack and Rose. I kept publishing through the National Day holiday, which was no small feat; if you need the dataset, message me directly or join the study group. Thanks for your support!
I. The Dataset
1. Load the dataset
import pandas as pd
df = pd.read_excel('titanic3.xls')
df.describe()
 | pclass | survived | age | sibsp | parch | fare | body
---|---|---|---|---|---|---|---
count | 1309.000000 | 1309.000000 | 1046.000000 | 1309.000000 | 1309.000000 | 1308.000000 | 121.000000 |
mean | 2.294882 | 0.381971 | 29.881135 | 0.498854 | 0.385027 | 33.295479 | 160.809917 |
std | 0.837836 | 0.486055 | 14.413500 | 1.041658 | 0.865560 | 51.758668 | 97.696922 |
min | 1.000000 | 0.000000 | 0.166700 | 0.000000 | 0.000000 | 0.000000 | 1.000000 |
25% | 2.000000 | 0.000000 | 21.000000 | 0.000000 | 0.000000 | 7.895800 | 72.000000 |
50% | 3.000000 | 0.000000 | 28.000000 | 0.000000 | 0.000000 | 14.454200 | 155.000000 |
75% | 3.000000 | 1.000000 | 39.000000 | 1.000000 | 0.000000 | 31.275000 | 256.000000 |
max | 3.000000 | 1.000000 | 80.000000 | 8.000000 | 9.000000 | 512.329200 | 328.000000 |
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1309 entries, 0 to 1308
Data columns (total 14 columns):
 #   Column     Non-Null Count  Dtype
---  ------     --------------  -----
 0   pclass     1309 non-null   int64
 1   survived   1309 non-null   int64
 2   name       1309 non-null   object
 3   sex        1309 non-null   object
 4   age        1046 non-null   float64
 5   sibsp      1309 non-null   int64
 6   parch      1309 non-null   int64
 7   ticket     1309 non-null   object
 8   fare       1308 non-null   float64
 9   cabin      295 non-null    object
 10  embarked   1307 non-null   object
 11  boat       486 non-null    object
 12  body       121 non-null    float64
 13  home.dest  745 non-null    object
dtypes: float64(3), int64(4), object(7)
memory usage: 143.3+ KB
df.head()
 | pclass | survived | name | sex | age | sibsp | parch | ticket | fare | cabin | embarked | boat | body | home.dest
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | 1 | 1 | Allen, Miss. Elisabeth Walton | female | 29.0000 | 0 | 0 | 24160 | 211.3375 | B5 | S | 2 | NaN | St Louis, MO |
1 | 1 | 1 | Allison, Master. Hudson Trevor | male | 0.9167 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | 11 | NaN | Montreal, PQ / Chesterville, ON |
2 | 1 | 0 | Allison, Miss. Helen Loraine | female | 2.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | NaN | Montreal, PQ / Chesterville, ON |
3 | 1 | 0 | Allison, Mr. Hudson Joshua Creighton | male | 30.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | 135.0 | Montreal, PQ / Chesterville, ON |
4 | 1 | 0 | Allison, Mrs. Hudson J C (Bessie Waldo Daniels) | female | 25.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | NaN | Montreal, PQ / Chesterville, ON |
2. Preprocess the dataset
- Select the required columns
- Handle missing values
- Convert categorical values to numeric codes
- Drop the name column
# Select the columns we need
selected_cols = ['survived','name','pclass','sex','age','sibsp','parch','fare','embarked']
df_selected = df[selected_cols]
df = df[selected_cols]  # indexing with a list of names selects columns
df.head()
 | survived | name | pclass | sex | age | sibsp | parch | fare | embarked
---|---|---|---|---|---|---|---|---|---
0 | 1 | Allen, Miss. Elisabeth Walton | 1 | female | 29.0000 | 0 | 0 | 211.3375 | S |
1 | 1 | Allison, Master. Hudson Trevor | 1 | male | 0.9167 | 1 | 2 | 151.5500 | S |
2 | 0 | Allison, Miss. Helen Loraine | 1 | female | 2.0000 | 1 | 2 | 151.5500 | S |
3 | 0 | Allison, Mr. Hudson Joshua Creighton | 1 | male | 30.0000 | 1 | 2 | 151.5500 | S |
4 | 0 | Allison, Mrs. Hudson J C (Bessie Waldo Daniels) | 1 | female | 25.0000 | 1 | 2 | 151.5500 | S |
# Find the columns that contain null values
df.isnull().any()
survived False
name False
pclass False
sex False
age True
sibsp False
parch False
fare True
embarked True
dtype: bool
# Count the number of null values in each column
df.isnull().sum()
survived 0
name 0
pclass 0
sex 0
age 263
sibsp 0
parch 0
fare 1
embarked 2
dtype: int64
# Locate the rows that contain missing values
df[df.isnull().values==True]
 | survived | name | pclass | sex | age | sibsp | parch | fare | embarked
---|---|---|---|---|---|---|---|---|---
15 | 0 | Baumann, Mr. John D | 1 | male | NaN | 0 | 0 | 25.9250 | S |
37 | 1 | Bradley, Mr. George ("George Arthur Brayton") | 1 | male | NaN | 0 | 0 | 26.5500 | S |
40 | 0 | Brewe, Dr. Arthur Jackson | 1 | male | NaN | 0 | 0 | 39.6000 | C |
46 | 0 | Cairns, Mr. Alexander | 1 | male | NaN | 0 | 0 | 31.0000 | S |
59 | 1 | Cassebeer, Mrs. Henry Arthur Jr (Eleanor Genev... | 1 | female | NaN | 0 | 0 | 27.7208 | C |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1293 | 0 | Williams, Mr. Howard Hugh "Harry" | 3 | male | NaN | 0 | 0 | 8.0500 | S |
1297 | 0 | Wiseman, Mr. Phillippe | 3 | male | NaN | 0 | 0 | 7.2500 | S |
1302 | 0 | Yousif, Mr. Wazli | 3 | male | NaN | 0 | 0 | 7.2250 | C |
1303 | 0 | Yousseff, Mr. Gerious | 3 | male | NaN | 0 | 0 | 14.4583 | C |
1305 | 0 | Zabour, Miss. Thamine | 3 | female | NaN | 1 | 0 | 14.4542 | C |
266 rows × 9 columns
# Fill missing age values with the column mean
age_mean = df['age'].mean()
df['age'] = df['age'].fillna(age_mean)
df['age'].isnull().any()  # returns True if any value is still null
False
# Fill missing fare values with the column mean
fare_mean = df['fare'].mean()
df['fare'] = df['fare'].fillna(fare_mean)
# Fill missing embarked values with 'S', the most common port
df['embarked'] = df['embarked'].fillna('S')
df.isnull().any()
survived False
name False
pclass False
sex False
age False
sibsp False
parch False
fare False
embarked False
dtype: bool
# Convert encodings
# Map sex from strings to integer codes
df['sex'] = df['sex'].map({'female':0,'male':1}).astype(int)
# Map embarked from port letters to integer codes
df['embarked'] = df['embarked'].map({'C':0,'Q':1,'S':2}).astype(int)
# Drop the name column
df = df.drop(['name'],axis=1)  # axis=0 is rows, axis=1 is columns
df.head()
 | survived | pclass | sex | age | sibsp | parch | fare | embarked
---|---|---|---|---|---|---|---|---
0 | 1 | 1 | 0 | 29.0000 | 0 | 0 | 211.3375 | 2 |
1 | 1 | 1 | 1 | 0.9167 | 1 | 2 | 151.5500 | 2 |
2 | 0 | 1 | 0 | 2.0000 | 1 | 2 | 151.5500 | 2 |
3 | 0 | 1 | 1 | 30.0000 | 1 | 2 | 151.5500 | 2 |
4 | 0 | 1 | 0 | 25.0000 | 1 | 2 | 151.5500 | 2 |
3. Split features and labels
# Separate the features from the labels
data = df.values
# The last seven columns are the features
features = data[:,1:]  # an ndarray indexes by row first; a DataFrame indexed with a list selects columns
# Column 0 is the label
labels = data[:,0]
labels.shape
(1309,)
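The same split can also be written against the DataFrame directly, selecting the label column by name instead of by position. A minimal sketch on a hypothetical two-row frame that mirrors the layout above (label column first, features after):

```python
import pandas as pd

# Hypothetical mini-frame mirroring the layout above: 'survived' first, features after.
demo = pd.DataFrame({'survived': [1, 0], 'pclass': [1, 3], 'sex': [0, 1]})

labels = demo['survived'].values                  # the label column, by name
features = demo.drop(columns='survived').values   # every remaining column

print(features.shape, labels.shape)  # (2, 2) (2,)
```

Selecting by name is more robust than positional slicing if the column order ever changes.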
4. Define a preprocessing function
def prepare_data(df):
    # Drop the name column
    df = df.drop(['name'], axis=1)
    # Fill missing age values with the column mean
    age_mean = df['age'].mean()
    df['age'] = df['age'].fillna(age_mean)
    # Fill missing fare values with the column mean
    fare_mean = df['fare'].mean()
    df['fare'] = df['fare'].fillna(fare_mean)
    # Fill missing embarked values with 'S'
    df['embarked'] = df['embarked'].fillna('S')
    # Map sex from strings to integer codes
    df['sex'] = df['sex'].map({'female': 0, 'male': 1}).astype(int)
    # Map embarked from port letters to integer codes
    df['embarked'] = df['embarked'].map({'C': 0, 'Q': 1, 'S': 2}).astype(int)
    print(df.isnull().any())
    # Separate the features from the labels
    data = df.values
    # The last seven columns are the features
    features = data[:, 1:]
    # Column 0 is the label
    labels = data[:, 0]
    return features, labels
5. Split training and test sets
shuffle_df = df_selected.sample(frac=1)  # shuffle the rows before training; frac is the fraction to sample, and df_selected itself is left unchanged
x_data,y_data = prepare_data(shuffle_df)
x_data.shape,y_data.shape
survived False
pclass False
sex False
age False
sibsp False
parch False
fare False
embarked False
dtype: bool
((1309, 7), (1309,))
shuffle_df.head()
 | survived | name | pclass | sex | age | sibsp | parch | fare | embarked
---|---|---|---|---|---|---|---|---|---
58 | 0 | Case, Mr. Howard Brown | 1 | male | 49.0 | 0 | 0 | 26.0000 | S |
666 | 0 | Barbara, Mrs. (Catherine David) | 3 | female | 45.0 | 0 | 1 | 14.4542 | C |
781 | 0 | Drazenoic, Mr. Jozef | 3 | male | 33.0 | 0 | 0 | 7.8958 | C |
480 | 0 | Laroche, Mr. Joseph Philippe Lemercier | 2 | male | 25.0 | 1 | 2 | 41.5792 | C |
459 | 0 | Jacobsohn, Mr. Sidney Samuel | 2 | male | 42.0 | 1 | 0 | 27.0000 | S |
test_split = 0.2
train_num = int((1 - test_split) * x_data.shape[0])
# Training set
x_train = x_data[:train_num]
y_train = y_data[:train_num]
# Test set
x_test = x_data[train_num:]
y_test = y_data[train_num:]
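The slicing above can be checked on toy arrays. This sketch (toy shapes, not the real dataset) confirms that an 80/20 split of 10 samples leaves 8 training rows and 2 test rows:

```python
import numpy as np

# Toy stand-ins for x_data / y_data: 10 samples, 2 features each.
x_data = np.arange(20).reshape(10, 2)
y_data = np.arange(10)

test_split = 0.2
train_num = int((1 - test_split) * x_data.shape[0])  # 8

x_train, x_test = x_data[:train_num], x_data[train_num:]
y_train, y_test = y_data[:train_num], y_data[train_num:]

print(x_train.shape, x_test.shape)  # (8, 2) (2, 2)
```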
6. Normalization
from sklearn import preprocessing
minmax_scale = preprocessing.MinMaxScaler(feature_range=(0,1))
x_train = minmax_scale.fit_transform(x_train)  # fit the scaler on the training features and transform them
x_test = minmax_scale.transform(x_test)  # reuse the training-set min/max rather than re-fitting on the test data
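MinMaxScaler rescales each column to (x - min) / (max - min), mapping the smallest value to 0 and the largest to 1. A NumPy-only sketch of the same formula on a made-up feature column:

```python
import numpy as np

col = np.array([10.0, 25.0, 40.0])  # hypothetical feature column
scaled = (col - col.min()) / (col.max() - col.min())
print(scaled)  # [0.  0.5 1. ]
```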
II. The Model
import tensorflow as tf
tf.__version__
'2.6.0'
1. Build a sequential model
model = tf.keras.models.Sequential()
2. Add hidden layers
model.add(tf.keras.layers.Dense(units=64,
                                use_bias=True,
                                activation='relu',
                                input_dim=7,  # input_shape=(7,) also works
                                bias_initializer='zeros',
                                kernel_initializer='normal'))
model.add(tf.keras.layers.Dropout(rate=0.2))  # dropout layer; rate is the fraction of the previous layer's units to drop, to reduce overfitting
model.add(tf.keras.layers.Dense(units=32,
                                activation='sigmoid',
                                input_shape=(64,),  # input_dim=64 also works
                                bias_initializer='zeros',
                                kernel_initializer='uniform'))
model.add(tf.keras.layers.Dropout(rate=0.2))  # dropout again after the second hidden layer
3. Add the output layer
model.add(tf.keras.layers.Dense(units=1,
                                activation='sigmoid',
                                input_dim=32,  # input_shape=(32,) also works
                                bias_initializer='zeros',
                                kernel_initializer='uniform'))
model.summary()
Model: "sequential_23"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_68 (Dense) (None, 64) 512
_________________________________________________________________
dropout_6 (Dropout) (None, 64) 0
_________________________________________________________________
dense_69 (Dense) (None, 32) 2080
_________________________________________________________________
dropout_7 (Dropout) (None, 32) 0
_________________________________________________________________
dense_70 (Dense) (None, 1) 33
=================================================================
Total params: 2,625
Trainable params: 2,625
Non-trainable params: 0
_________________________________________________________________
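The Param # column in the summary can be reproduced by hand: a dense layer with n_in inputs and n_out units stores n_in * n_out weights plus one bias per unit. A quick check against the summary above:

```python
def dense_params(n_in, n_out):
    # weights (n_in * n_out) plus one bias per output unit
    return n_in * n_out + n_out

print(dense_params(7, 64))    # 512
print(dense_params(64, 32))   # 2080
print(dense_params(32, 1))    # 33
total = dense_params(7, 64) + dense_params(64, 32) + dense_params(32, 1)
print(total)                  # 2625
```

Dropout layers reshape nothing and learn nothing, so they contribute 0 parameters.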
III. Training
1. Train
# Define the training configuration
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.003),
              loss='binary_crossentropy',
              metrics=['accuracy'])
# Set the training parameters
train_epochs = 100
batch_size = 40
train_history = model.fit(x=x_train,              # training features
                          y=y_train,              # training labels
                          validation_split=0.2,   # fraction of the training data held out for validation
                          epochs=train_epochs,    # number of epochs
                          batch_size=batch_size,  # batch size
                          verbose=2)              # print one log line per epoch
Epoch 1/100
21/21 - 1s - loss: 0.6780 - accuracy: 0.5854 - val_loss: 0.6464 - val_accuracy: 0.6429
Epoch 2/100
21/21 - 0s - loss: 0.6623 - accuracy: 0.6057 - val_loss: 0.6293 - val_accuracy: 0.6429
Epoch 3/100
21/21 - 0s - loss: 0.6306 - accuracy: 0.6069 - val_loss: 0.5861 - val_accuracy: 0.6667
Epoch 4/100
21/21 - 0s - loss: 0.5771 - accuracy: 0.7336 - val_loss: 0.5199 - val_accuracy: 0.7905
Epoch 5/100
21/21 - 0s - loss: 0.5364 - accuracy: 0.7646 - val_loss: 0.4939 - val_accuracy: 0.7952
Epoch 6/100
21/21 - 0s - loss: 0.5200 - accuracy: 0.7670 - val_loss: 0.4847 - val_accuracy: 0.8143
Epoch 7/100
21/21 - 0s - loss: 0.5118 - accuracy: 0.7718 - val_loss: 0.4771 - val_accuracy: 0.8143
Epoch 8/100
21/21 - 0s - loss: 0.5060 - accuracy: 0.7766 - val_loss: 0.4738 - val_accuracy: 0.8095
Epoch 9/100
21/21 - 0s - loss: 0.4934 - accuracy: 0.7861 - val_loss: 0.4670 - val_accuracy: 0.7952
Epoch 10/100
21/21 - 0s - loss: 0.4966 - accuracy: 0.7814 - val_loss: 0.4637 - val_accuracy: 0.8000
Epoch 11/100
21/21 - 0s - loss: 0.4928 - accuracy: 0.7766 - val_loss: 0.4635 - val_accuracy: 0.7905
Epoch 12/100
21/21 - 0s - loss: 0.4995 - accuracy: 0.7670 - val_loss: 0.4691 - val_accuracy: 0.7905
Epoch 13/100
21/21 - 0s - loss: 0.4886 - accuracy: 0.7957 - val_loss: 0.4620 - val_accuracy: 0.8095
Epoch 14/100
21/21 - 0s - loss: 0.4790 - accuracy: 0.7838 - val_loss: 0.4565 - val_accuracy: 0.8095
Epoch 15/100
21/21 - 0s - loss: 0.4877 - accuracy: 0.7766 - val_loss: 0.4576 - val_accuracy: 0.8095
Epoch 16/100
21/21 - 0s - loss: 0.4839 - accuracy: 0.7897 - val_loss: 0.4560 - val_accuracy: 0.8095
Epoch 17/100
21/21 - 0s - loss: 0.4813 - accuracy: 0.7814 - val_loss: 0.4614 - val_accuracy: 0.8095
Epoch 18/100
21/21 - 0s - loss: 0.4812 - accuracy: 0.7742 - val_loss: 0.4553 - val_accuracy: 0.8095
Epoch 19/100
21/21 - 0s - loss: 0.4762 - accuracy: 0.7885 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 20/100
21/21 - 0s - loss: 0.4784 - accuracy: 0.7802 - val_loss: 0.4567 - val_accuracy: 0.8000
Epoch 21/100
21/21 - 0s - loss: 0.4794 - accuracy: 0.7885 - val_loss: 0.4626 - val_accuracy: 0.8000
Epoch 22/100
21/21 - 0s - loss: 0.4824 - accuracy: 0.7838 - val_loss: 0.4567 - val_accuracy: 0.7857
Epoch 23/100
21/21 - 0s - loss: 0.4786 - accuracy: 0.7849 - val_loss: 0.4553 - val_accuracy: 0.8048
Epoch 24/100
21/21 - 0s - loss: 0.4801 - accuracy: 0.7742 - val_loss: 0.4735 - val_accuracy: 0.7905
Epoch 25/100
21/21 - 0s - loss: 0.4752 - accuracy: 0.7849 - val_loss: 0.4571 - val_accuracy: 0.7905
Epoch 26/100
21/21 - 0s - loss: 0.4688 - accuracy: 0.7909 - val_loss: 0.4597 - val_accuracy: 0.8000
Epoch 27/100
21/21 - 0s - loss: 0.4624 - accuracy: 0.7873 - val_loss: 0.4577 - val_accuracy: 0.8048
Epoch 28/100
21/21 - 0s - loss: 0.4656 - accuracy: 0.7993 - val_loss: 0.4602 - val_accuracy: 0.8000
Epoch 29/100
21/21 - 0s - loss: 0.4649 - accuracy: 0.7969 - val_loss: 0.4546 - val_accuracy: 0.8000
Epoch 30/100
21/21 - 0s - loss: 0.4645 - accuracy: 0.7849 - val_loss: 0.4638 - val_accuracy: 0.8000
Epoch 31/100
21/21 - 0s - loss: 0.4635 - accuracy: 0.7921 - val_loss: 0.4603 - val_accuracy: 0.7952
Epoch 32/100
21/21 - 0s - loss: 0.4646 - accuracy: 0.7909 - val_loss: 0.4567 - val_accuracy: 0.7952
Epoch 33/100
21/21 - 0s - loss: 0.4664 - accuracy: 0.7909 - val_loss: 0.4583 - val_accuracy: 0.7952
Epoch 34/100
21/21 - 0s - loss: 0.4661 - accuracy: 0.7921 - val_loss: 0.4575 - val_accuracy: 0.8000
Epoch 35/100
21/21 - 0s - loss: 0.4660 - accuracy: 0.7838 - val_loss: 0.4582 - val_accuracy: 0.7952
Epoch 36/100
21/21 - 0s - loss: 0.4577 - accuracy: 0.8005 - val_loss: 0.4567 - val_accuracy: 0.8000
Epoch 37/100
21/21 - 0s - loss: 0.4648 - accuracy: 0.7909 - val_loss: 0.4585 - val_accuracy: 0.7952
Epoch 38/100
21/21 - 0s - loss: 0.4613 - accuracy: 0.7921 - val_loss: 0.4569 - val_accuracy: 0.7952
Epoch 39/100
21/21 - 0s - loss: 0.4643 - accuracy: 0.7921 - val_loss: 0.4687 - val_accuracy: 0.8000
Epoch 40/100
21/21 - 0s - loss: 0.4696 - accuracy: 0.7814 - val_loss: 0.4601 - val_accuracy: 0.8048
Epoch 41/100
21/21 - 0s - loss: 0.4589 - accuracy: 0.7933 - val_loss: 0.4562 - val_accuracy: 0.7952
Epoch 42/100
21/21 - 0s - loss: 0.4587 - accuracy: 0.7885 - val_loss: 0.4594 - val_accuracy: 0.8000
Epoch 43/100
21/21 - 0s - loss: 0.4601 - accuracy: 0.7981 - val_loss: 0.4563 - val_accuracy: 0.7905
Epoch 44/100
21/21 - 0s - loss: 0.4639 - accuracy: 0.7897 - val_loss: 0.4594 - val_accuracy: 0.8048
Epoch 45/100
21/21 - 0s - loss: 0.4569 - accuracy: 0.7957 - val_loss: 0.4587 - val_accuracy: 0.8000
Epoch 46/100
21/21 - 0s - loss: 0.4619 - accuracy: 0.7957 - val_loss: 0.4556 - val_accuracy: 0.8048
Epoch 47/100
21/21 - 0s - loss: 0.4661 - accuracy: 0.7861 - val_loss: 0.4563 - val_accuracy: 0.8000
Epoch 48/100
21/21 - 0s - loss: 0.4550 - accuracy: 0.7969 - val_loss: 0.4538 - val_accuracy: 0.8000
Epoch 49/100
21/21 - 0s - loss: 0.4550 - accuracy: 0.7873 - val_loss: 0.4572 - val_accuracy: 0.8048
Epoch 50/100
21/21 - 0s - loss: 0.4603 - accuracy: 0.7909 - val_loss: 0.4584 - val_accuracy: 0.8000
Epoch 51/100
21/21 - 0s - loss: 0.4575 - accuracy: 0.7957 - val_loss: 0.4531 - val_accuracy: 0.8095
Epoch 52/100
21/21 - 0s - loss: 0.4568 - accuracy: 0.8029 - val_loss: 0.4584 - val_accuracy: 0.8048
Epoch 53/100
21/21 - 0s - loss: 0.4594 - accuracy: 0.7909 - val_loss: 0.4558 - val_accuracy: 0.8000
Epoch 54/100
21/21 - 0s - loss: 0.4588 - accuracy: 0.8065 - val_loss: 0.4523 - val_accuracy: 0.8000
Epoch 55/100
21/21 - 0s - loss: 0.4532 - accuracy: 0.8029 - val_loss: 0.4593 - val_accuracy: 0.8048
Epoch 56/100
21/21 - 0s - loss: 0.4578 - accuracy: 0.8100 - val_loss: 0.4614 - val_accuracy: 0.8048
Epoch 57/100
21/21 - 0s - loss: 0.4549 - accuracy: 0.8041 - val_loss: 0.4580 - val_accuracy: 0.8095
Epoch 58/100
21/21 - 0s - loss: 0.4568 - accuracy: 0.7909 - val_loss: 0.4597 - val_accuracy: 0.8095
Epoch 59/100
21/21 - 0s - loss: 0.4567 - accuracy: 0.7981 - val_loss: 0.4532 - val_accuracy: 0.8095
Epoch 60/100
21/21 - 0s - loss: 0.4532 - accuracy: 0.7993 - val_loss: 0.4569 - val_accuracy: 0.7952
Epoch 61/100
21/21 - 0s - loss: 0.4543 - accuracy: 0.7969 - val_loss: 0.4555 - val_accuracy: 0.8000
Epoch 62/100
21/21 - 0s - loss: 0.4472 - accuracy: 0.8053 - val_loss: 0.4543 - val_accuracy: 0.8048
Epoch 63/100
21/21 - 0s - loss: 0.4458 - accuracy: 0.8100 - val_loss: 0.4534 - val_accuracy: 0.8095
Epoch 64/100
21/21 - 0s - loss: 0.4497 - accuracy: 0.8005 - val_loss: 0.4593 - val_accuracy: 0.8000
Epoch 65/100
21/21 - 0s - loss: 0.4511 - accuracy: 0.8053 - val_loss: 0.4522 - val_accuracy: 0.8095
Epoch 66/100
21/21 - 0s - loss: 0.4506 - accuracy: 0.8005 - val_loss: 0.4592 - val_accuracy: 0.7952
Epoch 67/100
21/21 - 0s - loss: 0.4533 - accuracy: 0.8005 - val_loss: 0.4545 - val_accuracy: 0.8000
Epoch 68/100
21/21 - 0s - loss: 0.4481 - accuracy: 0.7909 - val_loss: 0.4545 - val_accuracy: 0.7952
Epoch 69/100
21/21 - 0s - loss: 0.4555 - accuracy: 0.7981 - val_loss: 0.4551 - val_accuracy: 0.8000
Epoch 70/100
21/21 - 0s - loss: 0.4440 - accuracy: 0.8029 - val_loss: 0.4552 - val_accuracy: 0.7952
Epoch 71/100
21/21 - 0s - loss: 0.4584 - accuracy: 0.8029 - val_loss: 0.4530 - val_accuracy: 0.7952
Epoch 72/100
21/21 - 0s - loss: 0.4480 - accuracy: 0.7933 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 73/100
21/21 - 0s - loss: 0.4554 - accuracy: 0.7981 - val_loss: 0.4536 - val_accuracy: 0.7952
Epoch 74/100
21/21 - 0s - loss: 0.4438 - accuracy: 0.8029 - val_loss: 0.4532 - val_accuracy: 0.7952
Epoch 75/100
21/21 - 0s - loss: 0.4483 - accuracy: 0.8053 - val_loss: 0.4515 - val_accuracy: 0.8095
Epoch 76/100
21/21 - 0s - loss: 0.4408 - accuracy: 0.8041 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 77/100
21/21 - 0s - loss: 0.4470 - accuracy: 0.8017 - val_loss: 0.4531 - val_accuracy: 0.8000
Epoch 78/100
21/21 - 0s - loss: 0.4484 - accuracy: 0.8053 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 79/100
21/21 - 0s - loss: 0.4456 - accuracy: 0.8053 - val_loss: 0.4526 - val_accuracy: 0.8048
Epoch 80/100
21/21 - 0s - loss: 0.4459 - accuracy: 0.8100 - val_loss: 0.4573 - val_accuracy: 0.7952
Epoch 81/100
21/21 - 0s - loss: 0.4496 - accuracy: 0.7981 - val_loss: 0.4573 - val_accuracy: 0.8095
Epoch 82/100
21/21 - 0s - loss: 0.4515 - accuracy: 0.8053 - val_loss: 0.4502 - val_accuracy: 0.8095
Epoch 83/100
21/21 - 0s - loss: 0.4503 - accuracy: 0.8100 - val_loss: 0.4546 - val_accuracy: 0.7952
Epoch 84/100
21/21 - 0s - loss: 0.4386 - accuracy: 0.8065 - val_loss: 0.4540 - val_accuracy: 0.8048
Epoch 85/100
21/21 - 0s - loss: 0.4371 - accuracy: 0.8088 - val_loss: 0.4552 - val_accuracy: 0.8095
Epoch 86/100
21/21 - 0s - loss: 0.4420 - accuracy: 0.8053 - val_loss: 0.4553 - val_accuracy: 0.8048
Epoch 87/100
21/21 - 0s - loss: 0.4437 - accuracy: 0.8112 - val_loss: 0.4550 - val_accuracy: 0.7952
Epoch 88/100
21/21 - 0s - loss: 0.4432 - accuracy: 0.7969 - val_loss: 0.4565 - val_accuracy: 0.8095
Epoch 89/100
21/21 - 0s - loss: 0.4396 - accuracy: 0.8065 - val_loss: 0.4552 - val_accuracy: 0.8000
Epoch 90/100
21/21 - 0s - loss: 0.4477 - accuracy: 0.8088 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 91/100
21/21 - 0s - loss: 0.4412 - accuracy: 0.8017 - val_loss: 0.4507 - val_accuracy: 0.8048
Epoch 92/100
21/21 - 0s - loss: 0.4484 - accuracy: 0.7957 - val_loss: 0.4531 - val_accuracy: 0.8048
Epoch 93/100
21/21 - 0s - loss: 0.4433 - accuracy: 0.8017 - val_loss: 0.4519 - val_accuracy: 0.8048
Epoch 94/100
21/21 - 0s - loss: 0.4415 - accuracy: 0.7957 - val_loss: 0.4524 - val_accuracy: 0.8095
Epoch 95/100
21/21 - 0s - loss: 0.4399 - accuracy: 0.8065 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 96/100
21/21 - 0s - loss: 0.4387 - accuracy: 0.8065 - val_loss: 0.4546 - val_accuracy: 0.8095
Epoch 97/100
21/21 - 0s - loss: 0.4463 - accuracy: 0.7945 - val_loss: 0.4542 - val_accuracy: 0.8048
Epoch 98/100
21/21 - 0s - loss: 0.4447 - accuracy: 0.7993 - val_loss: 0.4542 - val_accuracy: 0.8143
Epoch 99/100
21/21 - 0s - loss: 0.4368 - accuracy: 0.8041 - val_loss: 0.4551 - val_accuracy: 0.8048
Epoch 100/100
21/21 - 0s - loss: 0.4395 - accuracy: 0.8053 - val_loss: 0.4501 - val_accuracy: 0.8095
2. Visualize the training history
# Visualize the training history
import matplotlib.pyplot as plt

def show_train_history(train_history, train_metric, validation_metric):
    plt.plot(train_history[train_metric])
    plt.plot(train_history[validation_metric])
    plt.title('Train History')
    plt.ylabel(train_metric)
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
show_train_history(train_history.history,'loss','val_loss')
show_train_history(train_history.history,'accuracy','val_accuracy')
3. Evaluate the model
loss,acc = model.evaluate(x_test,y_test)
9/9 [==============================] - 0s 2ms/step - loss: 0.3703 - accuracy: 0.8435
loss,acc
(0.3702643811702728, 0.8435114622116089)
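The reported accuracy is simply the fraction of predictions that match the labels after thresholding the sigmoid output at 0.5. A sketch on made-up probabilities (hypothetical values, not the model's real output):

```python
import numpy as np

probs = np.array([0.1, 0.8, 0.4, 0.9])  # hypothetical sigmoid outputs
labels = np.array([0, 1, 1, 1])

preds = (probs >= 0.5).astype(int)      # threshold at 0.5
acc = float((preds == labels).mean())
print(acc)  # 0.75
```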
IV. Prediction
Jack_info = [0,'Jack',3,'male',23,1,0,5.0000,'S']
Rose_info = [1,'Rose',1,'female',20,1,0,100.0000,'S']
x_pre = pd.DataFrame([Jack_info,Rose_info],columns=selected_cols)
x_pre
 | survived | name | pclass | sex | age | sibsp | parch | fare | embarked
---|---|---|---|---|---|---|---|---|---
0 | 0 | Jack | 3 | male | 23 | 1 | 0 | 5.0 | S |
1 | 1 | Rose | 1 | female | 20 | 1 | 0 | 100.0 | S |
x_pre_features,y = prepare_data(x_pre)
from sklearn import preprocessing
minmax_scale = preprocessing.MinMaxScaler(feature_range=(0,1))
x_pre_features = minmax_scale.fit_transform(x_pre_features)  # scale the new features; ideally the scaler fitted on the training set would be reused here instead of re-fitting on two rows
y_pre = model.predict(x_pre_features)
survived False
pclass False
sex False
age False
sibsp False
parch False
fare False
embarked False
dtype: bool
x_pre.insert(len(x_pre.columns),'surv_probabilty',y_pre)
x_pre
 | survived | name | pclass | sex | age | sibsp | parch | fare | embarked | surv_probabilty
---|---|---|---|---|---|---|---|---|---|---
0 | 0 | Jack | 3 | male | 23 | 1 | 0 | 5.0 | S | 0.058498 |
1 | 1 | Rose | 1 | female | 20 | 1 | 0 | 100.0 | S | 0.975978 |