"You jump, I jump" is the famous line from the classic romance Titanic. As Rose stands at the rail about to leap into the sea, Jack, trying to save her, delivers the line "You jump, I jump." That a total stranger would be willing to die for her, for no reason at all, is what first stirs Rose's love for Jack.
Of course, that has little to do with this tutorial. Here we will use a neural network to predict the survival odds of Jack and Rose. Keeping the series going through the National Day holiday was no small feat; if you need the dataset, feel free to message me or join the study group. Thanks for your support!

I. The Dataset

1. Reading the dataset

import pandas as pd

df = pd.read_excel('titanic3.xls')
df.describe()
pclass survived age sibsp parch fare body
count 1309.000000 1309.000000 1046.000000 1309.000000 1309.000000 1308.000000 121.000000
mean 2.294882 0.381971 29.881135 0.498854 0.385027 33.295479 160.809917
std 0.837836 0.486055 14.413500 1.041658 0.865560 51.758668 97.696922
min 1.000000 0.000000 0.166700 0.000000 0.000000 0.000000 1.000000
25% 2.000000 0.000000 21.000000 0.000000 0.000000 7.895800 72.000000
50% 3.000000 0.000000 28.000000 0.000000 0.000000 14.454200 155.000000
75% 3.000000 1.000000 39.000000 1.000000 0.000000 31.275000 256.000000
max 3.000000 1.000000 80.000000 8.000000 9.000000 512.329200 328.000000
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1309 entries, 0 to 1308
Data columns (total 14 columns):
 #   Column     Non-Null Count  Dtype
---  ------     --------------  -----
 0   pclass     1309 non-null   int64
 1   survived   1309 non-null   int64
 2   name       1309 non-null   object
 3   sex        1309 non-null   object
 4   age        1046 non-null   float64
 5   sibsp      1309 non-null   int64
 6   parch      1309 non-null   int64
 7   ticket     1309 non-null   object
 8   fare       1308 non-null   float64
 9   cabin      295 non-null    object
 10  embarked   1307 non-null   object
 11  boat       486 non-null    object
 12  body       121 non-null    float64
 13  home.dest  745 non-null    object
dtypes: float64(3), int64(4), object(7)
memory usage: 143.3+ KB
df.head()
pclass survived name sex age sibsp parch ticket fare cabin embarked boat body home.dest
0 1 1 Allen, Miss. Elisabeth Walton female 29.0000 0 0 24160 211.3375 B5 S 2 NaN St Louis, MO
1 1 1 Allison, Master. Hudson Trevor male 0.9167 1 2 113781 151.5500 C22 C26 S 11 NaN Montreal, PQ / Chesterville, ON
2 1 0 Allison, Miss. Helen Loraine female 2.0000 1 2 113781 151.5500 C22 C26 S NaN NaN Montreal, PQ / Chesterville, ON
3 1 0 Allison, Mr. Hudson Joshua Creighton male 30.0000 1 2 113781 151.5500 C22 C26 S NaN 135.0 Montreal, PQ / Chesterville, ON
4 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female 25.0000 1 2 113781 151.5500 C22 C26 S NaN NaN Montreal, PQ / Chesterville, ON

2. Preprocessing the data

  • Select the relevant columns
  • Handle missing values
  • Convert categorical fields to numeric codes
  • Drop the name column
# Select the columns we need
selected_cols = ['survived','name','pclass','sex','age','sibsp','parch','fare','embarked']
df_selected = df[selected_cols] # an untouched copy, reused later for the shuffled split
df = df[selected_cols] # indexing a DataFrame with a list selects columns
df.head()
survived name pclass sex age sibsp parch fare embarked
0 1 Allen, Miss. Elisabeth Walton 1 female 29.0000 0 0 211.3375 S
1 1 Allison, Master. Hudson Trevor 1 male 0.9167 1 2 151.5500 S
2 0 Allison, Miss. Helen Loraine 1 female 2.0000 1 2 151.5500 S
3 0 Allison, Mr. Hudson Joshua Creighton 1 male 30.0000 1 2 151.5500 S
4 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) 1 female 25.0000 1 2 151.5500 S
# Find which columns contain null values
df.isnull().any()
survived    False
name        False
pclass      False
sex         False
age          True
sibsp       False
parch       False
fare         True
embarked     True
dtype: bool
# Count how many null values each column has
df.isnull().sum()
survived      0
name          0
pclass        0
sex           0
age         263
sibsp         0
parch         0
fare          1
embarked      2
dtype: int64
# Show the rows that contain missing values
df[df.isnull().values==True]
survived name pclass sex age sibsp parch fare embarked
15 0 Baumann, Mr. John D 1 male NaN 0 0 25.9250 S
37 1 Bradley, Mr. George ("George Arthur Brayton") 1 male NaN 0 0 26.5500 S
40 0 Brewe, Dr. Arthur Jackson 1 male NaN 0 0 39.6000 C
46 0 Cairns, Mr. Alexander 1 male NaN 0 0 31.0000 S
59 1 Cassebeer, Mrs. Henry Arthur Jr (Eleanor Genev... 1 female NaN 0 0 27.7208 C
... ... ... ... ... ... ... ... ... ...
1293 0 Williams, Mr. Howard Hugh "Harry" 3 male NaN 0 0 8.0500 S
1297 0 Wiseman, Mr. Phillippe 3 male NaN 0 0 7.2500 S
1302 0 Yousif, Mr. Wazli 3 male NaN 0 0 7.2250 C
1303 0 Yousseff, Mr. Gerious 3 male NaN 0 0 14.4583 C
1305 0 Zabour, Miss. Thamine 3 female NaN 1 0 14.4542 C

266 rows × 9 columns

# Replace missing age values with the column mean
age_mean = df['age'].mean()
df['age'] = df['age'].fillna(age_mean)
df['age'].isnull().any() # True if any null value remains
False
# Replace missing fare values with the column mean
fare_mean = df['fare'].mean()
df['fare'] = df['fare'].fillna(fare_mean) # note: fill fare with fare_mean, not age_mean
# Fill the missing embarked records with 'S'
df['embarked'] = df['embarked'].fillna('S')
df.isnull().any()
survived    False
name        False
pclass      False
sex         False
age         False
sibsp       False
parch       False
fare        False
embarked    False
dtype: bool
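The mean is a reasonable default, but for skewed columns such as fare the median is more robust to outliers. A minimal sketch on a hypothetical series, not part of the tutorial's pipeline:

```python
import numpy as np
import pandas as pd

# Hypothetical age-like series with one missing value
s = pd.Series([0.17, 21.0, 28.0, 39.0, 80.0, np.nan])

# fillna with the median; NaN is ignored when computing the statistic
filled = s.fillna(s.median())
print(filled.iloc[-1])  # → 28.0
```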
# Convert categorical fields to numeric codes
# sex: map the strings to integer codes
df['sex'] = df['sex'].map({'female':0,'male':1}).astype(int)
# embarked: map the port letters to integer codes
df['embarked'] = df['embarked'].map({'C':0,'Q':1,'S':2}).astype(int)
# Drop the name column
df = df.drop(['name'],axis=1) # axis=0 refers to rows, axis=1 to columns
df.head()
survived pclass sex age sibsp parch fare embarked
0 1 1 0 29.0000 0 0 211.3375 2
1 1 1 1 0.9167 1 2 151.5500 2
2 0 1 0 2.0000 1 2 151.5500 2
3 0 1 1 30.0000 1 2 151.5500 2
4 0 1 0 25.0000 1 2 151.5500 2
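Mapping embarked to 0/1/2 imposes an artificial ordering on the three ports; one-hot encoding with pd.get_dummies is a common alternative. A sketch on a hypothetical column, not the tutorial's actual pipeline:

```python
import pandas as pd

# Hypothetical frame standing in for the tutorial's embarked column
df = pd.DataFrame({'embarked': ['S', 'C', 'Q', 'S']})

# One indicator column per port; no implied order between C, Q and S
encoded = pd.get_dummies(df, columns=['embarked'])
print(sorted(encoded.columns))  # → ['embarked_C', 'embarked_Q', 'embarked_S']
```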

3. Splitting features and labels

# Separate the feature values from the label values
data = df.values
# the last seven columns are the features
features = data[:,1:] # slicing an ndarray indexes rows by default, unlike a DataFrame
# column zero is the label
labels = data[:,0]
labels.shape
(1309,)

4. Defining a data-preprocessing function

def prepare_data(df):
    # drop the name column
    df = df.drop(['name'],axis=1)
    # replace missing age values with the column mean
    age_mean = df['age'].mean()
    df['age'] = df['age'].fillna(age_mean)
    # replace missing fare values with the column mean (fare_mean, not age_mean)
    fare_mean = df['fare'].mean()
    df['fare'] = df['fare'].fillna(fare_mean)
    # fill the missing embarked records with 'S'
    df['embarked'] = df['embarked'].fillna('S')
    # sex: map the strings to integer codes
    df['sex'] = df['sex'].map({'female':0,'male':1}).astype(int)
    # embarked: map the port letters to integer codes
    df['embarked'] = df['embarked'].map({'C':0,'Q':1,'S':2}).astype(int)
    print(df.isnull().any())
    # separate the feature values from the label values
    data = df.values
    # the last seven columns are the features
    features = data[:,1:]
    # column zero is the label
    labels = data[:,0]
    return features, labels

5. Splitting training and test sets

shuffle_df = df_selected.sample(frac=1) # shuffle the rows before training; frac is the fraction of rows to sample, and df_selected itself is left unchanged
x_data,y_data = prepare_data(shuffle_df)
x_data.shape,y_data.shape
survived    False
pclass      False
sex         False
age         False
sibsp       False
parch       False
fare        False
embarked    False
dtype: bool

((1309, 7), (1309,))
shuffle_df.head()
survived name pclass sex age sibsp parch fare embarked
58 0 Case, Mr. Howard Brown 1 male 49.0 0 0 26.0000 S
666 0 Barbara, Mrs. (Catherine David) 3 female 45.0 0 1 14.4542 C
781 0 Drazenoic, Mr. Jozef 3 male 33.0 0 0 7.8958 C
480 0 Laroche, Mr. Joseph Philippe Lemercier 2 male 25.0 1 2 41.5792 C
459 0 Jacobsohn, Mr. Sidney Samuel 2 male 42.0 1 0 27.0000 S
test_split = 0.2
train_num = int((1 - test_split) * x_data.shape[0])
# training set
x_train = x_data[:train_num]
y_train = y_data[:train_num]
# test set
x_test = x_data[train_num:]
y_test = y_data[train_num:]
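The manual slicing above works because the rows were shuffled first; sklearn's train_test_split does both steps in one call. A sketch with synthetic arrays standing in for x_data and y_data:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-ins with the same shapes as the tutorial's x_data / y_data
x_data = np.random.rand(1309, 7)
y_data = np.random.randint(0, 2, size=1309)

# shuffles and splits in one step; random_state makes the split reproducible
x_train, x_test, y_train, y_test = train_test_split(
    x_data, y_data, test_size=0.2, random_state=42)
print(x_train.shape, x_test.shape)  # → (1047, 7) (262, 7)
```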

6. Normalization

from sklearn import preprocessing

minmax_scale = preprocessing.MinMaxScaler(feature_range=(0,1))
x_train = minmax_scale.fit_transform(x_train) # fit the scaler on the training features and scale them
x_test = minmax_scale.transform(x_test) # reuse the training-set statistics; refitting on the test set would leak information

II. The Model

import tensorflow as tf
tf.__version__
'2.6.0'

1. Building a sequential model

model = tf.keras.models.Sequential()

2. Adding the hidden layers

model.add(tf.keras.layers.Dense(units=64,
                                use_bias=True,
                                activation='relu',
                                input_dim=7, # input_shape=(7,) also works
                                bias_initializer='zeros',
                                kernel_initializer='normal'))
model.add(tf.keras.layers.Dropout(rate=0.2)) # dropout layer; rate is the fraction of the previous layer's units dropped during training, to reduce overfitting
model.add(tf.keras.layers.Dense(units=32,
                                activation='sigmoid',
                                input_shape=(64,), # input_dim=64 also works
                                bias_initializer='zeros',
                                kernel_initializer='uniform'))
model.add(tf.keras.layers.Dropout(rate=0.2)) # second dropout layer, same purpose

3. Adding the output layer

model.add(tf.keras.layers.Dense(units=1,
                                activation='sigmoid',
                                input_dim=32, # input_shape=(32,) also works
                                bias_initializer='zeros',
                                kernel_initializer='uniform'))
model.summary()
Model: "sequential_23"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_68 (Dense)             (None, 64)                512
_________________________________________________________________
dropout_6 (Dropout)          (None, 64)                0
_________________________________________________________________
dense_69 (Dense)             (None, 32)                2080
_________________________________________________________________
dropout_7 (Dropout)          (None, 32)                0
_________________________________________________________________
dense_70 (Dense)             (None, 1)                 33
=================================================================
Total params: 2,625
Trainable params: 2,625
Non-trainable params: 0
_________________________________________________________________
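The parameter counts in the summary follow from (inputs + 1 bias) × units for each Dense layer; a quick arithmetic check:

```python
# (input_dim, units) for the three Dense layers above; Dropout adds no parameters
layers = [(7, 64), (64, 32), (32, 1)]

# each unit has one weight per input plus one bias
params = [(inp + 1) * units for inp, units in layers]
print(params, sum(params))  # → [512, 2080, 33] 2625
```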

III. Training

1. Training

# Configure the training setup
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.003),
              loss='binary_crossentropy',
              metrics=['accuracy'])
# Set the training parameters
train_epochs = 100
batch_size = 40
train_history = model.fit(x=x_train,            # training features
                          y=y_train,            # training labels
                          validation_split=0.2, # fraction of the training data held out for validation
                          epochs=train_epochs,  # number of training epochs
                          batch_size=batch_size,# batch size
                          verbose=2)            # log one line per epoch
Epoch 1/100
21/21 - 1s - loss: 0.6780 - accuracy: 0.5854 - val_loss: 0.6464 - val_accuracy: 0.6429
Epoch 2/100
21/21 - 0s - loss: 0.6623 - accuracy: 0.6057 - val_loss: 0.6293 - val_accuracy: 0.6429
Epoch 3/100
21/21 - 0s - loss: 0.6306 - accuracy: 0.6069 - val_loss: 0.5861 - val_accuracy: 0.6667
Epoch 4/100
21/21 - 0s - loss: 0.5771 - accuracy: 0.7336 - val_loss: 0.5199 - val_accuracy: 0.7905
Epoch 5/100
21/21 - 0s - loss: 0.5364 - accuracy: 0.7646 - val_loss: 0.4939 - val_accuracy: 0.7952
Epoch 6/100
21/21 - 0s - loss: 0.5200 - accuracy: 0.7670 - val_loss: 0.4847 - val_accuracy: 0.8143
Epoch 7/100
21/21 - 0s - loss: 0.5118 - accuracy: 0.7718 - val_loss: 0.4771 - val_accuracy: 0.8143
Epoch 8/100
21/21 - 0s - loss: 0.5060 - accuracy: 0.7766 - val_loss: 0.4738 - val_accuracy: 0.8095
Epoch 9/100
21/21 - 0s - loss: 0.4934 - accuracy: 0.7861 - val_loss: 0.4670 - val_accuracy: 0.7952
Epoch 10/100
21/21 - 0s - loss: 0.4966 - accuracy: 0.7814 - val_loss: 0.4637 - val_accuracy: 0.8000
Epoch 11/100
21/21 - 0s - loss: 0.4928 - accuracy: 0.7766 - val_loss: 0.4635 - val_accuracy: 0.7905
Epoch 12/100
21/21 - 0s - loss: 0.4995 - accuracy: 0.7670 - val_loss: 0.4691 - val_accuracy: 0.7905
Epoch 13/100
21/21 - 0s - loss: 0.4886 - accuracy: 0.7957 - val_loss: 0.4620 - val_accuracy: 0.8095
Epoch 14/100
21/21 - 0s - loss: 0.4790 - accuracy: 0.7838 - val_loss: 0.4565 - val_accuracy: 0.8095
Epoch 15/100
21/21 - 0s - loss: 0.4877 - accuracy: 0.7766 - val_loss: 0.4576 - val_accuracy: 0.8095
Epoch 16/100
21/21 - 0s - loss: 0.4839 - accuracy: 0.7897 - val_loss: 0.4560 - val_accuracy: 0.8095
Epoch 17/100
21/21 - 0s - loss: 0.4813 - accuracy: 0.7814 - val_loss: 0.4614 - val_accuracy: 0.8095
Epoch 18/100
21/21 - 0s - loss: 0.4812 - accuracy: 0.7742 - val_loss: 0.4553 - val_accuracy: 0.8095
Epoch 19/100
21/21 - 0s - loss: 0.4762 - accuracy: 0.7885 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 20/100
21/21 - 0s - loss: 0.4784 - accuracy: 0.7802 - val_loss: 0.4567 - val_accuracy: 0.8000
Epoch 21/100
21/21 - 0s - loss: 0.4794 - accuracy: 0.7885 - val_loss: 0.4626 - val_accuracy: 0.8000
Epoch 22/100
21/21 - 0s - loss: 0.4824 - accuracy: 0.7838 - val_loss: 0.4567 - val_accuracy: 0.7857
Epoch 23/100
21/21 - 0s - loss: 0.4786 - accuracy: 0.7849 - val_loss: 0.4553 - val_accuracy: 0.8048
Epoch 24/100
21/21 - 0s - loss: 0.4801 - accuracy: 0.7742 - val_loss: 0.4735 - val_accuracy: 0.7905
Epoch 25/100
21/21 - 0s - loss: 0.4752 - accuracy: 0.7849 - val_loss: 0.4571 - val_accuracy: 0.7905
Epoch 26/100
21/21 - 0s - loss: 0.4688 - accuracy: 0.7909 - val_loss: 0.4597 - val_accuracy: 0.8000
Epoch 27/100
21/21 - 0s - loss: 0.4624 - accuracy: 0.7873 - val_loss: 0.4577 - val_accuracy: 0.8048
Epoch 28/100
21/21 - 0s - loss: 0.4656 - accuracy: 0.7993 - val_loss: 0.4602 - val_accuracy: 0.8000
Epoch 29/100
21/21 - 0s - loss: 0.4649 - accuracy: 0.7969 - val_loss: 0.4546 - val_accuracy: 0.8000
Epoch 30/100
21/21 - 0s - loss: 0.4645 - accuracy: 0.7849 - val_loss: 0.4638 - val_accuracy: 0.8000
Epoch 31/100
21/21 - 0s - loss: 0.4635 - accuracy: 0.7921 - val_loss: 0.4603 - val_accuracy: 0.7952
Epoch 32/100
21/21 - 0s - loss: 0.4646 - accuracy: 0.7909 - val_loss: 0.4567 - val_accuracy: 0.7952
Epoch 33/100
21/21 - 0s - loss: 0.4664 - accuracy: 0.7909 - val_loss: 0.4583 - val_accuracy: 0.7952
Epoch 34/100
21/21 - 0s - loss: 0.4661 - accuracy: 0.7921 - val_loss: 0.4575 - val_accuracy: 0.8000
Epoch 35/100
21/21 - 0s - loss: 0.4660 - accuracy: 0.7838 - val_loss: 0.4582 - val_accuracy: 0.7952
Epoch 36/100
21/21 - 0s - loss: 0.4577 - accuracy: 0.8005 - val_loss: 0.4567 - val_accuracy: 0.8000
Epoch 37/100
21/21 - 0s - loss: 0.4648 - accuracy: 0.7909 - val_loss: 0.4585 - val_accuracy: 0.7952
Epoch 38/100
21/21 - 0s - loss: 0.4613 - accuracy: 0.7921 - val_loss: 0.4569 - val_accuracy: 0.7952
Epoch 39/100
21/21 - 0s - loss: 0.4643 - accuracy: 0.7921 - val_loss: 0.4687 - val_accuracy: 0.8000
Epoch 40/100
21/21 - 0s - loss: 0.4696 - accuracy: 0.7814 - val_loss: 0.4601 - val_accuracy: 0.8048
Epoch 41/100
21/21 - 0s - loss: 0.4589 - accuracy: 0.7933 - val_loss: 0.4562 - val_accuracy: 0.7952
Epoch 42/100
21/21 - 0s - loss: 0.4587 - accuracy: 0.7885 - val_loss: 0.4594 - val_accuracy: 0.8000
Epoch 43/100
21/21 - 0s - loss: 0.4601 - accuracy: 0.7981 - val_loss: 0.4563 - val_accuracy: 0.7905
Epoch 44/100
21/21 - 0s - loss: 0.4639 - accuracy: 0.7897 - val_loss: 0.4594 - val_accuracy: 0.8048
Epoch 45/100
21/21 - 0s - loss: 0.4569 - accuracy: 0.7957 - val_loss: 0.4587 - val_accuracy: 0.8000
Epoch 46/100
21/21 - 0s - loss: 0.4619 - accuracy: 0.7957 - val_loss: 0.4556 - val_accuracy: 0.8048
Epoch 47/100
21/21 - 0s - loss: 0.4661 - accuracy: 0.7861 - val_loss: 0.4563 - val_accuracy: 0.8000
Epoch 48/100
21/21 - 0s - loss: 0.4550 - accuracy: 0.7969 - val_loss: 0.4538 - val_accuracy: 0.8000
Epoch 49/100
21/21 - 0s - loss: 0.4550 - accuracy: 0.7873 - val_loss: 0.4572 - val_accuracy: 0.8048
Epoch 50/100
21/21 - 0s - loss: 0.4603 - accuracy: 0.7909 - val_loss: 0.4584 - val_accuracy: 0.8000
Epoch 51/100
21/21 - 0s - loss: 0.4575 - accuracy: 0.7957 - val_loss: 0.4531 - val_accuracy: 0.8095
Epoch 52/100
21/21 - 0s - loss: 0.4568 - accuracy: 0.8029 - val_loss: 0.4584 - val_accuracy: 0.8048
Epoch 53/100
21/21 - 0s - loss: 0.4594 - accuracy: 0.7909 - val_loss: 0.4558 - val_accuracy: 0.8000
Epoch 54/100
21/21 - 0s - loss: 0.4588 - accuracy: 0.8065 - val_loss: 0.4523 - val_accuracy: 0.8000
Epoch 55/100
21/21 - 0s - loss: 0.4532 - accuracy: 0.8029 - val_loss: 0.4593 - val_accuracy: 0.8048
Epoch 56/100
21/21 - 0s - loss: 0.4578 - accuracy: 0.8100 - val_loss: 0.4614 - val_accuracy: 0.8048
Epoch 57/100
21/21 - 0s - loss: 0.4549 - accuracy: 0.8041 - val_loss: 0.4580 - val_accuracy: 0.8095
Epoch 58/100
21/21 - 0s - loss: 0.4568 - accuracy: 0.7909 - val_loss: 0.4597 - val_accuracy: 0.8095
Epoch 59/100
21/21 - 0s - loss: 0.4567 - accuracy: 0.7981 - val_loss: 0.4532 - val_accuracy: 0.8095
Epoch 60/100
21/21 - 0s - loss: 0.4532 - accuracy: 0.7993 - val_loss: 0.4569 - val_accuracy: 0.7952
Epoch 61/100
21/21 - 0s - loss: 0.4543 - accuracy: 0.7969 - val_loss: 0.4555 - val_accuracy: 0.8000
Epoch 62/100
21/21 - 0s - loss: 0.4472 - accuracy: 0.8053 - val_loss: 0.4543 - val_accuracy: 0.8048
Epoch 63/100
21/21 - 0s - loss: 0.4458 - accuracy: 0.8100 - val_loss: 0.4534 - val_accuracy: 0.8095
Epoch 64/100
21/21 - 0s - loss: 0.4497 - accuracy: 0.8005 - val_loss: 0.4593 - val_accuracy: 0.8000
Epoch 65/100
21/21 - 0s - loss: 0.4511 - accuracy: 0.8053 - val_loss: 0.4522 - val_accuracy: 0.8095
Epoch 66/100
21/21 - 0s - loss: 0.4506 - accuracy: 0.8005 - val_loss: 0.4592 - val_accuracy: 0.7952
Epoch 67/100
21/21 - 0s - loss: 0.4533 - accuracy: 0.8005 - val_loss: 0.4545 - val_accuracy: 0.8000
Epoch 68/100
21/21 - 0s - loss: 0.4481 - accuracy: 0.7909 - val_loss: 0.4545 - val_accuracy: 0.7952
Epoch 69/100
21/21 - 0s - loss: 0.4555 - accuracy: 0.7981 - val_loss: 0.4551 - val_accuracy: 0.8000
Epoch 70/100
21/21 - 0s - loss: 0.4440 - accuracy: 0.8029 - val_loss: 0.4552 - val_accuracy: 0.7952
Epoch 71/100
21/21 - 0s - loss: 0.4584 - accuracy: 0.8029 - val_loss: 0.4530 - val_accuracy: 0.7952
Epoch 72/100
21/21 - 0s - loss: 0.4480 - accuracy: 0.7933 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 73/100
21/21 - 0s - loss: 0.4554 - accuracy: 0.7981 - val_loss: 0.4536 - val_accuracy: 0.7952
Epoch 74/100
21/21 - 0s - loss: 0.4438 - accuracy: 0.8029 - val_loss: 0.4532 - val_accuracy: 0.7952
Epoch 75/100
21/21 - 0s - loss: 0.4483 - accuracy: 0.8053 - val_loss: 0.4515 - val_accuracy: 0.8095
Epoch 76/100
21/21 - 0s - loss: 0.4408 - accuracy: 0.8041 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 77/100
21/21 - 0s - loss: 0.4470 - accuracy: 0.8017 - val_loss: 0.4531 - val_accuracy: 0.8000
Epoch 78/100
21/21 - 0s - loss: 0.4484 - accuracy: 0.8053 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 79/100
21/21 - 0s - loss: 0.4456 - accuracy: 0.8053 - val_loss: 0.4526 - val_accuracy: 0.8048
Epoch 80/100
21/21 - 0s - loss: 0.4459 - accuracy: 0.8100 - val_loss: 0.4573 - val_accuracy: 0.7952
Epoch 81/100
21/21 - 0s - loss: 0.4496 - accuracy: 0.7981 - val_loss: 0.4573 - val_accuracy: 0.8095
Epoch 82/100
21/21 - 0s - loss: 0.4515 - accuracy: 0.8053 - val_loss: 0.4502 - val_accuracy: 0.8095
Epoch 83/100
21/21 - 0s - loss: 0.4503 - accuracy: 0.8100 - val_loss: 0.4546 - val_accuracy: 0.7952
Epoch 84/100
21/21 - 0s - loss: 0.4386 - accuracy: 0.8065 - val_loss: 0.4540 - val_accuracy: 0.8048
Epoch 85/100
21/21 - 0s - loss: 0.4371 - accuracy: 0.8088 - val_loss: 0.4552 - val_accuracy: 0.8095
Epoch 86/100
21/21 - 0s - loss: 0.4420 - accuracy: 0.8053 - val_loss: 0.4553 - val_accuracy: 0.8048
Epoch 87/100
21/21 - 0s - loss: 0.4437 - accuracy: 0.8112 - val_loss: 0.4550 - val_accuracy: 0.7952
Epoch 88/100
21/21 - 0s - loss: 0.4432 - accuracy: 0.7969 - val_loss: 0.4565 - val_accuracy: 0.8095
Epoch 89/100
21/21 - 0s - loss: 0.4396 - accuracy: 0.8065 - val_loss: 0.4552 - val_accuracy: 0.8000
Epoch 90/100
21/21 - 0s - loss: 0.4477 - accuracy: 0.8088 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 91/100
21/21 - 0s - loss: 0.4412 - accuracy: 0.8017 - val_loss: 0.4507 - val_accuracy: 0.8048
Epoch 92/100
21/21 - 0s - loss: 0.4484 - accuracy: 0.7957 - val_loss: 0.4531 - val_accuracy: 0.8048
Epoch 93/100
21/21 - 0s - loss: 0.4433 - accuracy: 0.8017 - val_loss: 0.4519 - val_accuracy: 0.8048
Epoch 94/100
21/21 - 0s - loss: 0.4415 - accuracy: 0.7957 - val_loss: 0.4524 - val_accuracy: 0.8095
Epoch 95/100
21/21 - 0s - loss: 0.4399 - accuracy: 0.8065 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 96/100
21/21 - 0s - loss: 0.4387 - accuracy: 0.8065 - val_loss: 0.4546 - val_accuracy: 0.8095
Epoch 97/100
21/21 - 0s - loss: 0.4463 - accuracy: 0.7945 - val_loss: 0.4542 - val_accuracy: 0.8048
Epoch 98/100
21/21 - 0s - loss: 0.4447 - accuracy: 0.7993 - val_loss: 0.4542 - val_accuracy: 0.8143
Epoch 99/100
21/21 - 0s - loss: 0.4368 - accuracy: 0.8041 - val_loss: 0.4551 - val_accuracy: 0.8048
Epoch 100/100
21/21 - 0s - loss: 0.4395 - accuracy: 0.8053 - val_loss: 0.4501 - val_accuracy: 0.8095
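The validation loss above plateaus long before epoch 100. Keras can stop training automatically by passing tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=...) to model.fit(..., callbacks=[...]); the patience logic that callback applies can be sketched in plain Python:

```python
# Plain-Python sketch of early stopping: stop once val_loss has failed
# to improve for `patience` consecutive epochs
def early_stop_epoch(val_losses, patience=3):
    best, wait = float('inf'), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:          # improvement: remember it, reset the counter
            best, wait = loss, 0
        else:                    # no improvement: count one strike
            wait += 1
            if wait >= patience:
                return epoch     # patience exhausted, stop here
    return len(val_losses) - 1   # never triggered: ran to the end

print(early_stop_epoch([0.65, 0.52, 0.48, 0.49, 0.50, 0.51]))  # → 5
```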

2. Visualizing the training process

# Visualize the training history
import matplotlib.pyplot as plt

def show_train_history(train_history, train_metric, validation_metric):
    plt.plot(train_history[train_metric])
    plt.plot(train_history[validation_metric])
    plt.title('Train History')
    plt.ylabel(train_metric)
    plt.xlabel('epoch')
    plt.legend(['train','validation'], loc='upper left')
    plt.show()
show_train_history(train_history.history,'loss','val_loss')

show_train_history(train_history.history,'accuracy','val_accuracy')

3. Evaluating the model

loss,acc = model.evaluate(x_test,y_test)
9/9 [==============================] - 0s 2ms/step - loss: 0.3703 - accuracy: 0.8435
loss,acc
(0.3702643811702728, 0.8435114622116089)

IV. Prediction

Jack_info = [0,'Jack',3,'male',23,1,0,5.0000,'S']
Rose_info = [1,'Rose',1,'female',20,1,0,100.0000,'S']
x_pre = pd.DataFrame([Jack_info,Rose_info],columns=selected_cols)
x_pre
survived name pclass sex age sibsp parch fare embarked
0 0 Jack 3 male 23 1 0 5.0 S
1 1 Rose 1 female 20 1 0 100.0 S
x_pre_features,y = prepare_data(x_pre)
from sklearn import preprocessing

minmax_scale = preprocessing.MinMaxScaler(feature_range=(0,1))
x_pre_features = minmax_scale.fit_transform(x_pre_features) # note: fitting the scaler on just these two rows maps each feature to 0 or 1; reusing the scaler fitted on the training set would be more faithful
y_pre = model.predict(x_pre_features)
survived    False
pclass      False
sex         False
age         False
sibsp       False
parch       False
fare        False
embarked    False
dtype: bool
x_pre.insert(len(x_pre.columns),'surv_probability',y_pre)
x_pre
survived name pclass sex age sibsp parch fare embarked surv_probability
0 0 Jack 3 male 23 1 0 5.0 S 0.058498
1 1 Rose 1 female 20 1 0 100.0 S 0.975978
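The sigmoid output is a probability, not a class label; thresholding at 0.5 (a common but arbitrary cutoff, not something the tutorial itself applies) turns it into a survival prediction:

```python
import numpy as np

# The two probabilities produced by model.predict in the output above
y_pre = np.array([[0.058498], [0.975978]])

# values above the 0.5 cutoff become class 1 (survived)
classes = (y_pre > 0.5).astype(int).ravel()
print(classes.tolist())  # → [0, 1]
```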

TensorFlow 从入门到精通(8)—— 泰坦尼克号旅客生存预测