Hung-yi Lee (2020) Homework 9: Unsupervised Learning (Dimensionality Reduction, Clustering, Autoencoders)

Tags: clustering, machine learning, unsupervised learning, autoencoder, PCA dimensionality reduction


Dataset

Assignment

Task 1

Use at least two methods (autoencoder architecture, optimizer, data preprocessing, the downstream dimensionality-reduction method, the clustering algorithm, etc.) to improve the accuracy of the baseline code.

  • Record the accuracy before and after the improvement.
  • With the method before and after the improvement, plot the val data's dimensionality-reduced embedding together with the corresponding labels.

Task 2

Using your autoencoder with the highest accuracy, take the 6 images at indices 1, 2, 3, 6, 7, 9 from trainX.

  • Plot the original images and their reconstructions.

Task 3

Pick at least 10 checkpoints from the autoencoder's training run.

  • Plot the model's reconstruction error (MSE computed over all of trainX) and the val accuracy for those checkpoints.

Data

Load the data with np.load(). valX.npy and valY.npy are only for evaluating the training; they must not be used for training.

trainX.npy

  • Contains 8500 RGB images, each of size 32 * 32 * 3
  • shape is (8500, 32, 32, 3)

valX.npy

  • Do not use for training
  • Contains 500 RGB images, each of size 32 * 32 * 3
  • shape is (500, 32, 32, 3)

valY.npy

  • Do not use for training
  • The labels corresponding to valX.npy
  • shape is (500,)

Download the dataset

Create the checkpoints folder

#!gdown --id '1BZb2AqOHHaad7Mo82St1qTBaXo_xtcUc' --output trainX.npy 
# !gdown --id '152NKCpj8S_zuIx3bQy0NN5oqpvBjdPIq' --output valX.npy 
# !gdown --id '1_hRGsFtm5KEazUg2ZvPZcuNScGF-ANh4' --output valY.npy 
!mkdir checkpoints
!ls
mkdir: cannot create directory 'checkpoints': File exists
checkpoints	       trainX.npy
p1_baseline.png        valX.npy
prediction.csv	       valY.npy
prediction_invert.csv  李宏毅机器学习2020-作业9:无监督学习.ipynb
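
After downloading, a quick shape check (a minimal sketch) confirms the arrays match the description above:

import numpy as np

trainX = np.load('trainX.npy')
valX = np.load('valX.npy')
valY = np.load('valY.npy')
print(trainX.shape, valX.shape, valY.shape)  # (8500, 32, 32, 3) (500, 32, 32, 3) (500,)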

Prepare the training data

Define our preprocess: linearly map the pixel values from ints in 0~255 to floats in -1~1.

import numpy as np

def preprocess(image_list):
    """ Normalize Image and Permute (N,H,W,C) to (N,C,H,W)
    Args:
      image_list: List of images (N, 32, 32, 3)
    Returns:
      image_list: List of images (N, 3, 32, 32)
    """
    image_list = np.array(image_list)
    image_list = np.transpose(image_list, (0, 3, 1, 2))
    image_list = (image_list / 255.0) * 2 - 1
    image_list = image_list.astype(np.float32)
    return image_list
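
As a quick sanity check (a minimal sketch; the dummy array is purely illustrative), the mapping should send pixel value 0 to -1.0 and 255 to 1.0, and move the channel axis into position 1:

dummy = np.zeros((2, 32, 32, 3), dtype=np.uint8)  # two fake images with extreme pixel values
dummy[1] = 255
out = preprocess(dummy)
print(out.shape)             # (2, 3, 32, 32)
print(out.min(), out.max())  # -1.0 1.0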

Custom Dataset

from torch.utils.data import Dataset

class Image_Dataset(Dataset):
    def __init__(self, image_list):
        self.image_list = image_list
    def __len__(self):
        return len(self.image_list)
    def __getitem__(self, idx):
        images = self.image_list[idx]
        return images

Load the training data and run preprocess on it, then wrap the preprocessed data in the dataset we need. Do not use valX and valY for training.

from torch.utils.data import DataLoader

trainX = np.load('trainX.npy')
trainX_preprocessed = preprocess(trainX)
img_dataset = Image_Dataset(trainX_preprocessed)

Utility functions

Here are some useful functions: one counts a model's parameters (needed for the report), and the other fixes the random seeds used in training (so results can be reproduced).

import random
import torch

def count_parameters(model, only_trainable=False):
    if only_trainable:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in model.parameters())

def same_seeds(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.
    torch.backends.cudnn.benchmark = False     # disable cuDNN autotuning of conv algorithms
    torch.backends.cudnn.deterministic = True  # always select the same deterministic conv algorithm
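
For example (a small usage sketch with a toy layer, not part of the assignment code):

import torch.nn as nn

same_seeds(0)
toy = nn.Linear(10, 2)
print(count_parameters(toy))                       # 22 = 10*2 weights + 2 biases
print(count_parameters(toy, only_trainable=True))  # also 22; all parameters require grad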

Model

Define our baseline autoencoder.
ConvTranspose2d: transposed convolution (sometimes called "deconvolution")

As for improving the model, I only made the encoder and decoder one layer deeper each, which improves the result; the only hyperparameter change is raising the number of epochs to 1000.

import torch.nn as nn

class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(256, 512, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2)
        )
 
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 3, stride=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 5, stride=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 9, stride=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 17, stride=1),
            nn.Tanh()
        )

    def forward(self, x):
        x1 = self.encoder(x)
        x  = self.decoder(x1)
        return x1, x
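
To see why the decoder's stride-1 transposed convolutions use the growing kernel sizes 3, 5, 9, 17: with stride 1 and no padding, each ConvTranspose2d grows the spatial size by kernel_size - 1, which exactly undoes the encoder's four 2x poolings (32 -> 16 -> 8 -> 4 -> 2 down, then 2 -> 4 -> 8 -> 16 -> 32 back up). A quick shape check (a CPU-only sketch):

ae = AE()
dummy = torch.zeros(1, 3, 32, 32)  # one fake RGB image
latent, rec = ae(dummy)
print(latent.shape)  # torch.Size([1, 512, 2, 2]) -> 2048-dim when flattened
print(rec.shape)     # torch.Size([1, 3, 32, 32]), same size as the input
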
!nvidia-smi
Thu Nov  4 17:03:39 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03    Driver Version: 460.91.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  GeForce RTX 3090    Off  | 00000000:1A:00.0 Off |                  N/A |
| 57%   70C    P2   325W / 350W |   8107MiB / 24268MiB |     91%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 3090    Off  | 00000000:68:00.0 Off |                  N/A |
|  0%   29C    P8    25W / 350W |    299MiB / 24265MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     11145      C   python                           8103MiB |
|    1   N/A  N/A      2432      G   /usr/lib/xorg/Xorg                 14MiB |
|    1   N/A  N/A      4006      G   /usr/bin/gnome-shell               17MiB |
|    1   N/A  N/A      4984      G   /usr/lib/xorg/Xorg                 70MiB |
|    1   N/A  N/A      5058      G   /usr/lib/xorg/Xorg                 18MiB |
|    1   N/A  N/A      5233      G   /usr/bin/gnome-shell              100MiB |
|    1   N/A  N/A      5384      G   /usr/bin/gnome-shell               36MiB |
|    1   N/A  N/A      6548      G   ...2179,14311511775341437302       36MiB |
+-----------------------------------------------------------------------------+


Training

This is the main training stage. We hand the prepared dataset to the dataloader; once the dataloader, model, loss criterion, and optimizer are ready, training can start. When it finishes, we save the model.

import torch
from torch import optim

same_seeds(0)

model = AE().cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)

model.train()
n_epoch = 1000

# prepare dataloader, model, loss criterion and optimizer
img_dataloader = DataLoader(img_dataset, batch_size=64, shuffle=True)

# main training loop
for epoch in range(n_epoch):
    epoch_loss = 0
    for data in img_dataloader:
        img = data
        img = img.cuda()

        output1, output = model(img)
        loss = criterion(output, img)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # save a checkpoint every 10 epochs (outside the batch loop, so each
    # checkpoint file is written once per epoch rather than once per batch)
    if (epoch+1) % 10 == 0:
        torch.save(model.state_dict(), './checkpoints/checkpoint_{}.pth'.format(epoch+1))

    print('epoch [{}/{}], loss:{:.5f}'.format(epoch+1, n_epoch, epoch_loss))

# save the model once training is done
torch.save(model.state_dict(), './checkpoints/last_checkpoint.pth')
epoch [1/1000], loss:30.54165
epoch [2/1000], loss:26.34405
epoch [3/1000], loss:21.83250
epoch [4/1000], loss:19.13653
epoch [5/1000], loss:16.89123
epoch [6/1000], loss:15.81137
epoch [7/1000], loss:15.24495
epoch [8/1000], loss:14.82142
epoch [9/1000], loss:14.43517
epoch [10/1000], loss:14.08439
epoch [11/1000], loss:13.73920
epoch [12/1000], loss:13.40639
epoch [13/1000], loss:13.08327
epoch [14/1000], loss:12.66554
epoch [15/1000], loss:12.26715
epoch [16/1000], loss:11.93717
epoch [17/1000], loss:11.67487
epoch [18/1000], loss:11.45737
epoch [19/1000], loss:11.28208
epoch [20/1000], loss:11.08628
epoch [21/1000], loss:10.94622
epoch [22/1000], loss:10.80847
epoch [23/1000], loss:10.70417
epoch [24/1000], loss:10.58255
epoch [25/1000], loss:10.48495
epoch [26/1000], loss:10.39527
epoch [27/1000], loss:10.30006
epoch [28/1000], loss:10.20910
epoch [29/1000], loss:10.13124
epoch [30/1000], loss:10.04456
epoch [31/1000], loss:9.96836
epoch [32/1000], loss:9.88246
epoch [33/1000], loss:9.81235
epoch [34/1000], loss:9.72425
epoch [35/1000], loss:9.65545
epoch [36/1000], loss:9.57657
epoch [37/1000], loss:9.51310
epoch [38/1000], loss:9.45421
epoch [39/1000], loss:9.38250
epoch [40/1000], loss:9.31712
epoch [41/1000], loss:9.25833
epoch [42/1000], loss:9.20196
epoch [43/1000], loss:9.14868
epoch [44/1000], loss:9.08939
epoch [45/1000], loss:9.02597
epoch [46/1000], loss:8.95911
epoch [47/1000], loss:8.91480
epoch [48/1000], loss:8.86116
epoch [49/1000], loss:8.79443
epoch [50/1000], loss:8.73779
epoch [51/1000], loss:8.68570
epoch [52/1000], loss:8.62910
epoch [53/1000], loss:8.57338
epoch [54/1000], loss:8.53807
epoch [55/1000], loss:8.48156
epoch [56/1000], loss:8.43463
epoch [57/1000], loss:8.39641
epoch [58/1000], loss:8.34074
epoch [59/1000], loss:8.30465
epoch [60/1000], loss:8.27341
epoch [61/1000], loss:8.23230
epoch [62/1000], loss:8.18089
epoch [63/1000], loss:8.15129
epoch [64/1000], loss:8.11520
epoch [65/1000], loss:8.07959
epoch [66/1000], loss:8.04687
epoch [67/1000], loss:8.02380
epoch [68/1000], loss:7.98933
epoch [69/1000], loss:7.95649
epoch [70/1000], loss:7.92910
epoch [71/1000], loss:7.88972
epoch [72/1000], loss:7.85813
epoch [73/1000], loss:7.82851
epoch [74/1000], loss:7.81065
epoch [75/1000], loss:7.78497
epoch [76/1000], loss:7.73110
epoch [77/1000], loss:7.71461
epoch [78/1000], loss:7.68887
epoch [79/1000], loss:7.65523
epoch [80/1000], loss:7.63705
epoch [81/1000], loss:7.61096
epoch [82/1000], loss:7.57877
epoch [83/1000], loss:7.54703
epoch [84/1000], loss:7.52961
epoch [85/1000], loss:7.48876
epoch [86/1000], loss:7.46642
epoch [87/1000], loss:7.43804
epoch [88/1000], loss:7.41458
epoch [89/1000], loss:7.38298
epoch [90/1000], loss:7.38157
epoch [91/1000], loss:7.34053
epoch [92/1000], loss:7.32307
epoch [93/1000], loss:7.28897
epoch [94/1000], loss:7.27476
epoch [95/1000], loss:7.25432
epoch [96/1000], loss:7.23210
epoch [97/1000], loss:7.20764
epoch [98/1000], loss:7.17726
epoch [99/1000], loss:7.16785
epoch [100/1000], loss:7.14477
epoch [101/1000], loss:7.12776
epoch [102/1000], loss:7.10490
epoch [103/1000], loss:7.08108
epoch [104/1000], loss:7.06430
epoch [105/1000], loss:7.04382
epoch [106/1000], loss:7.01336
epoch [107/1000], loss:7.00099
epoch [108/1000], loss:6.97758
epoch [109/1000], loss:6.95376
epoch [110/1000], loss:6.94354
epoch [111/1000], loss:6.91744
epoch [112/1000], loss:6.91015
epoch [113/1000], loss:6.88055
epoch [114/1000], loss:6.86521
epoch [115/1000], loss:6.84671
epoch [116/1000], loss:6.82973
epoch [117/1000], loss:6.80817
epoch [118/1000], loss:6.78769
epoch [119/1000], loss:6.77140
epoch [120/1000], loss:6.76178
epoch [121/1000], loss:6.74296
epoch [122/1000], loss:6.71641
epoch [123/1000], loss:6.69564
epoch [124/1000], loss:6.67923
epoch [125/1000], loss:6.66339
epoch [126/1000], loss:6.64667
epoch [127/1000], loss:6.62993
epoch [128/1000], loss:6.60127
epoch [129/1000], loss:6.58229
epoch [130/1000], loss:6.57563
epoch [131/1000], loss:6.55139
epoch [132/1000], loss:6.53123
epoch [133/1000], loss:6.51448
epoch [134/1000], loss:6.49753
epoch [135/1000], loss:6.46827
epoch [136/1000], loss:6.45886
epoch [137/1000], loss:6.43451
epoch [138/1000], loss:6.41819
epoch [139/1000], loss:6.39429
epoch [140/1000], loss:6.38479
epoch [141/1000], loss:6.36964
epoch [142/1000], loss:6.34008
epoch [143/1000], loss:6.32599
epoch [144/1000], loss:6.30631
epoch [145/1000], loss:6.29071
epoch [146/1000], loss:6.27065
epoch [147/1000], loss:6.25629
epoch [148/1000], loss:6.23477
epoch [149/1000], loss:6.22027
epoch [150/1000], loss:6.20892
epoch [151/1000], loss:6.18379
epoch [152/1000], loss:6.16717
epoch [153/1000], loss:6.15294
epoch [154/1000], loss:6.13922
epoch [155/1000], loss:6.12273
epoch [156/1000], loss:6.09983
epoch [157/1000], loss:6.09613
epoch [158/1000], loss:6.08098
epoch [159/1000], loss:6.06648
epoch [160/1000], loss:6.05687
epoch [161/1000], loss:6.03163
epoch [162/1000], loss:6.00917
epoch [163/1000], loss:6.00572
epoch [164/1000], loss:5.99157
epoch [165/1000], loss:5.97707
epoch [166/1000], loss:5.96627
epoch [167/1000], loss:5.96171
epoch [168/1000], loss:5.93227
epoch [169/1000], loss:5.92656
epoch [170/1000], loss:5.92673
epoch [171/1000], loss:5.90135
epoch [172/1000], loss:5.89017
epoch [173/1000], loss:5.87263
epoch [174/1000], loss:5.86483
epoch [175/1000], loss:5.85099
epoch [176/1000], loss:5.83615
epoch [177/1000], loss:5.83101
epoch [178/1000], loss:5.82030
epoch [179/1000], loss:5.82544
epoch [180/1000], loss:5.78977
epoch [181/1000], loss:5.78293
epoch [182/1000], loss:5.77460
epoch [183/1000], loss:5.76192
epoch [184/1000], loss:5.75049
epoch [185/1000], loss:5.74188
epoch [186/1000], loss:5.73882
epoch [187/1000], loss:5.72205
epoch [188/1000], loss:5.70864
epoch [189/1000], loss:5.70273
epoch [190/1000], loss:5.69353
epoch [191/1000], loss:5.68343
epoch [192/1000], loss:5.67216
epoch [193/1000], loss:5.66239
epoch [194/1000], loss:5.65125
epoch [195/1000], loss:5.63932
epoch [196/1000], loss:5.63388
epoch [197/1000], loss:5.62116
epoch [198/1000], loss:5.61385
epoch [199/1000], loss:5.61483
epoch [200/1000], loss:5.59609
epoch [201/1000], loss:5.57955
epoch [202/1000], loss:5.57469
epoch [203/1000], loss:5.56383
epoch [204/1000], loss:5.55489
epoch [205/1000], loss:5.54320
epoch [206/1000], loss:5.52971
epoch [207/1000], loss:5.53083
epoch [208/1000], loss:5.51958
epoch [209/1000], loss:5.50594
epoch [210/1000], loss:5.50194
epoch [211/1000], loss:5.49632
epoch [212/1000], loss:5.47642
epoch [213/1000], loss:5.47105
epoch [214/1000], loss:5.46103
epoch [215/1000], loss:5.45136
epoch [216/1000], loss:5.44927
epoch [217/1000], loss:5.43372
epoch [218/1000], loss:5.43229
epoch [219/1000], loss:5.41293
epoch [220/1000], loss:5.40677
epoch [221/1000], loss:5.39713
epoch [222/1000], loss:5.39402
epoch [223/1000], loss:5.38856
epoch [224/1000], loss:5.37551
epoch [225/1000], loss:5.36045
epoch [226/1000], loss:5.35389
epoch [227/1000], loss:5.34672
epoch [228/1000], loss:5.33802
epoch [229/1000], loss:5.33105
epoch [230/1000], loss:5.32277
epoch [231/1000], loss:5.30828
epoch [232/1000], loss:5.29910
epoch [233/1000], loss:5.29399
epoch [234/1000], loss:5.28984
epoch [235/1000], loss:5.27597
epoch [236/1000], loss:5.26934
epoch [237/1000], loss:5.26663
epoch [238/1000], loss:5.25943
epoch [239/1000], loss:5.24395
epoch [240/1000], loss:5.24214
epoch [241/1000], loss:5.23017
epoch [242/1000], loss:5.21525
epoch [243/1000], loss:5.21001
epoch [244/1000], loss:5.20533
epoch [245/1000], loss:5.19778
epoch [246/1000], loss:5.19444
epoch [247/1000], loss:5.17834
epoch [248/1000], loss:5.17032
epoch [249/1000], loss:5.16573
epoch [250/1000], loss:5.16030
epoch [251/1000], loss:5.15691
epoch [252/1000], loss:5.14337
epoch [253/1000], loss:5.13357
epoch [254/1000], loss:5.12614
epoch [255/1000], loss:5.12397
epoch [256/1000], loss:5.11111
epoch [257/1000], loss:5.09905
epoch [258/1000], loss:5.09718
epoch [259/1000], loss:5.09271
epoch [260/1000], loss:5.08443
epoch [261/1000], loss:5.07630
epoch [262/1000], loss:5.06473
epoch [263/1000], loss:5.06329
epoch [264/1000], loss:5.05452
epoch [265/1000], loss:5.04306
epoch [266/1000], loss:5.04899
epoch [267/1000], loss:5.03139
epoch [268/1000], loss:5.02383
epoch [269/1000], loss:5.01982
epoch [270/1000], loss:5.01273
epoch [271/1000], loss:5.00642
epoch [272/1000], loss:4.99454
epoch [273/1000], loss:4.99690
epoch [274/1000], loss:4.98375
epoch [275/1000], loss:4.98370
epoch [276/1000], loss:4.96812
epoch [277/1000], loss:4.96210
epoch [278/1000], loss:4.96167
epoch [279/1000], loss:4.94264
epoch [280/1000], loss:4.94708
epoch [281/1000], loss:4.93381
epoch [282/1000], loss:4.92656
epoch [283/1000], loss:4.92751
epoch [284/1000], loss:4.91519
epoch [285/1000], loss:4.90649
epoch [286/1000], loss:4.90130
epoch [287/1000], loss:4.89965
epoch [288/1000], loss:4.88647
epoch [289/1000], loss:4.88522
epoch [290/1000], loss:4.87119
epoch [291/1000], loss:4.86967
epoch [292/1000], loss:4.86545
epoch [293/1000], loss:4.85670
epoch [294/1000], loss:4.84635
epoch [295/1000], loss:4.84253
epoch [296/1000], loss:4.84705
epoch [297/1000], loss:4.82709
epoch [298/1000], loss:4.82251
epoch [299/1000], loss:4.81915
epoch [300/1000], loss:4.81493
epoch [301/1000], loss:4.80140
epoch [302/1000], loss:4.79302
epoch [303/1000], loss:4.79099
epoch [304/1000], loss:4.78271
epoch [305/1000], loss:4.77509
epoch [306/1000], loss:4.76755
epoch [307/1000], loss:4.76485
epoch [308/1000], loss:4.76169
epoch [309/1000], loss:4.75328
epoch [310/1000], loss:4.74254
epoch [311/1000], loss:4.74224
epoch [312/1000], loss:4.74067
epoch [313/1000], loss:4.72933
epoch [314/1000], loss:4.71486
epoch [315/1000], loss:4.71784
epoch [316/1000], loss:4.70222
epoch [317/1000], loss:4.70290
epoch [318/1000], loss:4.69542
epoch [319/1000], loss:4.69025
epoch [320/1000], loss:4.68246
epoch [321/1000], loss:4.67295
epoch [322/1000], loss:4.67523
epoch [323/1000], loss:4.67207
epoch [324/1000], loss:4.66636
epoch [325/1000], loss:4.64616
epoch [326/1000], loss:4.64512
epoch [327/1000], loss:4.64286
epoch [328/1000], loss:4.63428
epoch [329/1000], loss:4.62759
epoch [330/1000], loss:4.62275
epoch [331/1000], loss:4.61570
epoch [332/1000], loss:4.61228
epoch [333/1000], loss:4.60109
epoch [334/1000], loss:4.60413
epoch [335/1000], loss:4.58950
epoch [336/1000], loss:4.59071
epoch [337/1000], loss:4.58295
epoch [338/1000], loss:4.57782
epoch [339/1000], loss:4.57129
epoch [340/1000], loss:4.56505
epoch [341/1000], loss:4.56037
epoch [342/1000], loss:4.55598
epoch [343/1000], loss:4.54537
epoch [344/1000], loss:4.54019
epoch [345/1000], loss:4.53571
epoch [346/1000], loss:4.53185
epoch [347/1000], loss:4.53183
epoch [348/1000], loss:4.52009
epoch [349/1000], loss:4.51411
epoch [350/1000], loss:4.50916
epoch [351/1000], loss:4.50595
epoch [352/1000], loss:4.50171
epoch [353/1000], loss:4.49431
epoch [354/1000], loss:4.48945
epoch [355/1000], loss:4.48904
epoch [356/1000], loss:4.47484
epoch [357/1000], loss:4.47601
epoch [358/1000], loss:4.46283
epoch [359/1000], loss:4.46043
epoch [360/1000], loss:4.45623
epoch [361/1000], loss:4.45144
epoch [387/1000], loss:4.32588
epoch [388/1000], loss:4.31738
epoch [389/1000], loss:4.31798
epoch [390/1000], loss:4.31714
epoch [391/1000], loss:4.30985
epoch [392/1000], loss:4.29957
epoch [393/1000], loss:4.29696
epoch [394/1000], loss:4.29420
epoch [395/1000], loss:4.28667
epoch [396/1000], loss:4.28612
epoch [397/1000], loss:4.27635
epoch [398/1000], loss:4.27332
epoch [399/1000], loss:4.27225
epoch [400/1000], loss:4.26569
epoch [401/1000], loss:4.26683
epoch [402/1000], loss:4.25562
epoch [403/1000], loss:4.24940
epoch [404/1000], loss:4.24415
epoch [405/1000], loss:4.24422
epoch [406/1000], loss:4.24053
epoch [407/1000], loss:4.23612
epoch [408/1000], loss:4.23212
epoch [409/1000], loss:4.23014
epoch [410/1000], loss:4.22054
epoch [411/1000], loss:4.21572
epoch [412/1000], loss:4.21339
epoch [413/1000], loss:4.20922
epoch [414/1000], loss:4.20910
epoch [415/1000], loss:4.20353
epoch [416/1000], loss:4.19610
epoch [417/1000], loss:4.19232
epoch [418/1000], loss:4.18926
epoch [419/1000], loss:4.18134
epoch [420/1000], loss:4.17638
epoch [421/1000], loss:4.17397
epoch [422/1000], loss:4.17142
epoch [423/1000], loss:4.16676
epoch [424/1000], loss:4.17102
epoch [425/1000], loss:4.15542
epoch [426/1000], loss:4.15438
epoch [427/1000], loss:4.15161
epoch [428/1000], loss:4.14431
epoch [429/1000], loss:4.14308
epoch [430/1000], loss:4.14248
epoch [431/1000], loss:4.13705
epoch [432/1000], loss:4.13069
epoch [433/1000], loss:4.12359
epoch [434/1000], loss:4.12440
epoch [435/1000], loss:4.12047
epoch [436/1000], loss:4.11715
epoch [437/1000], loss:4.11095
epoch [438/1000], loss:4.10556
epoch [439/1000], loss:4.10342
epoch [440/1000], loss:4.10314
epoch [441/1000], loss:4.09450
epoch [442/1000], loss:4.08683
epoch [443/1000], loss:4.08545
epoch [444/1000], loss:4.08673
epoch [445/1000], loss:4.07830
epoch [446/1000], loss:4.07518
epoch [447/1000], loss:4.06704
epoch [448/1000], loss:4.06815
epoch [449/1000], loss:4.06158
epoch [450/1000], loss:4.06410
epoch [451/1000], loss:4.05870
epoch [452/1000], loss:4.05462
epoch [453/1000], loss:4.04799
epoch [454/1000], loss:4.04455
epoch [455/1000], loss:4.03678
epoch [456/1000], loss:4.04038
epoch [457/1000], loss:4.03390
epoch [458/1000], loss:4.02727
epoch [459/1000], loss:4.02408
epoch [460/1000], loss:4.02337
epoch [461/1000], loss:4.01824
epoch [462/1000], loss:4.01433
epoch [463/1000], loss:4.00995
epoch [464/1000], loss:4.00826
epoch [465/1000], loss:4.00209
epoch [466/1000], loss:4.00384
epoch [467/1000], loss:3.99173
epoch [468/1000], loss:3.99856
epoch [469/1000], loss:3.99148
epoch [470/1000], loss:3.98304
epoch [471/1000], loss:3.98313
epoch [472/1000], loss:3.97725
epoch [473/1000], loss:3.97736
epoch [474/1000], loss:3.97326
epoch [475/1000], loss:3.96900
epoch [476/1000], loss:3.96096
epoch [477/1000], loss:3.96076
epoch [478/1000], loss:3.96005
epoch [479/1000], loss:3.95441
epoch [480/1000], loss:3.95287
epoch [481/1000], loss:3.94587
epoch [482/1000], loss:3.94024
epoch [483/1000], loss:3.93922
epoch [484/1000], loss:3.93559
epoch [485/1000], loss:3.93831
epoch [486/1000], loss:3.92520
epoch [487/1000], loss:3.92634
epoch [488/1000], loss:3.92151
epoch [489/1000], loss:3.91649
epoch [490/1000], loss:3.91573
epoch [491/1000], loss:3.91516
epoch [492/1000], loss:3.90679
epoch [493/1000], loss:3.90961
epoch [494/1000], loss:3.89975
epoch [495/1000], loss:3.89675
epoch [496/1000], loss:3.89311
epoch [497/1000], loss:3.89344
epoch [498/1000], loss:3.89109
epoch [499/1000], loss:3.88556
epoch [500/1000], loss:3.87982
epoch [501/1000], loss:3.87826
epoch [502/1000], loss:3.87651
epoch [503/1000], loss:3.87134
epoch [504/1000], loss:3.86625
epoch [505/1000], loss:3.86563
epoch [506/1000], loss:3.86109
epoch [507/1000], loss:3.86168
epoch [508/1000], loss:3.85732
epoch [509/1000], loss:3.84998
epoch [510/1000], loss:3.85233
epoch [511/1000], loss:3.84760
epoch [512/1000], loss:3.84713
epoch [513/1000], loss:3.83537
epoch [514/1000], loss:3.83900
epoch [515/1000], loss:3.82796
epoch [516/1000], loss:3.82622
epoch [517/1000], loss:3.83100
epoch [518/1000], loss:3.82413
epoch [519/1000], loss:3.81903
epoch [520/1000], loss:3.81732
epoch [521/1000], loss:3.81084
epoch [522/1000], loss:3.81144
epoch [523/1000], loss:3.80305
epoch [524/1000], loss:3.80411
epoch [525/1000], loss:3.80302
epoch [526/1000], loss:3.79430
epoch [527/1000], loss:3.79282
epoch [528/1000], loss:3.79408
epoch [529/1000], loss:3.79307
epoch [530/1000], loss:3.78673
epoch [531/1000], loss:3.78254
epoch [532/1000], loss:3.77649
epoch [533/1000], loss:3.77460
epoch [534/1000], loss:3.77207
epoch [535/1000], loss:3.76966
epoch [536/1000], loss:3.76757
epoch [537/1000], loss:3.76382
epoch [538/1000], loss:3.75726
epoch [539/1000], loss:3.76330
epoch [540/1000], loss:3.75130
epoch [541/1000], loss:3.74979
epoch [542/1000], loss:3.74968
epoch [543/1000], loss:3.73983
epoch [544/1000], loss:3.73901
epoch [545/1000], loss:3.73932
epoch [546/1000], loss:3.73718
epoch [547/1000], loss:3.73794
epoch [548/1000], loss:3.72818
epoch [549/1000], loss:3.72528
epoch [550/1000], loss:3.72475
epoch [551/1000], loss:3.71988
epoch [552/1000], loss:3.71729
epoch [553/1000], loss:3.71119
epoch [554/1000], loss:3.71207
epoch [555/1000], loss:3.71167
epoch [556/1000], loss:3.70275
epoch [557/1000], loss:3.70654
epoch [558/1000], loss:3.69792
epoch [559/1000], loss:3.69927
epoch [560/1000], loss:3.69409
epoch [561/1000], loss:3.69188
epoch [562/1000], loss:3.68632
epoch [563/1000], loss:3.68308
epoch [564/1000], loss:3.68161
epoch [565/1000], loss:3.68463
epoch [566/1000], loss:3.67181
epoch [567/1000], loss:3.67101
epoch [568/1000], loss:3.66956
epoch [569/1000], loss:3.66723
epoch [570/1000], loss:3.66829
epoch [571/1000], loss:3.66422
epoch [572/1000], loss:3.66120
epoch [573/1000], loss:3.65323
epoch [574/1000], loss:3.65280
epoch [575/1000], loss:3.65279
epoch [576/1000], loss:3.64698
epoch [577/1000], loss:3.64525
epoch [578/1000], loss:3.64385
epoch [579/1000], loss:3.63892
epoch [580/1000], loss:3.63570
epoch [581/1000], loss:3.63038
epoch [582/1000], loss:3.63306
epoch [583/1000], loss:3.62456
epoch [584/1000], loss:3.62961
epoch [585/1000], loss:3.61710
epoch [586/1000], loss:3.62218
epoch [587/1000], loss:3.61367
epoch [588/1000], loss:3.61351
epoch [589/1000], loss:3.61048
epoch [590/1000], loss:3.60863
epoch [591/1000], loss:3.60503
epoch [592/1000], loss:3.60068
epoch [593/1000], loss:3.59856
epoch [594/1000], loss:3.59472
epoch [595/1000], loss:3.59365
epoch [596/1000], loss:3.59324
epoch [597/1000], loss:3.58769
epoch [598/1000], loss:3.58214
epoch [599/1000], loss:3.58244
epoch [600/1000], loss:3.57799
epoch [601/1000], loss:3.57877
epoch [602/1000], loss:3.57055
epoch [603/1000], loss:3.57307
epoch [604/1000], loss:3.57202
epoch [605/1000], loss:3.56517
epoch [606/1000], loss:3.56280
epoch [607/1000], loss:3.56200
epoch [608/1000], loss:3.56267
epoch [609/1000], loss:3.55470
epoch [610/1000], loss:3.55250
epoch [611/1000], loss:3.54826
epoch [612/1000], loss:3.55154
epoch [613/1000], loss:3.54208
epoch [614/1000], loss:3.54206
epoch [615/1000], loss:3.54105
epoch [616/1000], loss:3.53665
epoch [617/1000], loss:3.53198
epoch [618/1000], loss:3.52956
epoch [619/1000], loss:3.52716
epoch [620/1000], loss:3.52535
epoch [621/1000], loss:3.52693
epoch [622/1000], loss:3.51926
epoch [623/1000], loss:3.51655
epoch [624/1000], loss:3.51352
epoch [625/1000], loss:3.51410
epoch [626/1000], loss:3.50871
epoch [627/1000], loss:3.50490
epoch [628/1000], loss:3.50470
epoch [629/1000], loss:3.50429
epoch [630/1000], loss:3.50063
epoch [631/1000], loss:3.49522
epoch [632/1000], loss:3.49489
epoch [633/1000], loss:3.49385
epoch [634/1000], loss:3.48804
epoch [635/1000], loss:3.48522
epoch [636/1000], loss:3.48331
epoch [637/1000], loss:3.47941
epoch [638/1000], loss:3.47592
epoch [639/1000], loss:3.47459
epoch [640/1000], loss:3.47359
epoch [641/1000], loss:3.47270
epoch [642/1000], loss:3.46967
epoch [643/1000], loss:3.46600
epoch [644/1000], loss:3.46549
epoch [645/1000], loss:3.46019
epoch [646/1000], loss:3.45748
epoch [647/1000], loss:3.45389
epoch [648/1000], loss:3.44896
epoch [649/1000], loss:3.44991
epoch [650/1000], loss:3.44311
epoch [651/1000], loss:3.44865
epoch [652/1000], loss:3.44133
epoch [653/1000], loss:3.43858
epoch [654/1000], loss:3.44189
epoch [655/1000], loss:3.43480
epoch [656/1000], loss:3.43255
epoch [657/1000], loss:3.42989
epoch [658/1000], loss:3.42864
epoch [659/1000], loss:3.42396
epoch [660/1000], loss:3.42112
epoch [661/1000], loss:3.42302
epoch [662/1000], loss:3.41736
epoch [663/1000], loss:3.41416
epoch [664/1000], loss:3.41132
epoch [665/1000], loss:3.41046
epoch [666/1000], loss:3.40492
epoch [667/1000], loss:3.40502
epoch [668/1000], loss:3.40614
epoch [669/1000], loss:3.40063
epoch [670/1000], loss:3.40028
epoch [671/1000], loss:3.39271
epoch [672/1000], loss:3.39536
epoch [673/1000], loss:3.39127
epoch [674/1000], loss:3.38746
epoch [675/1000], loss:3.38874
epoch [676/1000], loss:3.38427
epoch [677/1000], loss:3.38143
epoch [678/1000], loss:3.37742
epoch [679/1000], loss:3.37587
epoch [680/1000], loss:3.37513
epoch [681/1000], loss:3.37196
epoch [682/1000], loss:3.36916
epoch [683/1000], loss:3.36594
epoch [684/1000], loss:3.36606
epoch [685/1000], loss:3.36292
epoch [686/1000], loss:3.35892
epoch [687/1000], loss:3.35532
epoch [688/1000], loss:3.35597
epoch [689/1000], loss:3.35689
epoch [690/1000], loss:3.34953
epoch [691/1000], loss:3.34964
epoch [692/1000], loss:3.34474
epoch [693/1000], loss:3.34500
epoch [694/1000], loss:3.34074
epoch [695/1000], loss:3.34088
epoch [696/1000], loss:3.33748
epoch [697/1000], loss:3.33662
epoch [698/1000], loss:3.33202
epoch [699/1000], loss:3.33229
epoch [700/1000], loss:3.32739
epoch [701/1000], loss:3.32630
epoch [702/1000], loss:3.32807
epoch [703/1000], loss:3.32146
epoch [704/1000], loss:3.31806
epoch [705/1000], loss:3.31831
epoch [706/1000], loss:3.31332
epoch [707/1000], loss:3.31269
epoch [708/1000], loss:3.30964
epoch [709/1000], loss:3.30984
epoch [710/1000], loss:3.30538
epoch [711/1000], loss:3.30281
epoch [712/1000], loss:3.30262
epoch [713/1000], loss:3.29772
epoch [714/1000], loss:3.29625
epoch [715/1000], loss:3.29219
epoch [716/1000], loss:3.29506
epoch [717/1000], loss:3.28936
epoch [718/1000], loss:3.28897
epoch [719/1000], loss:3.29049
epoch [720/1000], loss:3.28375
epoch [721/1000], loss:3.28123
epoch [722/1000], loss:3.27900
epoch [723/1000], loss:3.27359
epoch [724/1000], loss:3.27611
epoch [725/1000], loss:3.27433
epoch [726/1000], loss:3.27112
epoch [727/1000], loss:3.26646
epoch [728/1000], loss:3.26737
epoch [729/1000], loss:3.26536
epoch [730/1000], loss:3.26612
epoch [731/1000], loss:3.26075
epoch [732/1000], loss:3.26027
epoch [733/1000], loss:3.25291
epoch [734/1000], loss:3.25916
epoch [735/1000], loss:3.24919
epoch [736/1000], loss:3.25470
epoch [737/1000], loss:3.24516
epoch [738/1000], loss:3.24314
epoch [739/1000], loss:3.24429
epoch [740/1000], loss:3.24261
epoch [741/1000], loss:3.23813
epoch [742/1000], loss:3.23578
epoch [743/1000], loss:3.23666
epoch [744/1000], loss:3.23200
epoch [745/1000], loss:3.23238
epoch [746/1000], loss:3.22988
epoch [747/1000], loss:3.22826
epoch [748/1000], loss:3.23023
epoch [749/1000], loss:3.22209
epoch [750/1000], loss:3.21966
epoch [751/1000], loss:3.21754
epoch [752/1000], loss:3.21620
epoch [753/1000], loss:3.21760
epoch [754/1000], loss:3.21165
epoch [755/1000], loss:3.21131
epoch [756/1000], loss:3.21038
epoch [757/1000], loss:3.20712
epoch [758/1000], loss:3.20317
epoch [759/1000], loss:3.20223
epoch [760/1000], loss:3.20180
epoch [761/1000], loss:3.20010
epoch [762/1000], loss:3.19946
epoch [763/1000], loss:3.19183
epoch [764/1000], loss:3.19291
epoch [765/1000], loss:3.18863
epoch [766/1000], loss:3.18918
epoch [767/1000], loss:3.18898
epoch [768/1000], loss:3.18414
epoch [769/1000], loss:3.18572
epoch [770/1000], loss:3.18738
epoch [771/1000], loss:3.17861
epoch [772/1000], loss:3.17652
epoch [773/1000], loss:3.17587
epoch [774/1000], loss:3.17144
epoch [775/1000], loss:3.17319
epoch [776/1000], loss:3.17009
epoch [777/1000], loss:3.16943
epoch [778/1000], loss:3.16559
epoch [779/1000], loss:3.16415
epoch [780/1000], loss:3.16417
epoch [781/1000], loss:3.16414
epoch [782/1000], loss:3.15878
epoch [783/1000], loss:3.15620
epoch [784/1000], loss:3.15162
epoch [785/1000], loss:3.15188
epoch [786/1000], loss:3.15056
epoch [787/1000], loss:3.14792
epoch [788/1000], loss:3.14884
epoch [789/1000], loss:3.14594
epoch [790/1000], loss:3.14544
epoch [791/1000], loss:3.14156
epoch [792/1000], loss:3.13851
epoch [793/1000], loss:3.13792
epoch [794/1000], loss:3.13770
epoch [795/1000], loss:3.13333
epoch [796/1000], loss:3.13036
epoch [797/1000], loss:3.12862
epoch [798/1000], loss:3.13088
epoch [799/1000], loss:3.12679
epoch [800/1000], loss:3.12329
epoch [801/1000], loss:3.12549
epoch [802/1000], loss:3.12244
epoch [803/1000], loss:3.11828
epoch [804/1000], loss:3.11357
epoch [805/1000], loss:3.11698
epoch [806/1000], loss:3.11326
epoch [807/1000], loss:3.11584
epoch [808/1000], loss:3.10921
epoch [809/1000], loss:3.10769
epoch [810/1000], loss:3.10721
epoch [811/1000], loss:3.10426
epoch [812/1000], loss:3.10207
epoch [813/1000], loss:3.09837
epoch [814/1000], loss:3.09836
epoch [815/1000], loss:3.09801
epoch [816/1000], loss:3.09438
epoch [817/1000], loss:3.09267
epoch [818/1000], loss:3.09224
epoch [819/1000], loss:3.08851
epoch [820/1000], loss:3.08578
epoch [821/1000], loss:3.08942
epoch [822/1000], loss:3.08425
epoch [823/1000], loss:3.08528
epoch [824/1000], loss:3.08140
epoch [825/1000], loss:3.07830
epoch [826/1000], loss:3.07588
epoch [827/1000], loss:3.07775
epoch [828/1000], loss:3.07456
epoch [829/1000], loss:3.07019
epoch [830/1000], loss:3.07405
epoch [831/1000], loss:3.06494
epoch [832/1000], loss:3.06572
epoch [833/1000], loss:3.06405
epoch [834/1000], loss:3.06366
epoch [835/1000], loss:3.05963
epoch [836/1000], loss:3.05978
epoch [837/1000], loss:3.05587
epoch [838/1000], loss:3.05641
epoch [839/1000], loss:3.05452
epoch [840/1000], loss:3.05307
epoch [841/1000], loss:3.04878
epoch [842/1000], loss:3.05134
epoch [843/1000], loss:3.04592
epoch [844/1000], loss:3.04432
epoch [845/1000], loss:3.04292
epoch [846/1000], loss:3.04020
epoch [847/1000], loss:3.04101
epoch [848/1000], loss:3.04131
epoch [849/1000], loss:3.03655
epoch [850/1000], loss:3.03434
epoch [851/1000], loss:3.03037
epoch [852/1000], loss:3.03011
epoch [853/1000], loss:3.03031
epoch [854/1000], loss:3.02658
epoch [855/1000], loss:3.02762
epoch [856/1000], loss:3.02805
epoch [857/1000], loss:3.02052
epoch [858/1000], loss:3.02101
epoch [859/1000], loss:3.01820
epoch [860/1000], loss:3.01740
epoch [861/1000], loss:3.01673
epoch [862/1000], loss:3.01265
epoch [863/1000], loss:3.00953
epoch [864/1000], loss:3.01045
epoch [865/1000], loss:3.00850
epoch [866/1000], loss:3.01031
epoch [867/1000], loss:3.00408
epoch [868/1000], loss:3.00111
epoch [869/1000], loss:3.00130
epoch [870/1000], loss:3.00163
epoch [871/1000], loss:2.99810
epoch [872/1000], loss:2.99874
epoch [873/1000], loss:2.99178
epoch [874/1000], loss:2.99280
epoch [875/1000], loss:2.99230
epoch [876/1000], loss:2.98815
epoch [877/1000], loss:2.98851
epoch [878/1000], loss:2.98612
epoch [879/1000], loss:2.98797
epoch [880/1000], loss:2.98337
epoch [881/1000], loss:2.98161
epoch [882/1000], loss:2.98003
epoch [883/1000], loss:2.97484
epoch [884/1000], loss:2.97611
epoch [885/1000], loss:2.97621
epoch [886/1000], loss:2.97396
epoch [887/1000], loss:2.96927
epoch [888/1000], loss:2.96680
epoch [889/1000], loss:2.96926
epoch [890/1000], loss:2.96575
epoch [891/1000], loss:2.96431
epoch [892/1000], loss:2.96193
epoch [893/1000], loss:2.95761
epoch [894/1000], loss:2.96028
epoch [895/1000], loss:2.96046
epoch [896/1000], loss:2.95814
epoch [897/1000], loss:2.95228
epoch [898/1000], loss:2.94921
epoch [899/1000], loss:2.95213
epoch [900/1000], loss:2.94890
epoch [901/1000], loss:2.94738
epoch [902/1000], loss:2.94390
epoch [903/1000], loss:2.94118
epoch [904/1000], loss:2.94426
epoch [905/1000], loss:2.94239
epoch [906/1000], loss:2.93883
epoch [907/1000], loss:2.93823
epoch [908/1000], loss:2.93640
epoch [909/1000], loss:2.93234
epoch [910/1000], loss:2.93235
epoch [911/1000], loss:2.92981
epoch [912/1000], loss:2.93039
epoch [913/1000], loss:2.93373
epoch [914/1000], loss:2.92795
epoch [915/1000], loss:2.92420
epoch [916/1000], loss:2.92136
epoch [917/1000], loss:2.91813
epoch [918/1000], loss:2.91754
epoch [919/1000], loss:2.91795
epoch [920/1000], loss:2.91643
epoch [921/1000], loss:2.91321
epoch [922/1000], loss:2.91369
epoch [923/1000], loss:2.91094
epoch [924/1000], loss:2.91049
epoch [925/1000], loss:2.90867
epoch [926/1000], loss:2.90595
epoch [927/1000], loss:2.90455
epoch [928/1000], loss:2.90523
epoch [929/1000], loss:2.90355
epoch [930/1000], loss:2.90085
epoch [931/1000], loss:2.89791
epoch [932/1000], loss:2.89439
epoch [933/1000], loss:2.89587
epoch [934/1000], loss:2.89358
epoch [935/1000], loss:2.89229
epoch [936/1000], loss:2.88939
epoch [937/1000], loss:2.89070
epoch [938/1000], loss:2.88834
epoch [939/1000], loss:2.88700
epoch [940/1000], loss:2.88633
epoch [941/1000], loss:2.88195
epoch [942/1000], loss:2.88308
epoch [943/1000], loss:2.87824
epoch [944/1000], loss:2.87709
epoch [945/1000], loss:2.87709
epoch [946/1000], loss:2.87699
epoch [947/1000], loss:2.87330
epoch [948/1000], loss:2.87141
epoch [949/1000], loss:2.87136
epoch [950/1000], loss:2.86982
epoch [951/1000], loss:2.86829
epoch [952/1000], loss:2.86615
epoch [953/1000], loss:2.86325
epoch [954/1000], loss:2.86094
epoch [955/1000], loss:2.86219
epoch [956/1000], loss:2.85894
epoch [957/1000], loss:2.86180
epoch [958/1000], loss:2.85887
epoch [959/1000], loss:2.85384
epoch [960/1000], loss:2.85410
epoch [961/1000], loss:2.85243
epoch [962/1000], loss:2.85051
epoch [963/1000], loss:2.84668
epoch [964/1000], loss:2.84494
epoch [965/1000], loss:2.84352
epoch [966/1000], loss:2.84500
epoch [967/1000], loss:2.84642
epoch [968/1000], loss:2.83922
epoch [969/1000], loss:2.83965
epoch [970/1000], loss:2.84072
epoch [971/1000], loss:2.83823
epoch [972/1000], loss:2.83543
epoch [973/1000], loss:2.83415
epoch [974/1000], loss:2.83639
epoch [975/1000], loss:2.82995
epoch [976/1000], loss:2.82914
epoch [977/1000], loss:2.82669
epoch [978/1000], loss:2.83094
epoch [979/1000], loss:2.82190
epoch [980/1000], loss:2.82548
epoch [981/1000], loss:2.82011
epoch [982/1000], loss:2.82137
epoch [983/1000], loss:2.81966
epoch [984/1000], loss:2.81743
epoch [985/1000], loss:2.81949
epoch [986/1000], loss:2.81346
epoch [987/1000], loss:2.81393
epoch [988/1000], loss:2.81204
epoch [989/1000], loss:2.81101
epoch [990/1000], loss:2.81068
epoch [991/1000], loss:2.80631
epoch [992/1000], loss:2.80828
epoch [993/1000], loss:2.80407
epoch [994/1000], loss:2.80417
epoch [995/1000], loss:2.80385
epoch [996/1000], loss:2.80122
epoch [997/1000], loss:2.80091
epoch [998/1000], loss:2.79750
epoch [999/1000], loss:2.79585
epoch [1000/1000], loss:2.79409

Dimensionality Reduction and Clustering

import numpy as np

def cal_acc(gt, pred):
    """ Computes categorization accuracy of our task.
    Args:
      gt: Ground truth labels (N,)
      pred: Predicted labels (N,)
    Returns:
      acc: Accuracy (0~1 scalar)
    """
    # count correct predictions
    correct = np.sum(gt == pred)
    acc = correct / gt.shape[0]
    # binary unsupervised clustering: the cluster ids are arbitrary, and we only care
    # whether the images were split into the right two groups, so take max(acc, 1-acc)
    return max(acc, 1-acc)
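
For instance (a tiny illustrative check), a prediction whose cluster ids are exactly flipped still scores 1.0:

gt   = np.array([0, 0, 1, 1])
pred = np.array([1, 1, 0, 0])  # cluster ids swapped, grouping still correct
print(cal_acc(gt, pred))       # 1.0
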
import matplotlib.pyplot as plt

def plot_scatter(feat, label, savefig=None):
    """ Plot Scatter Image.
    Args:
      feat: the (x, y) coordinates of the clustering result, shape: (N, 2)
      label: ground truth label of image (0/1), shape: (N,)
    Returns:
      None
    """
    X = feat[:, 0]
    Y = feat[:, 1]
    plt.scatter(X, Y, c=label)
    plt.legend(loc='best')  # note: no labeled artists, so matplotlib warns "No handles with labels found"
    if savefig is not None:
        plt.savefig(savefig)
    plt.show()
    return

Next we use the trained model to predict the classes of the testing data.

Since the testing data is the same as the training data, we reuse the same dataset to build the dataloader. The difference from training is that shuffle is set to False here.

With the model and dataloader ready, we can run prediction.

We only need the encoder's output (the latents); once the latents are clustered, the classification is done.

import torch
from sklearn.decomposition import KernelPCA
# Principal Component Analysis (PCA) is by far the most popular dimensionality-reduction
# algorithm. It finds the hyperplane closest to the data distribution (the one preserving
# the most variance) and projects every point onto it.
# kPCA (kernel PCA) is unsupervised, so there is no obvious metric for choosing the best
# kernel and hyperparameters; dimensionality reduction is, however, usually a preparation
# step for a supervised task such as classification.
from sklearn.manifold import TSNE
from sklearn.cluster import MiniBatchKMeans

def inference(X, model, batch_size=256):
    X = preprocess(X)
    dataset = Image_Dataset(X)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    latents = []
    for i, x in enumerate(dataloader):
        # convert to float tensors and run them through the model
        x = torch.FloatTensor(x)
        vec, img = model(x.cuda())
        # view() works like reshape: flatten each latent map into a vector,
        # (batch, 512, 2, 2) -> (batch, 2048); -1 lets PyTorch infer that dimension
        if i == 0:
            latents = vec.view(img.size()[0], -1).cpu().detach().numpy()
        else:
            # concatenate along axis 0, the batch dimension
            latents = np.concatenate((latents, vec.view(img.size()[0], -1).cpu().detach().numpy()), axis=0)
    print('Latents Shape:', latents.shape)
    return latents

def predict(latents):
    # first dimensionality reduction: kernel PCA with an RBF kernel
    # n_components: the number of principal components (features) to keep
    # n_jobs=-1: use all CPUs (n_jobs < -1 means (n_cpus + 1 + n_jobs) CPUs)
    transformer = KernelPCA(n_components=200, kernel='rbf', n_jobs=-1)
    # fit_transform fits the transformer on the latents and transforms them in one call;
    # a plain transform() would require an already-fitted transformer
    kpca = transformer.fit_transform(latents)
    print('First Reduction Shape:', kpca.shape)

    # second dimensionality reduction: t-SNE down to 2 dimensions
    X_embedded = TSNE(n_components=2).fit_transform(kpca)
    print('Second Reduction Shape:', X_embedded.shape)

    # clustering
    # n_clusters: the number of cluster centers (default 8)
    # random_state: seeds the random number generator (int, RandomState instance, or None)
    pred = MiniBatchKMeans(n_clusters=2, random_state=0).fit(X_embedded)
    pred = [int(i) for i in pred.labels_]
    pred = np.array(pred)
    return pred, X_embedded

def invert(pred):
    return np.abs(1-pred)

def save_prediction(pred, out_csv='prediction.csv'):
    with open(out_csv, 'w') as f:
        f.write('id,label\n')
        for i, p in enumerate(pred):
            f.write(f'{i},{p}\n')
    print(f'Save prediction to {out_csv}.')

# load model
model = AE().cuda()
model.load_state_dict(torch.load('./checkpoints/last_checkpoint.pth'))
model.eval()

# prepare the data
trainX = np.load('trainX.npy')

# predict
latents = inference(X=trainX, model=model)
pred, X_embedded = predict(latents)

# save the prediction and upload it to Kaggle
save_prediction(pred, 'prediction.csv')

# since this is an unsupervised binary problem, we only care whether the images were split into two groups
# if the accuracy on Kaggle is below 0.5 after uploading, just flip the labels
save_prediction(invert(pred), 'prediction_invert.csv')
Latents Shape: (8500, 2048)
First Reduction Shape: (8500, 200)
Second Reduction Shape: (8500, 2)
Save prediction to prediction.csv.
Save prediction to prediction_invert.csv.

Question 1 (plotting)

Plot the val data's dimensionality-reduced embedding together with the corresponding labels.

valX = np.load('valX.npy')
valY = np.load('valY.npy')

# ==============================================
#  Here we demonstrate the plot for the baseline model;
#  for the report, also draw one for the improved model.
# ==============================================
model.load_state_dict(torch.load('./checkpoints/last_checkpoint.pth'))
model.eval()
latents = inference(valX, model)
pred_from_latent, emb_from_latent = predict(latents)
acc_latent = cal_acc(valY, pred_from_latent)
print('The clustering accuracy is:', acc_latent)
print('The clustering result:')
plot_scatter(emb_from_latent, valY, savefig='p1_baseline.png')
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)


No handles with labels found to put in legend.


Second Reduction Shape: (500, 2)
The clustering accuracy is: 0.75
The clustering result:


Question 2

Using your autoencoder with the highest test accuracy, take the 6 images at indices 1, 2, 3, 6, 7, 9 from trainX and plot the originals together with their reconstructions.

import matplotlib.pyplot as plt
import numpy as np

# plot the original images
plt.figure(figsize=(10,4))
indexes = [1,2,3,6,7,9]
imgs = trainX[indexes,]
for i, img in enumerate(imgs):
    plt.subplot(2, 6, i+1, xticks=[], yticks=[])
    plt.imshow(img)

# plot the reconstructed images
inp = torch.Tensor(trainX_preprocessed[indexes,]).cuda()
latents, recs = model(inp)
recs = ((recs + 1) / 2).cpu().detach().numpy()  # map the tanh output from [-1, 1] back to [0, 1]
recs = recs.transpose(0, 2, 3, 1)
for i, img in enumerate(recs):
    plt.subplot(2, 6, 6+i+1, xticks=[], yticks=[])
    plt.imshow(img)
  
plt.tight_layout()


Question 3

Pick at least 10 checkpoints from the autoencoder's training run, plot the model's train reconstruction error against the val accuracy, and briefly describe what you observe.

import os
import glob
checkpoints_list = sorted(glob.glob('checkpoints/checkpoint_*.pth'), key= lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1]))
print(checkpoints_list)
# load data
dataset = Image_Dataset(trainX_preprocessed)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

points = []
with torch.no_grad():
    for i, checkpoint in enumerate(checkpoints_list):
        print('[{}/{}] {}'.format(i+1, len(checkpoints_list), checkpoint))
        model.load_state_dict(torch.load(checkpoint))
        model.eval()
        err = 0
        n = 0
        for x in dataloader:
            x = x.cuda()
            _, rec = model(x)
            err += torch.nn.MSELoss(reduction='sum')(x, rec).item()
            n += x.flatten().size(0)  # count every element so err/n is the per-element MSE
        print('Reconstruction error (MSE):', err/n)
        latents = inference(X=valX, model=model)
        pred, X_embedded = predict(latents)
        acc = cal_acc(valY, pred)
        print('Accuracy:', acc)
        points.append((err/n, acc))
['checkpoints/checkpoint_10.pth', 'checkpoints/checkpoint_20.pth', 'checkpoints/checkpoint_30.pth', 'checkpoints/checkpoint_40.pth', 'checkpoints/checkpoint_50.pth', 'checkpoints/checkpoint_60.pth', 'checkpoints/checkpoint_70.pth', 'checkpoints/checkpoint_80.pth', 'checkpoints/checkpoint_90.pth', 'checkpoints/checkpoint_100.pth', 'checkpoints/checkpoint_110.pth', 'checkpoints/checkpoint_120.pth', 'checkpoints/checkpoint_130.pth', 'checkpoints/checkpoint_140.pth', 'checkpoints/checkpoint_150.pth', 'checkpoints/checkpoint_160.pth', 'checkpoints/checkpoint_170.pth', 'checkpoints/checkpoint_180.pth', 'checkpoints/checkpoint_190.pth', 'checkpoints/checkpoint_200.pth', 'checkpoints/checkpoint_210.pth', 'checkpoints/checkpoint_220.pth', 'checkpoints/checkpoint_230.pth', 'checkpoints/checkpoint_240.pth', 'checkpoints/checkpoint_250.pth', 'checkpoints/checkpoint_260.pth', 'checkpoints/checkpoint_270.pth', 'checkpoints/checkpoint_280.pth', 'checkpoints/checkpoint_290.pth', 'checkpoints/checkpoint_300.pth', 'checkpoints/checkpoint_310.pth', 'checkpoints/checkpoint_320.pth', 'checkpoints/checkpoint_330.pth', 'checkpoints/checkpoint_340.pth', 'checkpoints/checkpoint_350.pth', 'checkpoints/checkpoint_360.pth', 'checkpoints/checkpoint_370.pth', 'checkpoints/checkpoint_380.pth', 'checkpoints/checkpoint_390.pth', 'checkpoints/checkpoint_400.pth', 'checkpoints/checkpoint_410.pth', 'checkpoints/checkpoint_420.pth', 'checkpoints/checkpoint_430.pth', 'checkpoints/checkpoint_440.pth', 'checkpoints/checkpoint_450.pth', 'checkpoints/checkpoint_460.pth', 'checkpoints/checkpoint_470.pth', 'checkpoints/checkpoint_480.pth', 'checkpoints/checkpoint_490.pth', 'checkpoints/checkpoint_500.pth', 'checkpoints/checkpoint_510.pth', 'checkpoints/checkpoint_520.pth', 'checkpoints/checkpoint_530.pth', 'checkpoints/checkpoint_540.pth', 'checkpoints/checkpoint_550.pth', 'checkpoints/checkpoint_560.pth', 'checkpoints/checkpoint_570.pth', 'checkpoints/checkpoint_580.pth', 'checkpoints/checkpoint_590.pth', 'checkpoints/checkpoint_600.pth', 'checkpoints/checkpoint_610.pth', 'checkpoints/checkpoint_620.pth', 'checkpoints/checkpoint_630.pth', 'checkpoints/checkpoint_640.pth', 'checkpoints/checkpoint_650.pth', 'checkpoints/checkpoint_660.pth', 'checkpoints/checkpoint_670.pth', 'checkpoints/checkpoint_680.pth', 'checkpoints/checkpoint_690.pth', 'checkpoints/checkpoint_700.pth', 'checkpoints/checkpoint_710.pth', 'checkpoints/checkpoint_720.pth', 'checkpoints/checkpoint_730.pth', 'checkpoints/checkpoint_740.pth', 'checkpoints/checkpoint_750.pth', 'checkpoints/checkpoint_760.pth', 'checkpoints/checkpoint_770.pth', 'checkpoints/checkpoint_780.pth', 'checkpoints/checkpoint_790.pth', 'checkpoints/checkpoint_800.pth', 'checkpoints/checkpoint_810.pth', 'checkpoints/checkpoint_820.pth', 'checkpoints/checkpoint_830.pth', 'checkpoints/checkpoint_840.pth', 'checkpoints/checkpoint_850.pth', 'checkpoints/checkpoint_860.pth', 'checkpoints/checkpoint_870.pth', 'checkpoints/checkpoint_880.pth', 'checkpoints/checkpoint_890.pth', 'checkpoints/checkpoint_900.pth', 'checkpoints/checkpoint_910.pth', 'checkpoints/checkpoint_920.pth', 'checkpoints/checkpoint_930.pth', 'checkpoints/checkpoint_940.pth', 'checkpoints/checkpoint_950.pth', 'checkpoints/checkpoint_960.pth', 'checkpoints/checkpoint_970.pth', 'checkpoints/checkpoint_980.pth', 'checkpoints/checkpoint_990.pth', 'checkpoints/checkpoint_1000.pth']
[1/100] checkpoints/checkpoint_10.pth
Reconstruction error (MSE): 0.10465650191961551
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.56
[2/100] checkpoints/checkpoint_20.pth
Reconstruction error (MSE): 0.08282024884691426
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.564
[3/100] checkpoints/checkpoint_30.pth
Reconstruction error (MSE): 0.07529751972123688
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.552
[4/100] checkpoints/checkpoint_40.pth
Reconstruction error (MSE): 0.06996455570295745
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[5/100] checkpoints/checkpoint_50.pth
Reconstruction error (MSE): 0.06556768215403837
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[6/100] checkpoints/checkpoint_60.pth
Reconstruction error (MSE): 0.0623151860704609
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.524
[7/100] checkpoints/checkpoint_70.pth
Reconstruction error (MSE): 0.05941213181439568
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.536
[8/100] checkpoints/checkpoint_80.pth
Reconstruction error (MSE): 0.057350127874636184
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.52
[9/100] checkpoints/checkpoint_90.pth
Reconstruction error (MSE): 0.05522508699753705
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.516
[10/100] checkpoints/checkpoint_100.pth
Reconstruction error (MSE): 0.05384457483478621
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[11/100] checkpoints/checkpoint_110.pth
Reconstruction error (MSE): 0.05195554022695504
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.5
[12/100] checkpoints/checkpoint_120.pth
Reconstruction error (MSE): 0.05074959627787272
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.524
[13/100] checkpoints/checkpoint_130.pth
Reconstruction error (MSE): 0.04992709094402837
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.514
[14/100] checkpoints/checkpoint_140.pth
Reconstruction error (MSE): 0.04817914684146058
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.502
[15/100] checkpoints/checkpoint_150.pth
Reconstruction error (MSE): 0.04657277587815827
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.512
[16/100] checkpoints/checkpoint_160.pth
Reconstruction error (MSE): 0.045626810316945994
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[17/100] checkpoints/checkpoint_170.pth
Reconstruction error (MSE): 0.04440261214387183
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.548
[18/100] checkpoints/checkpoint_180.pth
Reconstruction error (MSE): 0.04345548491384469
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[19/100] checkpoints/checkpoint_190.pth
Reconstruction error (MSE): 0.04282478637321323
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.518
[20/100] checkpoints/checkpoint_200.pth
Reconstruction error (MSE): 0.042173867076051
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[21/100] checkpoints/checkpoint_210.pth
Reconstruction error (MSE): 0.041361579988517014
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[22/100] checkpoints/checkpoint_220.pth
Reconstruction error (MSE): 0.040615920683916874
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[23/100] checkpoints/checkpoint_230.pth
Reconstruction error (MSE): 0.039873808785980826
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[24/100] checkpoints/checkpoint_240.pth
Reconstruction error (MSE): 0.03932966136932373
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[25/100] checkpoints/checkpoint_250.pth
Reconstruction error (MSE): 0.038771576432620775
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[26/100] checkpoints/checkpoint_260.pth
Reconstruction error (MSE): 0.0381339080099966
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[27/100] checkpoints/checkpoint_270.pth
Reconstruction error (MSE): 0.03751208638209923
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[28/100] checkpoints/checkpoint_280.pth
Reconstruction error (MSE): 0.037052626366708794
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[29/100] checkpoints/checkpoint_290.pth
Reconstruction error (MSE): 0.03666375287373861
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[30/100] checkpoints/checkpoint_300.pth
Reconstruction error (MSE): 0.03611967169069776
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[31/100] checkpoints/checkpoint_310.pth
Reconstruction error (MSE): 0.03564991631227381
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.518
[32/100] checkpoints/checkpoint_320.pth
Reconstruction error (MSE): 0.035199516689076144
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.514
[33/100] checkpoints/checkpoint_330.pth
Reconstruction error (MSE): 0.034691137108148314
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.524
[34/100] checkpoints/checkpoint_340.pth
Reconstruction error (MSE): 0.03432022960513246
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[35/100] checkpoints/checkpoint_350.pth
Reconstruction error (MSE): 0.033855246824376725
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.52
[36/100] checkpoints/checkpoint_360.pth
Reconstruction error (MSE): 0.03339220189113243
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[37/100] checkpoints/checkpoint_370.pth
Reconstruction error (MSE): 0.03329624884736304
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.502
[38/100] checkpoints/checkpoint_380.pth
Reconstruction error (MSE): 0.03264928217495189
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[39/100] checkpoints/checkpoint_390.pth
Reconstruction error (MSE): 0.03237577991859586
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[40/100] checkpoints/checkpoint_400.pth
Reconstruction error (MSE): 0.03208851829229617
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.514
[41/100] checkpoints/checkpoint_410.pth
Reconstruction error (MSE): 0.0316866933037253
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[42/100] checkpoints/checkpoint_420.pth
Reconstruction error (MSE): 0.031364078933117434
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[43/100] checkpoints/checkpoint_430.pth
Reconstruction error (MSE): 0.031136608348173254
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[44/100] checkpoints/checkpoint_440.pth
Reconstruction error (MSE): 0.0309632776297775
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[45/100] checkpoints/checkpoint_450.pth
Reconstruction error (MSE): 0.030496950392629587
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[46/100] checkpoints/checkpoint_460.pth
Reconstruction error (MSE): 0.030128193126005284
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[47/100] checkpoints/checkpoint_470.pth
Reconstruction error (MSE): 0.029998875262690527
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[48/100] checkpoints/checkpoint_480.pth
Reconstruction error (MSE): 0.029572404412662283
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[49/100] checkpoints/checkpoint_490.pth
Reconstruction error (MSE): 0.02939370559243595
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.508
[50/100] checkpoints/checkpoint_500.pth
Reconstruction error (MSE): 0.02911538221321854
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.554
[51/100] checkpoints/checkpoint_510.pth
Reconstruction error (MSE): 0.02889633548960966
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.546
[52/100] checkpoints/checkpoint_520.pth
Reconstruction error (MSE): 0.02860628096262614
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.648
[53/100] checkpoints/checkpoint_530.pth
Reconstruction error (MSE): 0.028405724600249645
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.54
[54/100] checkpoints/checkpoint_540.pth
Reconstruction error (MSE): 0.028084655219433353
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[55/100] checkpoints/checkpoint_550.pth
Reconstruction error (MSE): 0.02798689774905934
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.538
[56/100] checkpoints/checkpoint_560.pth
Reconstruction error (MSE): 0.027731095856311276
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.552
[57/100] checkpoints/checkpoint_570.pth
Reconstruction error (MSE): 0.027528591081207875
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[58/100] checkpoints/checkpoint_580.pth
Reconstruction error (MSE): 0.02748092877631094
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.516
[59/100] checkpoints/checkpoint_590.pth
Reconstruction error (MSE): 0.027148202690423704
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.536
[60/100] checkpoints/checkpoint_600.pth
Reconstruction error (MSE): 0.02693716204400156
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.538
[61/100] checkpoints/checkpoint_610.pth
Reconstruction error (MSE): 0.02663602849548938
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.656
[62/100] checkpoints/checkpoint_620.pth
Reconstruction error (MSE): 0.026486863996468338
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[63/100] checkpoints/checkpoint_630.pth
Reconstruction error (MSE): 0.026279585034239526
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.518
[64/100] checkpoints/checkpoint_640.pth
Reconstruction error (MSE): 0.02615043982337503
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.61
[65/100] checkpoints/checkpoint_650.pth
Reconstruction error (MSE): 0.025924385631785674
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[66/100] checkpoints/checkpoint_660.pth
Reconstruction error (MSE): 0.025687772582559023
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.59
[67/100] checkpoints/checkpoint_670.pth
Reconstruction error (MSE): 0.025555453281776577
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.628
[68/100] checkpoints/checkpoint_680.pth
Reconstruction error (MSE): 0.025691534911884983
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.648
[69/100] checkpoints/checkpoint_690.pth
Reconstruction error (MSE): 0.025101487271925984
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.708
[70/100] checkpoints/checkpoint_700.pth
Reconstruction error (MSE): 0.02504801980186911
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.732
[71/100] checkpoints/checkpoint_710.pth
Reconstruction error (MSE): 0.02484492769428328
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.752
[72/100] checkpoints/checkpoint_720.pth
Reconstruction error (MSE): 0.02478704075719796
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.76
[73/100] checkpoints/checkpoint_730.pth
Reconstruction error (MSE): 0.02446424291648117
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.608
[74/100] checkpoints/checkpoint_740.pth
Reconstruction error (MSE): 0.024349503433003145
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.508
[75/100] checkpoints/checkpoint_750.pth
Reconstruction error (MSE): 0.02417324640236649
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.764
[76/100] checkpoints/checkpoint_760.pth
Reconstruction error (MSE): 0.024010706882850796
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.726
[77/100] checkpoints/checkpoint_770.pth
Reconstruction error (MSE): 0.02394120900771197
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.762
[78/100] checkpoints/checkpoint_780.pth
Reconstruction error (MSE): 0.023713757514953613
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.55
[79/100] checkpoints/checkpoint_790.pth
Reconstruction error (MSE): 0.02374166191325468
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.554
[80/100] checkpoints/checkpoint_800.pth
Reconstruction error (MSE): 0.023461397339315977
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.64
[81/100] checkpoints/checkpoint_810.pth
Reconstruction error (MSE): 0.023291605500613943
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.756
[82/100] checkpoints/checkpoint_820.pth
Reconstruction error (MSE): 0.023138159677094105
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.526
[83/100] checkpoints/checkpoint_830.pth
Reconstruction error (MSE): 0.02306466459760479
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.792
[84/100] checkpoints/checkpoint_840.pth
Reconstruction error (MSE): 0.022922015835257138
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.664
[85/100] checkpoints/checkpoint_850.pth
Reconstruction error (MSE): 0.022727084767584706
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.67
[86/100] checkpoints/checkpoint_860.pth
Reconstruction error (MSE): 0.022709223756603166
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.76
[87/100] checkpoints/checkpoint_870.pth
Reconstruction error (MSE): 0.022506213861353257
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.542
[88/100] checkpoints/checkpoint_880.pth
Reconstruction error (MSE): 0.022330569921755323
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.682
[89/100] checkpoints/checkpoint_890.pth
Reconstruction error (MSE): 0.022259797694636325
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.792
[90/100] checkpoints/checkpoint_900.pth
Reconstruction error (MSE): 0.022161509541904226
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[91/100] checkpoints/checkpoint_910.pth
Reconstruction error (MSE): 0.022015575642679253
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.712
[92/100] checkpoints/checkpoint_920.pth
Reconstruction error (MSE): 0.021944920754900166
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.758
[93/100] checkpoints/checkpoint_930.pth
Reconstruction error (MSE): 0.021774335898605047
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.59
[94/100] checkpoints/checkpoint_940.pth
Reconstruction error (MSE): 0.021657160151238534
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.76
[95/100] checkpoints/checkpoint_950.pth
Reconstruction error (MSE): 0.021555810619803037
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.698
[96/100] checkpoints/checkpoint_960.pth
Reconstruction error (MSE): 0.021441521494996313
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[97/100] checkpoints/checkpoint_970.pth
Reconstruction error (MSE): 0.02138799679513071
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.802
[98/100] checkpoints/checkpoint_980.pth
Reconstruction error (MSE): 0.021166577629014558
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.508
[99/100] checkpoints/checkpoint_990.pth
Reconstruction error (MSE): 0.02112917330685784
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.736
[100/100] checkpoints/checkpoint_1000.pth
Reconstruction error (MSE): 0.02094145799150654
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.554
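
The log above was produced by evaluating each saved checkpoint in turn: load the weights, compute the reconstruction MSE over all of trainX, then embed valX, run the two-stage dimensionality reduction and clustering, and score the result against valY. A minimal sketch of such a loop follows; it assumes the baseline helpers inference, predict, and cal_acc, plus model and a DataLoader img_dataloader over the training set, are defined earlier in the notebook. The names and signatures here are reconstructions for illustration, not necessarily the author's exact code.

import glob

import torch
import torch.nn.functional as F

# Sketch of the checkpoint-evaluation loop that produced the log above.
# Assumes: `model` (the trained AE), `img_dataloader` (DataLoader over trainX),
# and the baseline helpers `inference`, `predict`, `cal_acc` defined earlier.
points = []
checkpoints_list = sorted(glob.glob('checkpoints/checkpoint_*.pth'),
                          key=lambda p: int(p.split('_')[-1].split('.')[0]))
for i, checkpoint in enumerate(checkpoints_list):
    print(f'[{i + 1}/{len(checkpoints_list)}] {checkpoint}')
    model.load_state_dict(torch.load(checkpoint))
    model.eval()

    # Reconstruction error: mean squared error over every pixel of trainX.
    err, n = 0.0, 0
    with torch.no_grad():
        for x in img_dataloader:
            x = x.cuda()
            _, rec = model(x)  # assumes forward() returns (latent, reconstruction)
            err += F.mse_loss(rec, x, reduction='sum').item()
            n += x.numel()
    print('Reconstruction error (MSE):', err / n)

    # Clustering accuracy: embed valX, reduce 2048 -> 200 -> 2, cluster, score.
    latents = inference(X=valX, model=model)
    pred, _ = predict(latents)
    acc = cal_acc(valY, pred)
    print('Accuracy:', acc)
    points.append((err / n, acc))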
import matplotlib.pyplot as plt

# points holds one (reconstruction MSE, val accuracy) pair per checkpoint;
# unzip it into two parallel sequences for plotting.
ps = list(zip(*points))
plt.figure(figsize=(6, 6))
plt.subplot(211, title='Reconstruction error (MSE)').plot(ps[0])
plt.subplot(212, title='Accuracy (val)').plot(ps[1])
plt.show()

[Figure: reconstruction error (MSE, top) and validation accuracy (bottom) across the 100 checkpoints]

The accuracy jitters quite violently: while the reconstruction error falls almost monotonically, the validation accuracy swings between roughly 0.50 and 0.80 from one checkpoint to the next. Unsupervised models really are hard to train, and a lower MSE clearly does not by itself guarantee a better clustering result.
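
Given this jitter, it is safer to pick the reporting checkpoint by validation accuracy rather than simply taking the final, lowest-MSE one. A small sketch, assuming the points and checkpoints_list variables built in the evaluation loop above:

import numpy as np

# Pick the checkpoint with the highest validation accuracy (sketch; assumes
# `points` = [(mse, acc), ...] and `checkpoints_list` from the loop above).
accs = np.array([acc for _, acc in points])
best = int(np.argmax(accs))
print(f'Best checkpoint: {checkpoints_list[best]} (val acc = {accs[best]:.3f})')

In the log above, for instance, the peak within this window is 0.802 at checkpoint_970.pth, whereas the final checkpoint_1000.pth scores only 0.554.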

Copyright notice: this is an original article by the blogger, released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/wl1780852311/article/details/121149036
