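Use at least two methods (autoencoder architecture, optimizer, data preprocessing, the subsequent dimensionality-reduction method, clustering algorithm, etc.) to improve the accuracy of the baseline code.

  • Record the accuracy before and after the improvement.
  • Using both the original and the improved method, plot the dimensionality-reduced val data (the embedding) together with the corresponding labels.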
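
Because clustering assigns arbitrary cluster ids, accuracy has to be measured up to a relabeling of the clusters. The sketch below assumes binary labels (0/1) and exactly two clusters, which this section does not state explicitly; `clustering_accuracy` is a hypothetical helper name, not part of the baseline.

```python
import numpy as np

def clustering_accuracy(labels, preds):
    """Accuracy of a 2-cluster assignment, invariant to swapping cluster ids."""
    labels = np.asarray(labels)
    preds = np.asarray(preds)
    acc = float(np.mean(labels == preds))
    # Cluster ids carry no meaning, so a prediction and its complement
    # describe the same partition; keep whichever agrees better.
    return max(acc, 1.0 - acc)

# Toy example: cluster ids are flipped relative to the labels, with one mistake.
labels = np.array([0, 0, 1, 1, 1])
preds = np.array([1, 1, 0, 0, 1])
print(clustering_accuracy(labels, preds))
```

The same function can be applied to the val predictions of both the baseline and the improved pipeline so the two numbers are comparable.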


Using your autoencoder with the highest accuracy, take the 6 images at indices 1, 2, 3, 6, 7, 9 from trainX.

  • Plot the original images alongside their reconstructions.
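
Before plotting, both the preprocessed inputs and the reconstructions need to be mapped back to displayable images. The sketch below assumes the arrays are channels-first floats in -1~1, as produced by the notebook's preprocess; `to_displayable` is a hypothetical helper, and the random array only stands in for real data so the example runs on its own.

```python
import numpy as np

def to_displayable(batch):
    """Map (N, C, H, W) floats in [-1, 1] back to (N, H, W, C) uint8 in [0, 255]."""
    batch = np.asarray(batch, dtype=np.float32)
    batch = (batch + 1.0) / 2.0 * 255.0          # invert (x / 255) * 2 - 1
    batch = np.clip(np.rint(batch), 0, 255).astype(np.uint8)
    return np.transpose(batch, (0, 2, 3, 1))     # channels-first -> channels-last

indexes = [1, 2, 3, 6, 7, 9]
# Stand-in for the preprocessed training data (random values, illustration only).
fake_preprocessed = np.random.uniform(-1, 1, size=(10, 3, 32, 32)).astype(np.float32)
originals = to_displayable(fake_preprocessed[indexes])
print(originals.shape, originals.dtype)
```

The resulting arrays can be passed directly to an image-plotting routine that expects height, width, channel order.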


During autoencoder training, pick at least 10 checkpoints.

  • Plot the model's reconstruction error (MSE computed over all of trainX) and val accuracy for those checkpoints.
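
One way to get the reconstruction error over all of trainX is to accumulate squared errors batch by batch and divide once at the end, since averaging per-batch means is slightly off when the last batch is smaller. This is a minimal NumPy sketch: `dataset_mse` and the `reconstruct` callable (standing in for a forward pass through a loaded checkpoint) are hypothetical names, not part of the baseline.

```python
import numpy as np

def dataset_mse(reconstruct, data, batch_size=64):
    """Mean squared error over every element of `data`, computed batch by batch."""
    total_sq_err = 0.0
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        # Accumulate in float64 to limit rounding error over many batches.
        diff = np.asarray(reconstruct(batch), dtype=np.float64) - batch
        total_sq_err += float(np.sum(diff ** 2))
    # Divide by the total number of elements, not by the number of batches.
    return total_sq_err / data.size

# Toy check: a "model" that is off by 0.5 everywhere has an MSE of 0.25.
toy_data = np.zeros((5, 3, 4, 4), dtype=np.float32)
print(dataset_mse(lambda b: b + 0.5, toy_data, batch_size=2))
```

Evaluating this once per checkpoint gives the reconstruction-error series for the plot.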


Load the data with np.load(). valX.npy and valY.npy are only for checking how well training worked; they must not be used for training.


trainX.npy
  • Contains 8500 RGB images in total, each of size 32 * 32 * 3
  • shape is (8500, 32, 32, 3)


valX.npy
  • Do not use it for training
  • Contains 500 RGB images in total, each of size 32 * 32 * 3
  • shape is (500, 32, 32, 3)


valY.npy
  • Do not use it for training
  • The labels corresponding to valX.npy
  • shape is (500,)


Create the checkpoints folder.

# !gdown --id '1BZb2AqOHHaad7Mo82St1qTBaXo_xtcUc' --output trainX.npy
# !gdown --id '152NKCpj8S_zuIx3bQy0NN5oqpvBjdPIq' --output valX.npy
# !gdown --id '1_hRGsFtm5KEazUg2ZvPZcuNScGF-ANh4' --output valY.npy
!mkdir -p checkpoints
Define our preprocess: linearly map image values from ints in 0~255 to floats in -1~1.

import numpy as np

def preprocess(image_list):
    """Normalize images and permute (N, H, W, C) to (N, C, H, W).

    Args:
      image_list: images of shape (N, 32, 32, 3), integer values in 0~255
    Returns:
      image_list: images of shape (N, 3, 32, 32), float32 values in -1~1
    """
    image_list = np.array(image_list)
    image_list = np.transpose(image_list, (0, 3, 1, 2))
    image_list = (image_list / 255.0) * 2 - 1
    image_list = image_list.astype(np.float32)
    return image_list


from torch.utils.data import Dataset

class Image_Dataset(Dataset):
    def __init__(self, image_list):
        self.image_list = image_list
    def __len__(self):
        return len(self.image_list)
    def __getitem__(self, idx):
        images = self.image_list[idx]
        return images

Load the training data and preprocess it. We then wrap the preprocessed training data in the dataset we need. Do not use valX and valY for training.

from torch.utils.data import DataLoader

trainX = np.load('trainX.npy')
trainX_preprocessed = preprocess(trainX)
img_dataset = Image_Dataset(trainX_preprocessed)


Here are some useful functions. One counts the number of model parameters (needed for the report); the other fixes the random seeds used in training (for reproducibility).

import random
import torch

def count_parameters(model, only_trainable=False):
    if only_trainable:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in model.parameters())

def same_seeds(seed):
    torch.manual_seed(seed)  # PyTorch CPU random number generator.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.
    torch.backends.cudnn.benchmark = False  # disable cuDNN auto-tuning of algorithms
    torch.backends.cudnn.deterministic = True  # always pick the same convolution algorithm
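
The number reported by count_parameters can be cross-checked by hand, which is handy when writing the report. A minimal sketch, assuming square kernels and no grouped convolutions; `conv_params` is a hypothetical helper, not part of the notebook.

```python
def conv_params(in_channels, out_channels, kernel_size, bias=True):
    """Parameter count of a square-kernel Conv2d or ConvTranspose2d layer (groups=1)."""
    # One kernel_size x kernel_size filter per (input channel, output channel) pair.
    weights = in_channels * out_channels * kernel_size * kernel_size
    # One bias term per output channel, if enabled.
    return weights + (out_channels if bias else 0)

# First encoder layer of the baseline: Conv2d(3, 64, 3)
print(conv_params(3, 64, 3))
```

Summing this over every convolutional layer should agree with count_parameters for a model that contains only such layers.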


Define our baseline autoencoder.


import torch.nn as nn

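class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()
        # Encoder: each 3x3 conv keeps the spatial size and each 2x2 max-pool
        # halves it (32 -> 16 -> 8 -> 4 -> 2). The ReLU / MaxPool2d / Tanh layers
        # are assumed here: without the downsampling, the decoder kernels below
        # would not bring the output back to 32x32.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=1, padding=1), nn.ReLU(True), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.ReLU(True), nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, stride=1, padding=1), nn.ReLU(True), nn.MaxPool2d(2),
            nn.Conv2d(256, 512, 3, stride=1, padding=1), nn.ReLU(True), nn.MaxPool2d(2),
        )
        # Decoder: a stride-1 transposed conv with kernel k adds k - 1 pixels
        # (2 -> 4 -> 8 -> 16 -> 32).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 3, stride=1), nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 5, stride=1), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 9, stride=1), nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 17, stride=1),
            nn.Tanh(),  # outputs in -1~1, matching the range produced by preprocess
        )

    def forward(self, x):
        x1 = self.encoder(x)  # latent feature map
        x  = self.decoder(x1)  # reconstruction
        return x1, x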
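
The decoder's kernel sizes (3, 5, 9, 17) only make sense for a particular latent size, which can be checked with the standard output-size formulas. The sketch below assumes each encoder convolution is followed by a 2×2 max-pool; that is an assumption made for this illustration, and `conv_out` and `conv_transpose_out` are hypothetical helper names.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of a convolution or pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride=1, padding=0):
    """Output size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel

size = 32
trace = [size]
for _ in range(4):
    size = conv_out(size, kernel=3, stride=1, padding=1)  # 3x3 conv keeps the size
    size = conv_out(size, kernel=2, stride=2)             # assumed 2x2 max-pool halves it
    trace.append(size)
for kernel in (3, 5, 9, 17):
    size = conv_transpose_out(size, kernel)               # stride 1 adds kernel - 1
    trace.append(size)
print(trace)
```

The trace ends at the input resolution, so the reconstruction and the input have matching shapes for the MSE loss under this assumption.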
This part is the main training stage. We first pass the prepared dataset to the dataloader. Once the dataloader, model, loss criterion, and optimizer are all ready, training can begin. After training finishes, we save the model.

import torch
from torch import optim


model = AE().cuda()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)

model.train()
n_epoch = 1000

# Prepare the dataloader (model, loss criterion and optimizer are set up above)
img_dataloader = DataLoader(img_dataset, batch_size=64, shuffle=True)

# Main training loop
for epoch in range(n_epoch):
    epoch_loss = 0  # sum of the per-batch losses for this epoch
    for data in img_dataloader:
        img = data.cuda()

        output1, output = model(img)
        loss = criterion(output, img)

        # Backpropagate and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Save a checkpoint every 10 epochs (once per epoch, not once per batch)
    if (epoch+1) % 10 == 0:
        torch.save(model.state_dict(), './checkpoints/checkpoint_{}.pth'.format(epoch+1))
    print('epoch [{}/{}], loss:{:.5f}'.format(epoch+1, n_epoch, epoch_loss))

# Save the model once training is done
torch.save(model.state_dict(), './checkpoints/last_checkpoint.pth')
epoch [1/1000], loss:30.54165
epoch [2/1000], loss:26.34405
epoch [3/1000], loss:21.83250
epoch [4/1000], loss:19.13653
epoch [5/1000], loss:16.89123
epoch [6/1000], loss:15.81137
epoch [7/1000], loss:15.24495
epoch [8/1000], loss:14.82142
epoch [9/1000], loss:14.43517
epoch [10/1000], loss:14.08439
epoch [11/1000], loss:13.73920
epoch [12/1000], loss:13.40639
epoch [13/1000], loss:13.08327
epoch [14/1000], loss:12.66554
epoch [15/1000], loss:12.26715
epoch [16/1000], loss:11.93717
epoch [17/1000], loss:11.67487
epoch [18/1000], loss:11.45737
epoch [19/1000], loss:11.28208
epoch [20/1000], loss:11.08628
epoch [21/1000], loss:10.94622
epoch [22/1000], loss:10.80847
epoch [23/1000], loss:10.70417
epoch [24/1000], loss:10.58255
epoch [25/1000], loss:10.48495
epoch [26/1000], loss:10.39527
epoch [27/1000], loss:10.30006
epoch [28/1000], loss:10.20910
epoch [29/1000], loss:10.13124
epoch [30/1000], loss:10.04456
epoch [31/1000], loss:9.96836
epoch [32/1000], loss:9.88246
epoch [33/1000], loss:9.81235
epoch [34/1000], loss:9.72425
epoch [35/1000], loss:9.65545
epoch [36/1000], loss:9.57657
epoch [37/1000], loss:9.51310
epoch [38/1000], loss:9.45421
epoch [39/1000], loss:9.38250
epoch [40/1000], loss:9.31712
epoch [41/1000], loss:9.25833
epoch [42/1000], loss:9.20196
epoch [43/1000], loss:9.14868
epoch [44/1000], loss:9.08939
epoch [45/1000], loss:9.02597
epoch [46/1000], loss:8.95911
epoch [47/1000], loss:8.91480
epoch [48/1000], loss:8.86116
epoch [49/1000], loss:8.79443
epoch [50/1000], loss:8.73779
epoch [51/1000], loss:8.68570
epoch [52/1000], loss:8.62910
epoch [53/1000], loss:8.57338
epoch [54/1000], loss:8.53807
epoch [55/1000], loss:8.48156
epoch [56/1000], loss:8.43463
epoch [57/1000], loss:8.39641
epoch [58/1000], loss:8.34074
epoch [59/1000], loss:8.30465
epoch [60/1000], loss:8.27341
epoch [61/1000], loss:8.23230
epoch [62/1000], loss:8.18089
epoch [63/1000], loss:8.15129
epoch [64/1000], loss:8.11520
epoch [65/1000], loss:8.07959
epoch [66/1000], loss:8.04687
epoch [67/1000], loss:8.02380
epoch [68/1000], loss:7.98933
epoch [69/1000], loss:7.95649
epoch [70/1000], loss:7.92910
epoch [71/1000], loss:7.88972
epoch [72/1000], loss:7.85813
epoch [73/1000], loss:7.82851
epoch [74/1000], loss:7.81065
epoch [75/1000], loss:7.78497
epoch [76/1000], loss:7.73110
epoch [77/1000], loss:7.71461
epoch [78/1000], loss:7.68887
epoch [79/1000], loss:7.65523
epoch [80/1000], loss:7.63705
epoch [81/1000], loss:7.61096
epoch [82/1000], loss:7.57877
epoch [83/1000], loss:7.54703
epoch [84/1000], loss:7.52961
epoch [85/1000], loss:7.48876
epoch [86/1000], loss:7.46642
epoch [87/1000], loss:7.43804
epoch [88/1000], loss:7.41458
epoch [89/1000], loss:7.38298
epoch [90/1000], loss:7.38157
epoch [91/1000], loss:7.34053
epoch [92/1000], loss:7.32307
epoch [93/1000], loss:7.28897
epoch [94/1000], loss:7.27476
epoch [95/1000], loss:7.25432
epoch [96/1000], loss:7.23210
epoch [97/1000], loss:7.20764
epoch [98/1000], loss:7.17726
epoch [99/1000], loss:7.16785
epoch [100/1000], loss:7.14477
epoch [101/1000], loss:7.12776
epoch [102/1000], loss:7.10490
epoch [103/1000], loss:7.08108
epoch [104/1000], loss:7.06430
epoch [105/1000], loss:7.04382
epoch [106/1000], loss:7.01336
epoch [107/1000], loss:7.00099
epoch [108/1000], loss:6.97758
epoch [109/1000], loss:6.95376
epoch [110/1000], loss:6.94354
epoch [111/1000], loss:6.91744
epoch [112/1000], loss:6.91015
epoch [113/1000], loss:6.88055
epoch [114/1000], loss:6.86521
epoch [115/1000], loss:6.84671
epoch [116/1000], loss:6.82973
epoch [117/1000], loss:6.80817
epoch [118/1000], loss:6.78769
epoch [119/1000], loss:6.77140
epoch [120/1000], loss:6.76178
epoch [121/1000], loss:6.74296
epoch [122/1000], loss:6.71641
epoch [123/1000], loss:6.69564
epoch [124/1000], loss:6.67923
epoch [125/1000], loss:6.66339
epoch [126/1000], loss:6.64667
epoch [127/1000], loss:6.62993
epoch [128/1000], loss:6.60127
epoch [129/1000], loss:6.58229
epoch [130/1000], loss:6.57563
epoch [131/1000], loss:6.55139
epoch [132/1000], loss:6.53123
epoch [133/1000], loss:6.51448
epoch [134/1000], loss:6.49753
epoch [135/1000], loss:6.46827
epoch [136/1000], loss:6.45886
epoch [137/1000], loss:6.43451
epoch [138/1000], loss:6.41819
epoch [139/1000], loss:6.39429
epoch [140/1000], loss:6.38479
epoch [141/1000], loss:6.36964
epoch [142/1000], loss:6.34008
epoch [143/1000], loss:6.32599
epoch [144/1000], loss:6.30631
epoch [145/1000], loss:6.29071
epoch [146/1000], loss:6.27065
epoch [147/1000], loss:6.25629
epoch [148/1000], loss:6.23477
epoch [149/1000], loss:6.22027
epoch [150/1000], loss:6.20892
epoch [151/1000], loss:6.18379
epoch [152/1000], loss:6.16717
epoch [153/1000], loss:6.15294
epoch [154/1000], loss:6.13922
epoch [155/1000], loss:6.12273
epoch [156/1000], loss:6.09983
epoch [157/1000], loss:6.09613
epoch [158/1000], loss:6.08098
epoch [159/1000], loss:6.06648
epoch [160/1000], loss:6.05687
epoch [161/1000], loss:6.03163
epoch [162/1000], loss:6.00917
epoch [163/1000], loss:6.00572
epoch [164/1000], loss:5.99157
epoch [165/1000], loss:5.97707
epoch [166/1000], loss:5.96627
epoch [167/1000], loss:5.96171
epoch [168/1000], loss:5.93227
epoch [169/1000], loss:5.92656
epoch [170/1000], loss:5.92673
epoch [171/1000], loss:5.90135
epoch [172/1000], loss:5.89017
epoch [173/1000], loss:5.87263
epoch [174/1000], loss:5.86483
epoch [175/1000], loss:5.85099
epoch [176/1000], loss:5.83615
epoch [177/1000], loss:5.83101
epoch [178/1000], loss:5.82030
epoch [179/1000], loss:5.82544
epoch [180/1000], loss:5.78977
epoch [181/1000], loss:5.78293
epoch [182/1000], loss:5.77460
epoch [183/1000], loss:5.76192
epoch [184/1000], loss:5.75049
epoch [185/1000], loss:5.74188
epoch [186/1000], loss:5.73882
epoch [187/1000], loss:5.72205
epoch [188/1000], loss:5.70864
epoch [189/1000], loss:5.70273
epoch [190/1000], loss:5.69353
epoch [191/1000], loss:5.68343
epoch [192/1000], loss:5.67216
epoch [193/1000], loss:5.66239
epoch [194/1000], loss:5.65125
epoch [195/1000], loss:5.63932
epoch [196/1000], loss:5.63388
epoch [197/1000], loss:5.62116
epoch [198/1000], loss:5.61385
epoch [199/1000], loss:5.61483
epoch [200/1000], loss:5.59609
epoch [201/1000], loss:5.57955
epoch [202/1000], loss:5.57469
epoch [203/1000], loss:5.56383
epoch [204/1000], loss:5.55489
epoch [205/1000], loss:5.54320
epoch [206/1000], loss:5.52971
epoch [207/1000], loss:5.53083
epoch [208/1000], loss:5.51958
epoch [209/1000], loss:5.50594
epoch [210/1000], loss:5.50194
epoch [211/1000], loss:5.49632
epoch [212/1000], loss:5.47642
epoch [213/1000], loss:5.47105
epoch [214/1000], loss:5.46103
epoch [215/1000], loss:5.45136
epoch [216/1000], loss:5.44927
epoch [217/1000], loss:5.43372
epoch [218/1000], loss:5.43229
epoch [219/1000], loss:5.41293
epoch [220/1000], loss:5.40677
epoch [221/1000], loss:5.39713
epoch [222/1000], loss:5.39402
epoch [223/1000], loss:5.38856
epoch [224/1000], loss:5.37551
epoch [225/1000], loss:5.36045
epoch [226/1000], loss:5.35389
epoch [227/1000], loss:5.34672
epoch [228/1000], loss:5.33802
epoch [229/1000], loss:5.33105
epoch [230/1000], loss:5.32277
epoch [231/1000], loss:5.30828
epoch [232/1000], loss:5.29910
epoch [233/1000], loss:5.29399
epoch [234/1000], loss:5.28984
epoch [235/1000], loss:5.27597
epoch [236/1000], loss:5.26934
epoch [237/1000], loss:5.26663
epoch [238/1000], loss:5.25943
epoch [239/1000], loss:5.24395
epoch [240/1000], loss:5.24214
epoch [241/1000], loss:5.23017
epoch [242/1000], loss:5.21525
epoch [243/1000], loss:5.21001
epoch [244/1000], loss:5.20533
epoch [245/1000], loss:5.19778
epoch [246/1000], loss:5.19444
epoch [247/1000], loss:5.17834
epoch [248/1000], loss:5.17032
epoch [249/1000], loss:5.16573
epoch [250/1000], loss:5.16030
epoch [251/1000], loss:5.15691
epoch [252/1000], loss:5.14337
epoch [253/1000], loss:5.13357
epoch [254/1000], loss:5.12614
epoch [255/1000], loss:5.12397
epoch [256/1000], loss:5.11111
epoch [257/1000], loss:5.09905
epoch [258/1000], loss:5.09718
epoch [259/1000], loss:5.09271
epoch [260/1000], loss:5.08443
epoch [261/1000], loss:5.07630
epoch [262/1000], loss:5.06473
epoch [263/1000], loss:5.06329
epoch [264/1000], loss:5.05452
epoch [265/1000], loss:5.04306
epoch [266/1000], loss:5.04899
epoch [267/1000], loss:5.03139
epoch [268/1000], loss:5.02383
epoch [269/1000], loss:5.01982
epoch [270/1000], loss:5.01273
epoch [271/1000], loss:5.00642
epoch [272/1000], loss:4.99454
epoch [273/1000], loss:4.99690
epoch [274/1000], loss:4.98375
epoch [275/1000], loss:4.98370
epoch [276/1000], loss:4.96812
epoch [277/1000], loss:4.96210
epoch [278/1000], loss:4.96167
epoch [279/1000], loss:4.94264
epoch [280/1000], loss:4.94708
epoch [281/1000], loss:4.93381
epoch [282/1000], loss:4.92656
epoch [283/1000], loss:4.92751
epoch [284/1000], loss:4.91519
epoch [285/1000], loss:4.90649
epoch [286/1000], loss:4.90130
epoch [287/1000], loss:4.89965
epoch [288/1000], loss:4.88647
epoch [289/1000], loss:4.88522
epoch [290/1000], loss:4.87119
epoch [291/1000], loss:4.86967
epoch [292/1000], loss:4.86545
epoch [293/1000], loss:4.85670
epoch [294/1000], loss:4.84635
epoch [295/1000], loss:4.84253
epoch [296/1000], loss:4.84705
epoch [297/1000], loss:4.82709
epoch [298/1000], loss:4.82251
epoch [299/1000], loss:4.81915
epoch [300/1000], loss:4.81493
epoch [301/1000], loss:4.80140
epoch [302/1000], loss:4.79302
epoch [303/1000], loss:4.79099
epoch [304/1000], loss:4.78271
epoch [305/1000], loss:4.77509
epoch [306/1000], loss:4.76755
epoch [307/1000], loss:4.76485
epoch [308/1000], loss:4.76169
epoch [309/1000], loss:4.75328
epoch [310/1000], loss:4.74254
epoch [311/1000], loss:4.74224
epoch [312/1000], loss:4.74067
epoch [313/1000], loss:4.72933
epoch [314/1000], loss:4.71486
epoch [315/1000], loss:4.71784
epoch [316/1000], loss:4.70222
epoch [317/1000], loss:4.70290
epoch [318/1000], loss:4.69542
epoch [319/1000], loss:4.69025
epoch [320/1000], loss:4.68246
epoch [321/1000], loss:4.67295
epoch [322/1000], loss:4.67523
epoch [323/1000], loss:4.67207
epoch [324/1000], loss:4.66636
epoch [325/1000], loss:4.64616
epoch [326/1000], loss:4.64512
epoch [327/1000], loss:4.64286
epoch [328/1000], loss:4.63428
epoch [329/1000], loss:4.62759
epoch [330/1000], loss:4.62275
epoch [331/1000], loss:4.61570
epoch [332/1000], loss:4.61228
epoch [333/1000], loss:4.60109
epoch [334/1000], loss:4.60413
epoch [335/1000], loss:4.58950
epoch [336/1000], loss:4.59071
epoch [337/1000], loss:4.58295
epoch [338/1000], loss:4.57782
epoch [339/1000], loss:4.57129
epoch [340/1000], loss:4.56505
epoch [341/1000], loss:4.56037
epoch [342/1000], loss:4.55598
epoch [343/1000], loss:4.54537
epoch [344/1000], loss:4.54019
epoch [345/1000], loss:4.53571
epoch [346/1000], loss:4.53185
epoch [347/1000], loss:4.53183
epoch [348/1000], loss:4.52009
epoch [349/1000], loss:4.51411
epoch [350/1000], loss:4.50916
epoch [351/1000], loss:4.50595
epoch [352/1000], loss:4.50171
epoch [353/1000], loss:4.49431
epoch [354/1000], loss:4.48945
epoch [355/1000], loss:4.48904
epoch [356/1000], loss:4.47484
epoch [357/1000], loss:4.47601
epoch [358/1000], loss:4.46283
epoch [359/1000], loss:4.46043
epoch [360/1000], loss:4.45623
epoch [361/1000], loss:4.45144
epoch [387/1000], loss:4.32588
epoch [388/1000], loss:4.31738
epoch [389/1000], loss:4.31798
epoch [390/1000], loss:4.31714
epoch [391/1000], loss:4.30985
epoch [392/1000], loss:4.29957
epoch [393/1000], loss:4.29696
epoch [394/1000], loss:4.29420
epoch [395/1000], loss:4.28667
epoch [396/1000], loss:4.28612
epoch [397/1000], loss:4.27635
epoch [398/1000], loss:4.27332
epoch [399/1000], loss:4.27225
epoch [400/1000], loss:4.26569
epoch [401/1000], loss:4.26683
epoch [402/1000], loss:4.25562
epoch [403/1000], loss:4.24940
epoch [404/1000], loss:4.24415
epoch [405/1000], loss:4.24422
epoch [406/1000], loss:4.24053
epoch [407/1000], loss:4.23612
epoch [408/1000], loss:4.23212
epoch [409/1000], loss:4.23014
epoch [410/1000], loss:4.22054
epoch [411/1000], loss:4.21572
epoch [412/1000], loss:4.21339
epoch [413/1000], loss:4.20922
epoch [414/1000], loss:4.20910
epoch [415/1000], loss:4.20353
epoch [416/1000], loss:4.19610
epoch [417/1000], loss:4.19232
epoch [418/1000], loss:4.18926
epoch [419/1000], loss:4.18134
epoch [420/1000], loss:4.17638
epoch [421/1000], loss:4.17397
epoch [422/1000], loss:4.17142
epoch [423/1000], loss:4.16676
epoch [424/1000], loss:4.17102
epoch [425/1000], loss:4.15542
epoch [426/1000], loss:4.15438
epoch [427/1000], loss:4.15161
epoch [428/1000], loss:4.14431
epoch [429/1000], loss:4.14308
epoch [430/1000], loss:4.14248
epoch [431/1000], loss:4.13705
epoch [432/1000], loss:4.13069
epoch [433/1000], loss:4.12359
epoch [434/1000], loss:4.12440
epoch [435/1000], loss:4.12047
epoch [436/1000], loss:4.11715
epoch [437/1000], loss:4.11095
epoch [438/1000], loss:4.10556
epoch [439/1000], loss:4.10342
epoch [440/1000], loss:4.10314
epoch [441/1000], loss:4.09450
epoch [442/1000], loss:4.08683
epoch [443/1000], loss:4.08545
epoch [444/1000], loss:4.08673
epoch [445/1000], loss:4.07830
epoch [446/1000], loss:4.07518
epoch [447/1000], loss:4.06704
epoch [448/1000], loss:4.06815
epoch [449/1000], loss:4.06158
epoch [450/1000], loss:4.06410
epoch [451/1000], loss:4.05870
epoch [452/1000], loss:4.05462
epoch [453/1000], loss:4.04799
epoch [454/1000], loss:4.04455
epoch [455/1000], loss:4.03678
epoch [456/1000], loss:4.04038
epoch [457/1000], loss:4.03390
epoch [458/1000], loss:4.02727
epoch [459/1000], loss:4.02408
epoch [460/1000], loss:4.02337
epoch [461/1000], loss:4.01824
epoch [462/1000], loss:4.01433
epoch [463/1000], loss:4.00995
epoch [464/1000], loss:4.00826
epoch [465/1000], loss:4.00209
epoch [466/1000], loss:4.00384
epoch [467/1000], loss:3.99173
epoch [468/1000], loss:3.99856
epoch [469/1000], loss:3.99148
epoch [470/1000], loss:3.98304
epoch [471/1000], loss:3.98313
epoch [472/1000], loss:3.97725
epoch [473/1000], loss:3.97736
epoch [474/1000], loss:3.97326
epoch [475/1000], loss:3.96900
epoch [476/1000], loss:3.96096
epoch [477/1000], loss:3.96076
epoch [478/1000], loss:3.96005
epoch [479/1000], loss:3.95441
epoch [480/1000], loss:3.95287
epoch [481/1000], loss:3.94587
epoch [482/1000], loss:3.94024
epoch [483/1000], loss:3.93922
epoch [484/1000], loss:3.93559
epoch [485/1000], loss:3.93831
epoch [486/1000], loss:3.92520
epoch [487/1000], loss:3.92634
epoch [488/1000], loss:3.92151
epoch [489/1000], loss:3.91649
epoch [490/1000], loss:3.91573
epoch [491/1000], loss:3.91516
epoch [492/1000], loss:3.90679
epoch [493/1000], loss:3.90961
epoch [494/1000], loss:3.89975
epoch [495/1000], loss:3.89675
epoch [496/1000], loss:3.89311
epoch [497/1000], loss:3.89344
epoch [498/1000], loss:3.89109
epoch [499/1000], loss:3.88556
epoch [500/1000], loss:3.87982
epoch [501/1000], loss:3.87826
epoch [502/1000], loss:3.87651
epoch [503/1000], loss:3.87134
epoch [504/1000], loss:3.86625
epoch [505/1000], loss:3.86563
epoch [506/1000], loss:3.86109
epoch [507/1000], loss:3.86168
epoch [508/1000], loss:3.85732
epoch [509/1000], loss:3.84998
epoch [510/1000], loss:3.85233
epoch [511/1000], loss:3.84760
epoch [512/1000], loss:3.84713
epoch [513/1000], loss:3.83537
epoch [514/1000], loss:3.83900
epoch [515/1000], loss:3.82796
epoch [516/1000], loss:3.82622
epoch [517/1000], loss:3.83100
epoch [518/1000], loss:3.82413
epoch [519/1000], loss:3.81903
epoch [520/1000], loss:3.81732
epoch [521/1000], loss:3.81084
epoch [522/1000], loss:3.81144
epoch [523/1000], loss:3.80305
epoch [524/1000], loss:3.80411
epoch [525/1000], loss:3.80302
epoch [526/1000], loss:3.79430
epoch [527/1000], loss:3.79282
epoch [528/1000], loss:3.79408
epoch [529/1000], loss:3.79307
epoch [530/1000], loss:3.78673
epoch [531/1000], loss:3.78254
epoch [532/1000], loss:3.77649
epoch [533/1000], loss:3.77460
epoch [534/1000], loss:3.77207
epoch [535/1000], loss:3.76966
epoch [536/1000], loss:3.76757
epoch [537/1000], loss:3.76382
epoch [538/1000], loss:3.75726
epoch [539/1000], loss:3.76330
epoch [540/1000], loss:3.75130
epoch [541/1000], loss:3.74979
epoch [542/1000], loss:3.74968
epoch [543/1000], loss:3.73983
epoch [544/1000], loss:3.73901
epoch [545/1000], loss:3.73932
epoch [546/1000], loss:3.73718
epoch [547/1000], loss:3.73794
epoch [548/1000], loss:3.72818
epoch [549/1000], loss:3.72528
epoch [550/1000], loss:3.72475
epoch [551/1000], loss:3.71988
epoch [552/1000], loss:3.71729
epoch [553/1000], loss:3.71119
epoch [554/1000], loss:3.71207
epoch [555/1000], loss:3.71167
epoch [556/1000], loss:3.70275
epoch [557/1000], loss:3.70654
epoch [558/1000], loss:3.69792
epoch [559/1000], loss:3.69927
epoch [560/1000], loss:3.69409
epoch [561/1000], loss:3.69188
epoch [562/1000], loss:3.68632
epoch [563/1000], loss:3.68308
epoch [564/1000], loss:3.68161
epoch [565/1000], loss:3.68463
epoch [566/1000], loss:3.67181
epoch [567/1000], loss:3.67101
epoch [568/1000], loss:3.66956
epoch [569/1000], loss:3.66723
epoch [570/1000], loss:3.66829
epoch [571/1000], loss:3.66422
epoch [572/1000], loss:3.66120
epoch [573/1000], loss:3.65323
epoch [574/1000], loss:3.65280
epoch [575/1000], loss:3.65279
epoch [576/1000], loss:3.64698
epoch [577/1000], loss:3.64525
epoch [578/1000], loss:3.64385
epoch [579/1000], loss:3.63892
epoch [580/1000], loss:3.63570
epoch [581/1000], loss:3.63038
epoch [582/1000], loss:3.63306
epoch [583/1000], loss:3.62456
epoch [584/1000], loss:3.62961
epoch [585/1000], loss:3.61710
epoch [586/1000], loss:3.62218
epoch [587/1000], loss:3.61367
epoch [588/1000], loss:3.61351
epoch [589/1000], loss:3.61048
epoch [590/1000], loss:3.60863
epoch [591/1000], loss:3.60503
epoch [592/1000], loss:3.60068
epoch [593/1000], loss:3.59856
epoch [594/1000], loss:3.59472
epoch [595/1000], loss:3.59365
epoch [596/1000], loss:3.59324
epoch [597/1000], loss:3.58769
epoch [598/1000], loss:3.58214
epoch [599/1000], loss:3.58244
epoch [600/1000], loss:3.57799
epoch [601/1000], loss:3.57877
epoch [602/1000], loss:3.57055
epoch [603/1000], loss:3.57307
epoch [604/1000], loss:3.57202
epoch [605/1000], loss:3.56517
epoch [606/1000], loss:3.56280
epoch [607/1000], loss:3.56200
epoch [608/1000], loss:3.56267
epoch [609/1000], loss:3.55470
epoch [610/1000], loss:3.55250
epoch [611/1000], loss:3.54826
epoch [612/1000], loss:3.55154
epoch [613/1000], loss:3.54208
epoch [614/1000], loss:3.54206
epoch [615/1000], loss:3.54105
epoch [616/1000], loss:3.53665
epoch [617/1000], loss:3.53198
epoch [618/1000], loss:3.52956
epoch [619/1000], loss:3.52716
epoch [620/1000], loss:3.52535
epoch [621/1000], loss:3.52693
epoch [622/1000], loss:3.51926
epoch [623/1000], loss:3.51655
epoch [624/1000], loss:3.51352
epoch [625/1000], loss:3.51410
epoch [626/1000], loss:3.50871
epoch [627/1000], loss:3.50490
epoch [628/1000], loss:3.50470
epoch [629/1000], loss:3.50429
epoch [630/1000], loss:3.50063
epoch [631/1000], loss:3.49522
epoch [632/1000], loss:3.49489
epoch [633/1000], loss:3.49385
epoch [634/1000], loss:3.48804
epoch [635/1000], loss:3.48522
epoch [636/1000], loss:3.48331
epoch [637/1000], loss:3.47941
epoch [638/1000], loss:3.47592
epoch [639/1000], loss:3.47459
epoch [640/1000], loss:3.47359
epoch [641/1000], loss:3.47270
epoch [642/1000], loss:3.46967
epoch [643/1000], loss:3.46600
epoch [644/1000], loss:3.46549
epoch [645/1000], loss:3.46019
epoch [646/1000], loss:3.45748
epoch [647/1000], loss:3.45389
epoch [648/1000], loss:3.44896
epoch [649/1000], loss:3.44991
epoch [650/1000], loss:3.44311
epoch [651/1000], loss:3.44865
epoch [652/1000], loss:3.44133
epoch [653/1000], loss:3.43858
epoch [654/1000], loss:3.44189
epoch [655/1000], loss:3.43480
epoch [656/1000], loss:3.43255
epoch [657/1000], loss:3.42989
epoch [658/1000], loss:3.42864
epoch [659/1000], loss:3.42396
epoch [660/1000], loss:3.42112
epoch [661/1000], loss:3.42302
epoch [662/1000], loss:3.41736
epoch [663/1000], loss:3.41416
epoch [664/1000], loss:3.41132
epoch [665/1000], loss:3.41046
epoch [666/1000], loss:3.40492
epoch [667/1000], loss:3.40502
epoch [668/1000], loss:3.40614
epoch [669/1000], loss:3.40063
epoch [670/1000], loss:3.40028
epoch [671/1000], loss:3.39271
epoch [672/1000], loss:3.39536
epoch [673/1000], loss:3.39127
epoch [674/1000], loss:3.38746
epoch [675/1000], loss:3.38874
epoch [676/1000], loss:3.38427
epoch [677/1000], loss:3.38143
epoch [678/1000], loss:3.37742
epoch [679/1000], loss:3.37587
epoch [680/1000], loss:3.37513
epoch [681/1000], loss:3.37196
epoch [682/1000], loss:3.36916
epoch [683/1000], loss:3.36594
epoch [684/1000], loss:3.36606
epoch [685/1000], loss:3.36292
epoch [686/1000], loss:3.35892
epoch [687/1000], loss:3.35532
epoch [688/1000], loss:3.35597
epoch [689/1000], loss:3.35689
epoch [690/1000], loss:3.34953
epoch [691/1000], loss:3.34964
epoch [692/1000], loss:3.34474
epoch [693/1000], loss:3.34500
epoch [694/1000], loss:3.34074
epoch [695/1000], loss:3.34088
epoch [696/1000], loss:3.33748
epoch [697/1000], loss:3.33662
epoch [698/1000], loss:3.33202
epoch [699/1000], loss:3.33229
epoch [700/1000], loss:3.32739
epoch [701/1000], loss:3.32630
epoch [702/1000], loss:3.32807
epoch [703/1000], loss:3.32146
epoch [704/1000], loss:3.31806
epoch [705/1000], loss:3.31831
epoch [706/1000], loss:3.31332
epoch [707/1000], loss:3.31269
epoch [708/1000], loss:3.30964
epoch [709/1000], loss:3.30984
epoch [710/1000], loss:3.30538
epoch [711/1000], loss:3.30281
epoch [712/1000], loss:3.30262
epoch [713/1000], loss:3.29772
epoch [714/1000], loss:3.29625
epoch [715/1000], loss:3.29219
epoch [716/1000], loss:3.29506
epoch [717/1000], loss:3.28936
epoch [718/1000], loss:3.28897
epoch [719/1000], loss:3.29049
epoch [720/1000], loss:3.28375
epoch [721/1000], loss:3.28123
epoch [722/1000], loss:3.27900
epoch [723/1000], loss:3.27359
epoch [724/1000], loss:3.27611
epoch [725/1000], loss:3.27433
epoch [726/1000], loss:3.27112
epoch [727/1000], loss:3.26646
epoch [728/1000], loss:3.26737
epoch [729/1000], loss:3.26536
epoch [730/1000], loss:3.26612
epoch [731/1000], loss:3.26075
epoch [732/1000], loss:3.26027
epoch [733/1000], loss:3.25291
epoch [734/1000], loss:3.25916
epoch [735/1000], loss:3.24919
epoch [736/1000], loss:3.25470
epoch [737/1000], loss:3.24516
epoch [738/1000], loss:3.24314
epoch [739/1000], loss:3.24429
epoch [740/1000], loss:3.24261
epoch [741/1000], loss:3.23813
epoch [742/1000], loss:3.23578
epoch [743/1000], loss:3.23666
epoch [744/1000], loss:3.23200
epoch [745/1000], loss:3.23238
epoch [746/1000], loss:3.22988
epoch [747/1000], loss:3.22826
epoch [748/1000], loss:3.23023
epoch [749/1000], loss:3.22209
epoch [750/1000], loss:3.21966
epoch [751/1000], loss:3.21754
epoch [752/1000], loss:3.21620
epoch [753/1000], loss:3.21760
epoch [754/1000], loss:3.21165
epoch [755/1000], loss:3.21131
epoch [756/1000], loss:3.21038
epoch [757/1000], loss:3.20712
epoch [758/1000], loss:3.20317
epoch [759/1000], loss:3.20223
epoch [760/1000], loss:3.20180
epoch [761/1000], loss:3.20010
epoch [762/1000], loss:3.19946
epoch [763/1000], loss:3.19183
epoch [764/1000], loss:3.19291
epoch [765/1000], loss:3.18863
epoch [766/1000], loss:3.18918
epoch [767/1000], loss:3.18898
epoch [768/1000], loss:3.18414
epoch [769/1000], loss:3.18572
epoch [770/1000], loss:3.18738
epoch [771/1000], loss:3.17861
epoch [772/1000], loss:3.17652
epoch [773/1000], loss:3.17587
epoch [774/1000], loss:3.17144
epoch [775/1000], loss:3.17319
epoch [776/1000], loss:3.17009
epoch [777/1000], loss:3.16943
epoch [778/1000], loss:3.16559
epoch [779/1000], loss:3.16415
epoch [780/1000], loss:3.16417
epoch [781/1000], loss:3.16414
epoch [782/1000], loss:3.15878
epoch [783/1000], loss:3.15620
epoch [784/1000], loss:3.15162
epoch [785/1000], loss:3.15188
epoch [786/1000], loss:3.15056
epoch [787/1000], loss:3.14792
epoch [788/1000], loss:3.14884
epoch [789/1000], loss:3.14594
epoch [790/1000], loss:3.14544
epoch [791/1000], loss:3.14156
epoch [792/1000], loss:3.13851
epoch [793/1000], loss:3.13792
epoch [794/1000], loss:3.13770
epoch [795/1000], loss:3.13333
epoch [796/1000], loss:3.13036
epoch [797/1000], loss:3.12862
epoch [798/1000], loss:3.13088
epoch [799/1000], loss:3.12679
epoch [800/1000], loss:3.12329
epoch [801/1000], loss:3.12549
epoch [802/1000], loss:3.12244
epoch [803/1000], loss:3.11828
epoch [804/1000], loss:3.11357
epoch [805/1000], loss:3.11698
epoch [806/1000], loss:3.11326
epoch [807/1000], loss:3.11584
epoch [808/1000], loss:3.10921
epoch [809/1000], loss:3.10769
epoch [810/1000], loss:3.10721
epoch [811/1000], loss:3.10426
epoch [812/1000], loss:3.10207
epoch [813/1000], loss:3.09837
epoch [814/1000], loss:3.09836
epoch [815/1000], loss:3.09801
epoch [816/1000], loss:3.09438
epoch [817/1000], loss:3.09267
epoch [818/1000], loss:3.09224
epoch [819/1000], loss:3.08851
epoch [820/1000], loss:3.08578
epoch [821/1000], loss:3.08942
epoch [822/1000], loss:3.08425
epoch [823/1000], loss:3.08528
epoch [824/1000], loss:3.08140
epoch [825/1000], loss:3.07830
epoch [826/1000], loss:3.07588
epoch [827/1000], loss:3.07775
epoch [828/1000], loss:3.07456
epoch [829/1000], loss:3.07019
epoch [830/1000], loss:3.07405
epoch [831/1000], loss:3.06494
epoch [832/1000], loss:3.06572
epoch [833/1000], loss:3.06405
epoch [834/1000], loss:3.06366
epoch [835/1000], loss:3.05963
epoch [836/1000], loss:3.05978
epoch [837/1000], loss:3.05587
epoch [838/1000], loss:3.05641
epoch [839/1000], loss:3.05452
epoch [840/1000], loss:3.05307
epoch [841/1000], loss:3.04878
epoch [842/1000], loss:3.05134
epoch [843/1000], loss:3.04592
epoch [844/1000], loss:3.04432
epoch [845/1000], loss:3.04292
epoch [846/1000], loss:3.04020
epoch [847/1000], loss:3.04101
epoch [848/1000], loss:3.04131
epoch [849/1000], loss:3.03655
epoch [850/1000], loss:3.03434
epoch [851/1000], loss:3.03037
epoch [852/1000], loss:3.03011
epoch [853/1000], loss:3.03031
epoch [854/1000], loss:3.02658
epoch [855/1000], loss:3.02762
epoch [856/1000], loss:3.02805
epoch [857/1000], loss:3.02052
epoch [858/1000], loss:3.02101
epoch [859/1000], loss:3.01820
epoch [860/1000], loss:3.01740
epoch [861/1000], loss:3.01673
epoch [862/1000], loss:3.01265
epoch [863/1000], loss:3.00953
epoch [864/1000], loss:3.01045
epoch [865/1000], loss:3.00850
epoch [866/1000], loss:3.01031
epoch [867/1000], loss:3.00408
epoch [868/1000], loss:3.00111
epoch [869/1000], loss:3.00130
epoch [870/1000], loss:3.00163
epoch [871/1000], loss:2.99810
epoch [872/1000], loss:2.99874
epoch [873/1000], loss:2.99178
epoch [874/1000], loss:2.99280
epoch [875/1000], loss:2.99230
epoch [876/1000], loss:2.98815
epoch [877/1000], loss:2.98851
epoch [878/1000], loss:2.98612
epoch [879/1000], loss:2.98797
epoch [880/1000], loss:2.98337
epoch [881/1000], loss:2.98161
epoch [882/1000], loss:2.98003
epoch [883/1000], loss:2.97484
epoch [884/1000], loss:2.97611
epoch [885/1000], loss:2.97621
epoch [886/1000], loss:2.97396
epoch [887/1000], loss:2.96927
epoch [888/1000], loss:2.96680
epoch [889/1000], loss:2.96926
epoch [890/1000], loss:2.96575
epoch [891/1000], loss:2.96431
epoch [892/1000], loss:2.96193
epoch [893/1000], loss:2.95761
epoch [894/1000], loss:2.96028
epoch [895/1000], loss:2.96046
epoch [896/1000], loss:2.95814
epoch [897/1000], loss:2.95228
epoch [898/1000], loss:2.94921
epoch [899/1000], loss:2.95213
epoch [900/1000], loss:2.94890
epoch [901/1000], loss:2.94738
epoch [902/1000], loss:2.94390
epoch [903/1000], loss:2.94118
epoch [904/1000], loss:2.94426
epoch [905/1000], loss:2.94239
epoch [906/1000], loss:2.93883
epoch [907/1000], loss:2.93823
epoch [908/1000], loss:2.93640
epoch [909/1000], loss:2.93234
epoch [910/1000], loss:2.93235
epoch [911/1000], loss:2.92981
epoch [912/1000], loss:2.93039
epoch [913/1000], loss:2.93373
epoch [914/1000], loss:2.92795
epoch [915/1000], loss:2.92420
epoch [916/1000], loss:2.92136
epoch [917/1000], loss:2.91813
epoch [918/1000], loss:2.91754
epoch [919/1000], loss:2.91795
epoch [920/1000], loss:2.91643
epoch [921/1000], loss:2.91321
epoch [922/1000], loss:2.91369
epoch [923/1000], loss:2.91094
epoch [924/1000], loss:2.91049
epoch [925/1000], loss:2.90867
epoch [926/1000], loss:2.90595
epoch [927/1000], loss:2.90455
epoch [928/1000], loss:2.90523
epoch [929/1000], loss:2.90355
epoch [930/1000], loss:2.90085
epoch [931/1000], loss:2.89791
epoch [932/1000], loss:2.89439
epoch [933/1000], loss:2.89587
epoch [934/1000], loss:2.89358
epoch [935/1000], loss:2.89229
epoch [936/1000], loss:2.88939
epoch [937/1000], loss:2.89070
epoch [938/1000], loss:2.88834
epoch [939/1000], loss:2.88700
epoch [940/1000], loss:2.88633
epoch [941/1000], loss:2.88195
epoch [942/1000], loss:2.88308
epoch [943/1000], loss:2.87824
epoch [944/1000], loss:2.87709
epoch [945/1000], loss:2.87709
epoch [946/1000], loss:2.87699
epoch [947/1000], loss:2.87330
epoch [948/1000], loss:2.87141
epoch [949/1000], loss:2.87136
epoch [950/1000], loss:2.86982
epoch [951/1000], loss:2.86829
epoch [952/1000], loss:2.86615
epoch [953/1000], loss:2.86325
epoch [954/1000], loss:2.86094
epoch [955/1000], loss:2.86219
epoch [956/1000], loss:2.85894
epoch [957/1000], loss:2.86180
epoch [958/1000], loss:2.85887
epoch [959/1000], loss:2.85384
epoch [960/1000], loss:2.85410
epoch [961/1000], loss:2.85243
epoch [962/1000], loss:2.85051
epoch [963/1000], loss:2.84668
epoch [964/1000], loss:2.84494
epoch [965/1000], loss:2.84352
epoch [966/1000], loss:2.84500
epoch [967/1000], loss:2.84642
epoch [968/1000], loss:2.83922
epoch [969/1000], loss:2.83965
epoch [970/1000], loss:2.84072
epoch [971/1000], loss:2.83823
epoch [972/1000], loss:2.83543
epoch [973/1000], loss:2.83415
epoch [974/1000], loss:2.83639
epoch [975/1000], loss:2.82995
epoch [976/1000], loss:2.82914
epoch [977/1000], loss:2.82669
epoch [978/1000], loss:2.83094
epoch [979/1000], loss:2.82190
epoch [980/1000], loss:2.82548
epoch [981/1000], loss:2.82011
epoch [982/1000], loss:2.82137
epoch [983/1000], loss:2.81966
epoch [984/1000], loss:2.81743
epoch [985/1000], loss:2.81949
epoch [986/1000], loss:2.81346
epoch [987/1000], loss:2.81393
epoch [988/1000], loss:2.81204
epoch [989/1000], loss:2.81101
epoch [990/1000], loss:2.81068
epoch [991/1000], loss:2.80631
epoch [992/1000], loss:2.80828
epoch [993/1000], loss:2.80407
epoch [994/1000], loss:2.80417
epoch [995/1000], loss:2.80385
epoch [996/1000], loss:2.80122
epoch [997/1000], loss:2.80091
epoch [998/1000], loss:2.79750
epoch [999/1000], loss:2.79585
epoch [1000/1000], loss:2.79409


import numpy as np

def cal_acc(gt, pred):
    """ Computes categorization accuracy of our task.
    Args:
      gt: Ground truth labels, shape (N,)
      pred: Predicted labels, shape (N,)
    Returns:
      acc: Accuracy (0~1 scalar)
    """
    # Calculate correct predictions
    correct = np.sum(gt == pred)
    acc = correct / gt.shape[0]
    # This is binary unsupervised clustering, so the cluster ids are arbitrary:
    # we only care whether the images were split into two groups, hence max(acc, 1-acc).
    return max(acc, 1-acc)
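To see why `max(acc, 1-acc)` is used, here is a small worked example on made-up labels (the arrays `gt`, `pred_flipped` and `pred_noisy` are illustrative and not taken from the dataset). K-means assigns cluster ids arbitrarily, so a perfect split can come back with every label flipped.

```python
import numpy as np

def cal_acc(gt, pred):
    # Same logic as the notebook's cal_acc, repeated so the example runs on its own.
    acc = np.sum(gt == pred) / gt.shape[0]
    return max(acc, 1 - acc)

gt = np.array([0, 0, 1, 1, 1, 0, 1, 0])

# A perfect split whose cluster ids happen to be swapped: raw accuracy is 0.0.
pred_flipped = 1 - gt

# A split with two mistakes: raw accuracy is 6/8.
pred_noisy = gt.copy()
pred_noisy[:2] = 1 - pred_noisy[:2]

acc_flipped = cal_acc(gt, pred_flipped)
acc_noisy = cal_acc(gt, pred_noisy)
print(acc_flipped, acc_noisy)  # 1.0 0.75
```

A side effect of this metric is that it can never fall below 0.5, so values close to 0.5 mean the clustering carries almost no information about the labels.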
import matplotlib.pyplot as plt

def plot_scatter(feat, label, savefig=None):
    """ Plot Scatter Image.
    Args:
      feat: the (x, y) coordinate of clustering result, shape: (N, 2)
      label: ground truth label of image (0/1), shape: (N,)
    """
    X = feat[:, 0]
    Y = feat[:, 1]
    plt.scatter(X, Y, c = label)
    if savefig is not None:
        plt.savefig(savefig)
    plt.show()

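Next, we use the trained model to predict the classes of the testing data.

The testing data is the same as the training data, so we reuse the same dataset class to build the dataloader. The only difference from training is that the shuffle parameter is set to False here.

With the model and the dataloader ready, we can run the prediction.

We only need the encoder output (latents); clustering the latents then gives the classification.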
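The reduce-then-cluster idea can be tried without a trained model. The sketch below is only an illustration under stated assumptions: `fake_latents` is synthetic data (two Gaussian blobs standing in for encoder outputs), and the sizes and `n_components=20` are chosen to keep it fast, not taken from the assignment.

```python
import numpy as np
from sklearn.decomposition import KernelPCA
from sklearn.manifold import TSNE
from sklearn.cluster import MiniBatchKMeans

rng = np.random.default_rng(0)

# Synthetic stand-in for encoder latents: two Gaussian blobs in 64 dimensions.
n_per_class, dim = 100, 64
fake_latents = np.concatenate([
    rng.normal(0.0, 1.0, size=(n_per_class, dim)),
    rng.normal(3.0, 1.0, size=(n_per_class, dim)),
]).astype(np.float32)
fake_labels = np.array([0] * n_per_class + [1] * n_per_class)

# Step 1: kernel PCA down to a moderate number of components.
reduced = KernelPCA(n_components=20, kernel='rbf').fit_transform(fake_latents)

# Step 2: t-SNE down to 2-D, used both for clustering and for plotting.
embedded = TSNE(n_components=2, random_state=0).fit_transform(reduced)

# Step 3: split the 2-D points into two clusters.
cluster_ids = MiniBatchKMeans(n_clusters=2, random_state=0).fit(embedded).labels_

# Cluster ids are arbitrary, so score with max(acc, 1 - acc).
raw_acc = np.mean(cluster_ids == fake_labels)
print('embedding shape:', embedded.shape)
print('accuracy:', max(raw_acc, 1 - raw_acc))
```

Fixing `random_state` on both t-SNE and k-means makes repeated runs comparable, which helps when judging whether a change to the autoencoder actually improved the clustering.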

import torch
from sklearn.decomposition import KernelPCA
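# Principal Component Analysis (PCA) is one of the most popular dimensionality reduction algorithms. It first finds the hyperplane that lies closest to the data, then projects all the data onto it.
# Kernel PCA (kPCA) is an unsupervised algorithm, so there is no obvious performance metric for choosing the best kernel and hyperparameters. Dimensionality reduction is, however, often a preparation step for a supervised task (e.g. classification).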
from sklearn.manifold import TSNE
from sklearn.cluster import MiniBatchKMeans

def inference(X, model, batch_size=256):
    X = preprocess(X)
    dataset = Image_Dataset(X)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    latents = []
    for i, x in enumerate(dataloader):
        x = torch.FloatTensor(x)
        vec, img = model(x.cuda())
        # vec.view(batch_size, -1) keeps batch_size rows and lets -1 infer the
        # number of columns, i.e. it flattens each latent into one vector.
        if i == 0:
            latents = vec.view(img.size()[0], -1).cpu().detach().numpy()
        else:
            latents = np.concatenate((latents, vec.view(img.size()[0], -1).cpu().detach().numpy()), axis = 0)
    print('Latents Shape:', latents.shape)
    return latents

def predict(latents):
    # First Dimension Reduction
    transformer = KernelPCA(n_components=200, kernel='rbf', n_jobs=-1)
    # n_jobs=-1 uses all CPUs; for n_jobs < -1, (n_cpus + 1 + n_jobs) CPUs are used.

    kpca = transformer.fit_transform(latents)
    print('First Reduction Shape:', kpca.shape)

    # Second Dimension Reduction
    X_embedded = TSNE(n_components=2).fit_transform(kpca)
    print('Second Reduction Shape:', X_embedded.shape)

    # Clustering
    # random_state: an int, a RandomState instance, or None; it controls how the random numbers are generated.
    pred = MiniBatchKMeans(n_clusters=2, random_state=0).fit(X_embedded)
    pred = [int(i) for i in pred.labels_]
    pred = np.array(pred)
    return pred, X_embedded

def invert(pred):
    return np.abs(1-pred)

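def save_prediction(pred, out_csv='prediction.csv'):
    with open(out_csv, 'w') as f:
        # Assumed CSV layout: a header followed by one "id,label" row per image.
        f.write('id,label\n')
        for i, p in enumerate(pred):
            f.write(f'{i},{p}\n')
    print(f'Save prediction to {out_csv}.')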

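# load model (the checkpoint path is an assumption: the last file listed under checkpoints/)
model = AE().cuda()
model.load_state_dict(torch.load('checkpoints/checkpoint_1000.pth'))
model.eval()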

# prepare data
trainX = np.load('trainX.npy')

# predict
latents = inference(X=trainX, model=model)
pred, X_embedded = predict(latents)

# save the predictions for upload to Kaggle
save_prediction(pred, 'prediction.csv')

# This is an unsupervised binary classification problem, so we only care whether the images were split into two groups.
# If the file above scores below 0.5 accuracy on Kaggle, simply flip the labels.
save_prediction(invert(pred), 'prediction_invert.csv')
Latents Shape: (8500, 2048)
First Reduction Shape: (8500, 200)
Second Reduction Shape: (8500, 2)
Save prediction to prediction.csv.
Save prediction to prediction_invert.csv.


Plot the dimensionality-reduction result (embedding) of the val data together with the corresponding labels.

valX = np.load('valX.npy')
valY = np.load('valY.npy')

# ==============================================
#  Here we demonstrate the plot for the baseline model;
#  for the report, please also plot the improved model.
# ==============================================
latents = inference(valX, model)
pred_from_latent, emb_from_latent = predict(latents)
acc_latent = cal_acc(valY, pred_from_latent)
print('The clustering accuracy is:', acc_latent)
print('The clustering result:')
plot_scatter(emb_from_latent, valY, savefig='p1_baseline.png')
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)

No handles with labels found to put in legend.

Second Reduction Shape: (500, 2)
The clustering accuracy is: 0.75
The clustering result:



Using the autoencoder with your highest test accuracy, take the 6 images at index 1, 2, 3, 6, 7, 9 from trainX, and plot the original images together with their reconstructions.

import matplotlib.pyplot as plt
import numpy as np

# plot the original images
indexes = [1,2,3,6,7,9]
imgs = trainX[indexes,]
for i, img in enumerate(imgs):
    plt.subplot(2, 6, i+1, xticks=[], yticks=[])
    plt.imshow(img)

# plot the reconstructed images
inp = torch.Tensor(trainX_preprocessed[indexes,]).cuda()
latents, recs = model(inp)
# undo the preprocessing: [-1, 1] -> [0, 1], then (N, C, H, W) -> (N, H, W, C)
recs = ((recs+1)/2 ).cpu().detach().numpy()
recs = recs.transpose(0, 2, 3, 1)
for i, img in enumerate(recs):
    plt.subplot(2, 6, 6+i+1, xticks=[], yticks=[])
    plt.imshow(img)
plt.show()
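The reconstruction has to be mapped back from the network's value range before it can be displayed. The sketch below checks that round trip with NumPy only; `deprocess` is a hypothetical helper name introduced here for illustration, and the random `images` array merely stands in for entries of trainX.

```python
import numpy as np

def preprocess(image_list):
    # Same transform as the notebook: (N, H, W, C) uint8 -> (N, C, H, W) float32 in [-1, 1].
    image_list = np.array(image_list)
    image_list = np.transpose(image_list, (0, 3, 1, 2))
    image_list = (image_list / 255.0) * 2 - 1
    return image_list.astype(np.float32)

def deprocess(batch):
    # Hypothetical inverse: (N, C, H, W) in [-1, 1] -> (N, H, W, C) in [0, 1].
    batch = (batch + 1) / 2
    return np.transpose(batch, (0, 2, 3, 1))

rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(6, 32, 32, 3), dtype=np.uint8)

restored = deprocess(preprocess(images))
max_err = np.abs(restored - images / 255.0).max()

print(restored.shape)                               # (6, 32, 32, 3)
print(restored.min() >= 0.0, restored.max() <= 1.0)
print(max_err < 1e-6)
```

`plt.imshow` accepts float images in [0, 1], which is why the reconstruction is rescaled with `(recs+1)/2` rather than converted back to 0~255 integers.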



During autoencoder training, pick at least 10 checkpoints, plot the model's train reconstruction error against val accuracy, and briefly describe what you observe.

import os
import glob
checkpoints_list = sorted(glob.glob('checkpoints/checkpoint_*.pth'), key= lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1]))
# load data
dataset = Image_Dataset(trainX_preprocessed)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

points = []
with torch.no_grad():
    for i, checkpoint in enumerate(checkpoints_list):
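        print('[{}/{}] {}'.format(i+1, len(checkpoints_list), checkpoint))
        # load this checkpoint's weights (assumes each file stores a state_dict);
        # without this, every iteration would evaluate the same model
        model.load_state_dict(torch.load(checkpoint))
        model.eval()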
        err = 0
        n = 0
        for x in dataloader:
            x = x.cuda()
            _, rec = model(x)
            err += torch.nn.MSELoss(reduction='sum')(x, rec).item()
            n += x.flatten().size(0)
        print('Reconstruction error (MSE):', err/n)
        latents = inference(X=valX, model=model)
        pred, X_embedded = predict(latents)
        acc = cal_acc(valY, pred)
        print('Accuracy:', acc)
        points.append((err/n, acc))
['checkpoints/checkpoint_10.pth', 'checkpoints/checkpoint_20.pth', 'checkpoints/checkpoint_30.pth', 'checkpoints/checkpoint_40.pth', 'checkpoints/checkpoint_50.pth', 'checkpoints/checkpoint_60.pth', 'checkpoints/checkpoint_70.pth', 'checkpoints/checkpoint_80.pth', 'checkpoints/checkpoint_90.pth', 'checkpoints/checkpoint_100.pth', 'checkpoints/checkpoint_110.pth', 'checkpoints/checkpoint_120.pth', 'checkpoints/checkpoint_130.pth', 'checkpoints/checkpoint_140.pth', 'checkpoints/checkpoint_150.pth', 'checkpoints/checkpoint_160.pth', 'checkpoints/checkpoint_170.pth', 'checkpoints/checkpoint_180.pth', 'checkpoints/checkpoint_190.pth', 'checkpoints/checkpoint_200.pth', 'checkpoints/checkpoint_210.pth', 'checkpoints/checkpoint_220.pth', 'checkpoints/checkpoint_230.pth', 'checkpoints/checkpoint_240.pth', 'checkpoints/checkpoint_250.pth', 'checkpoints/checkpoint_260.pth', 'checkpoints/checkpoint_270.pth', 'checkpoints/checkpoint_280.pth', 'checkpoints/checkpoint_290.pth', 'checkpoints/checkpoint_300.pth', 'checkpoints/checkpoint_310.pth', 'checkpoints/checkpoint_320.pth', 'checkpoints/checkpoint_330.pth', 'checkpoints/checkpoint_340.pth', 'checkpoints/checkpoint_350.pth', 'checkpoints/checkpoint_360.pth', 'checkpoints/checkpoint_370.pth', 'checkpoints/checkpoint_380.pth', 'checkpoints/checkpoint_390.pth', 'checkpoints/checkpoint_400.pth', 'checkpoints/checkpoint_410.pth', 'checkpoints/checkpoint_420.pth', 'checkpoints/checkpoint_430.pth', 'checkpoints/checkpoint_440.pth', 'checkpoints/checkpoint_450.pth', 'checkpoints/checkpoint_460.pth', 'checkpoints/checkpoint_470.pth', 'checkpoints/checkpoint_480.pth', 'checkpoints/checkpoint_490.pth', 'checkpoints/checkpoint_500.pth', 'checkpoints/checkpoint_510.pth', 'checkpoints/checkpoint_520.pth', 'checkpoints/checkpoint_530.pth', 'checkpoints/checkpoint_540.pth', 'checkpoints/checkpoint_550.pth', 'checkpoints/checkpoint_560.pth', 'checkpoints/checkpoint_570.pth', 'checkpoints/checkpoint_580.pth', 'checkpoints/checkpoint_590.pth', 
'checkpoints/checkpoint_600.pth', 'checkpoints/checkpoint_610.pth', 'checkpoints/checkpoint_620.pth', 'checkpoints/checkpoint_630.pth', 'checkpoints/checkpoint_640.pth', 'checkpoints/checkpoint_650.pth', 'checkpoints/checkpoint_660.pth', 'checkpoints/checkpoint_670.pth', 'checkpoints/checkpoint_680.pth', 'checkpoints/checkpoint_690.pth', 'checkpoints/checkpoint_700.pth', 'checkpoints/checkpoint_710.pth', 'checkpoints/checkpoint_720.pth', 'checkpoints/checkpoint_730.pth', 'checkpoints/checkpoint_740.pth', 'checkpoints/checkpoint_750.pth', 'checkpoints/checkpoint_760.pth', 'checkpoints/checkpoint_770.pth', 'checkpoints/checkpoint_780.pth', 'checkpoints/checkpoint_790.pth', 'checkpoints/checkpoint_800.pth', 'checkpoints/checkpoint_810.pth', 'checkpoints/checkpoint_820.pth', 'checkpoints/checkpoint_830.pth', 'checkpoints/checkpoint_840.pth', 'checkpoints/checkpoint_850.pth', 'checkpoints/checkpoint_860.pth', 'checkpoints/checkpoint_870.pth', 'checkpoints/checkpoint_880.pth', 'checkpoints/checkpoint_890.pth', 'checkpoints/checkpoint_900.pth', 'checkpoints/checkpoint_910.pth', 'checkpoints/checkpoint_920.pth', 'checkpoints/checkpoint_930.pth', 'checkpoints/checkpoint_940.pth', 'checkpoints/checkpoint_950.pth', 'checkpoints/checkpoint_960.pth', 'checkpoints/checkpoint_970.pth', 'checkpoints/checkpoint_980.pth', 'checkpoints/checkpoint_990.pth', 'checkpoints/checkpoint_1000.pth']
[1/100] checkpoints/checkpoint_10.pth
Reconstruction error (MSE): 0.10465650191961551
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.56
[2/100] checkpoints/checkpoint_20.pth
Reconstruction error (MSE): 0.08282024884691426
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.564
[3/100] checkpoints/checkpoint_30.pth
Reconstruction error (MSE): 0.07529751972123688
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.552
[4/100] checkpoints/checkpoint_40.pth
Reconstruction error (MSE): 0.06996455570295745
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[5/100] checkpoints/checkpoint_50.pth
Reconstruction error (MSE): 0.06556768215403837
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[6/100] checkpoints/checkpoint_60.pth
Reconstruction error (MSE): 0.0623151860704609
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.524
[7/100] checkpoints/checkpoint_70.pth
Reconstruction error (MSE): 0.05941213181439568
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.536
[8/100] checkpoints/checkpoint_80.pth
Reconstruction error (MSE): 0.057350127874636184
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.52
[9/100] checkpoints/checkpoint_90.pth
Reconstruction error (MSE): 0.05522508699753705
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.516
[10/100] checkpoints/checkpoint_100.pth
Reconstruction error (MSE): 0.05384457483478621
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[11/100] checkpoints/checkpoint_110.pth
Reconstruction error (MSE): 0.05195554022695504
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.5
[12/100] checkpoints/checkpoint_120.pth
Reconstruction error (MSE): 0.05074959627787272
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.524
[13/100] checkpoints/checkpoint_130.pth
Reconstruction error (MSE): 0.04992709094402837
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.514
[14/100] checkpoints/checkpoint_140.pth
Reconstruction error (MSE): 0.04817914684146058
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.502
[15/100] checkpoints/checkpoint_150.pth
Reconstruction error (MSE): 0.04657277587815827
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.512
[16/100] checkpoints/checkpoint_160.pth
Reconstruction error (MSE): 0.045626810316945994
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[17/100] checkpoints/checkpoint_170.pth
Reconstruction error (MSE): 0.04440261214387183
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.548
[18/100] checkpoints/checkpoint_180.pth
Reconstruction error (MSE): 0.04345548491384469
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[19/100] checkpoints/checkpoint_190.pth
Reconstruction error (MSE): 0.04282478637321323
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.518
[20/100] checkpoints/checkpoint_200.pth
Reconstruction error (MSE): 0.042173867076051
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[21/100] checkpoints/checkpoint_210.pth
Reconstruction error (MSE): 0.041361579988517014
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[22/100] checkpoints/checkpoint_220.pth
Reconstruction error (MSE): 0.040615920683916874
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[23/100] checkpoints/checkpoint_230.pth
Reconstruction error (MSE): 0.039873808785980826
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[24/100] checkpoints/checkpoint_240.pth
Reconstruction error (MSE): 0.03932966136932373
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[25/100] checkpoints/checkpoint_250.pth
Reconstruction error (MSE): 0.038771576432620775
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[26/100] checkpoints/checkpoint_260.pth
Reconstruction error (MSE): 0.0381339080099966
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[27/100] checkpoints/checkpoint_270.pth
Reconstruction error (MSE): 0.03751208638209923
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[28/100] checkpoints/checkpoint_280.pth
Reconstruction error (MSE): 0.037052626366708794
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[29/100] checkpoints/checkpoint_290.pth
Reconstruction error (MSE): 0.03666375287373861
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[30/100] checkpoints/checkpoint_300.pth
Reconstruction error (MSE): 0.03611967169069776
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[31/100] checkpoints/checkpoint_310.pth
Reconstruction error (MSE): 0.03564991631227381
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.518
[32/100] checkpoints/checkpoint_320.pth
Reconstruction error (MSE): 0.035199516689076144
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.514
[33/100] checkpoints/checkpoint_330.pth
Reconstruction error (MSE): 0.034691137108148314
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.524
[34/100] checkpoints/checkpoint_340.pth
Reconstruction error (MSE): 0.03432022960513246
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[35/100] checkpoints/checkpoint_350.pth
Reconstruction error (MSE): 0.033855246824376725
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.52
[36/100] checkpoints/checkpoint_360.pth
Reconstruction error (MSE): 0.03339220189113243
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[37/100] checkpoints/checkpoint_370.pth
Reconstruction error (MSE): 0.03329624884736304
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.502
[38/100] checkpoints/checkpoint_380.pth
Reconstruction error (MSE): 0.03264928217495189
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[39/100] checkpoints/checkpoint_390.pth
Reconstruction error (MSE): 0.03237577991859586
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[40/100] checkpoints/checkpoint_400.pth
Reconstruction error (MSE): 0.03208851829229617
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.514
[41/100] checkpoints/checkpoint_410.pth
Reconstruction error (MSE): 0.0316866933037253
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[42/100] checkpoints/checkpoint_420.pth
Reconstruction error (MSE): 0.031364078933117434
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[43/100] checkpoints/checkpoint_430.pth
Reconstruction error (MSE): 0.031136608348173254
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[44/100] checkpoints/checkpoint_440.pth
Reconstruction error (MSE): 0.0309632776297775
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[45/100] checkpoints/checkpoint_450.pth
Reconstruction error (MSE): 0.030496950392629587
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[46/100] checkpoints/checkpoint_460.pth
Reconstruction error (MSE): 0.030128193126005284
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[47/100] checkpoints/checkpoint_470.pth
Reconstruction error (MSE): 0.029998875262690527
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[48/100] checkpoints/checkpoint_480.pth
Reconstruction error (MSE): 0.029572404412662283
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[49/100] checkpoints/checkpoint_490.pth
Reconstruction error (MSE): 0.02939370559243595
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.508
[50/100] checkpoints/checkpoint_500.pth
Reconstruction error (MSE): 0.02911538221321854
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.554
[51/100] checkpoints/checkpoint_510.pth
Reconstruction error (MSE): 0.02889633548960966
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.546
[52/100] checkpoints/checkpoint_520.pth
Reconstruction error (MSE): 0.02860628096262614
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.648
[53/100] checkpoints/checkpoint_530.pth
Reconstruction error (MSE): 0.028405724600249645
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.54
[54/100] checkpoints/checkpoint_540.pth
Reconstruction error (MSE): 0.028084655219433353
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[55/100] checkpoints/checkpoint_550.pth
Reconstruction error (MSE): 0.02798689774905934
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.538
[56/100] checkpoints/checkpoint_560.pth
Reconstruction error (MSE): 0.027731095856311276
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.552
[57/100] checkpoints/checkpoint_570.pth
Reconstruction error (MSE): 0.027528591081207875
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[58/100] checkpoints/checkpoint_580.pth
Reconstruction error (MSE): 0.02748092877631094
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.516
[59/100] checkpoints/checkpoint_590.pth
Reconstruction error (MSE): 0.027148202690423704
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.536
[60/100] checkpoints/checkpoint_600.pth
Reconstruction error (MSE): 0.02693716204400156
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.538
[61/100] checkpoints/checkpoint_610.pth
Reconstruction error (MSE): 0.02663602849548938
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.656
[62/100] checkpoints/checkpoint_620.pth
Reconstruction error (MSE): 0.026486863996468338
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[63/100] checkpoints/checkpoint_630.pth
Reconstruction error (MSE): 0.026279585034239526
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.518
[64/100] checkpoints/checkpoint_640.pth
Reconstruction error (MSE): 0.02615043982337503
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.61
[65/100] checkpoints/checkpoint_650.pth
Reconstruction error (MSE): 0.025924385631785674
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[66/100] checkpoints/checkpoint_660.pth
Reconstruction error (MSE): 0.025687772582559023
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.59
[67/100] checkpoints/checkpoint_670.pth
Reconstruction error (MSE): 0.025555453281776577
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.628
[68/100] checkpoints/checkpoint_680.pth
Reconstruction error (MSE): 0.025691534911884983
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.648
[69/100] checkpoints/checkpoint_690.pth
Reconstruction error (MSE): 0.025101487271925984
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.708
[70/100] checkpoints/checkpoint_700.pth
Reconstruction error (MSE): 0.02504801980186911
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.732
[71/100] checkpoints/checkpoint_710.pth
Reconstruction error (MSE): 0.02484492769428328
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.752
[72/100] checkpoints/checkpoint_720.pth
Reconstruction error (MSE): 0.02478704075719796
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.76
[73/100] checkpoints/checkpoint_730.pth
Reconstruction error (MSE): 0.02446424291648117
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.608
[74/100] checkpoints/checkpoint_740.pth
Reconstruction error (MSE): 0.024349503433003145
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.508
[75/100] checkpoints/checkpoint_750.pth
Reconstruction error (MSE): 0.02417324640236649
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.764
[76/100] checkpoints/checkpoint_760.pth
Reconstruction error (MSE): 0.024010706882850796
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.726
[77/100] checkpoints/checkpoint_770.pth
Reconstruction error (MSE): 0.02394120900771197
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.762
[78/100] checkpoints/checkpoint_780.pth
Reconstruction error (MSE): 0.023713757514953613
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.55
[79/100] checkpoints/checkpoint_790.pth
Reconstruction error (MSE): 0.02374166191325468
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.554
[80/100] checkpoints/checkpoint_800.pth
Reconstruction error (MSE): 0.023461397339315977
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.64
[81/100] checkpoints/checkpoint_810.pth
Reconstruction error (MSE): 0.023291605500613943
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.756
[82/100] checkpoints/checkpoint_820.pth
Reconstruction error (MSE): 0.023138159677094105
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.526
[83/100] checkpoints/checkpoint_830.pth
Reconstruction error (MSE): 0.02306466459760479
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.792
[84/100] checkpoints/checkpoint_840.pth
Reconstruction error (MSE): 0.022922015835257138
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.664
[85/100] checkpoints/checkpoint_850.pth
Reconstruction error (MSE): 0.022727084767584706
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.67
[86/100] checkpoints/checkpoint_860.pth
Reconstruction error (MSE): 0.022709223756603166
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.76
[87/100] checkpoints/checkpoint_870.pth
Reconstruction error (MSE): 0.022506213861353257
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.542
[88/100] checkpoints/checkpoint_880.pth
Reconstruction error (MSE): 0.022330569921755323
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.682
[89/100] checkpoints/checkpoint_890.pth
Reconstruction error (MSE): 0.022259797694636325
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.792
[90/100] checkpoints/checkpoint_900.pth
Reconstruction error (MSE): 0.022161509541904226
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[91/100] checkpoints/checkpoint_910.pth
Reconstruction error (MSE): 0.022015575642679253
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.712
[92/100] checkpoints/checkpoint_920.pth
Reconstruction error (MSE): 0.021944920754900166
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.758
[93/100] checkpoints/checkpoint_930.pth
Reconstruction error (MSE): 0.021774335898605047
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.59
[94/100] checkpoints/checkpoint_940.pth
Reconstruction error (MSE): 0.021657160151238534
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.76
[95/100] checkpoints/checkpoint_950.pth
Reconstruction error (MSE): 0.021555810619803037
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.698
[96/100] checkpoints/checkpoint_960.pth
Reconstruction error (MSE): 0.021441521494996313
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[97/100] checkpoints/checkpoint_970.pth
Reconstruction error (MSE): 0.02138799679513071
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.802
[98/100] checkpoints/checkpoint_980.pth
Reconstruction error (MSE): 0.021166577629014558
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.508
[99/100] checkpoints/checkpoint_990.pth
Reconstruction error (MSE): 0.02112917330685784
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.736
[100/100] checkpoints/checkpoint_1000.pth
Reconstruction error (MSE): 0.02094145799150654
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.554
ps = list(zip(*points))
plt.subplot(211, title='Reconstruction error (MSE)').plot(ps[0])
plt.subplot(212, title='Accuracy (val)').plot(ps[1])
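The checkpoint loop above accumulates a summed squared error per batch and divides by the total element count once at the end. A minimal NumPy sketch of why that is done, using synthetic arrays (`x` and `rec` are made up; the sizes are chosen so that the last batch is smaller than the others):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=(150, 3, 32, 32))     # stand-in for preprocessed images
rec = x + rng.normal(0, 0.1, size=x.shape)        # stand-in for reconstructions

batch_size = 64  # 150 samples -> batches of 64, 64 and 22
err, n = 0.0, 0
batch_means = []
for start in range(0, len(x), batch_size):
    xb = x[start:start + batch_size]
    rb = rec[start:start + batch_size]
    err += np.sum((xb - rb) ** 2)                 # like MSELoss(reduction='sum')
    n += xb.size                                  # like x.flatten().size(0)
    batch_means.append(np.mean((xb - rb) ** 2))

batched_mse = err / n
full_mse = np.mean((x - rec) ** 2)
mean_of_means = np.mean(batch_means)

print(np.isclose(batched_mse, full_mse))          # True
print(batched_mse, mean_of_means)
```

Averaging the per-batch means instead would give the smaller last batch the same weight as a full one, so the sum-then-divide form is the one that matches the MSE over all of trainX.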



