Table of Contents

  • Dataset
  • Goal
  • Data
  • Assignment
  • KNN
  • PCA
  • Autoencoder
  • Evaluation

Dataset

  • Dataset download: test
  • Dataset download: train

Goal

  • Semi-supervised anomaly detection: given only clean (anomaly-free) training data, determine which testing samples come from the training classes and which come from classes never seen during training.

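Data

  • Training: the samples of an image dataset's training set (each image is 32×32×3) that belong to certain labels (40,000 samples)
  • Testing: all testing data of this dataset (10,000 samples)
  • Notice: do not use any extra data for training, and do not use pretrained models. Extra data may be used to help with validation. Searching for or hand-labeling the given data is prohibited.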

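Assignment

  • The task in this assignment is semi-supervised anomaly detection: the training set is clean, and outlier data (anomalies) are mixed in only at test time. We use a simple image dataset (images plus their class labels) as the example. The training data consists of a few classes from the original training set, while the testing data is the entire original testing set; the anomalies to detect are the classes that never appear in the training data. For labels, 1 denotes outlier data and 0 denotes inlier data. Performance is measured by AUC.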
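AUC depends only on how the anomaly scores rank outliers against inliers, so any score where larger means more anomalous (such as the distances and reconstruction errors computed in the methods that follow) can be evaluated directly, without choosing a threshold. As a minimal sketch, assuming 1 = outlier and 0 = inlier as stated above, the rank-based definition that `sklearn.metrics.roc_auc_score` computes for binary labels can be written in NumPy; `auc_score` is a hypothetical helper name, not part of the assignment code.

```python
import numpy as np


def auc_score(labels, scores):
    """AUC = probability that a randomly chosen outlier (label 1) gets a higher
    anomaly score than a randomly chosen inlier (label 0); ties count as 1/2."""
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float)
    pos = scores[labels == 1]  # outlier scores
    neg = scores[labels == 0]  # inlier scores
    # Compare every outlier score with every inlier score via broadcasting.
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


# Toy example: 3 of the 4 (outlier, inlier) pairs are ranked correctly.
print(auc_score([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

The labels and scores in the example are made up for illustration; flipping the sign of a score (smaller = more anomalous) would turn an AUC of a into 1 - a.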

Note: the labels of the original test data are not public, so no scores are shown.

  • Three methods are covered below: K-means, PCA, and Autoencoder.
# Download the dataset
# !gdown --id '1_zT3JOpvXFGr7mkxs3XJDeGxTn_8pItq' --output train.npy
# !gdown --id '11Y_6JDjlhIY-M5-jW1rLRshDMqeKi9Kr' --output test.npy
import numpy as np

train = np.load('train.npy', allow_pickle=True)
test = np.load('test.npy', allow_pickle=True)
# Which method to run: 'knn' (K-means), 'pca', or 'ae' (autoencoder)
task = 'ae'

KNN

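  • K-means: assume the training data contains only a few label types (e.g., < 20); since the exact number is unknown, guess it as n.
  • Assume the training data forms n clusters. First use K-means to compute n centroids from the training data, then use these n centroids to assign the testing data to clusters.
  • Inlier data should then lie closer to the centroid of its assigned cluster than outlier data does.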
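A minimal NumPy sketch of the K-means scoring idea above, on made-up 2-D data rather than the assignment's images: fit centroids on training data only, then score each test point by its squared distance to the nearest (assigned) centroid. `kmeans_fit` and `kmeans_anomaly_score` are hypothetical helpers; the deterministic farthest-point initialization and fixed iteration count are simplifying assumptions, not what `MiniBatchKMeans` does. The all-pairs broadcasting is only meant for small toy arrays; on the real 40000 × 3072 data it would need far too much memory.

```python
import numpy as np


def kmeans_fit(x, n_clusters, n_iter=20):
    """Plain Lloyd's K-means on the rows of x; returns (n_clusters, d) centroids."""
    # Deterministic farthest-point initialization: start from x[0], then keep
    # adding the point that is farthest from all centroids chosen so far.
    centroids = [x[0]]
    for _ in range(1, n_clusters):
        c = np.array(centroids)
        d2 = ((x[:, None, :] - c[None, :, :]) ** 2).sum(axis=2).min(axis=1)
        centroids.append(x[d2.argmax()])
    centroids = np.array(centroids, dtype=float)

    for _ in range(n_iter):
        # Assignment step: index of the nearest centroid for every point.
        d2 = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        assign = d2.argmin(axis=1)
        # Update step: move each centroid to the mean of its members.
        for k in range(n_clusters):
            members = x[assign == k]
            if len(members) > 0:  # leave an empty cluster's centroid in place
                centroids[k] = members.mean(axis=0)
    return centroids


def kmeans_anomaly_score(centroids, y):
    """Squared distance from each row of y to its nearest (= assigned) centroid."""
    d2 = ((y[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return d2.min(axis=1)


# Toy data: two tight inlier clusters for training; at test time, 20 inliers
# followed by 20 outliers from a region the training data never covered.
rng = np.random.default_rng(0)
train_x = np.concatenate([rng.normal(0.0, 0.5, (100, 2)),
                          rng.normal([10.0, 0.0], 0.5, (100, 2))])
test_y = np.concatenate([rng.normal(0.0, 0.5, (20, 2)),
                         rng.normal([5.0, 8.0], 0.5, (20, 2))])

centroids = kmeans_fit(train_x, n_clusters=2)
scores = kmeans_anomaly_score(centroids, test_y)
print(scores[:20].mean(), scores[20:].mean())  # inliers score far lower
```

Here n = 2 matches the toy data by construction; with real data n is a guess, which is why the original code loops over several values of n.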
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import f1_score, pairwise_distances, roc_auc_score
from scipy.cluster.vq import vq, kmeans

if task == 'knn':
    x = train.reshape(len(train), -1)  # flatten each image to a 32*32*3 vector
    y = test.reshape(len(test), -1)
    scores = list()
    for n in range(1, 10):  # try n = 1..9 clusters
        kmeans_x = MiniBatchKMeans(n_clusters=n, batch_size=100).fit(x)
        y_cluster = kmeans_x.predict(y)  # nearest centroid for each test sample
        # Anomaly score: squared distance to the assigned centroid
        y_dist = np.sum(np.square(kmeans_x.cluster_centers_[y_cluster] - y), axis=1)
        y_pred = y_dist
        # y_label (1 = outlier, 0 = inlier) is not public, so scoring is commented out.
        # score = f1_score(y_label, y_pred, average='micro')
        # score = roc_auc_score(y_label, y_pred, average='micro')
        # scores.append(score)
    # print(np.max(scores), np.argmax(scores))
    # print(scores)
    # print('auc score: {}'.format(np.max(scores)))

PCA

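  • First compute the principal components of the training data.
  • Project the testing data onto these components.
  • Reconstruct vectors in the original space from these projections.
  • Compute the MSE between each reconstructed image and the original; inlier data should have smaller values than outliers.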
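The four steps above can be sketched with NumPy's SVD on toy 3-D data instead of images. `pca_fit` and `pca_reconstruction_error` are hypothetical helpers; the projection and reconstruction mirror what `PCA.transform` and `PCA.inverse_transform` compute when `whiten=False`. The original code scores with the Euclidean norm of the residual rather than the MSE; for a fixed dimension the two rank samples identically, so the AUC is the same.

```python
import numpy as np


def pca_fit(x, n_components):
    """Return (mean, components); components has shape (n_components, d)."""
    mean = x.mean(axis=0)
    # The rows of vt are the principal directions of the centered data,
    # ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(x - mean, full_matrices=False)
    return mean, vt[:n_components]


def pca_reconstruction_error(mean, components, y):
    """Project y onto the components, map back, and return the per-sample MSE."""
    z = (y - mean) @ components.T      # coordinates in the component space
    y_rec = z @ components + mean      # reconstruction in the original space
    return ((y_rec - y) ** 2).mean(axis=1)


# Toy data: training points lie close to a line in 3-D, so one component
# captures them. Test outliers are inliers shifted perpendicular to that line.
rng = np.random.default_rng(0)
direction = np.array([1.0, 2.0, 3.0]) / np.sqrt(14.0)
perp = np.array([2.0, -1.0, 0.0]) / np.sqrt(5.0)   # orthogonal to direction
train_x = rng.normal(0.0, 5.0, (200, 1)) * direction + rng.normal(0.0, 0.05, (200, 3))
inliers = rng.normal(0.0, 5.0, (20, 1)) * direction + rng.normal(0.0, 0.05, (20, 3))
outliers = inliers + 3.0 * perp

mean, comps = pca_fit(train_x, n_components=1)
err_in = pca_reconstruction_error(mean, comps, inliers)
err_out = pca_reconstruction_error(mean, comps, outliers)
print(err_in.mean(), err_out.mean())
```

The shift of 3 along `perp` cannot be expressed by the single component, so it survives as residual: about 9 / 3 = 3 per outlier, versus only the small noise for inliers.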
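from sklearn.decomposition import PCA

if task == 'pca':
    x = train.reshape(len(train), -1)
    y = test.reshape(len(test), -1)
    pca = PCA(n_components=2).fit(x)  # principal components of the training data
    y_projected = pca.transform(y)  # project the test data onto them
    y_reconstructed = pca.inverse_transform(y_projected)  # map back to pixel space
    # Per-sample Euclidean reconstruction error (monotone in the MSE, so the AUC is unchanged)
    dist = np.sqrt(np.sum(np.square(y_reconstructed - y).reshape(len(y), -1), axis=1))
    y_pred = dist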
#     score = roc_auc_score(y_label, y_pred, average='micro')
#     score = f1_score(y_label, y_pred, average='micro')
#     print('auc score: {}'.format(score))

Autoencoder

import torch
from torch import nn
import torch.nn.functional as F


class fcn_autoencoder(nn.Module):
    def __init__(self):
        super(fcn_autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(32 * 32 * 3, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, 12),
            nn.ReLU(True),
            nn.Linear(12, 3))
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, 32 * 32 * 3),
            nn.Tanh())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


class conv_autoencoder(nn.Module):
    def __init__(self):
        super(conv_autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 12, 4, stride=2, padding=1),            # [batch, 12, 16, 16]
            nn.ReLU(),
            nn.Conv2d(12, 24, 4, stride=2, padding=1),           # [batch, 24, 8, 8]
            nn.ReLU(),
            nn.Conv2d(24, 48, 4, stride=2, padding=1),           # [batch, 48, 4, 4]
            nn.ReLU(),
#             nn.Conv2d(48, 96, 4, stride=2, padding=1),           # [batch, 96, 2, 2]
#             nn.ReLU(),
        )
        self.decoder = nn.Sequential(
#             nn.ConvTranspose2d(96, 48, 4, stride=2, padding=1),  # [batch, 48, 4, 4]
#             nn.ReLU(),
            nn.ConvTranspose2d(48, 24, 4, stride=2, padding=1),  # [batch, 24, 8, 8]
            nn.ReLU(),
            nn.ConvTranspose2d(24, 12, 4, stride=2, padding=1),  # [batch, 12, 16, 16]
            nn.ReLU(),
            nn.ConvTranspose2d(12, 3, 4, stride=2, padding=1),   # [batch, 3, 32, 32]
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(32*32*3, 400)
        self.fc21 = nn.Linear(400, 20)  # latent mean
        self.fc22 = nn.Linear(400, 20)  # latent log variance
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 32*32*3)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        # z = mu + std * eps, eps ~ N(0, I); randn_like puts eps on std's device and dtype
        std = logvar.mul(0.5).exp()
        eps = torch.randn_like(std)
        return eps.mul(std).add(mu)

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar


def loss_vae(recon_x, x, mu, logvar, criterion):
    """
    recon_x: generated images
    x: original images
    mu: latent mean
    logvar: latent log variance
    """
    mse = criterion(recon_x, x)  # reconstruction (MSE) loss
    # KL divergence: KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    return mse + KLD
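The KL term in `loss_vae` is the closed-form KL divergence between the encoder's diagonal Gaussian and a standard normal prior, with `logvar` = log σ², and `reparametrize` draws z with the reparameterization trick. Written out for J latent dimensions (J = 20 in this VAE):

```latex
\begin{aligned}
z &= \mu + \sigma \odot \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I), \qquad \sigma = \exp\!\left(\tfrac{1}{2}\,\mathrm{logvar}\right) \\
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \operatorname{diag}\sigma^2)\,\middle\|\,\mathcal{N}(0, I)\right)
  &= -\frac{1}{2} \sum_{j=1}^{J} \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right) \\
\mathcal{L} &= \mathrm{MSE}(\hat{x}, x) + D_{\mathrm{KL}}
\end{aligned}
```

In the code, `torch.sum` also sums the KL term over the batch, while `nn.MSELoss()` averages over all elements by default, so the two terms are on different scales and the KL term may dominate; how to weight them is a design choice the original code leaves open.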
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.optim import Adam, AdamW
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)

if task == 'ae':
    num_epochs = 1000
    batch_size = 128
    learning_rate = 1e-3

    # {'fcn', 'cnn', 'vae'}
    model_type = 'cnn'

    x = train
    if model_type == 'fcn' or model_type == 'vae':
        x = x.reshape(len(x), -1)  # fully connected models take flat 32*32*3 vectors

    data = torch.tensor(x, dtype=torch.float)
    train_dataset = TensorDataset(data)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)

    model_classes = {'fcn': fcn_autoencoder(), 'cnn': conv_autoencoder(), 'vae': VAE()}
    model = model_classes[model_type].cuda()
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    best_loss = np.inf
    model.train()
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            if model_type == 'cnn':
                # [batch, 32, 32, 3] -> [batch, 3, 32, 32]; transpose(3, 1) also swaps H and W,
                # which is harmless as long as test images are transformed the same way
                img = batch[0].transpose(3, 1).cuda()
            else:
                img = batch[0].cuda()
            # ===================forward=====================
            output = model(img)
            if model_type == 'vae':
                loss = loss_vae(output[0], img, output[1], output[2], criterion)
            else:
                loss = criterion(output, img)
            # ===================backward====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # ===================save====================
        # Note: this compares only the last mini-batch loss of each epoch
        if loss.item() < best_loss:
            best_loss = loss.item()
            torch.save(model, 'best_model_{}.pt'.format(model_type))
        # ===================log========================
        print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, loss.item()))
epoch [1/1000], loss:0.0407
epoch [2/1000], loss:0.0271
epoch [3/1000], loss:0.0204
epoch [4/1000], loss:0.0167
epoch [5/1000], loss:0.0150
epoch [6/1000], loss:0.0152
epoch [7/1000], loss:0.0133
epoch [8/1000], loss:0.0135
epoch [9/1000], loss:0.0127
epoch [10/1000], loss:0.0124
epoch [11/1000], loss:0.0139
epoch [12/1000], loss:0.0172
epoch [13/1000], loss:0.0116
epoch [14/1000], loss:0.0109
epoch [15/1000], loss:0.0106
epoch [16/1000], loss:0.0107
epoch [17/1000], loss:0.0105
epoch [18/1000], loss:0.0102
epoch [19/1000], loss:0.0099
epoch [20/1000], loss:0.0108
epoch [21/1000], loss:0.0082
epoch [22/1000], loss:0.0093
epoch [23/1000], loss:0.0093
epoch [24/1000], loss:0.0080
epoch [25/1000], loss:0.0077
epoch [26/1000], loss:0.0079
epoch [27/1000], loss:0.0080
epoch [28/1000], loss:0.0080
epoch [29/1000], loss:0.0069
epoch [30/1000], loss:0.0072
epoch [31/1000], loss:0.0073
epoch [32/1000], loss:0.0069
epoch [33/1000], loss:0.0070
epoch [34/1000], loss:0.0064
epoch [35/1000], loss:0.0063
epoch [36/1000], loss:0.0070
epoch [37/1000], loss:0.0060
epoch [38/1000], loss:0.0064
epoch [39/1000], loss:0.0066
epoch [40/1000], loss:0.0064
epoch [41/1000], loss:0.0063
epoch [42/1000], loss:0.0066
epoch [43/1000], loss:0.0060
epoch [44/1000], loss:0.0072
epoch [45/1000], loss:0.0065
epoch [46/1000], loss:0.0068
epoch [47/1000], loss:0.0067
epoch [48/1000], loss:0.0060
epoch [49/1000], loss:0.0057
epoch [50/1000], loss:0.0058
epoch [51/1000], loss:0.0059
epoch [52/1000], loss:0.0061
epoch [53/1000], loss:0.0064
epoch [54/1000], loss:0.0060
epoch [55/1000], loss:0.0054
epoch [56/1000], loss:0.0052
epoch [57/1000], loss:0.0054
epoch [58/1000], loss:0.0060
epoch [59/1000], loss:0.0065
epoch [60/1000], loss:0.0054
epoch [61/1000], loss:0.0053
epoch [62/1000], loss:0.0062
epoch [63/1000], loss:0.0052
epoch [64/1000], loss:0.0053
epoch [65/1000], loss:0.0052
epoch [66/1000], loss:0.0049
epoch [67/1000], loss:0.0050
epoch [68/1000], loss:0.0044
epoch [69/1000], loss:0.0049
epoch [70/1000], loss:0.0048
epoch [71/1000], loss:0.0051
epoch [72/1000], loss:0.0054
epoch [73/1000], loss:0.0049
epoch [74/1000], loss:0.0050
epoch [75/1000], loss:0.0049
epoch [76/1000], loss:0.0048
epoch [77/1000], loss:0.0052
epoch [78/1000], loss:0.0045
epoch [79/1000], loss:0.0049
epoch [80/1000], loss:0.0053
epoch [81/1000], loss:0.0055
epoch [82/1000], loss:0.0044
epoch [83/1000], loss:0.0041
epoch [84/1000], loss:0.0052
epoch [85/1000], loss:0.0062
epoch [86/1000], loss:0.0039
epoch [87/1000], loss:0.0039
epoch [88/1000], loss:0.0044
epoch [89/1000], loss:0.0041
epoch [90/1000], loss:0.0043
epoch [91/1000], loss:0.0044
epoch [92/1000], loss:0.0038
epoch [93/1000], loss:0.0049
epoch [94/1000], loss:0.0046
epoch [95/1000], loss:0.0040
epoch [96/1000], loss:0.0040
epoch [97/1000], loss:0.0045
epoch [98/1000], loss:0.0045
epoch [99/1000], loss:0.0040
epoch [100/1000], loss:0.0038
epoch [101/1000], loss:0.0038
epoch [102/1000], loss:0.0044
epoch [103/1000], loss:0.0045
epoch [104/1000], loss:0.0040
epoch [105/1000], loss:0.0041
epoch [106/1000], loss:0.0038
epoch [107/1000], loss:0.0041
epoch [108/1000], loss:0.0040
epoch [109/1000], loss:0.0034
epoch [110/1000], loss:0.0036
epoch [111/1000], loss:0.0040
epoch [112/1000], loss:0.0046
epoch [113/1000], loss:0.0035
epoch [114/1000], loss:0.0042
epoch [115/1000], loss:0.0041
epoch [116/1000], loss:0.0034
epoch [117/1000], loss:0.0031
epoch [118/1000], loss:0.0046
epoch [119/1000], loss:0.0037
epoch [120/1000], loss:0.0038
epoch [121/1000], loss:0.0038
epoch [122/1000], loss:0.0037
epoch [123/1000], loss:0.0033
epoch [124/1000], loss:0.0042
epoch [125/1000], loss:0.0037
epoch [126/1000], loss:0.0032
epoch [127/1000], loss:0.0037
epoch [128/1000], loss:0.0037
epoch [129/1000], loss:0.0035
epoch [130/1000], loss:0.0033
epoch [131/1000], loss:0.0040
epoch [132/1000], loss:0.0050
epoch [133/1000], loss:0.0039
epoch [134/1000], loss:0.0037
epoch [135/1000], loss:0.0041
epoch [136/1000], loss:0.0037
epoch [137/1000], loss:0.0032
epoch [138/1000], loss:0.0037
epoch [139/1000], loss:0.0031
epoch [140/1000], loss:0.0036
epoch [141/1000], loss:0.0034
epoch [142/1000], loss:0.0038
epoch [143/1000], loss:0.0035
epoch [144/1000], loss:0.0039
epoch [145/1000], loss:0.0032
epoch [146/1000], loss:0.0035
epoch [147/1000], loss:0.0035
epoch [148/1000], loss:0.0037
epoch [149/1000], loss:0.0031
epoch [150/1000], loss:0.0031
epoch [151/1000], loss:0.0035
epoch [152/1000], loss:0.0038
epoch [153/1000], loss:0.0030
epoch [154/1000], loss:0.0032
epoch [155/1000], loss:0.0031
epoch [156/1000], loss:0.0033
epoch [157/1000], loss:0.0035
epoch [158/1000], loss:0.0034
epoch [159/1000], loss:0.0037
epoch [160/1000], loss:0.0028
epoch [161/1000], loss:0.0036
epoch [162/1000], loss:0.0032
epoch [163/1000], loss:0.0036
epoch [164/1000], loss:0.0035
epoch [165/1000], loss:0.0036
epoch [166/1000], loss:0.0030
epoch [167/1000], loss:0.0037
epoch [168/1000], loss:0.0034
epoch [169/1000], loss:0.0030
epoch [170/1000], loss:0.0033
epoch [171/1000], loss:0.0036
epoch [172/1000], loss:0.0031
epoch [173/1000], loss:0.0033
epoch [174/1000], loss:0.0031
epoch [175/1000], loss:0.0033
epoch [176/1000], loss:0.0032
epoch [177/1000], loss:0.0032
epoch [178/1000], loss:0.0033
epoch [179/1000], loss:0.0041
epoch [180/1000], loss:0.0035
epoch [181/1000], loss:0.0030
epoch [182/1000], loss:0.0032
epoch [183/1000], loss:0.0028
epoch [184/1000], loss:0.0032
epoch [185/1000], loss:0.0033
epoch [186/1000], loss:0.0034
epoch [187/1000], loss:0.0033
epoch [188/1000], loss:0.0034
epoch [189/1000], loss:0.0035
epoch [190/1000], loss:0.0034
epoch [191/1000], loss:0.0035
epoch [192/1000], loss:0.0037
epoch [193/1000], loss:0.0039
epoch [194/1000], loss:0.0031
epoch [195/1000], loss:0.0032
epoch [196/1000], loss:0.0028
epoch [197/1000], loss:0.0032
epoch [198/1000], loss:0.0034
epoch [199/1000], loss:0.0035
epoch [200/1000], loss:0.0032
epoch [201/1000], loss:0.0036
epoch [202/1000], loss:0.0033
epoch [203/1000], loss:0.0032
epoch [204/1000], loss:0.0034
epoch [205/1000], loss:0.0026
epoch [206/1000], loss:0.0030
epoch [207/1000], loss:0.0031
epoch [208/1000], loss:0.0029
epoch [209/1000], loss:0.0032
epoch [210/1000], loss:0.0024
epoch [211/1000], loss:0.0028
epoch [212/1000], loss:0.0036
epoch [213/1000], loss:0.0033
epoch [214/1000], loss:0.0029
epoch [215/1000], loss:0.0046
epoch [216/1000], loss:0.0034
epoch [217/1000], loss:0.0033
epoch [218/1000], loss:0.0029
epoch [219/1000], loss:0.0033
epoch [220/1000], loss:0.0040
epoch [221/1000], loss:0.0032
epoch [222/1000], loss:0.0036
epoch [223/1000], loss:0.0035
epoch [224/1000], loss:0.0029
epoch [225/1000], loss:0.0033
epoch [226/1000], loss:0.0035
epoch [227/1000], loss:0.0028
epoch [228/1000], loss:0.0030
epoch [229/1000], loss:0.0031
epoch [230/1000], loss:0.0030
epoch [231/1000], loss:0.0026
epoch [232/1000], loss:0.0027
epoch [233/1000], loss:0.0032
epoch [234/1000], loss:0.0034
epoch [235/1000], loss:0.0032
epoch [236/1000], loss:0.0027
epoch [237/1000], loss:0.0030
epoch [238/1000], loss:0.0026
epoch [239/1000], loss:0.0032
epoch [240/1000], loss:0.0029
epoch [241/1000], loss:0.0029
epoch [242/1000], loss:0.0033
epoch [243/1000], loss:0.0031
epoch [244/1000], loss:0.0028
epoch [245/1000], loss:0.0028
epoch [246/1000], loss:0.0039
epoch [247/1000], loss:0.0036
epoch [248/1000], loss:0.0032
epoch [249/1000], loss:0.0031
epoch [250/1000], loss:0.0031
epoch [251/1000], loss:0.0036
epoch [252/1000], loss:0.0030
epoch [253/1000], loss:0.0026
epoch [254/1000], loss:0.0028
epoch [255/1000], loss:0.0028
epoch [256/1000], loss:0.0029
epoch [257/1000], loss:0.0032
epoch [258/1000], loss:0.0030
epoch [259/1000], loss:0.0029
epoch [260/1000], loss:0.0032
epoch [261/1000], loss:0.0027
epoch [262/1000], loss:0.0027
epoch [263/1000], loss:0.0031
epoch [264/1000], loss:0.0028
epoch [265/1000], loss:0.0023
epoch [266/1000], loss:0.0031
epoch [267/1000], loss:0.0029
epoch [268/1000], loss:0.0029
epoch [269/1000], loss:0.0030
epoch [270/1000], loss:0.0032
epoch [271/1000], loss:0.0030
epoch [272/1000], loss:0.0026
epoch [273/1000], loss:0.0027
epoch [274/1000], loss:0.0028
epoch [275/1000], loss:0.0030
epoch [276/1000], loss:0.0033
epoch [277/1000], loss:0.0024
epoch [278/1000], loss:0.0031
epoch [279/1000], loss:0.0030
epoch [280/1000], loss:0.0035
epoch [281/1000], loss:0.0031
epoch [282/1000], loss:0.0026
epoch [283/1000], loss:0.0029
epoch [284/1000], loss:0.0031
epoch [285/1000], loss:0.0027
epoch [286/1000], loss:0.0030
epoch [287/1000], loss:0.0031
epoch [288/1000], loss:0.0033
epoch [289/1000], loss:0.0027
epoch [290/1000], loss:0.0038
epoch [291/1000], loss:0.0027
epoch [292/1000], loss:0.0027
epoch [293/1000], loss:0.0030
epoch [294/1000], loss:0.0029
epoch [295/1000], loss:0.0031
epoch [296/1000], loss:0.0026
epoch [297/1000], loss:0.0028
epoch [298/1000], loss:0.0024
epoch [299/1000], loss:0.0028
epoch [300/1000], loss:0.0024
epoch [301/1000], loss:0.0032
epoch [302/1000], loss:0.0029
epoch [303/1000], loss:0.0036
epoch [304/1000], loss:0.0029
epoch [305/1000], loss:0.0030
epoch [306/1000], loss:0.0029
epoch [307/1000], loss:0.0032
epoch [308/1000], loss:0.0027
epoch [309/1000], loss:0.0026
epoch [310/1000], loss:0.0030
epoch [311/1000], loss:0.0028
epoch [312/1000], loss:0.0032
epoch [313/1000], loss:0.0032
epoch [314/1000], loss:0.0029
epoch [315/1000], loss:0.0031
epoch [316/1000], loss:0.0026
epoch [317/1000], loss:0.0034
epoch [318/1000], loss:0.0043
epoch [319/1000], loss:0.0034
epoch [320/1000], loss:0.0031
epoch [321/1000], loss:0.0027
epoch [322/1000], loss:0.0026
epoch [323/1000], loss:0.0030
epoch [324/1000], loss:0.0031
epoch [325/1000], loss:0.0029
epoch [326/1000], loss:0.0033
epoch [327/1000], loss:0.0029
epoch [328/1000], loss:0.0026
epoch [329/1000], loss:0.0027
epoch [330/1000], loss:0.0030
epoch [331/1000], loss:0.0027
epoch [332/1000], loss:0.0028
epoch [333/1000], loss:0.0028
epoch [334/1000], loss:0.0026
epoch [335/1000], loss:0.0028
epoch [336/1000], loss:0.0026
epoch [337/1000], loss:0.0037
epoch [338/1000], loss:0.0030
epoch [339/1000], loss:0.0031
epoch [340/1000], loss:0.0028
epoch [341/1000], loss:0.0034
epoch [342/1000], loss:0.0027
epoch [343/1000], loss:0.0032
epoch [344/1000], loss:0.0028
epoch [345/1000], loss:0.0033
epoch [346/1000], loss:0.0030
epoch [347/1000], loss:0.0028
epoch [348/1000], loss:0.0027
epoch [349/1000], loss:0.0026
epoch [350/1000], loss:0.0028
epoch [351/1000], loss:0.0029
epoch [352/1000], loss:0.0029
epoch [353/1000], loss:0.0028
epoch [354/1000], loss:0.0031
epoch [355/1000], loss:0.0026
epoch [356/1000], loss:0.0028
epoch [357/1000], loss:0.0025
epoch [358/1000], loss:0.0030
epoch [359/1000], loss:0.0027
epoch [360/1000], loss:0.0025
epoch [361/1000], loss:0.0033
epoch [362/1000], loss:0.0030
epoch [363/1000], loss:0.0024
epoch [364/1000], loss:0.0028
epoch [365/1000], loss:0.0028
epoch [366/1000], loss:0.0024
epoch [367/1000], loss:0.0026
epoch [368/1000], loss:0.0026
epoch [369/1000], loss:0.0028
epoch [370/1000], loss:0.0026
epoch [371/1000], loss:0.0029
epoch [372/1000], loss:0.0027
epoch [373/1000], loss:0.0027
epoch [374/1000], loss:0.0029
epoch [375/1000], loss:0.0026
epoch [376/1000], loss:0.0028
epoch [377/1000], loss:0.0025
epoch [378/1000], loss:0.0027
epoch [379/1000], loss:0.0028
epoch [380/1000], loss:0.0027
epoch [381/1000], loss:0.0028
epoch [382/1000], loss:0.0032
epoch [383/1000], loss:0.0031
epoch [384/1000], loss:0.0025
epoch [385/1000], loss:0.0026
epoch [386/1000], loss:0.0029
epoch [387/1000], loss:0.0024
epoch [388/1000], loss:0.0030
epoch [389/1000], loss:0.0026
epoch [390/1000], loss:0.0028
epoch [391/1000], loss:0.0030
epoch [392/1000], loss:0.0025
epoch [393/1000], loss:0.0027
epoch [394/1000], loss:0.0043
epoch [395/1000], loss:0.0029
epoch [396/1000], loss:0.0036
epoch [397/1000], loss:0.0025
epoch [398/1000], loss:0.0029
epoch [399/1000], loss:0.0030
epoch [400/1000], loss:0.0027
epoch [401/1000], loss:0.0026
epoch [402/1000], loss:0.0026
epoch [403/1000], loss:0.0027
epoch [404/1000], loss:0.0025
epoch [405/1000], loss:0.0029
epoch [406/1000], loss:0.0024
epoch [407/1000], loss:0.0029
epoch [408/1000], loss:0.0030
epoch [409/1000], loss:0.0028
epoch [410/1000], loss:0.0028
epoch [411/1000], loss:0.0027
epoch [412/1000], loss:0.0028
epoch [413/1000], loss:0.0027
epoch [414/1000], loss:0.0034
epoch [415/1000], loss:0.0028
epoch [416/1000], loss:0.0028
epoch [417/1000], loss:0.0030
epoch [418/1000], loss:0.0028
epoch [419/1000], loss:0.0026
epoch [420/1000], loss:0.0029
epoch [421/1000], loss:0.0027
epoch [422/1000], loss:0.0027
epoch [423/1000], loss:0.0029
epoch [424/1000], loss:0.0027
epoch [425/1000], loss:0.0029
epoch [426/1000], loss:0.0024
epoch [427/1000], loss:0.0025
epoch [428/1000], loss:0.0027
epoch [429/1000], loss:0.0026
epoch [430/1000], loss:0.0029
epoch [431/1000], loss:0.0028
epoch [432/1000], loss:0.0030
epoch [433/1000], loss:0.0026
epoch [434/1000], loss:0.0029
epoch [435/1000], loss:0.0030
epoch [436/1000], loss:0.0026
epoch [437/1000], loss:0.0027
epoch [438/1000], loss:0.0023
epoch [439/1000], loss:0.0028
epoch [440/1000], loss:0.0030
epoch [441/1000], loss:0.0032
epoch [442/1000], loss:0.0032
epoch [443/1000], loss:0.0021
epoch [444/1000], loss:0.0029
epoch [445/1000], loss:0.0032
epoch [446/1000], loss:0.0028
epoch [447/1000], loss:0.0030
epoch [448/1000], loss:0.0026
epoch [449/1000], loss:0.0032
epoch [450/1000], loss:0.0024
epoch [451/1000], loss:0.0026
epoch [452/1000], loss:0.0031
epoch [453/1000], loss:0.0030
epoch [454/1000], loss:0.0026
epoch [455/1000], loss:0.0024
epoch [456/1000], loss:0.0031
epoch [457/1000], loss:0.0031
epoch [458/1000], loss:0.0028
epoch [459/1000], loss:0.0030
epoch [460/1000], loss:0.0025
epoch [461/1000], loss:0.0028
epoch [462/1000], loss:0.0024
epoch [463/1000], loss:0.0028
epoch [464/1000], loss:0.0026
epoch [465/1000], loss:0.0032
epoch [466/1000], loss:0.0030
epoch [467/1000], loss:0.0030
epoch [468/1000], loss:0.0033
epoch [469/1000], loss:0.0031
epoch [470/1000], loss:0.0028
epoch [471/1000], loss:0.0031
epoch [472/1000], loss:0.0025
epoch [473/1000], loss:0.0023
epoch [474/1000], loss:0.0030
epoch [475/1000], loss:0.0030
epoch [476/1000], loss:0.0025
epoch [477/1000], loss:0.0032
epoch [478/1000], loss:0.0026
epoch [479/1000], loss:0.0027
epoch [480/1000], loss:0.0025
epoch [481/1000], loss:0.0065
epoch [482/1000], loss:0.0028
epoch [483/1000], loss:0.0030
epoch [484/1000], loss:0.0024
epoch [485/1000], loss:0.0030
epoch [486/1000], loss:0.0027
epoch [487/1000], loss:0.0026
epoch [488/1000], loss:0.0033
epoch [489/1000], loss:0.0025
epoch [490/1000], loss:0.0033
epoch [491/1000], loss:0.0025
epoch [492/1000], loss:0.0027
epoch [493/1000], loss:0.0028
epoch [494/1000], loss:0.0028
epoch [495/1000], loss:0.0024
epoch [496/1000], loss:0.0029
epoch [497/1000], loss:0.0028
epoch [498/1000], loss:0.0027
epoch [499/1000], loss:0.0027
epoch [500/1000], loss:0.0031
epoch [501/1000], loss:0.0030
epoch [502/1000], loss:0.0030
epoch [503/1000], loss:0.0027
epoch [504/1000], loss:0.0026
epoch [505/1000], loss:0.0024
epoch [506/1000], loss:0.0030
epoch [507/1000], loss:0.0028
epoch [508/1000], loss:0.0025
epoch [509/1000], loss:0.0034
epoch [510/1000], loss:0.0030
epoch [511/1000], loss:0.0030
epoch [512/1000], loss:0.0025
epoch [513/1000], loss:0.0028
epoch [514/1000], loss:0.0028
epoch [515/1000], loss:0.0027
epoch [516/1000], loss:0.0028
epoch [517/1000], loss:0.0029
epoch [518/1000], loss:0.0034
epoch [519/1000], loss:0.0032
epoch [520/1000], loss:0.0026
epoch [521/1000], loss:0.0029
epoch [522/1000], loss:0.0026
epoch [523/1000], loss:0.0027
epoch [524/1000], loss:0.0029
epoch [525/1000], loss:0.0030
epoch [526/1000], loss:0.0028
epoch [527/1000], loss:0.0029
epoch [528/1000], loss:0.0027
epoch [529/1000], loss:0.0028
epoch [530/1000], loss:0.0023
epoch [531/1000], loss:0.0025
epoch [532/1000], loss:0.0029
epoch [533/1000], loss:0.0027
epoch [534/1000], loss:0.0030
epoch [535/1000], loss:0.0031
epoch [536/1000], loss:0.0025
epoch [537/1000], loss:0.0029
epoch [538/1000], loss:0.0028
epoch [539/1000], loss:0.0031
epoch [540/1000], loss:0.0028
epoch [541/1000], loss:0.0025
epoch [542/1000], loss:0.0030
epoch [543/1000], loss:0.0027
epoch [544/1000], loss:0.0026
epoch [545/1000], loss:0.0029
epoch [546/1000], loss:0.0025
epoch [547/1000], loss:0.0026
epoch [548/1000], loss:0.0028
epoch [549/1000], loss:0.0032
epoch [550/1000], loss:0.0026
epoch [551/1000], loss:0.0026
epoch [552/1000], loss:0.0027
epoch [553/1000], loss:0.0028
epoch [554/1000], loss:0.0025
epoch [555/1000], loss:0.0030
epoch [556/1000], loss:0.0028
epoch [557/1000], loss:0.0028
epoch [558/1000], loss:0.0035
epoch [559/1000], loss:0.0033
epoch [560/1000], loss:0.0029
epoch [561/1000], loss:0.0034
epoch [562/1000], loss:0.0029
epoch [563/1000], loss:0.0026
epoch [564/1000], loss:0.0029
epoch [565/1000], loss:0.0029
epoch [566/1000], loss:0.0027
epoch [567/1000], loss:0.0030
epoch [568/1000], loss:0.0029
epoch [569/1000], loss:0.0029
epoch [570/1000], loss:0.0026
epoch [571/1000], loss:0.0029
epoch [572/1000], loss:0.0034
epoch [573/1000], loss:0.0029
epoch [574/1000], loss:0.0026
epoch [575/1000], loss:0.0027
epoch [576/1000], loss:0.0028
epoch [577/1000], loss:0.0029
epoch [578/1000], loss:0.0031
epoch [579/1000], loss:0.0027
epoch [580/1000], loss:0.0030
epoch [581/1000], loss:0.0032
epoch [582/1000], loss:0.0034
epoch [583/1000], loss:0.0028
epoch [584/1000], loss:0.0024
epoch [585/1000], loss:0.0026
epoch [586/1000], loss:0.0027
epoch [587/1000], loss:0.0024
epoch [588/1000], loss:0.0026
epoch [589/1000], loss:0.0025
epoch [590/1000], loss:0.0023
epoch [591/1000], loss:0.0029
epoch [592/1000], loss:0.0025
epoch [593/1000], loss:0.0026
epoch [594/1000], loss:0.0026
epoch [595/1000], loss:0.0029
epoch [596/1000], loss:0.0027
epoch [597/1000], loss:0.0027
epoch [598/1000], loss:0.0025
epoch [599/1000], loss:0.0028
epoch [600/1000], loss:0.0025
epoch [601/1000], loss:0.0031
epoch [602/1000], loss:0.0027
epoch [603/1000], loss:0.0041
epoch [604/1000], loss:0.0028
epoch [605/1000], loss:0.0028
epoch [606/1000], loss:0.0033
epoch [607/1000], loss:0.0029
epoch [608/1000], loss:0.0031
epoch [609/1000], loss:0.0029
epoch [610/1000], loss:0.0029
epoch [611/1000], loss:0.0026
epoch [612/1000], loss:0.0032
epoch [613/1000], loss:0.0027
epoch [614/1000], loss:0.0029
epoch [615/1000], loss:0.0036
epoch [616/1000], loss:0.0028
epoch [617/1000], loss:0.0028
epoch [618/1000], loss:0.0029
epoch [619/1000], loss:0.0030
epoch [620/1000], loss:0.0025
epoch [621/1000], loss:0.0028
epoch [622/1000], loss:0.0026
epoch [623/1000], loss:0.0026
epoch [624/1000], loss:0.0025
epoch [625/1000], loss:0.0025
epoch [626/1000], loss:0.0030
epoch [627/1000], loss:0.0031
epoch [628/1000], loss:0.0030
epoch [629/1000], loss:0.0031
epoch [630/1000], loss:0.0032
epoch [631/1000], loss:0.0027
epoch [632/1000], loss:0.0024
epoch [633/1000], loss:0.0027
epoch [634/1000], loss:0.0031
epoch [635/1000], loss:0.0031
epoch [636/1000], loss:0.0032
epoch [637/1000], loss:0.0029
epoch [638/1000], loss:0.0030
epoch [639/1000], loss:0.0031
epoch [640/1000], loss:0.0026
epoch [641/1000], loss:0.0027
epoch [642/1000], loss:0.0029
epoch [643/1000], loss:0.0027
epoch [644/1000], loss:0.0031
epoch [645/1000], loss:0.0029
epoch [646/1000], loss:0.0023
epoch [647/1000], loss:0.0025
epoch [648/1000], loss:0.0023
epoch [649/1000], loss:0.0029
epoch [650/1000], loss:0.0028
epoch [651/1000], loss:0.0027
epoch [652/1000], loss:0.0027
epoch [653/1000], loss:0.0031
epoch [654/1000], loss:0.0026
epoch [655/1000], loss:0.0030
epoch [656/1000], loss:0.0032
epoch [657/1000], loss:0.0028
epoch [658/1000], loss:0.0028
epoch [659/1000], loss:0.0030
epoch [660/1000], loss:0.0029
epoch [661/1000], loss:0.0026
epoch [662/1000], loss:0.0028
epoch [663/1000], loss:0.0027
epoch [664/1000], loss:0.0027
epoch [665/1000], loss:0.0027
epoch [666/1000], loss:0.0025
epoch [667/1000], loss:0.0024
epoch [668/1000], loss:0.0025
epoch [669/1000], loss:0.0030
epoch [670/1000], loss:0.0029
epoch [671/1000], loss:0.0028
epoch [672/1000], loss:0.0023
epoch [673/1000], loss:0.0030
epoch [674/1000], loss:0.0028
epoch [675/1000], loss:0.0027
epoch [676/1000], loss:0.0028
epoch [677/1000], loss:0.0026
epoch [678/1000], loss:0.0028
epoch [679/1000], loss:0.0026
epoch [680/1000], loss:0.0025
epoch [681/1000], loss:0.0027
epoch [682/1000], loss:0.0034
epoch [683/1000], loss:0.0028
epoch [684/1000], loss:0.0029
epoch [685/1000], loss:0.0027
epoch [686/1000], loss:0.0026
epoch [687/1000], loss:0.0027
epoch [688/1000], loss:0.0025
epoch [689/1000], loss:0.0025
epoch [690/1000], loss:0.0037
epoch [691/1000], loss:0.0028
epoch [692/1000], loss:0.0025
epoch [693/1000], loss:0.0030
epoch [694/1000], loss:0.0026
epoch [695/1000], loss:0.0026
epoch [696/1000], loss:0.0028
epoch [697/1000], loss:0.0027
epoch [698/1000], loss:0.0025
epoch [699/1000], loss:0.0030
epoch [700/1000], loss:0.0026
epoch [701/1000], loss:0.0031
epoch [702/1000], loss:0.0029
epoch [703/1000], loss:0.0023
epoch [704/1000], loss:0.0027
epoch [705/1000], loss:0.0025
epoch [706/1000], loss:0.0028
epoch [707/1000], loss:0.0026
epoch [708/1000], loss:0.0027
epoch [709/1000], loss:0.0027
epoch [710/1000], loss:0.0028
epoch [711/1000], loss:0.0023
epoch [712/1000], loss:0.0026
epoch [713/1000], loss:0.0029
epoch [714/1000], loss:0.0030
epoch [715/1000], loss:0.0026
epoch [716/1000], loss:0.0030
epoch [717/1000], loss:0.0041
epoch [718/1000], loss:0.0026
epoch [719/1000], loss:0.0026
epoch [720/1000], loss:0.0028
epoch [721/1000], loss:0.0027
epoch [722/1000], loss:0.0027
epoch [723/1000], loss:0.0035
epoch [724/1000], loss:0.0028
epoch [725/1000], loss:0.0029
epoch [726/1000], loss:0.0027
epoch [727/1000], loss:0.0030
epoch [728/1000], loss:0.0034
epoch [729/1000], loss:0.0029
epoch [730/1000], loss:0.0029
epoch [731/1000], loss:0.0029
epoch [732/1000], loss:0.0028
epoch [733/1000], loss:0.0026
epoch [734/1000], loss:0.0030
epoch [735/1000], loss:0.0021
epoch [736/1000], loss:0.0030
epoch [737/1000], loss:0.0027
epoch [738/1000], loss:0.0025
epoch [739/1000], loss:0.0027
epoch [740/1000], loss:0.0026
epoch [741/1000], loss:0.0024
epoch [742/1000], loss:0.0028
epoch [743/1000], loss:0.0027
epoch [744/1000], loss:0.0028
epoch [745/1000], loss:0.0030
epoch [746/1000], loss:0.0026
epoch [747/1000], loss:0.0026
epoch [748/1000], loss:0.0030
epoch [749/1000], loss:0.0025
epoch [750/1000], loss:0.0030
epoch [751/1000], loss:0.0032
epoch [752/1000], loss:0.0027
epoch [753/1000], loss:0.0027
epoch [754/1000], loss:0.0027
epoch [755/1000], loss:0.0029
epoch [756/1000], loss:0.0028
epoch [757/1000], loss:0.0026
epoch [758/1000], loss:0.0023
epoch [759/1000], loss:0.0027
epoch [760/1000], loss:0.0026
epoch [761/1000], loss:0.0025
epoch [762/1000], loss:0.0024
epoch [763/1000], loss:0.0026
epoch [764/1000], loss:0.0025
epoch [765/1000], loss:0.0027
epoch [766/1000], loss:0.0026
epoch [767/1000], loss:0.0027
epoch [768/1000], loss:0.0025
epoch [769/1000], loss:0.0024
epoch [770/1000], loss:0.0026
epoch [771/1000], loss:0.0027
epoch [772/1000], loss:0.0024
epoch [773/1000], loss:0.0032
epoch [774/1000], loss:0.0025
epoch [775/1000], loss:0.0021
epoch [776/1000], loss:0.0026
epoch [777/1000], loss:0.0028
epoch [778/1000], loss:0.0030
epoch [779/1000], loss:0.0025
epoch [780/1000], loss:0.0030
epoch [781/1000], loss:0.0029
epoch [782/1000], loss:0.0029
epoch [783/1000], loss:0.0025
epoch [784/1000], loss:0.0026
epoch [785/1000], loss:0.0028
epoch [786/1000], loss:0.0030
epoch [787/1000], loss:0.0027
epoch [788/1000], loss:0.0024
epoch [789/1000], loss:0.0026
epoch [790/1000], loss:0.0026
epoch [791/1000], loss:0.0026
epoch [792/1000], loss:0.0026
epoch [793/1000], loss:0.0028
epoch [794/1000], loss:0.0025
epoch [795/1000], loss:0.0023
epoch [796/1000], loss:0.0031
epoch [797/1000], loss:0.0025
epoch [798/1000], loss:0.0024
epoch [799/1000], loss:0.0025
epoch [800/1000], loss:0.0028
epoch [801/1000], loss:0.0024
epoch [802/1000], loss:0.0026
epoch [803/1000], loss:0.0025
epoch [804/1000], loss:0.0030
epoch [805/1000], loss:0.0024
epoch [806/1000], loss:0.0025
epoch [807/1000], loss:0.0030
epoch [808/1000], loss:0.0025
epoch [809/1000], loss:0.0028
epoch [810/1000], loss:0.0030
epoch [811/1000], loss:0.0030
epoch [812/1000], loss:0.0025
epoch [813/1000], loss:0.0025
epoch [814/1000], loss:0.0025
epoch [815/1000], loss:0.0026
epoch [816/1000], loss:0.0031
epoch [817/1000], loss:0.0024
epoch [818/1000], loss:0.0028
epoch [819/1000], loss:0.0028
epoch [820/1000], loss:0.0030
epoch [821/1000], loss:0.0028
epoch [822/1000], loss:0.0029
epoch [823/1000], loss:0.0029
epoch [824/1000], loss:0.0030
epoch [825/1000], loss:0.0026
epoch [826/1000], loss:0.0028
epoch [827/1000], loss:0.0023
epoch [828/1000], loss:0.0024
epoch [829/1000], loss:0.0025
epoch [830/1000], loss:0.0028
epoch [831/1000], loss:0.0027
epoch [832/1000], loss:0.0025
epoch [833/1000], loss:0.0024
epoch [834/1000], loss:0.0023
epoch [835/1000], loss:0.0034
epoch [836/1000], loss:0.0027
epoch [837/1000], loss:0.0031
epoch [838/1000], loss:0.0028
epoch [839/1000], loss:0.0025
epoch [840/1000], loss:0.0026
epoch [841/1000], loss:0.0025
epoch [842/1000], loss:0.0025
epoch [843/1000], loss:0.0024
epoch [844/1000], loss:0.0029
epoch [845/1000], loss:0.0025
epoch [846/1000], loss:0.0028
epoch [847/1000], loss:0.0025
epoch [848/1000], loss:0.0039
epoch [849/1000], loss:0.0026
epoch [850/1000], loss:0.0031
epoch [851/1000], loss:0.0032
epoch [852/1000], loss:0.0025
epoch [853/1000], loss:0.0025
epoch [854/1000], loss:0.0028
epoch [855/1000], loss:0.0025
epoch [856/1000], loss:0.0032
epoch [857/1000], loss:0.0025
epoch [858/1000], loss:0.0026
epoch [859/1000], loss:0.0029
epoch [860/1000], loss:0.0026
epoch [861/1000], loss:0.0027
epoch [862/1000], loss:0.0023
epoch [863/1000], loss:0.0026
epoch [864/1000], loss:0.0025
epoch [865/1000], loss:0.0025
epoch [866/1000], loss:0.0025
epoch [867/1000], loss:0.0031
epoch [868/1000], loss:0.0029
epoch [869/1000], loss:0.0027
epoch [870/1000], loss:0.0031
epoch [871/1000], loss:0.0025
epoch [872/1000], loss:0.0027
epoch [873/1000], loss:0.0026
epoch [874/1000], loss:0.0027
epoch [875/1000], loss:0.0038
epoch [876/1000], loss:0.0025
epoch [877/1000], loss:0.0029
epoch [878/1000], loss:0.0026
epoch [879/1000], loss:0.0026
epoch [880/1000], loss:0.0026
epoch [881/1000], loss:0.0030
epoch [882/1000], loss:0.0030
epoch [883/1000], loss:0.0025
epoch [884/1000], loss:0.0029
epoch [885/1000], loss:0.0033
epoch [886/1000], loss:0.0025
epoch [887/1000], loss:0.0024
epoch [888/1000], loss:0.0021
epoch [889/1000], loss:0.0025
epoch [890/1000], loss:0.0027
epoch [891/1000], loss:0.0025
epoch [892/1000], loss:0.0031
epoch [893/1000], loss:0.0027
epoch [894/1000], loss:0.0030
epoch [895/1000], loss:0.0030
epoch [896/1000], loss:0.0029
epoch [897/1000], loss:0.0024
epoch [898/1000], loss:0.0027
epoch [899/1000], loss:0.0026
epoch [900/1000], loss:0.0022
epoch [901/1000], loss:0.0025
epoch [902/1000], loss:0.0028
epoch [903/1000], loss:0.0027
epoch [904/1000], loss:0.0024
epoch [905/1000], loss:0.0027
epoch [906/1000], loss:0.0027
epoch [907/1000], loss:0.0029
epoch [908/1000], loss:0.0035
epoch [909/1000], loss:0.0030
epoch [910/1000], loss:0.0030
epoch [911/1000], loss:0.0030
epoch [912/1000], loss:0.0030
epoch [913/1000], loss:0.0023
epoch [914/1000], loss:0.0026
epoch [915/1000], loss:0.0026
epoch [916/1000], loss:0.0024
epoch [917/1000], loss:0.0028
epoch [918/1000], loss:0.0029
epoch [919/1000], loss:0.0024
epoch [920/1000], loss:0.0023
epoch [921/1000], loss:0.0031
epoch [922/1000], loss:0.0033
epoch [923/1000], loss:0.0026
epoch [924/1000], loss:0.0030
epoch [925/1000], loss:0.0028
epoch [926/1000], loss:0.0030
epoch [927/1000], loss:0.0026
epoch [928/1000], loss:0.0032
epoch [929/1000], loss:0.0026
epoch [930/1000], loss:0.0028
epoch [931/1000], loss:0.0027
epoch [932/1000], loss:0.0025
epoch [933/1000], loss:0.0024
epoch [934/1000], loss:0.0026
epoch [935/1000], loss:0.0033
epoch [936/1000], loss:0.0026
epoch [937/1000], loss:0.0029
epoch [938/1000], loss:0.0028
epoch [939/1000], loss:0.0027
epoch [940/1000], loss:0.0026
epoch [941/1000], loss:0.0027
epoch [942/1000], loss:0.0026
epoch [943/1000], loss:0.0027
epoch [944/1000], loss:0.0024
epoch [945/1000], loss:0.0027
epoch [946/1000], loss:0.0025
epoch [947/1000], loss:0.0026
epoch [948/1000], loss:0.0028
epoch [949/1000], loss:0.0025
epoch [950/1000], loss:0.0031
epoch [951/1000], loss:0.0023
epoch [952/1000], loss:0.0027
epoch [953/1000], loss:0.0028
epoch [954/1000], loss:0.0027
epoch [955/1000], loss:0.0030
epoch [956/1000], loss:0.0024
epoch [957/1000], loss:0.0029
epoch [958/1000], loss:0.0026
epoch [959/1000], loss:0.0028
epoch [960/1000], loss:0.0025
epoch [961/1000], loss:0.0026
epoch [962/1000], loss:0.0026
epoch [963/1000], loss:0.0021
epoch [964/1000], loss:0.0033
epoch [965/1000], loss:0.0024
epoch [966/1000], loss:0.0027
epoch [967/1000], loss:0.0028
epoch [968/1000], loss:0.0026
epoch [969/1000], loss:0.0023
epoch [970/1000], loss:0.0029
epoch [971/1000], loss:0.0027
epoch [972/1000], loss:0.0029
epoch [973/1000], loss:0.0027
epoch [974/1000], loss:0.0027
epoch [975/1000], loss:0.0027
epoch [976/1000], loss:0.0026
epoch [977/1000], loss:0.0025
epoch [978/1000], loss:0.0027
epoch [979/1000], loss:0.0030
epoch [980/1000], loss:0.0035
epoch [981/1000], loss:0.0025
epoch [982/1000], loss:0.0029
epoch [983/1000], loss:0.0026
epoch [984/1000], loss:0.0023
epoch [985/1000], loss:0.0027
epoch [986/1000], loss:0.0025
epoch [987/1000], loss:0.0024
epoch [988/1000], loss:0.0026
epoch [989/1000], loss:0.0030
epoch [990/1000], loss:0.0025
epoch [991/1000], loss:0.0027
epoch [992/1000], loss:0.0029
epoch [993/1000], loss:0.0027
epoch [994/1000], loss:0.0030
epoch [995/1000], loss:0.0026
epoch [996/1000], loss:0.0027
epoch [997/1000], loss:0.0028
epoch [998/1000], loss:0.0024
epoch [999/1000], loss:0.0025
epoch [1000/1000], loss:0.0028
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# fcn_autoencoder, conv_autoencoder, VAE and loss_vae are the model/loss definitions from the Autoencoder section
if task == 'ae':
    num_epochs = 1000
    batch_size = 128
    learning_rate = 1e-3

    # one of {'fcn', 'cnn', 'vae'}
    model_type = 'vae'

    x = train
    if model_type == 'fcn' or model_type == 'vae':
        # fully connected models take each image flattened into one vector
        x = x.reshape(len(x), -1)

    data = torch.tensor(x, dtype=torch.float)
    train_dataset = TensorDataset(data)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)

    # store the classes and instantiate only the selected model
    model_classes = {'fcn': fcn_autoencoder, 'cnn': conv_autoencoder, 'vae': VAE}
    model = model_classes[model_type]().cuda()
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    best_loss = np.inf
    model.train()
    for epoch in range(num_epochs):
        for data in train_dataloader:
            if model_type == 'cnn':
                # NHWC -> channels-first; transpose(3, 1) also swaps H and W,
                # which is harmless as long as testing applies the same transpose
                img = data[0].transpose(3, 1).cuda()
            else:
                img = data[0].cuda()
            # ===================forward=====================
            output = model(img)
            if model_type == 'vae':
                # the VAE returns (reconstruction, mu, logvar), matching loss_vae's arguments
                loss = loss_vae(output[0], img, output[1], output[2], criterion)
            else:
                loss = criterion(output, img)
            # ===================backward====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # ===================save====================
            # compares the loss of a single minibatch, so the saved "best" model is noisy
            if loss.item() < best_loss:
                best_loss = loss.item()
                torch.save(model, 'best_model_{}.pt'.format(model_type))
        # ===================log========================
        # prints the loss of the epoch's last minibatch, not the epoch average
        print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, loss.item()))
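The loop above saves the checkpoint with the lowest single-minibatch loss, and its log prints only the last minibatch of each epoch. Both numbers are therefore noisy; the VAE log below even spikes to 90.9289 at epoch 18. The following is a minimal sketch of an alternative that assumes a plain reconstruction-MSE autoencoder. It averages the loss over each epoch and keeps the weights from the best epoch. `train_with_epoch_checkpoint` and the toy data and model are hypothetical and only for illustration. For the VAE you would replace `criterion(output, batch)` with a call to `loss_vae(...)`.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_with_epoch_checkpoint(model, loader, num_epochs, lr=1e-3):
    """Hypothetical helper: train an autoencoder and keep the weights
    from the epoch with the lowest mean reconstruction loss."""
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    best_loss, best_state, history = float('inf'), None, []
    model.train()
    for epoch in range(num_epochs):
        total, count = 0.0, 0
        for (batch,) in loader:  # TensorDataset yields 1-element tuples
            output = model(batch)
            loss = criterion(output, batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # weight by batch size so a smaller last batch is not over-counted
            total += loss.item() * batch.size(0)
            count += batch.size(0)
        epoch_loss = total / count
        history.append(epoch_loss)
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            # copy the weights in memory instead of writing to disk every time
            best_state = copy.deepcopy(model.state_dict())
        print('epoch [{}/{}], mean loss:{:.4f}'.format(epoch + 1, num_epochs, epoch_loss))
    return best_loss, best_state, history


# toy usage on random data (CPU only)
torch.manual_seed(0)
toy = torch.rand(256, 16)
loader = DataLoader(TensorDataset(toy), batch_size=64, shuffle=True)
toy_model = nn.Sequential(nn.Linear(16, 4), nn.ReLU(), nn.Linear(4, 16))
best_loss, best_state, history = train_with_epoch_checkpoint(toy_model, loader, num_epochs=5)
toy_model.load_state_dict(best_state)  # restore the best epoch's weights
```

The epoch mean is accumulated while the weights are still changing within the epoch, so it is only an approximation. A separate evaluation pass on held-out data after each epoch would be a stricter selection criterion.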
E:\Anaconda\envs\pytorch\lib\site-packages\torch\nn\functional.py:1805: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
epoch [1/1000], loss:0.7209
epoch [2/1000], loss:0.7183
epoch [3/1000], loss:0.3760
epoch [4/1000], loss:0.3718
epoch [5/1000], loss:0.2899
epoch [6/1000], loss:0.2597
epoch [7/1000], loss:0.3089
epoch [8/1000], loss:0.2674
epoch [9/1000], loss:0.2420
epoch [10/1000], loss:0.2356
epoch [11/1000], loss:0.2095
epoch [12/1000], loss:0.2775
epoch [13/1000], loss:0.2764
epoch [14/1000], loss:0.2764
epoch [15/1000], loss:0.2673
epoch [16/1000], loss:0.2465
epoch [17/1000], loss:0.2245
epoch [18/1000], loss:90.9289
epoch [19/1000], loss:0.6038
epoch [20/1000], loss:0.4177
epoch [21/1000], loss:0.3409
epoch [22/1000], loss:0.3918
epoch [23/1000], loss:0.3704
epoch [24/1000], loss:0.2902
epoch [25/1000], loss:0.8151
epoch [26/1000], loss:0.5705
epoch [27/1000], loss:0.3563
epoch [28/1000], loss:0.3049
epoch [29/1000], loss:0.3005
epoch [30/1000], loss:0.3972
epoch [31/1000], loss:0.3455
epoch [32/1000], loss:0.2981
epoch [33/1000], loss:0.2781
epoch [34/1000], loss:0.2959
epoch [35/1000], loss:0.2964
epoch [36/1000], loss:0.3022
epoch [37/1000], loss:0.2539
epoch [38/1000], loss:0.2813
epoch [39/1000], loss:0.2777
epoch [40/1000], loss:0.2561
epoch [41/1000], loss:0.2424
epoch [42/1000], loss:0.2589
epoch [43/1000], loss:0.2461
epoch [44/1000], loss:0.2503
epoch [45/1000], loss:0.2892
epoch [46/1000], loss:0.2715
epoch [47/1000], loss:0.2635
epoch [48/1000], loss:0.2492
epoch [49/1000], loss:0.2676
epoch [50/1000], loss:0.2636
epoch [51/1000], loss:0.2269
epoch [52/1000], loss:0.2280
epoch [53/1000], loss:0.2540
epoch [54/1000], loss:0.2465
epoch [55/1000], loss:0.2533
epoch [56/1000], loss:0.2762
epoch [57/1000], loss:0.2701
epoch [58/1000], loss:0.2389
epoch [59/1000], loss:0.2433
epoch [60/1000], loss:0.2521
epoch [61/1000], loss:0.2634
epoch [62/1000], loss:0.2244
epoch [63/1000], loss:0.2742
epoch [64/1000], loss:0.2327
epoch [65/1000], loss:0.2265
epoch [66/1000], loss:0.2631
epoch [67/1000], loss:0.2541
epoch [68/1000], loss:0.2563
epoch [69/1000], loss:0.2548
epoch [70/1000], loss:0.2906
epoch [71/1000], loss:0.2532
epoch [72/1000], loss:0.2912
epoch [73/1000], loss:0.2573
epoch [74/1000], loss:0.2257
epoch [75/1000], loss:0.2304
epoch [76/1000], loss:0.2328
epoch [77/1000], loss:0.2703
epoch [78/1000], loss:0.2532
epoch [79/1000], loss:0.2288
epoch [80/1000], loss:0.2499
epoch [81/1000], loss:0.2417
epoch [82/1000], loss:0.2583
epoch [83/1000], loss:0.2386
epoch [84/1000], loss:0.2603
epoch [85/1000], loss:0.2728
epoch [86/1000], loss:0.2634
epoch [87/1000], loss:0.2557
epoch [88/1000], loss:0.2523
epoch [89/1000], loss:0.2510
epoch [90/1000], loss:0.2788
epoch [91/1000], loss:0.2454
epoch [92/1000], loss:0.2635
epoch [93/1000], loss:0.2677
epoch [94/1000], loss:0.2130
epoch [95/1000], loss:0.2564
epoch [96/1000], loss:0.2474
epoch [97/1000], loss:0.2316
epoch [98/1000], loss:0.2303
epoch [99/1000], loss:0.2416
epoch [100/1000], loss:0.2476
epoch [101/1000], loss:0.2511
epoch [102/1000], loss:0.2583
epoch [103/1000], loss:0.2610
epoch [104/1000], loss:0.2770
epoch [105/1000], loss:0.2647
epoch [106/1000], loss:0.2736
epoch [107/1000], loss:0.2679
epoch [108/1000], loss:0.2406
epoch [109/1000], loss:0.2589
epoch [110/1000], loss:0.2255
epoch [111/1000], loss:0.2520
epoch [112/1000], loss:0.2610
epoch [113/1000], loss:0.2587
epoch [114/1000], loss:0.2329
epoch [115/1000], loss:0.2484
epoch [116/1000], loss:0.2382
epoch [117/1000], loss:0.2510
epoch [118/1000], loss:0.2565
epoch [119/1000], loss:0.2228
epoch [120/1000], loss:0.2412
epoch [121/1000], loss:0.2668
epoch [122/1000], loss:0.2475
epoch [123/1000], loss:0.2300
epoch [124/1000], loss:0.2489
epoch [125/1000], loss:0.2479
epoch [126/1000], loss:0.2435
epoch [127/1000], loss:0.2563
epoch [128/1000], loss:0.2633
epoch [129/1000], loss:0.2426
epoch [130/1000], loss:0.2260
epoch [131/1000], loss:0.2722
epoch [132/1000], loss:0.2800
epoch [133/1000], loss:0.2313
epoch [134/1000], loss:0.2458
epoch [135/1000], loss:0.2796
epoch [136/1000], loss:0.2428
epoch [137/1000], loss:0.2643
epoch [138/1000], loss:0.2503
epoch [139/1000], loss:0.2494
epoch [140/1000], loss:0.2237
epoch [141/1000], loss:0.2408
epoch [142/1000], loss:0.2520
epoch [143/1000], loss:0.2365
epoch [144/1000], loss:0.2293
epoch [145/1000], loss:0.2407
epoch [146/1000], loss:0.2616
epoch [147/1000], loss:0.2404
epoch [148/1000], loss:0.3104
epoch [149/1000], loss:0.2616
epoch [150/1000], loss:0.2410
epoch [151/1000], loss:0.2588
epoch [152/1000], loss:0.2291
epoch [153/1000], loss:0.2687
epoch [154/1000], loss:0.2820
epoch [155/1000], loss:0.2449
epoch [156/1000], loss:0.2641
epoch [157/1000], loss:0.2175
epoch [158/1000], loss:0.2721
epoch [159/1000], loss:0.2589
epoch [160/1000], loss:0.2604
epoch [161/1000], loss:0.2269
epoch [162/1000], loss:0.2372
epoch [163/1000], loss:0.2637
epoch [164/1000], loss:0.2272
epoch [165/1000], loss:0.2769
epoch [166/1000], loss:0.2735
epoch [167/1000], loss:0.2238
epoch [168/1000], loss:0.2486
epoch [169/1000], loss:0.2782
epoch [170/1000], loss:0.2684
epoch [171/1000], loss:0.2510
epoch [172/1000], loss:0.2624
epoch [173/1000], loss:0.2593
epoch [174/1000], loss:0.2757
epoch [175/1000], loss:0.2628
epoch [176/1000], loss:0.2583
epoch [177/1000], loss:0.2486
epoch [178/1000], loss:0.2739
epoch [179/1000], loss:0.2569
epoch [180/1000], loss:0.2540
epoch [181/1000], loss:0.2670
epoch [182/1000], loss:0.2339
epoch [183/1000], loss:0.2346
epoch [184/1000], loss:0.2925
epoch [185/1000], loss:0.2664
epoch [186/1000], loss:0.2730
epoch [187/1000], loss:0.2480
epoch [188/1000], loss:0.2414
epoch [189/1000], loss:0.2796
epoch [190/1000], loss:0.2642
epoch [191/1000], loss:0.2609
epoch [192/1000], loss:0.2401
epoch [193/1000], loss:0.2368
epoch [194/1000], loss:0.2525
epoch [195/1000], loss:0.2642
epoch [196/1000], loss:0.2574
epoch [197/1000], loss:0.2571
epoch [198/1000], loss:0.2540
epoch [199/1000], loss:0.2620
epoch [200/1000], loss:0.2627
epoch [201/1000], loss:0.2538
epoch [202/1000], loss:0.2611
epoch [203/1000], loss:0.2764
epoch [204/1000], loss:0.2915
epoch [205/1000], loss:0.2389
epoch [206/1000], loss:0.2206
epoch [207/1000], loss:0.2576
epoch [208/1000], loss:0.2513
epoch [209/1000], loss:0.2231
epoch [210/1000], loss:0.2506
epoch [211/1000], loss:0.2438
epoch [212/1000], loss:0.2531
epoch [213/1000], loss:0.2600
epoch [214/1000], loss:0.2364
epoch [215/1000], loss:0.2541
epoch [216/1000], loss:0.2763
epoch [217/1000], loss:0.2455
epoch [218/1000], loss:0.2452
epoch [219/1000], loss:0.2660
epoch [220/1000], loss:0.2480
epoch [221/1000], loss:0.2418
epoch [222/1000], loss:0.2753
epoch [223/1000], loss:0.2715
epoch [224/1000], loss:0.2284
epoch [225/1000], loss:0.2278
epoch [226/1000], loss:0.2512
epoch [227/1000], loss:0.2533
epoch [228/1000], loss:0.2436
epoch [229/1000], loss:0.2577
epoch [230/1000], loss:0.2588
epoch [231/1000], loss:0.2705
epoch [232/1000], loss:0.2273
epoch [233/1000], loss:0.2475
epoch [234/1000], loss:0.2480
epoch [235/1000], loss:0.2533
epoch [236/1000], loss:0.2874
epoch [237/1000], loss:0.2440
epoch [238/1000], loss:0.2415
epoch [239/1000], loss:0.2444
epoch [240/1000], loss:0.2716
epoch [241/1000], loss:0.2689
epoch [242/1000], loss:0.2521
epoch [243/1000], loss:0.2386
epoch [244/1000], loss:0.2526
epoch [245/1000], loss:0.2298
epoch [246/1000], loss:0.2550
epoch [247/1000], loss:0.2261
epoch [248/1000], loss:0.2364
epoch [249/1000], loss:0.2290
epoch [250/1000], loss:0.2769
epoch [251/1000], loss:0.2333
epoch [252/1000], loss:0.2530
epoch [253/1000], loss:0.2282
epoch [254/1000], loss:0.2411
epoch [255/1000], loss:0.2184
epoch [256/1000], loss:0.2505
epoch [257/1000], loss:0.2744
epoch [258/1000], loss:0.2539
epoch [259/1000], loss:0.2730
epoch [260/1000], loss:0.2452
epoch [261/1000], loss:0.2375
epoch [262/1000], loss:0.2487
epoch [263/1000], loss:0.2472
epoch [264/1000], loss:0.2470
epoch [265/1000], loss:0.2426
epoch [266/1000], loss:0.2652
epoch [267/1000], loss:0.2532
epoch [268/1000], loss:0.2840
epoch [269/1000], loss:0.2585
epoch [270/1000], loss:0.2498
epoch [271/1000], loss:0.2403
epoch [272/1000], loss:0.2600
epoch [273/1000], loss:0.2408
epoch [274/1000], loss:0.2417
epoch [275/1000], loss:0.2651
epoch [276/1000], loss:0.2295
epoch [277/1000], loss:0.2624
epoch [278/1000], loss:0.2820
epoch [279/1000], loss:0.2341
epoch [280/1000], loss:0.2728
epoch [281/1000], loss:0.2579
epoch [282/1000], loss:0.2703
epoch [283/1000], loss:0.2625
epoch [284/1000], loss:0.2600
epoch [285/1000], loss:0.2431
epoch [286/1000], loss:0.2478
epoch [287/1000], loss:0.2215
epoch [288/1000], loss:0.2387
epoch [289/1000], loss:0.2742
epoch [290/1000], loss:0.2767
epoch [291/1000], loss:0.2241
epoch [292/1000], loss:0.2456
epoch [293/1000], loss:0.2445
epoch [294/1000], loss:0.2558
epoch [295/1000], loss:0.2277
epoch [296/1000], loss:0.2750
epoch [297/1000], loss:0.2419
epoch [298/1000], loss:0.2788
epoch [299/1000], loss:0.2688
epoch [300/1000], loss:0.2632
epoch [301/1000], loss:0.2397
epoch [302/1000], loss:0.2310
epoch [303/1000], loss:0.2640
epoch [304/1000], loss:0.2463
epoch [305/1000], loss:0.2387
epoch [306/1000], loss:0.2441
epoch [307/1000], loss:0.2602
epoch [308/1000], loss:0.2371
epoch [309/1000], loss:0.2778
epoch [310/1000], loss:0.2506
epoch [311/1000], loss:0.2707
epoch [312/1000], loss:0.2390
epoch [313/1000], loss:0.2758
epoch [314/1000], loss:0.2503
epoch [315/1000], loss:0.2328
epoch [316/1000], loss:0.2483
epoch [317/1000], loss:0.2458
epoch [318/1000], loss:0.2744
epoch [319/1000], loss:0.2577
epoch [320/1000], loss:0.2832
epoch [321/1000], loss:0.2703
epoch [322/1000], loss:0.2514
epoch [323/1000], loss:0.2351
epoch [324/1000], loss:0.2388
epoch [325/1000], loss:0.2571
epoch [326/1000], loss:0.2666
epoch [327/1000], loss:0.2683
epoch [328/1000], loss:0.2426
epoch [329/1000], loss:0.2351
epoch [330/1000], loss:0.2441
epoch [331/1000], loss:0.2346
epoch [332/1000], loss:0.2828
epoch [333/1000], loss:0.2357
epoch [334/1000], loss:0.2686
epoch [335/1000], loss:0.2436
epoch [336/1000], loss:0.2393
epoch [337/1000], loss:0.2547
epoch [338/1000], loss:0.2833
epoch [339/1000], loss:0.2502
epoch [340/1000], loss:0.2572
epoch [341/1000], loss:0.2537
epoch [342/1000], loss:0.2686
epoch [343/1000], loss:0.2728
epoch [344/1000], loss:0.2438
epoch [345/1000], loss:0.2434
epoch [346/1000], loss:0.2324
epoch [347/1000], loss:0.2491
epoch [348/1000], loss:0.2639
epoch [349/1000], loss:0.2672
epoch [350/1000], loss:0.2490
epoch [351/1000], loss:0.2105
epoch [352/1000], loss:0.2471
epoch [353/1000], loss:0.2651
epoch [354/1000], loss:0.2695
epoch [355/1000], loss:0.2379
epoch [356/1000], loss:0.2697
epoch [357/1000], loss:0.2554
epoch [358/1000], loss:0.2453
epoch [359/1000], loss:0.2656
epoch [360/1000], loss:0.2211
epoch [361/1000], loss:0.2395
epoch [362/1000], loss:0.2595
epoch [363/1000], loss:0.2742
epoch [364/1000], loss:0.2707
epoch [365/1000], loss:0.2198
epoch [366/1000], loss:0.2333
epoch [367/1000], loss:0.2519
epoch [368/1000], loss:0.2398
epoch [369/1000], loss:0.2235
epoch [370/1000], loss:0.2535
epoch [371/1000], loss:0.2459
epoch [372/1000], loss:0.2537
epoch [373/1000], loss:0.2582
epoch [374/1000], loss:0.2646
epoch [375/1000], loss:0.2599
epoch [376/1000], loss:0.2349
epoch [377/1000], loss:0.2483
epoch [378/1000], loss:0.2527
epoch [379/1000], loss:0.2406
epoch [380/1000], loss:0.2793
epoch [381/1000], loss:0.2506
epoch [382/1000], loss:0.2622
epoch [383/1000], loss:0.2544
epoch [384/1000], loss:0.2371
epoch [385/1000], loss:0.2506
epoch [386/1000], loss:0.2592
epoch [387/1000], loss:0.2676
epoch [388/1000], loss:0.2377
epoch [389/1000], loss:0.2426
epoch [390/1000], loss:0.2433
epoch [391/1000], loss:0.2252
epoch [392/1000], loss:0.2578
epoch [393/1000], loss:0.2211
epoch [394/1000], loss:0.2334
epoch [395/1000], loss:0.2749
epoch [396/1000], loss:0.2447
epoch [397/1000], loss:0.2459
epoch [398/1000], loss:0.2486
epoch [399/1000], loss:0.2681
epoch [400/1000], loss:0.2425
epoch [401/1000], loss:0.2662
epoch [402/1000], loss:0.2399
epoch [403/1000], loss:0.2621
epoch [404/1000], loss:0.2690
epoch [405/1000], loss:0.2430
epoch [406/1000], loss:0.2536
epoch [407/1000], loss:0.2264
epoch [408/1000], loss:0.2609
epoch [409/1000], loss:0.2582
epoch [410/1000], loss:0.2617
epoch [411/1000], loss:0.2385
epoch [412/1000], loss:0.2603
epoch [413/1000], loss:0.2576
epoch [414/1000], loss:0.2614
epoch [415/1000], loss:0.2670
epoch [416/1000], loss:0.2438
epoch [417/1000], loss:0.2395
epoch [418/1000], loss:0.2719
epoch [419/1000], loss:0.2742
epoch [420/1000], loss:0.2527
epoch [421/1000], loss:0.2765
epoch [422/1000], loss:0.2508
epoch [423/1000], loss:0.2531
epoch [424/1000], loss:0.2317
epoch [425/1000], loss:0.2495
epoch [426/1000], loss:0.2665
epoch [427/1000], loss:0.2593
epoch [428/1000], loss:0.2715
epoch [429/1000], loss:0.2561
epoch [430/1000], loss:0.2501
epoch [431/1000], loss:0.2503
epoch [432/1000], loss:0.2603
epoch [433/1000], loss:0.2264
epoch [434/1000], loss:0.2879
epoch [435/1000], loss:0.2396
epoch [436/1000], loss:0.2585
epoch [437/1000], loss:0.2631
epoch [438/1000], loss:0.2547
epoch [439/1000], loss:0.2453
epoch [440/1000], loss:0.2732
epoch [441/1000], loss:0.2315
epoch [442/1000], loss:0.2773
epoch [443/1000], loss:0.2526
epoch [444/1000], loss:0.2929
epoch [445/1000], loss:0.2248
epoch [446/1000], loss:0.3030
epoch [447/1000], loss:0.2620
epoch [448/1000], loss:0.2514
epoch [449/1000], loss:0.2505
epoch [450/1000], loss:0.2674
epoch [451/1000], loss:0.2107
epoch [452/1000], loss:0.2331
epoch [453/1000], loss:0.2493
epoch [454/1000], loss:0.2508
epoch [455/1000], loss:0.2393
epoch [456/1000], loss:0.2226
epoch [457/1000], loss:0.2529
epoch [458/1000], loss:0.2189
epoch [459/1000], loss:0.2476
epoch [460/1000], loss:0.2641
epoch [461/1000], loss:0.2752
epoch [462/1000], loss:0.2503
epoch [463/1000], loss:0.2692
epoch [464/1000], loss:0.2637
epoch [465/1000], loss:0.2759
epoch [466/1000], loss:0.2631
epoch [467/1000], loss:0.2464
epoch [468/1000], loss:0.2792
epoch [469/1000], loss:0.2467
epoch [470/1000], loss:0.2545
epoch [471/1000], loss:0.2386
epoch [472/1000], loss:0.2733
epoch [473/1000], loss:0.2604
epoch [474/1000], loss:0.3001
epoch [475/1000], loss:0.2307
epoch [476/1000], loss:0.2630
epoch [477/1000], loss:0.2731
epoch [478/1000], loss:0.2190
epoch [479/1000], loss:0.2658
epoch [480/1000], loss:0.2590
epoch [481/1000], loss:0.2608
epoch [482/1000], loss:0.2714
epoch [483/1000], loss:0.2594
epoch [484/1000], loss:0.2154
epoch [485/1000], loss:0.2902
epoch [486/1000], loss:0.2503
epoch [487/1000], loss:0.2624
epoch [488/1000], loss:0.2402
epoch [489/1000], loss:0.2443
epoch [490/1000], loss:0.2518
epoch [491/1000], loss:0.2400
epoch [492/1000], loss:0.2465
epoch [493/1000], loss:0.2786
epoch [494/1000], loss:0.2538
epoch [495/1000], loss:0.2590
epoch [496/1000], loss:0.2549
epoch [497/1000], loss:0.2347
epoch [498/1000], loss:0.2587
epoch [499/1000], loss:0.2318
epoch [500/1000], loss:0.2598
epoch [501/1000], loss:0.2455
epoch [502/1000], loss:0.2573
epoch [503/1000], loss:0.2513
epoch [504/1000], loss:0.2558
epoch [505/1000], loss:0.2548
epoch [506/1000], loss:0.2515
epoch [507/1000], loss:0.2373
epoch [508/1000], loss:0.2621
epoch [509/1000], loss:0.2479
epoch [510/1000], loss:0.2453
epoch [511/1000], loss:0.2533
epoch [512/1000], loss:0.2606
epoch [513/1000], loss:0.2357
epoch [514/1000], loss:0.2432
epoch [515/1000], loss:0.2622
epoch [516/1000], loss:0.2365
epoch [517/1000], loss:0.2402
epoch [518/1000], loss:0.2498
epoch [519/1000], loss:0.2513
epoch [520/1000], loss:0.2477
epoch [521/1000], loss:0.2588
epoch [522/1000], loss:0.2627
epoch [523/1000], loss:0.2451
epoch [524/1000], loss:0.2521
epoch [525/1000], loss:0.2116
epoch [526/1000], loss:0.2411
epoch [527/1000], loss:0.2521
epoch [528/1000], loss:0.2715
epoch [529/1000], loss:0.2831
epoch [530/1000], loss:0.2694
epoch [531/1000], loss:0.2536
epoch [532/1000], loss:0.2276
epoch [533/1000], loss:0.2558
epoch [534/1000], loss:0.2846
epoch [535/1000], loss:0.2709
epoch [536/1000], loss:0.2378
epoch [537/1000], loss:0.2502
epoch [538/1000], loss:0.2792
epoch [539/1000], loss:0.2675
epoch [540/1000], loss:0.2825
epoch [541/1000], loss:0.2427
epoch [542/1000], loss:0.2675
epoch [543/1000], loss:0.2360
epoch [544/1000], loss:0.2554
epoch [545/1000], loss:0.2644
epoch [546/1000], loss:0.2774
epoch [547/1000], loss:0.2420
epoch [548/1000], loss:0.2383
epoch [549/1000], loss:0.2894
epoch [550/1000], loss:0.2760
epoch [551/1000], loss:0.2502
epoch [552/1000], loss:0.2798
epoch [553/1000], loss:0.2554
epoch [554/1000], loss:0.2379
epoch [555/1000], loss:0.2428
epoch [556/1000], loss:0.2574
epoch [557/1000], loss:0.2589
epoch [558/1000], loss:0.2619
epoch [559/1000], loss:0.2608
epoch [560/1000], loss:0.2400
epoch [561/1000], loss:0.2548
epoch [562/1000], loss:0.2401
epoch [563/1000], loss:0.2492
epoch [564/1000], loss:0.2592
epoch [565/1000], loss:0.2283
epoch [566/1000], loss:0.2574
epoch [567/1000], loss:0.2614
epoch [568/1000], loss:0.2549
epoch [569/1000], loss:0.2566
epoch [570/1000], loss:0.2516
epoch [571/1000], loss:0.2550
epoch [572/1000], loss:0.2269
epoch [573/1000], loss:0.2525
epoch [574/1000], loss:0.2209
epoch [575/1000], loss:0.2750
epoch [576/1000], loss:0.2471
epoch [577/1000], loss:0.2591
epoch [578/1000], loss:0.2342
epoch [579/1000], loss:0.2574
epoch [580/1000], loss:0.2910
epoch [581/1000], loss:0.2420
epoch [582/1000], loss:0.2457
epoch [583/1000], loss:0.2442
epoch [584/1000], loss:0.2584
epoch [585/1000], loss:0.2614
epoch [586/1000], loss:0.2477
epoch [587/1000], loss:0.2551
epoch [588/1000], loss:0.2168
epoch [589/1000], loss:0.2321
epoch [590/1000], loss:0.2499
epoch [591/1000], loss:0.2696
epoch [592/1000], loss:0.2405
epoch [593/1000], loss:0.2603
epoch [594/1000], loss:0.2668
epoch [595/1000], loss:0.2498
epoch [596/1000], loss:0.2780
epoch [597/1000], loss:0.2284
epoch [598/1000], loss:0.2189
epoch [599/1000], loss:0.2208
epoch [600/1000], loss:0.2420
epoch [601/1000], loss:0.2324
epoch [602/1000], loss:0.2607
epoch [603/1000], loss:0.2531
epoch [604/1000], loss:0.2494
epoch [605/1000], loss:0.2555
epoch [606/1000], loss:0.2423
epoch [607/1000], loss:0.2430
epoch [608/1000], loss:0.2336
epoch [609/1000], loss:0.2475
epoch [610/1000], loss:0.2550
epoch [611/1000], loss:0.2906
epoch [612/1000], loss:0.2548
epoch [613/1000], loss:0.2354
epoch [614/1000], loss:0.2642
epoch [615/1000], loss:0.2629
epoch [616/1000], loss:0.2414
epoch [617/1000], loss:0.2693
epoch [618/1000], loss:0.2530
epoch [619/1000], loss:0.2641
epoch [620/1000], loss:0.2726
epoch [621/1000], loss:0.2598
epoch [622/1000], loss:0.2452
epoch [623/1000], loss:0.2357
epoch [624/1000], loss:0.2550
epoch [625/1000], loss:0.2809
epoch [626/1000], loss:0.2389
epoch [627/1000], loss:0.2780
epoch [628/1000], loss:0.2559
epoch [629/1000], loss:0.2377
epoch [630/1000], loss:0.2639
epoch [631/1000], loss:0.2304
epoch [632/1000], loss:0.2656
epoch [633/1000], loss:0.2514
epoch [634/1000], loss:0.2302
epoch [635/1000], loss:0.2800
epoch [636/1000], loss:0.2569
epoch [637/1000], loss:0.2314
epoch [638/1000], loss:0.2479
epoch [639/1000], loss:0.2368
epoch [640/1000], loss:0.2750
epoch [641/1000], loss:0.2547
epoch [642/1000], loss:0.2523
epoch [643/1000], loss:0.2414
epoch [644/1000], loss:0.2425
epoch [645/1000], loss:0.2226
epoch [646/1000], loss:0.2483
epoch [647/1000], loss:0.2529
epoch [648/1000], loss:0.2497
epoch [649/1000], loss:0.2491
epoch [650/1000], loss:0.2613
epoch [651/1000], loss:0.2772
epoch [652/1000], loss:0.2142
epoch [653/1000], loss:0.2421
epoch [654/1000], loss:0.2371
epoch [655/1000], loss:0.2438
epoch [656/1000], loss:0.2628
epoch [657/1000], loss:0.2724
epoch [658/1000], loss:0.2529
epoch [659/1000], loss:0.2934
epoch [660/1000], loss:0.2407
epoch [661/1000], loss:0.2719
epoch [662/1000], loss:0.2738
epoch [663/1000], loss:0.2506
epoch [664/1000], loss:0.2482
epoch [665/1000], loss:0.2517
epoch [666/1000], loss:0.2181
epoch [667/1000], loss:0.2534
epoch [668/1000], loss:0.2152
epoch [669/1000], loss:0.2473
epoch [670/1000], loss:0.2644
epoch [671/1000], loss:0.2500
epoch [672/1000], loss:0.2684
epoch [673/1000], loss:0.2310
epoch [674/1000], loss:0.2393
epoch [675/1000], loss:0.2243
epoch [676/1000], loss:0.2435
epoch [677/1000], loss:0.2530
epoch [678/1000], loss:0.2568
epoch [679/1000], loss:0.2536
epoch [680/1000], loss:0.2451
epoch [681/1000], loss:0.2552
epoch [682/1000], loss:0.2288
epoch [683/1000], loss:0.2235
epoch [684/1000], loss:0.2490
epoch [685/1000], loss:0.2734
epoch [686/1000], loss:0.2513
epoch [687/1000], loss:0.2598
epoch [688/1000], loss:0.2545
epoch [689/1000], loss:0.2612
epoch [690/1000], loss:0.2221
epoch [691/1000], loss:0.2367
epoch [692/1000], loss:0.2340
epoch [693/1000], loss:0.2676
epoch [694/1000], loss:0.2649
epoch [695/1000], loss:0.2782
epoch [696/1000], loss:0.2506
epoch [697/1000], loss:0.2497
epoch [698/1000], loss:0.2528
epoch [699/1000], loss:0.2381
epoch [700/1000], loss:0.2581
epoch [701/1000], loss:0.2445
epoch [702/1000], loss:0.2283
epoch [703/1000], loss:0.2557
epoch [704/1000], loss:0.2587
epoch [705/1000], loss:0.2493
epoch [706/1000], loss:0.2704
epoch [707/1000], loss:0.2452
epoch [708/1000], loss:0.2673
epoch [709/1000], loss:0.2607
epoch [710/1000], loss:0.2370
epoch [711/1000], loss:0.2546
epoch [712/1000], loss:0.2446
epoch [713/1000], loss:0.2614
epoch [714/1000], loss:0.2662
epoch [715/1000], loss:0.2255
epoch [716/1000], loss:0.2529
epoch [717/1000], loss:0.2716
epoch [718/1000], loss:0.2506
epoch [719/1000], loss:0.2566
epoch [720/1000], loss:0.2198
epoch [721/1000], loss:0.2471
epoch [722/1000], loss:0.2802
epoch [723/1000], loss:0.2348
epoch [724/1000], loss:0.2691
epoch [725/1000], loss:0.2604
epoch [726/1000], loss:0.2882
epoch [727/1000], loss:0.2529
epoch [728/1000], loss:0.2441
epoch [729/1000], loss:0.2557
epoch [730/1000], loss:0.2723
epoch [731/1000], loss:0.2654
epoch [732/1000], loss:0.2362
epoch [733/1000], loss:0.2651
epoch [734/1000], loss:0.2527
epoch [735/1000], loss:0.2645
epoch [736/1000], loss:0.2323
epoch [737/1000], loss:0.2488
epoch [738/1000], loss:0.2596
epoch [739/1000], loss:0.2276
epoch [740/1000], loss:0.2470
epoch [741/1000], loss:0.2307
epoch [742/1000], loss:0.2275
epoch [743/1000], loss:0.2611
epoch [744/1000], loss:0.2633
epoch [745/1000], loss:0.2451
epoch [746/1000], loss:0.2733
epoch [747/1000], loss:0.2599
epoch [748/1000], loss:0.2224
epoch [749/1000], loss:0.2372
epoch [750/1000], loss:0.2370
epoch [751/1000], loss:0.2606
epoch [752/1000], loss:0.2518
epoch [753/1000], loss:0.2529
epoch [754/1000], loss:0.2612
epoch [755/1000], loss:0.2347
epoch [756/1000], loss:0.2359
epoch [757/1000], loss:0.2553
epoch [758/1000], loss:0.2661
epoch [759/1000], loss:0.2375
epoch [760/1000], loss:0.2632
epoch [761/1000], loss:0.2589
epoch [762/1000], loss:0.2271
epoch [763/1000], loss:0.2587
epoch [764/1000], loss:0.2441
epoch [765/1000], loss:0.2506
epoch [766/1000], loss:0.2437
epoch [767/1000], loss:0.2308
epoch [768/1000], loss:0.2172
epoch [769/1000], loss:0.2270
epoch [770/1000], loss:0.2563
epoch [771/1000], loss:0.2399
epoch [772/1000], loss:0.2744
epoch [773/1000], loss:0.2825
epoch [774/1000], loss:0.2577
epoch [775/1000], loss:0.2145
epoch [776/1000], loss:0.2752
epoch [777/1000], loss:0.2237
epoch [778/1000], loss:0.2347
epoch [779/1000], loss:0.2642
epoch [780/1000], loss:0.2543
epoch [781/1000], loss:0.2443
epoch [782/1000], loss:0.2623
epoch [783/1000], loss:0.2527
epoch [784/1000], loss:0.2183
epoch [785/1000], loss:0.2595
epoch [786/1000], loss:0.2620
epoch [787/1000], loss:0.2547
epoch [788/1000], loss:0.2423
epoch [789/1000], loss:0.2464
epoch [790/1000], loss:0.2570
epoch [791/1000], loss:0.2673
epoch [792/1000], loss:0.2651
epoch [793/1000], loss:0.2193
epoch [794/1000], loss:0.2277
epoch [795/1000], loss:0.2609
epoch [796/1000], loss:0.2343
epoch [797/1000], loss:0.2775
epoch [798/1000], loss:0.2480
epoch [799/1000], loss:0.2624
epoch [800/1000], loss:0.2325
epoch [801/1000], loss:0.2423
epoch [802/1000], loss:0.2682
epoch [803/1000], loss:0.2578
epoch [804/1000], loss:0.2710
epoch [805/1000], loss:0.2166
epoch [806/1000], loss:0.2562
epoch [807/1000], loss:0.2403
epoch [808/1000], loss:0.2233
epoch [809/1000], loss:0.2492
epoch [810/1000], loss:0.2165
epoch [811/1000], loss:0.2394
epoch [812/1000], loss:0.2494
epoch [813/1000], loss:0.2545
epoch [814/1000], loss:0.2457
epoch [815/1000], loss:0.2592
epoch [816/1000], loss:0.2466
epoch [817/1000], loss:0.2438
epoch [818/1000], loss:0.2413
epoch [819/1000], loss:0.2774
epoch [820/1000], loss:0.2784
epoch [821/1000], loss:0.2649
epoch [822/1000], loss:0.2614
epoch [823/1000], loss:0.2588
epoch [824/1000], loss:0.2528
epoch [825/1000], loss:0.2389
epoch [826/1000], loss:0.2368
epoch [827/1000], loss:0.2382
epoch [828/1000], loss:0.2555
epoch [829/1000], loss:0.2468
epoch [830/1000], loss:0.2574
epoch [831/1000], loss:0.2752
epoch [832/1000], loss:0.2375
epoch [833/1000], loss:0.2526
epoch [834/1000], loss:0.2439
epoch [835/1000], loss:0.2539
epoch [836/1000], loss:0.2434
epoch [837/1000], loss:0.2505
epoch [838/1000], loss:0.2777
epoch [839/1000], loss:0.2396
epoch [840/1000], loss:0.2512
epoch [841/1000], loss:0.2603
epoch [842/1000], loss:0.2415
epoch [843/1000], loss:0.2402
epoch [844/1000], loss:0.2529
epoch [845/1000], loss:0.2524
epoch [846/1000], loss:0.2702
epoch [847/1000], loss:0.2154
epoch [848/1000], loss:0.2688
epoch [849/1000], loss:0.2683
epoch [850/1000], loss:0.2583
epoch [851/1000], loss:0.2563
epoch [852/1000], loss:0.2472
epoch [853/1000], loss:0.2686
epoch [854/1000], loss:0.2628
epoch [855/1000], loss:0.2476
epoch [856/1000], loss:0.2473
epoch [857/1000], loss:0.2521
epoch [858/1000], loss:0.2828
epoch [859/1000], loss:0.2417
epoch [860/1000], loss:0.2772
epoch [861/1000], loss:0.2510
epoch [862/1000], loss:0.2389
epoch [863/1000], loss:0.2451
epoch [864/1000], loss:0.2360
epoch [865/1000], loss:0.2328
epoch [866/1000], loss:0.2388
epoch [867/1000], loss:0.2288
epoch [868/1000], loss:0.2471
epoch [869/1000], loss:0.2605
epoch [870/1000], loss:0.2558
epoch [871/1000], loss:0.2420
epoch [872/1000], loss:0.2251
epoch [873/1000], loss:0.2544
epoch [874/1000], loss:0.2438
epoch [875/1000], loss:0.2719
epoch [876/1000], loss:0.2348
epoch [877/1000], loss:0.2469
epoch [878/1000], loss:0.2869
epoch [879/1000], loss:0.2412
epoch [880/1000], loss:0.2572
epoch [881/1000], loss:0.2807
epoch [882/1000], loss:0.2565
epoch [883/1000], loss:0.2972
epoch [884/1000], loss:0.2521
epoch [885/1000], loss:0.2390
epoch [886/1000], loss:0.2553
epoch [887/1000], loss:0.2367
epoch [888/1000], loss:0.2447
epoch [889/1000], loss:0.2721
epoch [890/1000], loss:0.2803
epoch [891/1000], loss:0.2083
epoch [892/1000], loss:0.2662
epoch [893/1000], loss:0.2916
epoch [894/1000], loss:0.2884
epoch [895/1000], loss:0.2273
epoch [896/1000], loss:0.2602
epoch [897/1000], loss:0.2312
epoch [898/1000], loss:0.2515
epoch [899/1000], loss:0.3014
epoch [900/1000], loss:0.2425
epoch [901/1000], loss:0.2474
epoch [902/1000], loss:0.2730
epoch [903/1000], loss:0.2351
epoch [904/1000], loss:0.2698
epoch [905/1000], loss:0.2624
epoch [906/1000], loss:0.2527
epoch [907/1000], loss:0.2580
epoch [908/1000], loss:0.2642
epoch [909/1000], loss:0.2573
epoch [910/1000], loss:0.2451
epoch [911/1000], loss:0.2776
epoch [912/1000], loss:0.2364
epoch [913/1000], loss:0.2389
epoch [914/1000], loss:0.2474
epoch [915/1000], loss:0.2454
epoch [916/1000], loss:0.2277
epoch [917/1000], loss:0.2339
epoch [918/1000], loss:0.2343
epoch [919/1000], loss:0.2763
epoch [920/1000], loss:0.2513
epoch [921/1000], loss:0.2491
epoch [922/1000], loss:0.2379
epoch [923/1000], loss:0.2558
epoch [924/1000], loss:0.2418
epoch [925/1000], loss:0.2424
epoch [926/1000], loss:0.2616
epoch [927/1000], loss:0.2567
epoch [928/1000], loss:0.2267
epoch [929/1000], loss:0.2335
epoch [930/1000], loss:0.2399
epoch [931/1000], loss:0.2589
epoch [932/1000], loss:0.2747
epoch [933/1000], loss:0.2577
epoch [934/1000], loss:0.2512
epoch [935/1000], loss:0.2418
epoch [936/1000], loss:0.2359
epoch [937/1000], loss:0.2690
epoch [938/1000], loss:0.2234
epoch [939/1000], loss:0.2475
epoch [940/1000], loss:0.2500
epoch [941/1000], loss:0.2870
epoch [942/1000], loss:0.2366
epoch [943/1000], loss:0.2707
epoch [944/1000], loss:0.2843
epoch [945/1000], loss:0.2529
epoch [946/1000], loss:0.2566
epoch [947/1000], loss:0.2609
epoch [948/1000], loss:0.2522
epoch [949/1000], loss:0.2667
epoch [950/1000], loss:0.2201
epoch [951/1000], loss:0.2463
epoch [952/1000], loss:0.2447
epoch [953/1000], loss:0.2828
epoch [954/1000], loss:0.3006
epoch [955/1000], loss:0.2669
epoch [956/1000], loss:0.2244
epoch [957/1000], loss:0.2670
epoch [958/1000], loss:0.2275
epoch [959/1000], loss:0.2621
epoch [960/1000], loss:0.2298
epoch [961/1000], loss:0.2376
epoch [962/1000], loss:0.2533
epoch [963/1000], loss:0.2661
epoch [964/1000], loss:0.2361
epoch [965/1000], loss:0.2569
epoch [966/1000], loss:0.2497
epoch [967/1000], loss:0.2824
epoch [968/1000], loss:0.2899
epoch [969/1000], loss:0.2375
epoch [970/1000], loss:0.2424
epoch [971/1000], loss:0.2526
epoch [972/1000], loss:0.2592
epoch [973/1000], loss:0.2564
epoch [974/1000], loss:0.2611
epoch [975/1000], loss:0.2699
epoch [976/1000], loss:0.2429
epoch [977/1000], loss:0.2609
epoch [978/1000], loss:0.2539
epoch [979/1000], loss:0.2379
epoch [980/1000], loss:0.3064
epoch [981/1000], loss:0.2268
epoch [982/1000], loss:0.2291
epoch [983/1000], loss:0.2522
epoch [984/1000], loss:0.2367
epoch [985/1000], loss:0.2612
epoch [986/1000], loss:0.2687
epoch [987/1000], loss:0.2552
epoch [988/1000], loss:0.2559
epoch [989/1000], loss:0.2527
epoch [990/1000], loss:0.2519
epoch [991/1000], loss:0.2684
epoch [992/1000], loss:0.2680
epoch [993/1000], loss:0.2424
epoch [994/1000], loss:0.2650
epoch [995/1000], loss:0.2659
epoch [996/1000], loss:0.2775
epoch [997/1000], loss:0.2589
epoch [998/1000], loss:0.2558
epoch [999/1000], loss:0.2510
epoch [1000/1000], loss:0.2552
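In the `cnn` branch of the autoencoder training code, each batch is converted from channels-last to channels-first with `transpose(3, 1)`. That call swaps only two axes, so an `(N, H, W, C)` batch becomes `(N, C, W, H)`, not `(N, C, H, W)`. The sketch below illustrates the difference with NumPy, assuming the training images are stored channels-last; `np.swapaxes(x, 3, 1)` mirrors torch's `tensor.transpose(3, 1)`, and `np.transpose(x, (0, 3, 1, 2))` mirrors `tensor.permute(0, 3, 1, 2)`. The tiny `batch` array is a hypothetical stand-in, with H != W so the axis order is visible.

```python
import numpy as np

# Hypothetical stand-in for a training batch stored channels-last as (N, H, W, C).
# H (4) differs from W (6) so the resulting axis order is easy to see.
batch = np.arange(2 * 4 * 6 * 3).reshape(2, 4, 6, 3)

# Mirrors torch's batch.transpose(3, 1): swap axis 3 (C) with axis 1 (H).
swapped = np.swapaxes(batch, 3, 1)        # -> (N, C, W, H)

# Mirrors torch's batch.permute(0, 3, 1, 2): move C forward, keep H before W.
nchw = np.transpose(batch, (0, 3, 1, 2))  # -> (N, C, H, W)

print(swapped.shape)  # (2, 3, 6, 4)
print(nchw.shape)     # (2, 3, 4, 6)

# Both hold the same pixels; one is the spatial transpose of the other.
assert np.array_equal(swapped, np.swapaxes(nchw, 2, 3))
```

For square 32×32 images both calls produce the same shape, so the convolutional model trains either way; what matters is applying the identical transform when computing reconstructions on the test set.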
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

if task == 'ae':
    num_epochs = 1000
    batch_size = 128
    learning_rate = 1e-3

    # one of {'fcn', 'cnn', 'vae'}
    model_type = 'fcn'

    x = train
    if model_type == 'fcn' or model_type == 'vae':
        # fully connected models expect each image flattened into one vector
        x = x.reshape(len(x), -1)

    data = torch.tensor(x, dtype=torch.float)
    train_dataset = TensorDataset(data)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)

    # fcn_autoencoder, conv_autoencoder, VAE and loss_vae are assumed to be defined in earlier cells;
    # storing the classes (not instances) means only the selected model is built
    model_classes = {'fcn': fcn_autoencoder, 'cnn': conv_autoencoder, 'vae': VAE}
    model = model_classes[model_type]().cuda()
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    best_loss = np.inf
    model.train()
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            if model_type == 'cnn':
                # (N, H, W, C) -> channels first; transpose(3, 1) also swaps H and W,
                # which is harmless for square images as long as test time uses the same transform
                img = batch[0].transpose(3, 1).cuda()
            else:
                img = batch[0].cuda()
            # ===================forward=====================
            output = model(img)
            if model_type == 'vae':
                loss = loss_vae(output[0], img, output[1], output[2], criterion)
            else:
                loss = criterion(output, img)
            # ===================backward====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # ===================save====================
        # checkpoint whenever the last batch's loss improves on the best seen so far
        if loss.item() < best_loss:
            best_loss = loss.item()
            torch.save(model, 'best_model_{}.pt'.format(model_type))
        # ===================log========================
        print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, loss.item()))
epoch [1/1000], loss:0.1453
epoch [2/1000], loss:0.1202
epoch [3/1000], loss:0.1129
epoch [4/1000], loss:0.1325
epoch [5/1000], loss:0.1200
epoch [6/1000], loss:0.1342
epoch [7/1000], loss:0.1035
epoch [8/1000], loss:0.1239
epoch [9/1000], loss:0.1267
epoch [10/1000], loss:0.1101
epoch [11/1000], loss:0.1245
epoch [12/1000], loss:0.1240
epoch [13/1000], loss:0.1189
epoch [14/1000], loss:0.1245
epoch [15/1000], loss:0.1138
epoch [16/1000], loss:0.1153
epoch [17/1000], loss:0.1249
epoch [18/1000], loss:0.1152
epoch [19/1000], loss:0.1220
epoch [20/1000], loss:0.1227
epoch [21/1000], loss:0.1276
epoch [22/1000], loss:0.1122
epoch [23/1000], loss:0.1127
epoch [24/1000], loss:0.1227
epoch [25/1000], loss:0.1130
epoch [26/1000], loss:0.1133
epoch [27/1000], loss:0.1245
epoch [28/1000], loss:0.1180
epoch [29/1000], loss:0.1276
epoch [30/1000], loss:0.1317
epoch [31/1000], loss:0.1141
epoch [32/1000], loss:0.1210
epoch [33/1000], loss:0.1151
epoch [34/1000], loss:0.1192
epoch [35/1000], loss:0.1175
epoch [36/1000], loss:0.1195
epoch [37/1000], loss:0.1247
epoch [38/1000], loss:0.1114
epoch [39/1000], loss:0.1314
epoch [40/1000], loss:0.1216
epoch [41/1000], loss:0.1085
epoch [42/1000], loss:0.1231
epoch [43/1000], loss:0.1193
epoch [44/1000], loss:0.1187
epoch [45/1000], loss:0.1102
epoch [46/1000], loss:0.1294
epoch [47/1000], loss:0.1164
epoch [48/1000], loss:0.1187
epoch [49/1000], loss:0.1123
epoch [50/1000], loss:0.1033
epoch [51/1000], loss:0.1340
epoch [52/1000], loss:0.1214
epoch [53/1000], loss:0.1288
epoch [54/1000], loss:0.1228
epoch [55/1000], loss:0.1117
epoch [56/1000], loss:0.1133
epoch [57/1000], loss:0.1283
epoch [58/1000], loss:0.1136
epoch [59/1000], loss:0.1221
epoch [60/1000], loss:0.1177
epoch [61/1000], loss:0.1179
epoch [62/1000], loss:0.1038
epoch [63/1000], loss:0.1101
epoch [64/1000], loss:0.1282
epoch [65/1000], loss:0.1122
epoch [66/1000], loss:0.1197
epoch [67/1000], loss:0.1176
epoch [68/1000], loss:0.1222
epoch [69/1000], loss:0.1153
epoch [70/1000], loss:0.1137
epoch [71/1000], loss:0.1127
epoch [72/1000], loss:0.1188
epoch [73/1000], loss:0.1136
epoch [74/1000], loss:0.1333
epoch [75/1000], loss:0.1208
epoch [76/1000], loss:0.1281
epoch [77/1000], loss:0.1110
epoch [78/1000], loss:0.1274
epoch [79/1000], loss:0.1248
epoch [80/1000], loss:0.1035
epoch [81/1000], loss:0.1105
epoch [82/1000], loss:0.1306
epoch [83/1000], loss:0.1137
epoch [84/1000], loss:0.1186
epoch [85/1000], loss:0.1133
epoch [86/1000], loss:0.1038
epoch [87/1000], loss:0.1205
epoch [88/1000], loss:0.1144
epoch [89/1000], loss:0.1230
epoch [90/1000], loss:0.1155
epoch [91/1000], loss:0.1233
epoch [92/1000], loss:0.1189
epoch [93/1000], loss:0.1251
epoch [94/1000], loss:0.1123
epoch [95/1000], loss:0.1300
epoch [96/1000], loss:0.1241
epoch [97/1000], loss:0.1151
epoch [98/1000], loss:0.1173
epoch [99/1000], loss:0.1135
epoch [100/1000], loss:0.1203
epoch [101/1000], loss:0.1263
epoch [102/1000], loss:0.1118
epoch [103/1000], loss:0.1118
epoch [104/1000], loss:0.1104
epoch [105/1000], loss:0.1079
epoch [106/1000], loss:0.1029
epoch [107/1000], loss:0.1095
epoch [108/1000], loss:0.1238
epoch [109/1000], loss:0.1138
epoch [110/1000], loss:0.1242
epoch [111/1000], loss:0.1079
epoch [112/1000], loss:0.1137
epoch [113/1000], loss:0.1285
epoch [114/1000], loss:0.1199
epoch [115/1000], loss:0.1073
epoch [116/1000], loss:0.1158
epoch [117/1000], loss:0.1237
epoch [118/1000], loss:0.1309
epoch [119/1000], loss:0.1100
epoch [120/1000], loss:0.1174
epoch [121/1000], loss:0.0937
epoch [122/1000], loss:0.1158
epoch [123/1000], loss:0.1279
epoch [124/1000], loss:0.1358
epoch [125/1000], loss:0.1203
epoch [126/1000], loss:0.1047
epoch [127/1000], loss:0.1262
epoch [128/1000], loss:0.1136
epoch [129/1000], loss:0.1118
epoch [130/1000], loss:0.1216
epoch [131/1000], loss:0.1257
epoch [132/1000], loss:0.1111
epoch [133/1000], loss:0.1185
epoch [134/1000], loss:0.1152
epoch [135/1000], loss:0.1088
epoch [136/1000], loss:0.1151
epoch [137/1000], loss:0.1254
epoch [138/1000], loss:0.1240
epoch [139/1000], loss:0.1214
epoch [140/1000], loss:0.1204
epoch [141/1000], loss:0.1189
epoch [142/1000], loss:0.1199
epoch [143/1000], loss:0.1181
epoch [144/1000], loss:0.1154
epoch [145/1000], loss:0.1274
epoch [146/1000], loss:0.1258
epoch [147/1000], loss:0.1088
epoch [148/1000], loss:0.1297
epoch [149/1000], loss:0.1164
epoch [150/1000], loss:0.1183
epoch [151/1000], loss:0.1179
epoch [152/1000], loss:0.1164
epoch [153/1000], loss:0.1236
epoch [154/1000], loss:0.1114
epoch [155/1000], loss:0.1113
epoch [156/1000], loss:0.0983
epoch [157/1000], loss:0.1217
epoch [158/1000], loss:0.1231
epoch [159/1000], loss:0.1068
epoch [160/1000], loss:0.1110
epoch [161/1000], loss:0.1109
epoch [162/1000], loss:0.1213
epoch [163/1000], loss:0.1322
epoch [164/1000], loss:0.1177
epoch [165/1000], loss:0.1306
epoch [166/1000], loss:0.1153
epoch [167/1000], loss:0.1203
epoch [168/1000], loss:0.1392
epoch [169/1000], loss:0.1272
epoch [170/1000], loss:0.1046
epoch [171/1000], loss:0.1217
epoch [172/1000], loss:0.1184
epoch [173/1000], loss:0.1046
epoch [174/1000], loss:0.1131
epoch [175/1000], loss:0.1165
epoch [176/1000], loss:0.1136
epoch [177/1000], loss:0.1313
epoch [178/1000], loss:0.1087
epoch [179/1000], loss:0.1174
epoch [180/1000], loss:0.1112
epoch [181/1000], loss:0.1037
epoch [182/1000], loss:0.1258
epoch [183/1000], loss:0.1108
epoch [184/1000], loss:0.1213
epoch [185/1000], loss:0.1204
epoch [186/1000], loss:0.1158
epoch [187/1000], loss:0.1079
epoch [188/1000], loss:0.1254
epoch [189/1000], loss:0.1124
epoch [190/1000], loss:0.1077
epoch [191/1000], loss:0.1185
epoch [192/1000], loss:0.1201
epoch [193/1000], loss:0.1152
epoch [194/1000], loss:0.1043
epoch [195/1000], loss:0.1257
epoch [196/1000], loss:0.1151
epoch [197/1000], loss:0.1265
epoch [198/1000], loss:0.1203
epoch [199/1000], loss:0.1232
epoch [200/1000], loss:0.1344
epoch [201/1000], loss:0.1097
epoch [202/1000], loss:0.1168
epoch [203/1000], loss:0.1124
epoch [204/1000], loss:0.1169
epoch [205/1000], loss:0.1064
epoch [206/1000], loss:0.1137
epoch [207/1000], loss:0.1149
epoch [208/1000], loss:0.1317
epoch [209/1000], loss:0.1101
epoch [210/1000], loss:0.1415
epoch [211/1000], loss:0.1274
epoch [212/1000], loss:0.1159
epoch [213/1000], loss:0.1081
epoch [214/1000], loss:0.1279
epoch [215/1000], loss:0.1182
epoch [216/1000], loss:0.1256
epoch [217/1000], loss:0.1108
epoch [218/1000], loss:0.1079
epoch [219/1000], loss:0.1254
epoch [220/1000], loss:0.1235
epoch [221/1000], loss:0.1074
epoch [222/1000], loss:0.1163
epoch [223/1000], loss:0.1163
epoch [224/1000], loss:0.1131
epoch [225/1000], loss:0.1027
epoch [226/1000], loss:0.1156
epoch [227/1000], loss:0.1105
epoch [228/1000], loss:0.1094
epoch [229/1000], loss:0.1162
epoch [230/1000], loss:0.1196
epoch [231/1000], loss:0.1092
epoch [232/1000], loss:0.1185
epoch [233/1000], loss:0.1150
epoch [234/1000], loss:0.1203
epoch [235/1000], loss:0.1338
epoch [236/1000], loss:0.1171
epoch [237/1000], loss:0.1213
epoch [238/1000], loss:0.1162
epoch [239/1000], loss:0.1104
epoch [240/1000], loss:0.1176
epoch [241/1000], loss:0.1255
epoch [242/1000], loss:0.1170
epoch [243/1000], loss:0.1092
epoch [244/1000], loss:0.1352
epoch [245/1000], loss:0.1121
epoch [246/1000], loss:0.1096
epoch [247/1000], loss:0.1285
epoch [248/1000], loss:0.1152
epoch [249/1000], loss:0.1157
epoch [250/1000], loss:0.1266
epoch [251/1000], loss:0.1154
epoch [252/1000], loss:0.1171
epoch [253/1000], loss:0.1218
epoch [254/1000], loss:0.1245
epoch [255/1000], loss:0.1161
epoch [256/1000], loss:0.1261
epoch [257/1000], loss:0.1089
epoch [258/1000], loss:0.1163
epoch [259/1000], loss:0.1270
epoch [260/1000], loss:0.1194
epoch [261/1000], loss:0.1164
epoch [262/1000], loss:0.1090
epoch [263/1000], loss:0.1308
epoch [264/1000], loss:0.1269
epoch [265/1000], loss:0.1140
epoch [266/1000], loss:0.1207
epoch [267/1000], loss:0.1162
epoch [268/1000], loss:0.1089
epoch [269/1000], loss:0.1244
epoch [270/1000], loss:0.1202
epoch [271/1000], loss:0.1208
epoch [272/1000], loss:0.1184
epoch [273/1000], loss:0.1183
epoch [274/1000], loss:0.1231
epoch [275/1000], loss:0.1372
epoch [276/1000], loss:0.1189
epoch [277/1000], loss:0.1132
epoch [278/1000], loss:0.1245
epoch [279/1000], loss:0.1210
epoch [280/1000], loss:0.1130
epoch [281/1000], loss:0.1187
epoch [282/1000], loss:0.1224
epoch [283/1000], loss:0.1218
epoch [284/1000], loss:0.1114
epoch [285/1000], loss:0.1215
epoch [286/1000], loss:0.1162
epoch [287/1000], loss:0.1182
epoch [288/1000], loss:0.0981
epoch [289/1000], loss:0.1083
epoch [290/1000], loss:0.1089
epoch [291/1000], loss:0.1121
epoch [292/1000], loss:0.1136
epoch [293/1000], loss:0.1099
epoch [294/1000], loss:0.1107
epoch [295/1000], loss:0.1207
epoch [296/1000], loss:0.1069
epoch [297/1000], loss:0.1259
epoch [298/1000], loss:0.1020
epoch [299/1000], loss:0.1068
epoch [300/1000], loss:0.1240
epoch [301/1000], loss:0.1143
epoch [302/1000], loss:0.1253
epoch [303/1000], loss:0.1229
epoch [304/1000], loss:0.1052
epoch [305/1000], loss:0.1252
epoch [306/1000], loss:0.1206
epoch [307/1000], loss:0.1226
epoch [308/1000], loss:0.1222
epoch [309/1000], loss:0.1128
epoch [310/1000], loss:0.1078
epoch [311/1000], loss:0.1235
epoch [312/1000], loss:0.1191
epoch [313/1000], loss:0.1180
epoch [314/1000], loss:0.1226
epoch [315/1000], loss:0.1204
epoch [316/1000], loss:0.1099
epoch [317/1000], loss:0.1225
epoch [318/1000], loss:0.1205
epoch [319/1000], loss:0.1043
epoch [320/1000], loss:0.1162
epoch [321/1000], loss:0.1113
epoch [322/1000], loss:0.1097
epoch [323/1000], loss:0.1056
epoch [324/1000], loss:0.1154
epoch [325/1000], loss:0.1102
epoch [326/1000], loss:0.1197
epoch [327/1000], loss:0.1174
epoch [328/1000], loss:0.1205
epoch [329/1000], loss:0.1129
epoch [330/1000], loss:0.1196
epoch [331/1000], loss:0.1070
epoch [332/1000], loss:0.1183
epoch [333/1000], loss:0.1128
epoch [334/1000], loss:0.1171
epoch [335/1000], loss:0.1126
epoch [336/1000], loss:0.1148
epoch [337/1000], loss:0.1023
epoch [338/1000], loss:0.1109
epoch [339/1000], loss:0.1165
epoch [340/1000], loss:0.1183
epoch [341/1000], loss:0.1126
epoch [342/1000], loss:0.1100
epoch [343/1000], loss:0.1111
epoch [344/1000], loss:0.1131
epoch [345/1000], loss:0.1213
epoch [346/1000], loss:0.1362
epoch [347/1000], loss:0.1164
epoch [348/1000], loss:0.1081
epoch [349/1000], loss:0.1036
epoch [350/1000], loss:0.1121
epoch [351/1000], loss:0.1182
epoch [352/1000], loss:0.1181
epoch [353/1000], loss:0.1266
epoch [354/1000], loss:0.1054
epoch [355/1000], loss:0.1124
epoch [356/1000], loss:0.1333
epoch [357/1000], loss:0.1213
epoch [358/1000], loss:0.1123
epoch [359/1000], loss:0.1114
epoch [360/1000], loss:0.1100
epoch [361/1000], loss:0.1166
epoch [362/1000], loss:0.1174
epoch [363/1000], loss:0.1221
epoch [364/1000], loss:0.1110
epoch [365/1000], loss:0.1244
epoch [366/1000], loss:0.1144
epoch [367/1000], loss:0.1114
epoch [368/1000], loss:0.1350
epoch [369/1000], loss:0.1127
epoch [370/1000], loss:0.1091
epoch [371/1000], loss:0.1212
epoch [372/1000], loss:0.1207
epoch [373/1000], loss:0.1282
epoch [374/1000], loss:0.1185
epoch [375/1000], loss:0.1198
epoch [376/1000], loss:0.1122
epoch [377/1000], loss:0.1034
epoch [378/1000], loss:0.1314
epoch [379/1000], loss:0.1122
epoch [380/1000], loss:0.1265
epoch [381/1000], loss:0.1123
epoch [382/1000], loss:0.1157
epoch [383/1000], loss:0.1198
epoch [384/1000], loss:0.1163
epoch [385/1000], loss:0.1239
epoch [386/1000], loss:0.1079
epoch [387/1000], loss:0.1103
epoch [388/1000], loss:0.1153
epoch [389/1000], loss:0.1138
epoch [390/1000], loss:0.1079
epoch [391/1000], loss:0.1122
epoch [392/1000], loss:0.1234
epoch [393/1000], loss:0.1191
epoch [394/1000], loss:0.1205
epoch [395/1000], loss:0.1310
epoch [396/1000], loss:0.1298
epoch [397/1000], loss:0.1137
epoch [398/1000], loss:0.1124
epoch [399/1000], loss:0.1247
epoch [400/1000], loss:0.1332
epoch [401/1000], loss:0.1197
epoch [402/1000], loss:0.1104
epoch [403/1000], loss:0.1149
epoch [404/1000], loss:0.1240
epoch [405/1000], loss:0.1175
epoch [406/1000], loss:0.1247
epoch [407/1000], loss:0.1092
epoch [408/1000], loss:0.1180
epoch [409/1000], loss:0.1250
epoch [410/1000], loss:0.1144
epoch [411/1000], loss:0.1106
epoch [412/1000], loss:0.1146
epoch [413/1000], loss:0.1146
epoch [414/1000], loss:0.1181
epoch [415/1000], loss:0.1141
epoch [416/1000], loss:0.1036
epoch [417/1000], loss:0.1221
epoch [418/1000], loss:0.1119
epoch [419/1000], loss:0.1182
epoch [420/1000], loss:0.1366
epoch [421/1000], loss:0.1112
epoch [422/1000], loss:0.1247
epoch [423/1000], loss:0.1164
epoch [424/1000], loss:0.1294
epoch [425/1000], loss:0.1269
epoch [426/1000], loss:0.1232
epoch [427/1000], loss:0.1186
epoch [428/1000], loss:0.1128
epoch [429/1000], loss:0.1130
epoch [430/1000], loss:0.1170
epoch [431/1000], loss:0.1276
epoch [432/1000], loss:0.1141
epoch [433/1000], loss:0.1256
epoch [434/1000], loss:0.1247
epoch [435/1000], loss:0.1377
epoch [436/1000], loss:0.1329
epoch [437/1000], loss:0.1155
epoch [438/1000], loss:0.1077
epoch [439/1000], loss:0.1222
epoch [440/1000], loss:0.1197
epoch [441/1000], loss:0.1140
epoch [442/1000], loss:0.1162
epoch [443/1000], loss:0.1323
epoch [444/1000], loss:0.1225
epoch [445/1000], loss:0.1301
epoch [446/1000], loss:0.1139
epoch [447/1000], loss:0.1227
epoch [448/1000], loss:0.1257
epoch [449/1000], loss:0.1194
epoch [450/1000], loss:0.1323
epoch [451/1000], loss:0.1269
epoch [452/1000], loss:0.1266
epoch [453/1000], loss:0.1073
epoch [454/1000], loss:0.1173
epoch [455/1000], loss:0.1115
epoch [456/1000], loss:0.1339
epoch [457/1000], loss:0.1048
epoch [458/1000], loss:0.1225
epoch [459/1000], loss:0.1203
epoch [460/1000], loss:0.1276
epoch [461/1000], loss:0.1206
epoch [462/1000], loss:0.1207
epoch [463/1000], loss:0.1024
epoch [464/1000], loss:0.1185
epoch [465/1000], loss:0.1199
epoch [466/1000], loss:0.1197
epoch [467/1000], loss:0.1108
epoch [468/1000], loss:0.1165
epoch [469/1000], loss:0.1119
epoch [470/1000], loss:0.1135
epoch [471/1000], loss:0.1239
epoch [472/1000], loss:0.1148
epoch [473/1000], loss:0.1142
epoch [474/1000], loss:0.1210
epoch [475/1000], loss:0.1202
epoch [476/1000], loss:0.1197
epoch [477/1000], loss:0.1059
epoch [478/1000], loss:0.1167
epoch [479/1000], loss:0.1253
epoch [480/1000], loss:0.1233
epoch [481/1000], loss:0.0988
epoch [482/1000], loss:0.1265
epoch [483/1000], loss:0.1202
epoch [484/1000], loss:0.1116
epoch [485/1000], loss:0.1144
epoch [486/1000], loss:0.1121
epoch [487/1000], loss:0.1182
epoch [488/1000], loss:0.1179
epoch [489/1000], loss:0.1257
epoch [490/1000], loss:0.1354
epoch [491/1000], loss:0.1118
epoch [492/1000], loss:0.1192
epoch [493/1000], loss:0.1052
epoch [494/1000], loss:0.1241
epoch [495/1000], loss:0.1155
epoch [496/1000], loss:0.1212
epoch [497/1000], loss:0.1133
epoch [498/1000], loss:0.1173
epoch [499/1000], loss:0.1134
epoch [500/1000], loss:0.1103
epoch [501/1000], loss:0.1183
epoch [502/1000], loss:0.1105
epoch [503/1000], loss:0.1138
epoch [504/1000], loss:0.1032
epoch [505/1000], loss:0.1183
epoch [506/1000], loss:0.1232
epoch [507/1000], loss:0.1134
epoch [508/1000], loss:0.1132
epoch [509/1000], loss:0.1123
epoch [510/1000], loss:0.1106
epoch [511/1000], loss:0.1162
epoch [512/1000], loss:0.1153
epoch [513/1000], loss:0.1109
epoch [514/1000], loss:0.1366
epoch [515/1000], loss:0.1183
epoch [516/1000], loss:0.1099
epoch [517/1000], loss:0.1180
epoch [518/1000], loss:0.1147
epoch [519/1000], loss:0.1221
epoch [520/1000], loss:0.1244
epoch [521/1000], loss:0.1253
epoch [522/1000], loss:0.1269
epoch [523/1000], loss:0.1262
epoch [524/1000], loss:0.1163
epoch [525/1000], loss:0.1144
epoch [526/1000], loss:0.1020
epoch [527/1000], loss:0.1079
epoch [528/1000], loss:0.1254
epoch [529/1000], loss:0.1277
epoch [530/1000], loss:0.1152
epoch [531/1000], loss:0.1126
epoch [532/1000], loss:0.1149
epoch [533/1000], loss:0.1120
epoch [534/1000], loss:0.1030
epoch [535/1000], loss:0.1184
epoch [536/1000], loss:0.1176
epoch [537/1000], loss:0.1126
epoch [538/1000], loss:0.1190
epoch [539/1000], loss:0.1044
epoch [540/1000], loss:0.1292
epoch [541/1000], loss:0.1179
epoch [542/1000], loss:0.1174
epoch [543/1000], loss:0.1269
epoch [544/1000], loss:0.1238
epoch [545/1000], loss:0.1280
epoch [546/1000], loss:0.1145
epoch [547/1000], loss:0.1195
epoch [548/1000], loss:0.1191
epoch [549/1000], loss:0.1208
epoch [550/1000], loss:0.1189
epoch [551/1000], loss:0.1171
epoch [552/1000], loss:0.1248
epoch [553/1000], loss:0.1061
epoch [554/1000], loss:0.1074
epoch [555/1000], loss:0.1148
epoch [556/1000], loss:0.1180
epoch [557/1000], loss:0.1137
epoch [558/1000], loss:0.1069
epoch [559/1000], loss:0.1235
epoch [560/1000], loss:0.1103
epoch [561/1000], loss:0.1146
epoch [562/1000], loss:0.1052
epoch [563/1000], loss:0.1165
epoch [564/1000], loss:0.1286
epoch [565/1000], loss:0.1134
epoch [566/1000], loss:0.1141
epoch [567/1000], loss:0.1209
epoch [568/1000], loss:0.1223
epoch [569/1000], loss:0.1195
epoch [570/1000], loss:0.0965
epoch [571/1000], loss:0.1334
epoch [572/1000], loss:0.1114
epoch [573/1000], loss:0.1200
epoch [574/1000], loss:0.1192
epoch [575/1000], loss:0.1108
epoch [576/1000], loss:0.1190
epoch [577/1000], loss:0.1123
epoch [578/1000], loss:0.1178
epoch [579/1000], loss:0.1297
epoch [580/1000], loss:0.1109
epoch [581/1000], loss:0.1117
epoch [582/1000], loss:0.0962
epoch [583/1000], loss:0.1327
epoch [584/1000], loss:0.1160
epoch [585/1000], loss:0.1188
epoch [586/1000], loss:0.1216
epoch [587/1000], loss:0.1154
epoch [588/1000], loss:0.1074
epoch [589/1000], loss:0.1287
epoch [590/1000], loss:0.1153
epoch [591/1000], loss:0.1132
epoch [592/1000], loss:0.1198
epoch [593/1000], loss:0.1263
epoch [594/1000], loss:0.1125
epoch [595/1000], loss:0.1258
epoch [596/1000], loss:0.1217
epoch [597/1000], loss:0.1286
epoch [598/1000], loss:0.1024
epoch [599/1000], loss:0.1177
epoch [600/1000], loss:0.1169
epoch [601/1000], loss:0.1115
epoch [602/1000], loss:0.1135
epoch [603/1000], loss:0.1317
epoch [604/1000], loss:0.1193
epoch [605/1000], loss:0.0980
epoch [606/1000], loss:0.1167
epoch [607/1000], loss:0.1076
epoch [608/1000], loss:0.1127
epoch [609/1000], loss:0.1311
epoch [610/1000], loss:0.1152
epoch [611/1000], loss:0.1259
epoch [612/1000], loss:0.1231
epoch [613/1000], loss:0.1086
epoch [614/1000], loss:0.1257
epoch [615/1000], loss:0.1132
epoch [616/1000], loss:0.1134
epoch [617/1000], loss:0.1262
epoch [618/1000], loss:0.1280
epoch [619/1000], loss:0.1008
epoch [620/1000], loss:0.1082
epoch [621/1000], loss:0.1212
epoch [622/1000], loss:0.1224
epoch [623/1000], loss:0.1069
epoch [624/1000], loss:0.1359
epoch [625/1000], loss:0.1215
epoch [626/1000], loss:0.1192
epoch [627/1000], loss:0.1206
epoch [628/1000], loss:0.1190
epoch [629/1000], loss:0.1182
epoch [630/1000], loss:0.1162
epoch [631/1000], loss:0.1317
epoch [632/1000], loss:0.1314
epoch [633/1000], loss:0.1259
epoch [634/1000], loss:0.1087
epoch [635/1000], loss:0.1242
epoch [636/1000], loss:0.1164
epoch [637/1000], loss:0.1068
epoch [638/1000], loss:0.1191
epoch [639/1000], loss:0.1145
epoch [640/1000], loss:0.1307
epoch [641/1000], loss:0.1179
epoch [642/1000], loss:0.1247
epoch [643/1000], loss:0.1182
epoch [644/1000], loss:0.1266
epoch [645/1000], loss:0.1201
epoch [646/1000], loss:0.1271
epoch [647/1000], loss:0.1343
epoch [648/1000], loss:0.1250
epoch [649/1000], loss:0.1177
epoch [650/1000], loss:0.1263
epoch [651/1000], loss:0.1270
epoch [652/1000], loss:0.1059
epoch [653/1000], loss:0.1225
epoch [654/1000], loss:0.1317
epoch [655/1000], loss:0.1197
epoch [656/1000], loss:0.1219
epoch [657/1000], loss:0.1232
epoch [658/1000], loss:0.1157
epoch [659/1000], loss:0.1081
epoch [660/1000], loss:0.1324
epoch [661/1000], loss:0.1158
epoch [662/1000], loss:0.1139
epoch [663/1000], loss:0.1186
epoch [664/1000], loss:0.1268
epoch [665/1000], loss:0.1170
epoch [666/1000], loss:0.1222
epoch [667/1000], loss:0.1077
epoch [668/1000], loss:0.1220
epoch [669/1000], loss:0.1212
epoch [670/1000], loss:0.1304
epoch [671/1000], loss:0.1195
epoch [672/1000], loss:0.1049
epoch [673/1000], loss:0.1118
epoch [674/1000], loss:0.1264
epoch [675/1000], loss:0.1224
epoch [676/1000], loss:0.1181
epoch [677/1000], loss:0.1018
epoch [678/1000], loss:0.1180
epoch [679/1000], loss:0.1168
epoch [680/1000], loss:0.1227
epoch [681/1000], loss:0.1244
epoch [682/1000], loss:0.1221
epoch [683/1000], loss:0.1227
epoch [684/1000], loss:0.1055
epoch [685/1000], loss:0.1233
epoch [686/1000], loss:0.1190
epoch [687/1000], loss:0.1163
epoch [688/1000], loss:0.1291
epoch [689/1000], loss:0.1075
epoch [690/1000], loss:0.1293
epoch [691/1000], loss:0.1152
epoch [692/1000], loss:0.1346
epoch [693/1000], loss:0.1165
epoch [694/1000], loss:0.1179
epoch [695/1000], loss:0.1071
epoch [696/1000], loss:0.1307
epoch [697/1000], loss:0.1168
epoch [698/1000], loss:0.1037
epoch [699/1000], loss:0.1216
epoch [700/1000], loss:0.1115
epoch [701/1000], loss:0.1224
epoch [702/1000], loss:0.1258
epoch [703/1000], loss:0.1188
epoch [704/1000], loss:0.1110
epoch [705/1000], loss:0.1248
epoch [706/1000], loss:0.1244
epoch [707/1000], loss:0.1202
epoch [708/1000], loss:0.1160
epoch [709/1000], loss:0.1139
epoch [710/1000], loss:0.1023
epoch [711/1000], loss:0.1155
epoch [712/1000], loss:0.1206
epoch [713/1000], loss:0.1292
epoch [714/1000], loss:0.1112
epoch [715/1000], loss:0.1380
epoch [716/1000], loss:0.1287
epoch [717/1000], loss:0.1218
epoch [718/1000], loss:0.1087
epoch [719/1000], loss:0.1154
epoch [720/1000], loss:0.1344
epoch [721/1000], loss:0.1319
epoch [722/1000], loss:0.1117
epoch [723/1000], loss:0.1203
epoch [724/1000], loss:0.1074
epoch [725/1000], loss:0.1147
epoch [726/1000], loss:0.1179
epoch [727/1000], loss:0.1097
epoch [728/1000], loss:0.1167
epoch [729/1000], loss:0.1267
epoch [730/1000], loss:0.0980
epoch [731/1000], loss:0.1109
epoch [732/1000], loss:0.1271
epoch [733/1000], loss:0.1101
epoch [734/1000], loss:0.1168
epoch [735/1000], loss:0.1211
epoch [736/1000], loss:0.1303
epoch [737/1000], loss:0.1219
epoch [738/1000], loss:0.1150
epoch [739/1000], loss:0.1226
epoch [740/1000], loss:0.1132
epoch [741/1000], loss:0.1254
epoch [742/1000], loss:0.1227
epoch [743/1000], loss:0.1152
epoch [744/1000], loss:0.1162
epoch [745/1000], loss:0.1231
epoch [746/1000], loss:0.1176
epoch [747/1000], loss:0.1072
epoch [748/1000], loss:0.1187
epoch [749/1000], loss:0.1193
epoch [750/1000], loss:0.1096
epoch [751/1000], loss:0.1131
epoch [752/1000], loss:0.1128
epoch [753/1000], loss:0.1117
epoch [754/1000], loss:0.1286
epoch [755/1000], loss:0.1108
epoch [756/1000], loss:0.1055
epoch [757/1000], loss:0.1141
epoch [758/1000], loss:0.1187
epoch [759/1000], loss:0.1129
epoch [760/1000], loss:0.1246
epoch [761/1000], loss:0.1308
epoch [762/1000], loss:0.1201
epoch [763/1000], loss:0.1064
epoch [764/1000], loss:0.1354
epoch [765/1000], loss:0.1178
epoch [766/1000], loss:0.1224
epoch [767/1000], loss:0.1202
epoch [768/1000], loss:0.1273
epoch [769/1000], loss:0.1180
epoch [770/1000], loss:0.1357
epoch [771/1000], loss:0.1212
epoch [772/1000], loss:0.1213
epoch [773/1000], loss:0.1252
epoch [774/1000], loss:0.1155
epoch [775/1000], loss:0.1152
epoch [776/1000], loss:0.1054
epoch [777/1000], loss:0.1263
epoch [778/1000], loss:0.1172
epoch [779/1000], loss:0.1188
epoch [780/1000], loss:0.1260
epoch [781/1000], loss:0.1175
epoch [782/1000], loss:0.1178
epoch [783/1000], loss:0.1253
epoch [784/1000], loss:0.1320
epoch [785/1000], loss:0.1139
epoch [786/1000], loss:0.1185
epoch [787/1000], loss:0.1221
epoch [788/1000], loss:0.1256
epoch [789/1000], loss:0.1244
epoch [790/1000], loss:0.1190
epoch [791/1000], loss:0.1158
epoch [792/1000], loss:0.1208
epoch [793/1000], loss:0.1272
epoch [794/1000], loss:0.1071
epoch [795/1000], loss:0.1277
epoch [796/1000], loss:0.1139
epoch [797/1000], loss:0.1268
epoch [798/1000], loss:0.1154
epoch [799/1000], loss:0.1245
epoch [800/1000], loss:0.1298
epoch [801/1000], loss:0.1067
epoch [802/1000], loss:0.1286
epoch [803/1000], loss:0.1235
epoch [804/1000], loss:0.1226
epoch [805/1000], loss:0.1176
epoch [806/1000], loss:0.1117
epoch [807/1000], loss:0.1133
epoch [808/1000], loss:0.1122
epoch [809/1000], loss:0.1083
epoch [810/1000], loss:0.1097
epoch [811/1000], loss:0.1091
epoch [812/1000], loss:0.1125
epoch [813/1000], loss:0.1065
epoch [814/1000], loss:0.1202
epoch [815/1000], loss:0.1164
epoch [816/1000], loss:0.1197
epoch [817/1000], loss:0.1179
epoch [818/1000], loss:0.1241
epoch [819/1000], loss:0.1231
epoch [820/1000], loss:0.1164
epoch [821/1000], loss:0.1224
epoch [822/1000], loss:0.1221
epoch [823/1000], loss:0.1247
epoch [824/1000], loss:0.1136
epoch [825/1000], loss:0.1249
epoch [826/1000], loss:0.1032
epoch [827/1000], loss:0.1125
epoch [828/1000], loss:0.1089
epoch [829/1000], loss:0.1088
epoch [830/1000], loss:0.1236
epoch [831/1000], loss:0.1122
epoch [832/1000], loss:0.0993
epoch [833/1000], loss:0.1231
epoch [834/1000], loss:0.1168
epoch [835/1000], loss:0.1149
epoch [836/1000], loss:0.1273
epoch [837/1000], loss:0.1049
epoch [838/1000], loss:0.1188
epoch [839/1000], loss:0.1143
epoch [840/1000], loss:0.1188
epoch [841/1000], loss:0.1194
epoch [842/1000], loss:0.1123
epoch [843/1000], loss:0.1149
epoch [844/1000], loss:0.1094
epoch [845/1000], loss:0.1270
epoch [846/1000], loss:0.1153
epoch [847/1000], loss:0.1135
epoch [848/1000], loss:0.1132
epoch [849/1000], loss:0.1124
epoch [850/1000], loss:0.1411
epoch [851/1000], loss:0.1139
epoch [852/1000], loss:0.1135
epoch [853/1000], loss:0.1195
epoch [854/1000], loss:0.1213
epoch [855/1000], loss:0.1069
epoch [856/1000], loss:0.1099
epoch [857/1000], loss:0.1156
epoch [858/1000], loss:0.1146
epoch [859/1000], loss:0.1115
epoch [860/1000], loss:0.1112
epoch [861/1000], loss:0.1175
epoch [862/1000], loss:0.1113
epoch [863/1000], loss:0.1094
epoch [864/1000], loss:0.1241
epoch [865/1000], loss:0.1353
epoch [866/1000], loss:0.1343
epoch [867/1000], loss:0.1211
epoch [868/1000], loss:0.1395
epoch [869/1000], loss:0.1096
epoch [870/1000], loss:0.1163
epoch [871/1000], loss:0.1102
epoch [872/1000], loss:0.1127
epoch [873/1000], loss:0.1079
epoch [874/1000], loss:0.1147
epoch [875/1000], loss:0.1104
epoch [876/1000], loss:0.0987
epoch [877/1000], loss:0.1114
epoch [878/1000], loss:0.1223
epoch [879/1000], loss:0.1254
epoch [880/1000], loss:0.1094
epoch [881/1000], loss:0.1143
epoch [882/1000], loss:0.1157
epoch [883/1000], loss:0.1241
epoch [884/1000], loss:0.1208
epoch [885/1000], loss:0.1189
epoch [886/1000], loss:0.1112
epoch [887/1000], loss:0.1208
epoch [888/1000], loss:0.1098
epoch [889/1000], loss:0.1103
epoch [890/1000], loss:0.1277
epoch [891/1000], loss:0.1263
epoch [892/1000], loss:0.1142
epoch [893/1000], loss:0.1209
epoch [894/1000], loss:0.1184
epoch [895/1000], loss:0.1230
epoch [896/1000], loss:0.1085
epoch [897/1000], loss:0.0995
epoch [898/1000], loss:0.1186
epoch [899/1000], loss:0.1171
epoch [900/1000], loss:0.1151
epoch [901/1000], loss:0.1114
epoch [902/1000], loss:0.1152
epoch [903/1000], loss:0.1266
epoch [904/1000], loss:0.1192
epoch [905/1000], loss:0.1044
epoch [906/1000], loss:0.1126
epoch [907/1000], loss:0.1116
epoch [908/1000], loss:0.1198
epoch [909/1000], loss:0.1120
epoch [910/1000], loss:0.1152
epoch [911/1000], loss:0.1165
epoch [912/1000], loss:0.1148
epoch [913/1000], loss:0.1043
epoch [914/1000], loss:0.1170
epoch [915/1000], loss:0.1091
epoch [916/1000], loss:0.1168
epoch [917/1000], loss:0.1190
epoch [918/1000], loss:0.1091
epoch [919/1000], loss:0.1118
epoch [920/1000], loss:0.1199
epoch [921/1000], loss:0.1022
epoch [922/1000], loss:0.1198
epoch [923/1000], loss:0.1119
epoch [924/1000], loss:0.1145
epoch [925/1000], loss:0.1240
epoch [926/1000], loss:0.1201
epoch [927/1000], loss:0.1125
epoch [928/1000], loss:0.1278
epoch [929/1000], loss:0.1136
epoch [930/1000], loss:0.1210
epoch [931/1000], loss:0.1256
epoch [932/1000], loss:0.1123
epoch [933/1000], loss:0.1272
epoch [934/1000], loss:0.1303
epoch [935/1000], loss:0.1107
epoch [936/1000], loss:0.1166
epoch [937/1000], loss:0.1076
epoch [938/1000], loss:0.1165
epoch [939/1000], loss:0.1241
epoch [940/1000], loss:0.1045
epoch [941/1000], loss:0.1239
epoch [942/1000], loss:0.1220
epoch [943/1000], loss:0.1146
epoch [944/1000], loss:0.1130
epoch [945/1000], loss:0.1222
epoch [946/1000], loss:0.1226
epoch [947/1000], loss:0.1105
epoch [948/1000], loss:0.1106
epoch [949/1000], loss:0.1299
epoch [950/1000], loss:0.1163
epoch [951/1000], loss:0.1157
epoch [952/1000], loss:0.1184
epoch [953/1000], loss:0.1250
epoch [954/1000], loss:0.1239
epoch [955/1000], loss:0.1216
epoch [956/1000], loss:0.1163
epoch [957/1000], loss:0.1246
epoch [958/1000], loss:0.1169
epoch [959/1000], loss:0.1145
epoch [960/1000], loss:0.1157
epoch [961/1000], loss:0.1331
epoch [962/1000], loss:0.1308
epoch [963/1000], loss:0.1188
epoch [964/1000], loss:0.1130
epoch [965/1000], loss:0.1190
epoch [966/1000], loss:0.1249
epoch [967/1000], loss:0.1122
epoch [968/1000], loss:0.1167
epoch [969/1000], loss:0.1266
epoch [970/1000], loss:0.1212
epoch [971/1000], loss:0.1076
epoch [972/1000], loss:0.1159
epoch [973/1000], loss:0.1286
epoch [974/1000], loss:0.1204
epoch [975/1000], loss:0.1262
epoch [976/1000], loss:0.1241
epoch [977/1000], loss:0.1159
epoch [978/1000], loss:0.1079
epoch [979/1000], loss:0.0965
epoch [980/1000], loss:0.1198
epoch [981/1000], loss:0.1138
epoch [982/1000], loss:0.1194
epoch [983/1000], loss:0.1176
epoch [984/1000], loss:0.1169
epoch [985/1000], loss:0.1280
epoch [986/1000], loss:0.1165
epoch [987/1000], loss:0.1262
epoch [988/1000], loss:0.1217
epoch [989/1000], loss:0.1168
epoch [990/1000], loss:0.1180
epoch [991/1000], loss:0.1156
epoch [992/1000], loss:0.1092
epoch [993/1000], loss:0.1133
epoch [994/1000], loss:0.1166
epoch [995/1000], loss:0.1180
epoch [996/1000], loss:0.1206
epoch [997/1000], loss:0.1176
epoch [998/1000], loss:0.1160
epoch [999/1000], loss:0.1152
epoch [1000/1000], loss:0.1230

Evaluation

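Feeding the testing images into the model gives their reconstructions. We then take the squared difference between each image and its reconstruction. The inliers' squared differences and the outliers' squared differences should form two clearly separated groups.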
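Written out, this is the score that the code below assigns to a test image $x$ with reconstruction $\hat{x}$. The sum runs over all $D$ pixel values of the image:

$$
s(x) = \lVert x - \hat{x} \rVert_2 = \sqrt{\sum_{j=1}^{D} \left(x_j - \hat{x}_j\right)^2}
$$

The square root is monotonic, so it does not change how the test images are ranked. The AUC is therefore the same as with the plain sum of squared differences.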

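import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

# Assumes `test`, `task`, `model_type` and `batch_size` were defined in the earlier cells.
if task == 'ae':
    # The fully connected and VAE models take flattened vectors; the CNN takes images.
    if model_type == 'fcn' or model_type == 'vae':
        y = test.reshape(len(test), -1)
    else:
        y = test

    data = torch.tensor(y, dtype=torch.float)
    test_dataset = TensorDataset(data)
    test_sampler = SequentialSampler(test_dataset)  # keep the original order so the ids line up
    test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=batch_size)

    # The whole model object was saved, so weights_only=False is required on PyTorch >= 2.6.
    model = torch.load('best_model_{}.pt'.format(model_type), map_location='cuda', weights_only=False)
    model.eval()

    reconstructed = list()
    with torch.no_grad():  # inference only, no gradients needed
        for batch in test_dataloader:
            if model_type == 'cnn':
                # (N, H, W, C) -> (N, C, W, H): swap axes 1 and 3 to put channels first
                img = batch[0].transpose(3, 1).cuda()
            else:
                img = batch[0].cuda()
            output = model(img)
            if model_type == 'cnn':
                output = output.transpose(3, 1)  # back to (N, H, W, C)
            elif model_type == 'vae':
                output = output[0]  # the VAE returns a tuple; the first item is the reconstruction
            reconstructed.append(output.cpu().numpy())
    reconstructed = np.concatenate(reconstructed, axis=0)

    # Anomaly score: L2 distance between each test image and its reconstruction.
    anomality = np.sqrt(np.sum(np.square(reconstructed - y).reshape(len(y), -1), axis=1))
    y_pred = anomality

    with open('prediction.csv', 'w') as f:
        f.write('id,anomaly\n')
        for i in range(len(y_pred)):
            f.write('{},{}\n'.format(i + 1, y_pred[i]))

    # The test labels are not public, so the score cannot be computed here:
    # score = roc_auc_score(y_label, y_pred, average='micro')
    # score = f1_score(y_label, y_pred, average='micro')
    # print('auc score: {}'.format(score))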
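The test labels are not public. One way to sanity-check the scoring and AUC logic is therefore on synthetic data where the labels are known. The following is a minimal NumPy-only sketch, and it rests on several assumptions. The data are synthetic: inliers lie near a random low-dimensional subspace and outliers are spread over all dimensions. A rank-k PCA projection stands in for the trained autoencoder. `anomaly_scores` and `auc_from_scores` are hypothetical helper names. The AUC is computed with the Mann-Whitney formulation instead of `sklearn.metrics.roc_auc_score`.

```python
import numpy as np


def anomaly_scores(y, reconstructed):
    """Per-sample L2 reconstruction error, flattening every non-batch axis."""
    diff = (np.asarray(reconstructed) - np.asarray(y)).reshape(len(y), -1)
    return np.sqrt(np.sum(np.square(diff), axis=1))


def auc_from_scores(labels, scores):
    """ROC AUC via the Mann-Whitney statistic (label 1 = outlier, 0 = inlier).

    Equals the probability that a random outlier scores higher than a random
    inlier, with ties counted as one half.
    """
    labels = np.asarray(labels).astype(bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[labels], scores[~labels]
    # Compare every outlier with every inlier: memory is O(len(pos) * len(neg)).
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


# Synthetic stand-in data (assumption): inliers near a 4-dim subspace of R^64.
rng = np.random.default_rng(0)
dim, k = 64, 4
basis = np.linalg.qr(rng.normal(size=(dim, k)))[0]  # (64, 4), orthonormal columns
train_x = rng.normal(size=(500, k)) @ basis.T + 0.05 * rng.normal(size=(500, dim))
test_in = rng.normal(size=(200, k)) @ basis.T + 0.05 * rng.normal(size=(200, dim))
test_out = rng.normal(size=(200, dim))  # outliers: not confined to the subspace
test_x = np.concatenate([test_in, test_out])
labels = np.concatenate([np.zeros(200), np.ones(200)])

# Stand-in "autoencoder": project onto the top-k principal directions of train_x.
mean = train_x.mean(axis=0)
_, _, vt = np.linalg.svd(train_x - mean, full_matrices=False)
components = vt[:k]  # (k, dim)
reconstructed = (test_x - mean) @ components.T @ components + mean

scores = anomaly_scores(test_x, reconstructed)
print('auc score: {:.3f}'.format(auc_from_scores(labels, scores)))
```

On this toy data the inliers are reconstructed almost perfectly and the outliers are not, so the printed AUC should be close to 1. The assignment allows extra data for validation. With such a labeled validation set, `auc_from_scores` could be applied to `y_pred` from the evaluation code in the same way. Up to floating-point error, it gives the same value as `roc_auc_score`.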
