Article Outline

  • Using Variational Autoencoders to predict future diagnosis
      • VAE for Collaborative Filtering
      • Diagnosis Codes
      • Data
      • VAE for next diagnosis prediction
      • Applications
      • Some background
    • IMPLEMENTATION
      • Function to plot Losses
      • Load Data
      • Configure Network
      • Using a two output network
      • Network 1: Same objective functions for the two outputs
      • Network 2: Different objective functions for the two outputs
      • Calculating Recall
      • Impact of different ways of calculating the objective functions

Using Variational Autoencoders to predict future diagnosis

VAE for Collaborative Filtering

This work is an adaptation of the work by Dawen Liang et al., who used VAEs for collaborative filtering. Their approach exploits the generative nature of VAEs to reconstruct complete user-preference information from an input of partial user-preference information.

Work by Liang et al.: https://arxiv.org/abs/1802.05814

Diagnosis Codes

In the healthcare industry, the diagnoses a patient receives are standardized as diagnosis codes: each disease or medical condition is mapped to a diagnosis code.
ICD10 Diagnosis Codes: https://www.icd10data.com/
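To make this concrete, here is a purely illustrative mapping of a few common ICD-10 category codes to conditions (my own shorthand descriptions, not taken from the dataset used below):

# Illustrative only: a few common ICD-10 category codes with rough descriptions.
icd10_examples = {
    "E11": "Type 2 diabetes mellitus",
    "I10": "Essential (primary) hypertension",
    "E78": "Disorders of lipoprotein metabolism (e.g. high cholesterol)",
}
for code, description in icd10_examples.items():
    print(code, "->", description)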

Data

The data we are using contains a set of patients and the diagnoses they have undergone, each mapped to a diagnosis code. The data contains a total of 1567 unique diagnosis codes.
A given patient is therefore represented by a binary vector of dimension 1567, where an element is 1 if the patient has undergone that particular diagnosis and 0 otherwise.
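As a minimal sketch of how such a binary patient-by-code matrix could be built from raw records (the per-patient code lists below are hypothetical, not the actual dataset), one option is scikit-learn's MultiLabelBinarizer:

import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical raw records: one list of diagnosis codes per patient.
patient_codes = [
    ["E11", "I10"],          # patient 0
    ["E11", "E78", "I10"],   # patient 1
    ["Z48"],                 # patient 2
]

# One column per unique diagnosis code: 1 if the patient has that code, else 0.
mlb = MultiLabelBinarizer()
binary_matrix = pd.DataFrame(mlb.fit_transform(patient_codes), columns=mlb.classes_)
print(binary_matrix)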

VAE for next diagnosis prediction

Given the patient diagnosis information, the VAE encodes it into a latent space, learning the distribution of patients and the clusters of diagnoses they undergo.

As a simple example, consider diabetes in older adults: the group of diagnoses that commonly appear together for such patients might include diabetes, high cholesterol, high blood pressure, arthritis, and so on.
Given a new patient whose diagnosis set contains only diabetes and high cholesterol, this patient gets mapped to the same region of the latent space as the older adults with the fuller set of diagnoses mentioned above.

This mapping of similar patients to nearby points in the latent space has a very favourable effect on decoding/reconstruction back to the original space: the decoder also reconstructs diagnoses that are missing from a patient's input but have a high probability of occurrence for that patient.

This ability to fill in missing diagnoses, a form of collaborative filtering, is why I apply this technique to predict the next diagnosis.
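As a minimal sketch of how that filled-in reconstruction could be turned into a prediction (made-up numbers; the actual evaluation used later is the recall@k calculation in the implementation section), the idea is to rank the codes the patient does not yet have by their reconstructed probability:

import numpy as np

def rank_missing_diagnoses(patient_vector, reconstructed_probs, k=5):
    # patient_vector: 1-D binary array (1 = diagnosis already present)
    # reconstructed_probs: 1-D array of per-code probabilities from the decoder
    scores = np.array(reconstructed_probs, dtype=float)
    scores[patient_vector == 1] = -np.inf      # exclude diagnoses already on record
    return np.argsort(scores)[::-1][:k]        # indices of the k most likely new codes

# Toy usage with made-up numbers:
patient = np.array([1, 0, 1, 0, 0])
probs = np.array([0.9, 0.4, 0.8, 0.7, 0.1])
print(rank_missing_diagnoses(patient, probs, k=2))   # -> [3 1]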

Applications

This work can be used in many applications, ranging from insurance companies better anticipating a patient's needs to healthcare apps that encourage people to improve their lifestyle choices.

Some background

I came across the application of VAEs to collaborative filtering while working on my previous project, “Hybrid VAE for Collaborative Filtering”, which processes movie plot information from IMDb and uses it as an input to improve movie recommendation systems. That work was published at the RecSys 2018 Knowledge Transfer Learning workshop: https://arxiv.org/abs/1808.01006

IMPLEMENTATION

%matplotlib inline
import numpy as np
import pandas as pd
import pickle
import os
from matplotlib import pyplot as plt
from keras.layers import Input, Dense, Lambda, Multiply, Dropout, Embedding, Flatten, Activation, Reshape
from keras.models import Model
from keras import losses
from keras import backend as K
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, Callback
from IPython.display import clear_output
from sklearn import preprocessing
from keras import regularizers
import keras

Using TensorFlow backend.

os.getcwd()
'C:\\Users\\iz\\Desktop'
df = pd.read_csv('test.csv')
df = df.drop(['KEY'],axis = 1)
df.head()
T40 A08 I69 Z48 R44 N92 R59 B97 M96 I35 ... H61 T84 M16 J38 Z90 D68 K83 Z87 Z75 Z43
0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 1 0 0 1 0 0
1 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 1 0 0 1 0 0
2 0 1 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 1 0 1 0 0
3 0 0 0 0 0 0 1 0 0 0 ... 0 0 0 0 0 0 0 1 0 0
4 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0

5 rows × 500 columns

from sklearn.model_selection import train_test_split

y = range(df.shape[0])
xtrain, xtest, ytrain, ytest = train_test_split(df, y, test_size = 0.1, random_state = 42)
xtrain, xval, ytrain, yval = train_test_split(xtrain, ytrain, test_size = 0.1, random_state = 42)

with open('./train.data', 'wb') as f:
    np.save(f, xtrain)
with open('./test.data', 'wb') as f:
    np.save(f, xtest)
with open('./val.data', 'wb') as f:
    np.save(f, xval)

Function to plot Losses

class PlotLosses(Callback):
    def on_train_begin(self, logs={}):
        self.i = 0
        self.x = []
        self.losses = []
        self.val_losses = []
        self.fig = plt.figure()
        self.logs = []

    def on_epoch_end(self, epoch, logs={}):
        self.logs.append(logs)
        self.x.append(self.i)
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))
        self.i += 1
        clear_output(wait=True)
        plt.plot(self.x, self.losses, label="loss")
        plt.plot(self.x, self.val_losses, label="val_loss")
        plt.legend()
        plt.show()

plot_losses = PlotLosses()

Load Data

with open('train.data', 'rb') as f:
    x_train = np.load(f)
print("number of training users: ", x_train.shape[0])

with open('val.data', 'rb') as f:
    x_val = np.load(f)
print("number of validation users: ", x_val.shape[0])
number of training users:  5234
number of validation users:  582
x_train.shape,x_val.shape
((5234, 500), (582, 500))
x_train[0].shape
(500,)
x_train = x_train[:5000]
x_val = x_val[:500]

Configure Network

# encoder/decoder network size
batch_size = 100
original_dim = x_train.shape[1]
intermediate_dim = 200
latent_dim = 100
nb_epochs = 30
epsilon_std = 1.0

Using a two output network

Here, the network has two outputs, which differs from the original VAE network proposed by Liang et al.

The first output reconstructs the given input, while the second output produces a probability distribution over the diagnosis codes. Each output has its own loss function that optimizes its particular objective.
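As a minimal sketch of the general Keras pattern used here (toy layer sizes and standard built-in losses, not the actual network and custom losses defined next), a model with two outputs is compiled with one loss per output and per-output weights:

from keras.layers import Input, Dense
from keras.models import Model

# Toy two-output model: both heads share the same hidden representation.
inp = Input(shape=(8,))
hidden = Dense(4, activation='relu')(inp)
out_recon = Dense(8, activation='sigmoid', name='recon')(hidden)   # reconstruction head
out_probs = Dense(8, activation='softmax', name='probs')(hidden)   # probability head

model = Model(inp, [out_recon, out_probs])
# One loss per output; loss_weights controls how the two objectives are combined.
model.compile(optimizer='adam',
              loss=['binary_crossentropy', 'categorical_crossentropy'],
              loss_weights=[1.0, 1.0])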

Network 1: Same objective functions for the two outputs

# Function to increase the relevance of the KL regularization as the training progresses
class increaseBeta(Callback):
    def __init__(self):
        self.global_beta = 0.0

    def on_train_begin(self, logs={}):
        self.global_beta = 0.0

    def on_epoch_end(self, epoch, logs={}):
        self.global_beta = self.global_beta + 0.01

updateBeta = increaseBeta()

# Function to l2 normalize the inputs
def l2normalize(args):
    _x = args
    return K.l2_normalize(_x, axis=-1)

# Function to do the sampling from the latent space
def sampling(args):
    _mean, _log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., stddev=epsilon_std)
    return _mean + K.exp(_log_var / 2) * epsilon

# encoder network
x = Input(batch_shape=(batch_size, original_dim))
norm_x = Lambda(l2normalize, output_shape=(original_dim,))(x)
norm_x = Dropout(rate=0.5)(norm_x)
h = Dense(intermediate_dim, activation='relu')(norm_x)
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# decoder network
h_decoder = Dense(intermediate_dim, activation='relu')
x_bar = Dense(original_dim, activation='sigmoid')
x_prob = Dense(original_dim, activation='softmax')
h_decoded = h_decoder(z)

# We have two outputs, one which reconstructs the given input, the other which reconstructs the probability
x_decoded = x_bar(h_decoded)
x_probability = x_prob(h_decoded)

def vae_loss(x, x_bar):
    reconst_loss = K.sum(losses.binary_crossentropy(x, x_bar), axis=-1)
    kl_loss = K.sum(0.5 * (K.exp(z_log_var) - z_log_var + K.square(z_mean) - 1), axis=-1)
    return reconst_loss + (updateBeta.global_beta) * kl_loss

# build and compile model
vae = Model(x, [x_decoded, x_probability])
vae.compile(optimizer='adam', loss=vae_loss, loss_weights=[1., 1.])

weightsPath = "./weights/weights_vae1.hdf5"
# checkpointer = ModelCheckpoint(filepath=weightsPath, verbose=1, save_best_only=True)
# reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)
# vae.fit(x=x_train, y=[x_train, x_train], batch_size=batch_size, epochs=30,
#         validation_data=(x_val, [x_val, x_val]), callbacks=[checkpointer, reduce_lr, plot_losses, updateBeta])

vae.fit(x=x_train, y=[x_train, x_train], batch_size=batch_size, epochs=30,
        validation_data=(x_val, [x_val, x_val]))
Train on 5000 samples, validate on 500 samples
Epoch 1/30
5000/5000 [==============================] - 4s 758us/step - loss: 78.1787 - dense_23_loss: 34.3244 - dense_24_loss: 43.8543 - val_loss: 64.6742 - val_dense_23_loss: 22.1748 - val_dense_24_loss: 42.4994
Epoch 2/30
5000/5000 [==============================] - 4s 743us/step - loss: 64.4046 - dense_23_loss: 21.9435 - dense_24_loss: 42.4611 - val_loss: 63.6366 - val_dense_23_loss: 21.5553 - val_dense_24_loss: 42.0814
Epoch 3/30
5000/5000 [==============================] - 4s 837us/step - loss: 63.6365 - dense_23_loss: 21.5137 - dense_24_loss: 42.1228 - val_loss: 62.9220 - val_dense_23_loss: 21.1661 - val_dense_24_loss: 41.7559
Epoch 4/30
5000/5000 [==============================] - 4s 727us/step - loss: 62.5700 - dense_23_loss: 20.9471 - dense_24_loss: 41.6229 - val_loss: 61.4582 - val_dense_23_loss: 20.4014 - val_dense_24_loss: 41.0568
Epoch 5/30
5000/5000 [==============================] - 4s 733us/step - loss: 61.5120 - dense_23_loss: 20.3677 - dense_24_loss: 41.1443 - val_loss: 60.5572 - val_dense_23_loss: 19.8999 - val_dense_24_loss: 40.6573
Epoch 6/30
5000/5000 [==============================] - 3s 669us/step - loss: 60.6162 - dense_23_loss: 19.8809 - dense_24_loss: 40.7353 - val_loss: 59.4887 - val_dense_23_loss: 19.3130 - val_dense_24_loss: 40.1757
Epoch 7/30
5000/5000 [==============================] - 4s 879us/step - loss: 59.7845 - dense_23_loss: 19.4161 - dense_24_loss: 40.3683 - val_loss: 58.6560 - val_dense_23_loss: 18.8436 - val_dense_24_loss: 39.8125
Epoch 8/30
5000/5000 [==============================] - 3s 581us/step - loss: 59.0007 - dense_23_loss: 18.9740 - dense_24_loss: 40.0267 - val_loss: 57.7723 - val_dense_23_loss: 18.3500 - val_dense_24_loss: 39.4223
Epoch 9/30
5000/5000 [==============================] - 3s 664us/step - loss: 58.1647 - dense_23_loss: 18.4887 - dense_24_loss: 39.6760 - val_loss: 56.8185 - val_dense_23_loss: 17.8036 - val_dense_24_loss: 39.0149
Epoch 10/30
5000/5000 [==============================] - 4s 705us/step - loss: 57.3740 - dense_23_loss: 18.0296 - dense_24_loss: 39.3444 - val_loss: 56.0715 - val_dense_23_loss: 17.3908 - val_dense_24_loss: 38.6807
Epoch 11/30
5000/5000 [==============================] - 4s 711us/step - loss: 56.7545 - dense_23_loss: 17.6735 - dense_24_loss: 39.0810 - val_loss: 55.1875 - val_dense_23_loss: 16.8682 - val_dense_24_loss: 38.3194
Epoch 12/30
5000/5000 [==============================] - 3s 550us/step - loss: 56.1791 - dense_23_loss: 17.3366 - dense_24_loss: 38.8425 - val_loss: 54.6427 - val_dense_23_loss: 16.5640 - val_dense_24_loss: 38.0787
Epoch 13/30
5000/5000 [==============================] - 3s 568us/step - loss: 55.6814 - dense_23_loss: 17.0418 - dense_24_loss: 38.6397 - val_loss: 54.1010 - val_dense_23_loss: 16.2432 - val_dense_24_loss: 37.8578
Epoch 14/30
5000/5000 [==============================] - 3s 624us/step - loss: 55.2011 - dense_23_loss: 16.7642 - dense_24_loss: 38.4370 - val_loss: 53.4904 - val_dense_23_loss: 15.8693 - val_dense_24_loss: 37.6210
Epoch 15/30
5000/5000 [==============================] - 3s 699us/step - loss: 54.8549 - dense_23_loss: 16.5582 - dense_24_loss: 38.2967 - val_loss: 53.0989 - val_dense_23_loss: 15.6399 - val_dense_24_loss: 37.4590
Epoch 16/30
5000/5000 [==============================] - 3s 612us/step - loss: 54.5157 - dense_23_loss: 16.3574 - dense_24_loss: 38.1583 - val_loss: 52.5809 - val_dense_23_loss: 15.3208 - val_dense_24_loss: 37.2601
Epoch 17/30
5000/5000 [==============================] - 3s 628us/step - loss: 54.2176 - dense_23_loss: 16.1803 - dense_24_loss: 38.0373 - val_loss: 52.2525 - val_dense_23_loss: 15.1291 - val_dense_24_loss: 37.1233
Epoch 18/30
5000/5000 [==============================] - 4s 744us/step - loss: 53.8684 - dense_23_loss: 15.9648 - dense_24_loss: 37.9036 - val_loss: 51.7984 - val_dense_23_loss: 14.8564 - val_dense_24_loss: 36.9420
Epoch 19/30
5000/5000 [==============================] - 3s 600us/step - loss: 53.5824 - dense_23_loss: 15.7955 - dense_24_loss: 37.7869 - val_loss: 51.4563 - val_dense_23_loss: 14.6374 - val_dense_24_loss: 36.8189
Epoch 20/30
5000/5000 [==============================] - 3s 534us/step - loss: 53.3213 - dense_23_loss: 15.6395 - dense_24_loss: 37.6818 - val_loss: 51.2304 - val_dense_23_loss: 14.5035 - val_dense_24_loss: 36.7269
Epoch 21/30
5000/5000 [==============================] - 4s 701us/step - loss: 53.0680 - dense_23_loss: 15.4855 - dense_24_loss: 37.5825 - val_loss: 50.9055 - val_dense_23_loss: 14.3052 - val_dense_24_loss: 36.6004
Epoch 22/30
5000/5000 [==============================] - 4s 735us/step - loss: 52.8112 - dense_23_loss: 15.3356 - dense_24_loss: 37.4755 - val_loss: 50.6851 - val_dense_23_loss: 14.1706 - val_dense_24_loss: 36.5145
Epoch 23/30
5000/5000 [==============================] - 4s 733us/step - loss: 52.6116 - dense_23_loss: 15.2092 - dense_24_loss: 37.4025 - val_loss: 50.3594 - val_dense_23_loss: 13.9758 - val_dense_24_loss: 36.3836
Epoch 24/30
5000/5000 [==============================] - 4s 849us/step - loss: 52.4208 - dense_23_loss: 15.0949 - dense_24_loss: 37.3259 - val_loss: 50.1555 - val_dense_23_loss: 13.8413 - val_dense_24_loss: 36.3142
Epoch 25/30
5000/5000 [==============================] - 4s 732us/step - loss: 52.2180 - dense_23_loss: 14.9752 - dense_24_loss: 37.2428 - val_loss: 49.9194 - val_dense_23_loss: 13.6918 - val_dense_24_loss: 36.2276
Epoch 26/30
5000/5000 [==============================] - 3s 604us/step - loss: 52.0623 - dense_23_loss: 14.8808 - dense_24_loss: 37.1815 - val_loss: 49.7312 - val_dense_23_loss: 13.5534 - val_dense_24_loss: 36.1778
Epoch 27/30
5000/5000 [==============================] - 3s 507us/step - loss: 51.8958 - dense_23_loss: 14.7736 - dense_24_loss: 37.1222 - val_loss: 49.5618 - val_dense_23_loss: 13.4641 - val_dense_24_loss: 36.0977
Epoch 28/30
5000/5000 [==============================] - 3s 566us/step - loss: 51.7471 - dense_23_loss: 14.6879 - dense_24_loss: 37.0592 - val_loss: 49.3073 - val_dense_23_loss: 13.3000 - val_dense_24_loss: 36.0073
Epoch 29/30
5000/5000 [==============================] - 3s 595us/step - loss: 51.5263 - dense_23_loss: 14.5482 - dense_24_loss: 36.9781 - val_loss: 49.1972 - val_dense_23_loss: 13.2198 - val_dense_24_loss: 35.9775
Epoch 30/30
5000/5000 [==============================] - 3s 551us/step - loss: 51.4380 - dense_23_loss: 14.4985 - dense_24_loss: 36.9395 - val_loss: 48.9715 - val_dense_23_loss: 13.0922 - val_dense_24_loss: 35.8793
<keras.callbacks.callbacks.History at 0x226a912c6d8>
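A side note on the increaseBeta callback: since vae_loss reads updateBeta.global_beta as a plain Python float when the model is compiled, the value baked into the loss graph stays at its compile-time value, so the per-epoch increase may not actually reach the KL term. A common workaround, shown here only as a sketch under that assumption rather than as the author's code, is to keep beta in a Keras backend variable and update it from the callback:

from keras import backend as K
from keras.callbacks import Callback

beta = K.variable(0.0)   # lives in the computation graph, so updates take effect

class IncreaseBeta(Callback):
    def on_epoch_end(self, epoch, logs={}):
        K.set_value(beta, K.get_value(beta) + 0.01)

# Inside the loss, use the variable instead of the Python float:
#     return reconst_loss + beta * kl_loss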

Network 2: Different objective functions for the two outputs

# Function to increase the relevance of the KL regularization as the training progresses
class increaseBeta(Callback):
    def __init__(self):
        self.global_beta = 0.0

    def on_train_begin(self, logs={}):
        self.global_beta = 0.0

    def on_epoch_end(self, epoch, logs={}):
        self.global_beta = self.global_beta + 0.01

updateBeta = increaseBeta()

# Function to l2 normalize the inputs
def l2normalize(args):
    _x = args
    return K.l2_normalize(_x, axis=-1)

# Function to do the sampling from the latent space
def sampling(args):
    _mean, _log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., stddev=epsilon_std)
    return _mean + K.exp(_log_var / 2) * epsilon

# encoder network
x = Input(batch_shape=(batch_size, original_dim))
norm_x = Lambda(l2normalize, output_shape=(original_dim,))(x)
norm_x = Dropout(rate=0.5)(norm_x)
h = Dense(intermediate_dim, activation='relu')(norm_x)
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# decoder network
h_decoder = Dense(intermediate_dim, activation='relu')
x_bar = Dense(original_dim, activation='sigmoid')
x_prob = Dense(original_dim, activation='softmax')
h_decoded = h_decoder(z)

# We have two outputs, one which reconstructs the given input, the other which reconstructs the probability
x_decoded = x_bar(h_decoded)
x_probability = x_prob(h_decoded)

def vae_loss1(x, x_bar):
    reconst_loss = K.sum(losses.binary_crossentropy(x, x_bar), axis=-1)
    kl_loss = K.sum(0.5 * (K.exp(z_log_var) - z_log_var + K.square(z_mean) - 1), axis=-1)
    return reconst_loss + (updateBeta.global_beta) * kl_loss

def vae_loss2(x, x_bar):
    neg_ll = -K.sum(x_bar * x, axis=-1)
    kl_loss = K.sum(0.5 * (K.exp(z_log_var) - z_log_var + K.square(z_mean) - 1), axis=-1)
    return neg_ll + (updateBeta.global_beta) * kl_loss

# build and compile model
vae2 = Model(x, [x_decoded, x_probability])
vae2.compile(optimizer='adam', loss=[vae_loss1, vae_loss2], loss_weights=[0.5, 0.5])

# weightsPath = "./weights/weights_vae2.hdf5"
# checkpointer = ModelCheckpoint(filepath=weightsPath, verbose=1, save_best_only=True)
# reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)
# vae2.fit(x=x_train, y=[x_train, x_train], batch_size=batch_size, epochs=30,
#         validation_data=(x_val, [x_val, x_val]), callbacks=[checkpointer, reduce_lr, plot_losses, updateBeta])

vae2.fit(x=x_train, y=[x_train, x_train], batch_size=batch_size, epochs=30,
        validation_data=(x_val, [x_val, x_val]))
Train on 5000 samples, validate on 500 samples
Epoch 1/30
5000/5000 [==============================] - 4s 729us/step - loss: 16.9416 - dense_29_loss: 34.2665 - dense_30_loss: -0.3832 - val_loss: 10.8511 - val_dense_29_loss: 22.3960 - val_dense_30_loss: -0.6938
Epoch 2/30
5000/5000 [==============================] - 3s 596us/step - loss: 10.6567 - dense_29_loss: 22.0604 - dense_30_loss: -0.7470 - val_loss: 10.5063 - val_dense_29_loss: 21.7948 - val_dense_30_loss: -0.7821
Epoch 3/30
5000/5000 [==============================] - 3s 587us/step - loss: 10.4868 - dense_29_loss: 21.7343 - dense_30_loss: -0.7607 - val_loss: 10.3348 - val_dense_29_loss: 21.4568 - val_dense_30_loss: -0.7872
Epoch 4/30
5000/5000 [==============================] - 3s 588us/step - loss: 10.3471 - dense_29_loss: 21.4561 - dense_30_loss: -0.7619 - val_loss: 10.2142 - val_dense_29_loss: 21.2168 - val_dense_30_loss: -0.7883
Epoch 5/30
5000/5000 [==============================] - 3s 583us/step - loss: 10.1537 - dense_29_loss: 21.0695 - dense_30_loss: -0.7622 - val_loss: 9.8815 - val_dense_29_loss: 20.5522 - val_dense_30_loss: -0.7891
Epoch 6/30
5000/5000 [==============================] - 4s 700us/step - loss: 9.8399 - dense_29_loss: 20.4424 - dense_30_loss: -0.7627 - val_loss: 9.6098 - val_dense_29_loss: 20.0090 - val_dense_30_loss: -0.7894
Epoch 7/30
5000/5000 [==============================] - 3s 661us/step - loss: 9.6009 - dense_29_loss: 19.9645 - dense_30_loss: -0.7627 - val_loss: 9.3354 - val_dense_29_loss: 19.4604 - val_dense_30_loss: -0.7896
Epoch 8/30
5000/5000 [==============================] - 4s 732us/step - loss: 9.3988 - dense_29_loss: 19.5603 - dense_30_loss: -0.7628 - val_loss: 9.1374 - val_dense_29_loss: 19.0644 - val_dense_30_loss: -0.7897
Epoch 9/30
5000/5000 [==============================] - 3s 698us/step - loss: 9.1813 - dense_29_loss: 19.1254 - dense_30_loss: -0.7629 - val_loss: 8.9246 - val_dense_29_loss: 18.6389 - val_dense_30_loss: -0.7896
Epoch 10/30
5000/5000 [==============================] - 4s 761us/step - loss: 8.9580 - dense_29_loss: 18.6789 - dense_30_loss: -0.7629 - val_loss: 8.6732 - val_dense_29_loss: 18.1362 - val_dense_30_loss: -0.7898
Epoch 11/30
5000/5000 [==============================] - 3s 695us/step - loss: 8.7604 - dense_29_loss: 18.2836 - dense_30_loss: -0.7629 - val_loss: 8.3996 - val_dense_29_loss: 17.5891 - val_dense_30_loss: -0.7899
Epoch 12/30
5000/5000 [==============================] - 3s 628us/step - loss: 8.5640 - dense_29_loss: 17.8911 - dense_30_loss: -0.7630 - val_loss: 8.1917 - val_dense_29_loss: 17.1734 - val_dense_30_loss: -0.7899
Epoch 13/30
5000/5000 [==============================] - 3s 616us/step - loss: 8.3888 - dense_29_loss: 17.5405 - dense_30_loss: -0.7630 - val_loss: 7.9692 - val_dense_29_loss: 16.7284 - val_dense_30_loss: -0.7900
Epoch 14/30
5000/5000 [==============================] - 3s 638us/step - loss: 8.2512 - dense_29_loss: 17.2653 - dense_30_loss: -0.7630 - val_loss: 7.8081 - val_dense_29_loss: 16.4062 - val_dense_30_loss: -0.7900
Epoch 15/30
5000/5000 [==============================] - 4s 777us/step - loss: 8.1205 - dense_29_loss: 17.0040 - dense_30_loss: -0.7630 - val_loss: 7.6419 - val_dense_29_loss: 16.0737 - val_dense_30_loss: -0.7900
Epoch 16/30
5000/5000 [==============================] - 5s 972us/step - loss: 7.9886 - dense_29_loss: 16.7401 - dense_30_loss: -0.7630 - val_loss: 7.5225 - val_dense_29_loss: 15.8349 - val_dense_30_loss: -0.7900
Epoch 17/30
5000/5000 [==============================] - 4s 724us/step - loss: 7.8793 - dense_29_loss: 16.5216 - dense_30_loss: -0.7630 - val_loss: 7.4022 - val_dense_29_loss: 15.5944 - val_dense_30_loss: -0.7900
Epoch 18/30
5000/5000 [==============================] - 4s 753us/step - loss: 7.7752 - dense_29_loss: 16.3134 - dense_30_loss: -0.7630 - val_loss: 7.2394 - val_dense_29_loss: 15.2688 - val_dense_30_loss: -0.7900
Epoch 19/30
5000/5000 [==============================] - 4s 841us/step - loss: 7.6858 - dense_29_loss: 16.1345 - dense_30_loss: -0.7630 - val_loss: 7.1225 - val_dense_29_loss: 15.0350 - val_dense_30_loss: -0.7900
Epoch 20/30
5000/5000 [==============================] - 4s 801us/step - loss: 7.5966 - dense_29_loss: 15.9562 - dense_30_loss: -0.7630 - val_loss: 6.9988 - val_dense_29_loss: 14.7876 - val_dense_30_loss: -0.7900
Epoch 21/30
5000/5000 [==============================] - 4s 744us/step - loss: 7.5157 - dense_29_loss: 15.7943 - dense_30_loss: -0.7630 - val_loss: 6.9055 - val_dense_29_loss: 14.6010 - val_dense_30_loss: -0.7900
Epoch 22/30
5000/5000 [==============================] - 4s 847us/step - loss: 7.4233 - dense_29_loss: 15.6096 - dense_30_loss: -0.7630 - val_loss: 6.8266 - val_dense_29_loss: 14.4433 - val_dense_30_loss: -0.7900
Epoch 23/30
5000/5000 [==============================] - 3s 668us/step - loss: 7.3547 - dense_29_loss: 15.4724 - dense_30_loss: -0.7630 - val_loss: 6.7169 - val_dense_29_loss: 14.2239 - val_dense_30_loss: -0.7900
Epoch 24/30
5000/5000 [==============================] - 3s 653us/step - loss: 7.2897 - dense_29_loss: 15.3423 - dense_30_loss: -0.7630 - val_loss: 6.6318 - val_dense_29_loss: 14.0535 - val_dense_30_loss: -0.7900
Epoch 25/30
5000/5000 [==============================] - 3s 693us/step - loss: 7.2195 - dense_29_loss: 15.2020 - dense_30_loss: -0.7630 - val_loss: 6.5772 - val_dense_29_loss: 13.9444 - val_dense_30_loss: -0.7900
Epoch 26/30
5000/5000 [==============================] - 4s 846us/step - loss: 7.1611 - dense_29_loss: 15.0852 - dense_30_loss: -0.7630 - val_loss: 6.5330 - val_dense_29_loss: 13.8560 - val_dense_30_loss: -0.7900
Epoch 27/30
5000/5000 [==============================] - 4s 817us/step - loss: 7.1067 - dense_29_loss: 14.9764 - dense_30_loss: -0.7630 - val_loss: 6.4391 - val_dense_29_loss: 13.6681 - val_dense_30_loss: -0.7900
Epoch 28/30
5000/5000 [==============================] - 4s 779us/step - loss: 7.0528 - dense_29_loss: 14.8687 - dense_30_loss: -0.7630 - val_loss: 6.3591 - val_dense_29_loss: 13.5081 - val_dense_30_loss: -0.7900
Epoch 29/30
5000/5000 [==============================] - 4s 728us/step - loss: 6.9904 - dense_29_loss: 14.7439 - dense_30_loss: -0.7630 - val_loss: 6.2726 - val_dense_29_loss: 13.3353 - val_dense_30_loss: -0.7900
Epoch 30/30
5000/5000 [==============================] - 5s 901us/step - loss: 6.9431 - dense_29_loss: 14.6491 - dense_30_loss: -0.7630 - val_loss: 6.2600 - val_dense_29_loss: 13.3100 - val_dense_30_loss: -0.7900
<keras.callbacks.callbacks.History at 0x226ad466748>
with open('test.data', 'rb') as f:
    x_test = np.load(f)
print("number of testing users: ", x_test.shape[0])
number of testing users:  647
x_test = x_test[:600]
x_test.shape
(600, 500)
# x_test[0]

Calculating Recall

We test the trained system as follows. For each patient:

  1. Choose one diagnosis at random from the M diagnoses for which the patient has the value 1 (i.e. the patient has undergone that diagnosis).
  2. Set that diagnosis to 0.
  3. Pass the modified vector through the network to obtain the probability distribution over the diagnosis codes.
  4. Sort the diagnosis codes by their probabilities.
  5. The network was given an input with M-1 diagnoses. We now calculate recall@k as the percentage of patients for whom the held-out diagnosis appears among the top (M-1)+k codes by probability (a small worked example follows).
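As a tiny worked example of this metric (made-up numbers, mirroring the calc_heldout_recall logic below): suppose a patient originally had M = 6 diagnoses, one of them is held out and zeroed, and we check whether it lands in the top (M-1)+k positions of the predicted probabilities:

import numpy as np

# Made-up example: 10 diagnosis codes, the patient originally had M = 6 of them.
probs = np.array([0.02, 0.10, 0.08, 0.20, 0.01, 0.15, 0.05, 0.30, 0.04, 0.05])
held_out = 2      # index of the diagnosis that was zeroed out
M, k = 6, 2

ranked = np.argsort(probs)                    # ascending, as in calc_heldout_recall
hit = held_out in ranked[-((M - 1) + k):]     # held-out code in the top (M-1)+k?
print(hit)                                    # -> True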



x_test_hold_new = np.copy(x_test)
hold_out_ind_new = [np.random.choice(np.nonzero(i)[0]) for i in x_test[:,:473]]
for i in range(x_test.shape[0]):
    x_test_hold_new[i][hold_out_ind_new[i]] = 0

def calc_heldout_recall_new(x_test, x_rec, k):
    count = 1.0
    tot = 1.0
    x_rank = np.argsort(x_rec)
    for i in range(x_rank.shape[0]):
        sm = np.sum(x_test[i]) - 1
        if sm < 5:
            continue
        else:
            tot += 1
            if hold_out_ind_new[i] in x_rank[i][-(k + sm):]:
                count += 1.0
    return count / tot

x_rec, x_prob = vae.predict(x_test_hold_new, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10, 15]:
    print(calc_heldout_recall_new(x_test, x_prob[:, :473], k))

# x_rec, x_prob = vae2.predict(x_test_hold_new, batch_size=batch_size)
x_rec, x_prob = vae.predict(x_test_hold_new, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10, 15]:
    print(calc_heldout_recall_new(x_test, x_prob[:, :473], k))


x_test_hold = np.copy(x_test)
hold_out_ind = [np.random.choice(np.nonzero(i)[0]) for i in x_test]
for i in range(x_test.shape[0]):
    x_test_hold[i][hold_out_ind[i]] = 0

def calc_heldout_recall(x_test, x_rec, k):
    count = 1.0
    tot = 1.0
    x_rank = np.argsort(x_rec)
    for i in range(x_rank.shape[0]):
        sm = np.sum(x_test[i]) - 1
        if sm < 5:
            continue
        else:
            tot += 1
            if hold_out_ind[i] in x_rank[i][-(k + sm):]:
                count += 1.0
    return count / tot

x_rec, x_prob = vae.predict(x_test_hold, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10, 15]:
    print(calc_heldout_recall(x_test, x_prob, k))

x_rec, x_prob = vae2.predict(x_test_hold, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10]:
    print(calc_heldout_recall(x_test, x_prob, k))

Impact of different ways of calculating the objective functions

We can see that the recall@k for k = 1, 2, 3, 4, 5, 10, 15 is quite significant, considering that the network had to choose from among roughly 1500 other diagnoses.

An interesting observation is that the second approach to calculating the objective captures the recall for smaller values of k better than the first approach, while the opposite holds for larger values of k.

