[AI Project] LSTM Data Classification Experiment

This experiment distinguishes tree species from data collected in a CSV file: the feature columns of each sample are fed into an LSTM model, which outputs a predicted class.

Imports

# Imports
import numpy as np
import pandas as pd
import glob
import os
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
train_path = "./"
print(os.listdir(train_path))
['.ipynb_checkpoints', 'code.ipynb', 'data.csv', 'lstm.h5', 'plant_totaldatat.xlsx']

Reading the Data

# Read the CSV file
data = pd.read_csv("data.csv")
data
Filename 172.538 173.141 173.744 174.348 174.951 175.554 176.157 176.76 177.363 ... 1165.846 1166.373 1166.9 1167.427 1167.954 1168.481 1169.008 1169.535 1170.061 Label
0 芭蕉0001.ROH 414.421417 445.234558 482.571625 378.288757 483.976776 476.850617 423.253845 445.033813 477.653564 ... 487.088196 513.986938 532.956604 545.502625 504.853424 568.687744 547.811096 584.947449 564.773376 1
1 芭蕉0002.ROH 469.523712 450.353333 447.543030 457.880981 467.616699 456.375458 483.575287 447.543030 415.224365 ... 560.357178 511.477722 473.337708 613.151001 513.384766 495.418793 618.771606 618.570923 495.619507 1
2 芭蕉0003.ROH 508.265930 502.946411 522.317505 471.932556 512.682129 503.950104 498.429840 487.891144 465.910461 ... 597.694275 552.126953 540.885681 661.327881 553.030273 540.183106 650.889526 659.521240 550.219971 1
3 芭蕉0004.ROH 490.801819 514.789917 529.945557 463.501617 527.536682 525.027466 489.898499 514.288025 503.247528 ... 567.784424 576.416138 573.906921 625.596680 548.915161 621.280823 632.221008 652.595825 599.199768 1
4 芭蕉0005.ROH 431.383697 433.290680 436.703217 408.901154 459.386505 461.694977 453.264008 435.900269 438.810974 ... 535.867249 499.232788 503.849731 569.691467 518.704285 512.381043 577.921692 573.605835 520.812012 1
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
2422 樟树10096.ROH 376.682861 396.656189 391.637787 371.363342 434.796234 390.634094 378.489502 410.406677 394.147003 ... 478.155395 441.520904 413.317383 508.265930 447.844116 424.558624 497.927978 522.618652 449.951874 7
2423 樟树10097.ROH 312.647797 359.419495 336.836578 315.056641 381.299835 351.891876 333.925903 377.586182 353.397400 ... 434.495117 420.242798 342.758331 476.047668 397.258423 369.054871 446.639709 460.390167 384.913086 7
2424 樟树10098.ROH 383.809052 372.166290 419.941681 371.363342 412.112946 411.912201 399.566894 382.905731 405.287903 ... 438.208740 460.089081 427.469330 478.556885 463.300873 468.620392 485.181183 525.328613 500.035736 7
2425 樟树10099.ROH 327.100861 333.725159 347.676392 332.621124 376.181030 364.538300 361.727966 377.786926 347.274902 ... 417.934326 411.410370 377.184723 433.190338 413.919586 395.752899 432.989593 445.636017 425.562317 7
2426 樟树10100.ROH 380.697601 424.859741 441.119446 388.526367 448.446320 433.089966 416.428803 431.383697 450.654449 ... 500.838684 497.526520 458.182068 537.673889 493.210663 465.709717 551.625122 561.059753 480.263153 7

2427 rows × 1757 columns

# Inspect the first five rows
data.head()
Filename 172.538 173.141 173.744 174.348 174.951 175.554 176.157 176.76 177.363 ... 1165.846 1166.373 1166.9 1167.427 1167.954 1168.481 1169.008 1169.535 1170.061 Label
0 芭蕉0001.ROH 414.421417 445.234558 482.571625 378.288757 483.976776 476.850617 423.253845 445.033813 477.653564 ... 487.088196 513.986938 532.956604 545.502625 504.853424 568.687744 547.811096 584.947449 564.773376 1
1 芭蕉0002.ROH 469.523712 450.353333 447.543030 457.880981 467.616699 456.375458 483.575287 447.543030 415.224365 ... 560.357178 511.477722 473.337708 613.151001 513.384766 495.418793 618.771606 618.570923 495.619507 1
2 芭蕉0003.ROH 508.265930 502.946411 522.317505 471.932556 512.682129 503.950104 498.429840 487.891144 465.910461 ... 597.694275 552.126953 540.885681 661.327881 553.030273 540.183106 650.889526 659.521240 550.219971 1
3 芭蕉0004.ROH 490.801819 514.789917 529.945557 463.501617 527.536682 525.027466 489.898499 514.288025 503.247528 ... 567.784424 576.416138 573.906921 625.596680 548.915161 621.280823 632.221008 652.595825 599.199768 1
4 芭蕉0005.ROH 431.383697 433.290680 436.703217 408.901154 459.386505 461.694977 453.264008 435.900269 438.810974 ... 535.867249 499.232788 503.849731 569.691467 518.704285 512.381043 577.921692 573.605835 520.812012 1

5 rows × 1757 columns

Data Analysis

data.index
RangeIndex(start=0, stop=2427, step=1)
print(data.info())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2427 entries, 0 to 2426
Columns: 1757 entries, Filename to Label
dtypes: float64(1755), int64(1), object(1)
memory usage: 32.5+ MB
None
# Drop any rows with missing values
data.dropna(axis=0, how='any', inplace=True)
data
Filename 172.538 173.141 173.744 174.348 174.951 175.554 176.157 176.76 177.363 ... 1165.846 1166.373 1166.9 1167.427 1167.954 1168.481 1169.008 1169.535 1170.061 Label
0 芭蕉0001.ROH 414.421417 445.234558 482.571625 378.288757 483.976776 476.850617 423.253845 445.033813 477.653564 ... 487.088196 513.986938 532.956604 545.502625 504.853424 568.687744 547.811096 584.947449 564.773376 1
1 芭蕉0002.ROH 469.523712 450.353333 447.543030 457.880981 467.616699 456.375458 483.575287 447.543030 415.224365 ... 560.357178 511.477722 473.337708 613.151001 513.384766 495.418793 618.771606 618.570923 495.619507 1
2 芭蕉0003.ROH 508.265930 502.946411 522.317505 471.932556 512.682129 503.950104 498.429840 487.891144 465.910461 ... 597.694275 552.126953 540.885681 661.327881 553.030273 540.183106 650.889526 659.521240 550.219971 1
3 芭蕉0004.ROH 490.801819 514.789917 529.945557 463.501617 527.536682 525.027466 489.898499 514.288025 503.247528 ... 567.784424 576.416138 573.906921 625.596680 548.915161 621.280823 632.221008 652.595825 599.199768 1
4 芭蕉0005.ROH 431.383697 433.290680 436.703217 408.901154 459.386505 461.694977 453.264008 435.900269 438.810974 ... 535.867249 499.232788 503.849731 569.691467 518.704285 512.381043 577.921692 573.605835 520.812012 1
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
2422 樟树10096.ROH 376.682861 396.656189 391.637787 371.363342 434.796234 390.634094 378.489502 410.406677 394.147003 ... 478.155395 441.520904 413.317383 508.265930 447.844116 424.558624 497.927978 522.618652 449.951874 7
2423 樟树10097.ROH 312.647797 359.419495 336.836578 315.056641 381.299835 351.891876 333.925903 377.586182 353.397400 ... 434.495117 420.242798 342.758331 476.047668 397.258423 369.054871 446.639709 460.390167 384.913086 7
2424 樟树10098.ROH 383.809052 372.166290 419.941681 371.363342 412.112946 411.912201 399.566894 382.905731 405.287903 ... 438.208740 460.089081 427.469330 478.556885 463.300873 468.620392 485.181183 525.328613 500.035736 7
2425 樟树10099.ROH 327.100861 333.725159 347.676392 332.621124 376.181030 364.538300 361.727966 377.786926 347.274902 ... 417.934326 411.410370 377.184723 433.190338 413.919586 395.752899 432.989593 445.636017 425.562317 7
2426 樟树10100.ROH 380.697601 424.859741 441.119446 388.526367 448.446320 433.089966 416.428803 431.383697 450.654449 ... 500.838684 497.526520 458.182068 537.673889 493.210663 465.709717 551.625122 561.059753 480.263153 7

2427 rows × 1757 columns

# Drop the first column (Filename)
data = data.drop(['Filename'], axis=1)
data
172.538 173.141 173.744 174.348 174.951 175.554 176.157 176.76 177.363 177.966 ... 1165.846 1166.373 1166.9 1167.427 1167.954 1168.481 1169.008 1169.535 1170.061 Label
0 414.421417 445.234558 482.571625 378.288757 483.976776 476.850617 423.253845 445.033813 477.653564 595.285400 ... 487.088196 513.986938 532.956604 545.502625 504.853424 568.687744 547.811096 584.947449 564.773376 1
1 469.523712 450.353333 447.543030 457.880981 467.616699 456.375458 483.575287 447.543030 415.224365 601.006409 ... 560.357178 511.477722 473.337708 613.151001 513.384766 495.418793 618.771606 618.570923 495.619507 1
2 508.265930 502.946411 522.317505 471.932556 512.682129 503.950104 498.429840 487.891144 465.910461 655.907959 ... 597.694275 552.126953 540.885681 661.327881 553.030273 540.183106 650.889526 659.521240 550.219971 1
3 490.801819 514.789917 529.945557 463.501617 527.536682 525.027466 489.898499 514.288025 503.247528 661.628967 ... 567.784424 576.416138 573.906921 625.596680 548.915161 621.280823 632.221008 652.595825 599.199768 1
4 431.383697 433.290680 436.703217 408.901154 459.386505 461.694977 453.264008 435.900269 438.810974 592.675842 ... 535.867249 499.232788 503.849731 569.691467 518.704285 512.381043 577.921692 573.605835 520.812012 1
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
2422 376.682861 396.656189 391.637787 371.363342 434.796234 390.634094 378.489502 410.406677 394.147003 510.775147 ... 478.155395 441.520904 413.317383 508.265930 447.844116 424.558624 497.927978 522.618652 449.951874 7
2423 312.647797 359.419495 336.836578 315.056641 381.299835 351.891876 333.925903 377.586182 353.397400 424.960113 ... 434.495117 420.242798 342.758331 476.047668 397.258423 369.054871 446.639709 460.390167 384.913086 7
2424 383.809052 372.166290 419.941681 371.363342 412.112946 411.912201 399.566894 382.905731 405.287903 531.149963 ... 438.208740 460.089081 427.469330 478.556885 463.300873 468.620392 485.181183 525.328613 500.035736 7
2425 327.100861 333.725159 347.676392 332.621124 376.181030 364.538300 361.727966 377.786926 347.274902 443.227173 ... 417.934326 411.410370 377.184723 433.190338 413.919586 395.752899 432.989593 445.636017 425.562317 7
2426 380.697601 424.859741 441.119446 388.526367 448.446320 433.089966 416.428803 431.383697 450.654449 533.458435 ... 500.838684 497.526520 458.182068 537.673889 493.210663 465.709717 551.625122 561.059753 480.263153 7

2427 rows × 1756 columns

# Class distribution, shown as a bar plot
sns.countplot(x="Label", data=data)
plt.xlabel("Label")
plt.title("Number of samples per class")
Text(0.5, 1.0, 'Number of samples per class')
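The rendered countplot is not reproduced here; as a quick numeric complement (an optional check, not part of the original run), the per-class counts can be printed directly:

# Per-class sample counts, as a numeric complement to the countplot
print(data["Label"].value_counts().sort_index())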

# Shuffle the rows and reset the index
df = data.sample(frac=1).reset_index(drop=True)
df
172.538 173.141 173.744 174.348 174.951 175.554 176.157 176.76 177.363 177.966 ... 1165.846 1166.373 1166.9 1167.427 1167.954 1168.481 1169.008 1169.535 1170.061 Label
0 429.978546 448.345978 447.342285 430.380005 473.839569 442.323853 457.178406 452.862549 429.175598 582.739380 ... 494.415100 501.942718 467.616699 526.934509 482.170136 509.470367 526.432617 595.385803 530.246643 6
1 281.834656 293.979248 335.431427 288.057526 310.238953 317.365112 305.822723 321.179108 327.502319 398.462830 ... 329.911163 375.578827 351.691132 373.270355 349.984863 390.132263 369.255615 398.061371 381.199463 7
2 440.316498 426.164520 453.665497 430.480377 450.052216 461.594605 456.174713 440.517212 444.130493 607.429993 ... 515.291748 490.400360 489.998871 569.992554 481.567932 517.198731 559.353516 549.617737 548.614075 3
3 285.247192 309.737091 289.362305 302.008728 340.750977 327.000488 323.688324 345.769379 316.963623 418.034698 ... 338.442474 373.169983 362.530914 384.511627 340.750977 376.983978 375.578827 407.997833 391.236298 1
4 458.081696 475.345093 447.743744 437.606537 476.950989 457.278748 469.523712 437.305420 436.803589 594.683228 ... 486.887451 487.991516 502.544952 553.532104 483.374573 506.158203 554.033997 552.327698 540.785339 4
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
2422 441.822022 464.706024 461.996063 427.368958 481.266815 490.299988 465.207855 473.939911 446.037476 610.240295 ... 523.421570 524.927124 519.406860 577.720947 499.232788 537.473144 581.434570 584.646362 551.625122 2
2423 290.767456 264.270203 298.395477 291.670776 303.915741 331.416687 325.394592 316.762909 328.204895 404.384583 ... 352.795166 357.010651 372.969238 393.544769 368.553040 391.236298 399.767639 412.915894 403.180145 4
2424 303.915741 286.652344 305.722382 293.176300 329.810791 303.815369 317.967316 328.907471 294.179993 419.239105 ... 344.866058 354.802551 357.512512 400.972046 342.557587 383.006103 390.634094 388.024506 372.668152 1
2425 428.473022 434.394745 487.991516 408.098206 464.003449 489.396667 456.174713 431.684814 458.583557 590.066223 ... 528.038513 561.059753 473.237335 578.624268 537.673889 496.221741 553.732849 616.162048 509.169250 2
2426 454.468445 441.621277 450.754822 424.357910 459.286133 460.289825 469.222595 459.787964 443.729004 592.274353 ... 494.916931 484.779724 499.533875 555.840576 474.441772 526.231873 546.707092 568.988892 528.239258 4

2427 rows × 1756 columns

# Check for rows containing null values (expect an empty result)
df[df.isnull().any(axis=1)]
172.538 173.141 173.744 174.348 174.951 175.554 176.157 176.76 177.363 177.966 ... 1165.846 1166.373 1166.9 1167.427 1167.954 1168.481 1169.008 1169.535 1170.061 Label

0 rows × 1756 columns

# Split into features x (all columns but the last) and labels y (the last column)
x = df.iloc[:,:-1]
y = df.iloc[:,-1]

Splitting the Dataset

# Train/test split (80/20)
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)

# Standardize features to zero mean and unit variance, fitting only on the training set
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
scaler.fit(x_train)
X_train = scaler.transform(x_train)
X_test = scaler.transform(x_test)
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)
(1941, 1755)
(1941,)
(486, 1755)
(486,)
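The classes are visibly imbalanced in the countplot, so one optional refinement (a sketch of an alternative, not what the run above used) is a stratified split, which keeps the class proportions equal across train and test:

# Alternative split (not used above): stratify on y so each class keeps its
# proportion in both sets; random_state makes the shuffle reproducible
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, stratify=y, random_state=42)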
# Reshape to (samples, timesteps, features) for the LSTM and one-hot encode the labels
from keras.utils import np_utils

X_train = X_train.reshape((-1, 1, 1755))
Y_train = np_utils.to_categorical(y_train)
X_test = X_test.reshape((-1, 1, 1755))
Y_test = np_utils.to_categorical(y_test)
print(X_train.shape)
print(Y_train.shape)
print(X_test.shape)
print(Y_test.shape)
(1941, 1, 1755)
(1941, 8)
(486, 1, 1755)
(486, 8)
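Note that keras.utils.np_utils has been removed from recent Keras releases; on TensorFlow 2.x the equivalent one-hot encoding (a sketch assuming tf.keras) is:

# TF2 equivalent of np_utils.to_categorical (assumes tensorflow.keras)
from tensorflow.keras.utils import to_categorical
Y_train = to_categorical(y_train)
Y_test = to_categorical(y_test)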

Model

from keras import Sequential
from keras.layers import LSTM, Activation, Dense, Dropout, Input, Embedding, BatchNormalization, Add, concatenate, Flatten

# Stacked LSTM: three recurrent layers returning sequences, a final LSTM layer,
# then fully connected layers narrowing down to an 8-way softmax
model = Sequential()
model.add(LSTM(units=50, return_sequences=True, input_shape=(1, 1755)))
model.add(Dropout(0.2))
model.add(LSTM(units=50, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(units=50, return_sequences=True))
model.add(Dropout(0.2))
# model.add(LSTM(units=50, return_sequences=True))
# model.add(Dropout(0.2))
model.add(LSTM(units=50))
model.add(Dropout(0.2))
# model.add(Dense(units=256))
# model.add(Dropout(0.2))
model.add(Dense(units=128))
model.add(Dropout(0.2))
model.add(Dense(units=64))
model.add(Dropout(0.2))
model.add(Dense(units=16))
model.add(Dropout(0.2))
model.add(Dense(units=8, activation="softmax"))
# Callbacks: checkpoint the best model, stop early, and decay the learning rate on plateaus
from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, LearningRateScheduler

checkpoint = ModelCheckpoint("lstm.h5", monitor="val_loss", mode="min", save_best_only=True, verbose=1)
earlystop = EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=1, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, verbose=1)  # min_delta=0.00001
callbacks = [earlystop, checkpoint, reduce_lr]
model.compile(optimizer="adam", loss='categorical_crossentropy', metrics=['accuracy'])
history_fit = model.fit(x=X_train, y=Y_train, batch_size=8, epochs=30, verbose=1, validation_data=(X_test, Y_test),callbacks=callbacks)
Train on 1941 samples, validate on 486 samples
Epoch 1/30
1941/1941 [==============================] - 6s 3ms/step - loss: 1.0300 - accuracy: 0.6188 - val_loss: 0.5473 - val_accuracy: 0.8313
Epoch 00001: val_loss improved from inf to 0.54729, saving model to lstm.h5
Epoch 2/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.6064 - accuracy: 0.7836 - val_loss: 0.3829 - val_accuracy: 0.8374
Epoch 00002: val_loss improved from 0.54729 to 0.38287, saving model to lstm.h5
Epoch 3/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.4797 - accuracy: 0.8089 - val_loss: 0.3595 - val_accuracy: 0.8272
Epoch 00003: val_loss improved from 0.38287 to 0.35947, saving model to lstm.h5
Epoch 4/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.4672 - accuracy: 0.8083 - val_loss: 0.2970 - val_accuracy: 0.8354
Epoch 00004: val_loss improved from 0.35947 to 0.29702, saving model to lstm.h5
Epoch 5/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3946 - accuracy: 0.8557 - val_loss: 0.2658 - val_accuracy: 0.9033
Epoch 00005: val_loss improved from 0.29702 to 0.26579, saving model to lstm.h5
Epoch 6/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3519 - accuracy: 0.8712 - val_loss: 0.2217 - val_accuracy: 0.8909
Epoch 00006: val_loss improved from 0.26579 to 0.22171, saving model to lstm.h5
Epoch 7/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3287 - accuracy: 0.8743 - val_loss: 0.2439 - val_accuracy: 0.8683
Epoch 00007: val_loss did not improve from 0.22171
Epoch 8/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3400 - accuracy: 0.8635 - val_loss: 0.2036 - val_accuracy: 0.9259
Epoch 00008: val_loss improved from 0.22171 to 0.20360, saving model to lstm.h5
Epoch 9/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3541 - accuracy: 0.8666 - val_loss: 0.2087 - val_accuracy: 0.9321
Epoch 00009: val_loss did not improve from 0.20360
Epoch 10/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3227 - accuracy: 0.8691 - val_loss: 0.2141 - val_accuracy: 0.9362
Epoch 00010: val_loss did not improve from 0.20360
Epoch 11/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2842 - accuracy: 0.8851 - val_loss: 0.1821 - val_accuracy: 0.9506
Epoch 00011: val_loss improved from 0.20360 to 0.18205, saving model to lstm.h5
Epoch 12/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3343 - accuracy: 0.8712 - val_loss: 0.2297 - val_accuracy: 0.8951
Epoch 00012: val_loss did not improve from 0.18205
Epoch 13/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3082 - accuracy: 0.8800 - val_loss: 0.2213 - val_accuracy: 0.9321
Epoch 00013: val_loss did not improve from 0.18205
Epoch 14/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2550 - accuracy: 0.9052 - val_loss: 0.1765 - val_accuracy: 0.9444
Epoch 00014: val_loss improved from 0.18205 to 0.17651, saving model to lstm.h5
Epoch 15/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3290 - accuracy: 0.8856 - val_loss: 0.2044 - val_accuracy: 0.9383
Epoch 00015: val_loss did not improve from 0.17651
Epoch 16/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2812 - accuracy: 0.9031 - val_loss: 0.1578 - val_accuracy: 0.9465
Epoch 00016: val_loss improved from 0.17651 to 0.15778, saving model to lstm.h5
Epoch 17/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2332 - accuracy: 0.9145 - val_loss: 0.1287 - val_accuracy: 0.9547
Epoch 00017: val_loss improved from 0.15778 to 0.12870, saving model to lstm.h5
Epoch 18/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2597 - accuracy: 0.9114 - val_loss: 0.1607 - val_accuracy: 0.9280
Epoch 00018: val_loss did not improve from 0.12870
Epoch 19/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2570 - accuracy: 0.9052 - val_loss: 0.1230 - val_accuracy: 0.9671
Epoch 00019: val_loss improved from 0.12870 to 0.12305, saving model to lstm.h5
Epoch 20/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2401 - accuracy: 0.9129 - val_loss: 0.1639 - val_accuracy: 0.9588
Epoch 00020: val_loss did not improve from 0.12305
Epoch 21/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2233 - accuracy: 0.9155 - val_loss: 0.1172 - val_accuracy: 0.9671
Epoch 00021: val_loss improved from 0.12305 to 0.11718, saving model to lstm.h5
Epoch 22/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2524 - accuracy: 0.9088 - val_loss: 0.1627 - val_accuracy: 0.9588
Epoch 00022: val_loss did not improve from 0.11718
Epoch 23/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2185 - accuracy: 0.9176 - val_loss: 0.1313 - val_accuracy: 0.9342
Epoch 00023: val_loss did not improve from 0.11718
Epoch 24/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2344 - accuracy: 0.9160 - val_loss: 0.1223 - val_accuracy: 0.9527
Epoch 00024: val_loss did not improve from 0.11718
Epoch 00024: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 25/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.1890 - accuracy: 0.9274 - val_loss: 0.0862 - val_accuracy: 0.9691
Epoch 00025: val_loss improved from 0.11718 to 0.08617, saving model to lstm.h5
Epoch 26/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.1475 - accuracy: 0.9361 - val_loss: 0.0794 - val_accuracy: 0.9733
Epoch 00026: val_loss improved from 0.08617 to 0.07940, saving model to lstm.h5
Epoch 27/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.1507 - accuracy: 0.9392 - val_loss: 0.0673 - val_accuracy: 0.9774
Epoch 00027: val_loss improved from 0.07940 to 0.06732, saving model to lstm.h5
Epoch 28/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.1498 - accuracy: 0.9444 - val_loss: 0.0764 - val_accuracy: 0.9733
Epoch 00028: val_loss did not improve from 0.06732
Epoch 29/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.1513 - accuracy: 0.9423 - val_loss: 0.0733 - val_accuracy: 0.9774
Epoch 00029: val_loss did not improve from 0.06732
Epoch 30/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.1338 - accuracy: 0.9418 - val_loss: 0.0815 - val_accuracy: 0.9753
Epoch 00030: val_loss did not improve from 0.06732
Epoch 00030: ReduceLROnPlateau reducing learning rate to 4.0000001899898055e-05.
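Training ran all 30 epochs, while ModelCheckpoint kept the best weights (val_loss 0.06732 at epoch 27) in lstm.h5. A minimal sketch for reloading that checkpoint and evaluating it, rather than relying on the final-epoch weights:

# Reload the best checkpoint saved by ModelCheckpoint and evaluate it
from keras.models import load_model

best_model = load_model("lstm.h5")
loss, acc = best_model.evaluate(X_test, Y_test, verbose=0)
print("best checkpoint: val_loss=%.4f, val_accuracy=%.4f" % (loss, acc))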
# Plot the training curves
def plot_performance(history=None, figure_directory=None, ylim_pad=[0, 0]):
    xlabel = "Epoch"
    legends = ["Training", "Validation"]
    plt.figure(figsize=(20, 5))

    # Left panel: accuracy
    y1 = history.history["accuracy"]
    y2 = history.history["val_accuracy"]
    min_y = min(min(y1), min(y2)) - ylim_pad[0]
    max_y = max(max(y1), max(y2)) + ylim_pad[0]
    plt.subplot(121)
    plt.plot(y1)
    plt.plot(y2)
    plt.title("Model Accuracy\n", fontsize=17)
    plt.xlabel(xlabel, fontsize=15)
    plt.ylabel("Accuracy", fontsize=15)
    plt.ylim(min_y, max_y)
    plt.legend(legends, loc="upper left")
    plt.grid()

    # Right panel: loss
    y1 = history.history["loss"]
    y2 = history.history["val_loss"]
    min_y = min(min(y1), min(y2)) - ylim_pad[1]
    max_y = max(max(y1), max(y2)) + ylim_pad[1]
    plt.subplot(122)
    plt.plot(y1)
    plt.plot(y2)
    plt.title("Model Loss\n", fontsize=17)
    plt.xlabel(xlabel, fontsize=15)
    plt.ylabel("Loss", fontsize=15)
    plt.ylim(min_y, max_y)
    plt.legend(legends, loc="upper left")
    plt.grid()
    plt.show()
# Visualize the training curves
plot_performance(history=history_fit)

# Predict class labels for the test set
predict_y = model.predict_classes(X_test)
predict_y
array([1, 6, 4, 3, 4, 1, 4, 6, 6, 1, 1, 1, 1, 1, 1, 4, 4, 4, 4, 5, 4, 5,
       7, 1, 4, 5, 3, 4, 1, 6, 4, 4, 5, 4, 1, 1, 7, 4, 1, 4, 6, 4, 4, 5,
       4, 7, 7, 6, 1, 1, 5, 6, 2, 1, 4, 4, 1, 4, 4, 4, 6, 5, 2, 6, 3, 1,
       2, 4, 2, 4, 1, 1, 1, 1, 1, 1, 6, 4, 1, 3, 5, 2, 4, 6, 3, 4, 4, 3,
       6, 5, 7, 1, 1, 2, 7, 4, 1, 6, 6, 2, 6, 1, 3, 4, 1, 1, 1, 4, 2, 1,
       3, 6, 2, 4, 4, 4, 3, 4, 1, 1, 6, 7, 6, 7, 2, 5, 1, 3, 4, 1, 3, 3,
       5, 4, 4, 7, 6, 2, 6, 4, 6, 6, 3, 5, 3, 5, 6, 3, 4, 1, 3, 6, 1, 4,
       6, 4, 6, 2, 2, 1, 7, 4, 6, 3, 6, 6, 5, 4, 4, 4, 4, 2, 4, 6, 1, 3,
       1, 6, 6, 4, 1, 1, 4, 1, 4, 4, 2, 3, 1, 6, 4, 4, 3, 6, 5, 3, 4, 6,
       1, 1, 3, 5, 4, 1, 6, 3, 4, 3, 1, 2, 1, 4, 6, 5, 3, 5, 4, 4, 4, 4,
       7, 3, 1, 4, 2, 4, 6, 7, 4, 1, 4, 3, 1, 4, 1, 5, 2, 5, 3, 4, 1, 2,
       4, 5, 1, 4, 4, 6, 3, 1, 4, 4, 5, 5, 6, 4, 3, 3, 1, 4, 5, 1, 1, 2,
       3, 1, 1, 6, 7, 6, 4, 6, 1, 3, 4, 1, 4, 2, 7, 4, 5, 1, 4, 2, 1, 7,
       3, 6, 4, 4, 1, 7, 1, 5, 4, 4, 1, 4, 4, 1, 1, 4, 1, 1, 3, 6, 3, 3,
       6, 5, 4, 3, 1, 2, 6, 6, 6, 4, 2, 2, 3, 1, 5, 1, 4, 1, 7, 3, 1, 1,
       3, 5, 6, 2, 4, 1, 1, 6, 1, 6, 6, 6, 7, 1, 5, 4, 2, 7, 1, 6, 3, 1,
       4, 5, 2, 1, 4, 5, 6, 3, 1, 5, 1, 6, 3, 1, 3, 6, 6, 5, 1, 6, 4, 1,
       7, 3, 4, 3, 7, 3, 6, 1, 5, 3, 4, 2, 4, 5, 4, 1, 1, 4, 6, 3, 6, 5,
       4, 6, 1, 6, 3, 1, 4, 4, 3, 1, 5, 6, 6, 3, 5, 3, 5, 2, 1, 3, 2, 4,
       1, 4, 1, 3, 7, 6, 3, 4, 4, 1, 4, 2, 1, 4, 4, 2, 1, 3, 1, 3, 4, 7,
       4, 4, 1, 1, 1, 1, 4, 4, 1, 4, 5, 6, 5, 3, 3, 1, 4, 3, 2, 2, 6, 4,
       4, 3, 2, 2, 1, 6, 3, 1, 3, 1, 6, 7, 4, 4, 4, 1, 1, 4, 3, 1, 4, 5,
       4, 3], dtype=int64)
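predict_classes only exists on Sequential models in older Keras and was removed in TensorFlow 2.6; on newer versions the same labels come from the argmax of predict (a sketch):

# Equivalent of predict_classes on newer Keras: argmax over the softmax outputs
import numpy as np  # already imported above
predict_y = np.argmax(model.predict(X_test), axis=-1)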
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

print(classification_report(y_test, predict_y))
              precision    recall  f1-score   support

           1       1.00      1.00      1.00       117
           2       0.97      1.00      0.99        36
           3       0.99      0.88      0.93        75
           4       0.92      0.99      0.95       117
           5       1.00      1.00      1.00        43
           6       1.00      0.99      0.99        73
           7       1.00      0.96      0.98        25

    accuracy                           0.98       486
   macro avg       0.98      0.97      0.98       486
weighted avg       0.98      0.98      0.98       486
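To see where the remaining errors concentrate (class 3's recall of 0.88, for instance), a confusion-matrix heatmap is a natural follow-up. A minimal sketch using the confusion_matrix and seaborn imports already in this notebook:

# Confusion matrix heatmap; tick labels 1-7 match the Label values
cm = confusion_matrix(y_test, predict_y)
plt.figure(figsize=(7, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=list(range(1, 8)), yticklabels=list(range(1, 8)))
plt.xlabel("Predicted label")
plt.ylabel("True label")
plt.title("Confusion matrix")
plt.show()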

Summary

Folks, a like, comment, or bookmark would be much appreciated!!!
