|
|
|
import argparse |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch.utils.data import DataLoader |
|
|
|
from torch.autograd import Variable |
|
|
|
from tqdm import tqdm |
|
|
|
from models.protonet_embedding import ProtoNetEmbedding |
|
from models.R2D2_embedding import R2D2Embedding |
|
from models.ResNet12_embedding import resnet12 |
|
|
|
from models.classification_heads import ClassificationHead |
|
|
|
from utils import pprint, set_gpu, Timer, count_accuracy, log |
|
from sklearn.metrics import confusion_matrix, f1_score, roc_curve, auc |
|
import scikitplot as skplt |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
import numpy as np |
|
import os |
|
import random |
|
|
|
import pickle |
|
|
|
from dataloader.chest import label_dict |
|
|
|
|
|
import pandas as pd |
|
|
|
def multiclass_roc(y_test, y_score,n_classes = 3): |
|
|
|
|
|
|
|
fpr = dict() |
|
tpr = dict() |
|
roc_auc = dict() |
|
|
|
|
|
y_test_dummies = pd.get_dummies(y_test, drop_first=False).values |
|
for i in range(n_classes): |
|
fpr[i], tpr[i], _ = roc_curve(y_test_dummies[:, i], y_score[:, i]) |
|
roc_auc[i] = auc(fpr[i], tpr[i]) |
|
|
|
return fpr,tpr,roc_auc |
|
|
|
|
|
|
|
def seed_everything(seed: int): |
|
random.seed(seed) |
|
os.environ["PYTHONHASHSEED"] = str(seed) |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed(seed) |
|
torch.backends.cudnn.deterministic = True |
|
torch.backends.cudnn.benchmark = True |
|
|
|
def euclidean_dist(x, y): |
|
|
|
|
|
|
|
n = x.size(0) |
|
m = y.size(0) |
|
d = x.size(1) |
|
|
|
assert d == y.size(1) |
|
|
|
x = x.unsqueeze(1).expand(n, m, d) |
|
y = y.unsqueeze(0).expand(n, m, d) |
|
|
|
return torch.pow(x - y, 2).sum(2) |
|
|
|
def flip(x, dim): |
|
xsize = x.size() |
|
dim = x.dim() + dim if dim < 0 else dim |
|
x = x.view(-1, *xsize[dim:]) |
|
x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1, |
|
-1, -1), ('cpu','cuda')[x.is_cuda])().long(), :] |
|
return x.view(xsize) |
|
|
|
|
|
def get_model(options): |
|
|
|
if options.network == 'ProtoNet': |
|
network = ProtoNetEmbedding().cuda() |
|
elif options.network == 'R2D2': |
|
network = R2D2Embedding().cuda() |
|
elif options.network == 'ResNet': |
|
if options.dataset == 'miniImageNet' or options.dataset == 'tieredImageNet': |
|
network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=5,num_layer=options.num_layer).cuda() |
|
network = torch.nn.DataParallel(network) |
|
else: |
|
network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=2,num_layer=options.num_layer).cuda() |
|
else: |
|
print ("Cannot recognize the network type") |
|
assert(False) |
|
|
|
|
|
if opt.head == 'ProtoNet': |
|
cls_head = ClassificationHead(base_learner='ProtoNet').cuda() |
|
elif options.head == 'SubspaceTrans': |
|
cls_head = ClassificationHead(base_learner='SubspaceTrans').cuda() |
|
elif options.head == 'Subspace': |
|
cls_head = ClassificationHead(base_learner='Subspace').cuda() |
|
elif options.head == 'SubspaceFast': |
|
cls_head = ClassificationHead(base_learner='SubspaceFast').cuda() |
|
elif opt.head == 'Ridge': |
|
cls_head = ClassificationHead(base_learner='Ridge').cuda() |
|
elif opt.head == 'R2D2': |
|
cls_head = ClassificationHead(base_learner='R2D2').cuda() |
|
elif opt.head == 'SVM': |
|
cls_head = ClassificationHead(base_learner='SVM-CS').cuda() |
|
else: |
|
print ("Cannot recognize the classification head type") |
|
assert(False) |
|
|
|
return (network, cls_head) |
|
|
|
def get_dataset(options): |
|
|
|
if options.dataset == 'miniImageNet': |
|
from dataloader.mini_imagenet import MiniImageNet, FewShotDataloader |
|
dataset_test = MiniImageNet(phase='test') |
|
data_loader = FewShotDataloader |
|
elif options.dataset == 'tieredImageNet': |
|
from dataloader.tiered_imagenet import tieredImageNet, FewShotDataloader |
|
dataset_test = tieredImageNet(phase='test') |
|
data_loader = FewShotDataloader |
|
elif options.dataset == 'CIFAR_FS': |
|
from dataloader.CIFAR_FS import CIFAR_FS, FewShotDataloader |
|
dataset_test = CIFAR_FS(phase='test') |
|
data_loader = FewShotDataloader |
|
elif options.dataset == 'FC100': |
|
from dataloader.FC100 import FC100, FewShotDataloader |
|
dataset_test = FC100(phase='test') |
|
data_loader = FewShotDataloader |
|
elif options.dataset == 'Chest': |
|
from dataloader.chest import Chest, FewShotDataloader |
|
dataset_test = Chest(phase='test') |
|
data_loader = FewShotDataloader |
|
else: |
|
print ("Cannot recognize the dataset type") |
|
assert(False) |
|
|
|
return (dataset_test, data_loader) |
|
|
|
|
|
if __name__ == '__main__': |
|
parser = argparse.ArgumentParser() |
|
|
|
|
|
parser.add_argument('--gpu', default='3') |
|
|
|
parser.add_argument('--load', |
|
default='experiments/group2_subspace30_CE_train/best_model.pth', |
|
help='path of the checkpoint file') |
|
|
|
parser.add_argument('--num_layer', type=int, default=30, |
|
help='num of layer') |
|
|
|
parser.add_argument('--episode', type=int, default=1000, |
|
help='number of episodes to test') |
|
parser.add_argument('--way', type=int, default=3, |
|
help='number of classes in one test episode') |
|
parser.add_argument('--shot', type=int, default=5, |
|
help='number of support examples per training class') |
|
parser.add_argument('--query', type=int, default=5, |
|
help='number of query examples per training class') |
|
parser.add_argument('--network', type=str, default='ResNet', |
|
help='choose which embedding network to use. ProtoNet, R2D2, ResNet') |
|
parser.add_argument('--head', type=str, default='Subspace', |
|
help='choose which embedding network to use. ProtoNet, Ridge, R2D2, SVM') |
|
parser.add_argument('--dataset', type=str, default='Chest', |
|
help='choose which classification head to use. miniImageNet, tieredImageNet, CIFAR_FS, FC100') |
|
|
|
|
|
opt = parser.parse_args() |
|
|
|
seed_everything(42) |
|
|
|
(dataset_test, data_loader) = get_dataset(opt) |
|
|
|
set_gpu(opt.gpu) |
|
|
|
|
|
(embedding_net, cls_head) = get_model(opt) |
|
|
|
|
|
saved_models = torch.load(opt.load) |
|
embedding_net.load_state_dict(saved_models['embedding']) |
|
embedding_net.eval() |
|
cls_head.load_state_dict(saved_models['head']) |
|
cls_head.eval() |
|
|
|
|
|
aug=False |
|
|
|
label_dict_inv = {v:k for k,v in label_dict.items()} |
|
|
|
test_accuracies = [] |
|
per_class_accuracies = [] |
|
y_pred_list = [] |
|
y_list = [] |
|
dloader_test = data_loader( |
|
dataset=dataset_test, |
|
nKnovel=opt.way, |
|
nKbase=0, |
|
nExemplars=opt.shot, |
|
nTestNovel=opt.query * opt.way, |
|
nTestBase=0, |
|
batch_size=1, |
|
num_workers=1, |
|
epoch_size=opt.episode, |
|
) |
|
|
|
|
|
|
|
with torch.no_grad(): |
|
for i, batch in enumerate(tqdm(dloader_test()), 1): |
|
data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch] |
|
|
|
n_support = opt.way * opt.shot |
|
n_query = opt.way * opt.query |
|
|
|
if opt.shot == 1 and aug: |
|
flipped_data_support = flip(data_support, 3) |
|
data_support = torch.cat((data_support, flipped_data_support), dim=0) |
|
labels_support = torch.cat((labels_support, labels_support), dim=0) |
|
|
|
list_emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:]))) |
|
list_emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:]))) |
|
|
|
logits = torch.zeros(n_query, opt.way).cuda() |
|
|
|
for emb_support, emb_query in zip(list_emb_support, list_emb_query): |
|
|
|
|
|
emb_support = emb_support.view(1, opt.way, opt.shot, -1).mean(2) |
|
|
|
emb_query = emb_query.reshape(1, n_query, -1) |
|
|
|
dists = euclidean_dist(emb_query[0], emb_support[0]) |
|
|
|
|
|
logits += F.softmax(-dists, dim=1).view(1 * opt.way * opt.query, -1) |
|
|
|
|
|
|
|
logits /= opt.num_layer |
|
|
|
logits = logits.reshape(-1, opt.way) |
|
labels_query = labels_query.reshape(-1) |
|
|
|
|
|
acc,pca = count_accuracy(logits, labels_query) |
|
test_accuracies.append(acc.item()) |
|
per_class_accuracies.append(pca) |
|
|
|
y_pred_list.append(logits.detach().cpu().numpy()) |
|
y_list.append(labels_query.detach().cpu().numpy()) |
|
|
|
avg = np.mean(np.array(test_accuracies)) |
|
std = np.std(np.array(test_accuracies)) |
|
ci95 = 1.96 * std / np.sqrt(i + 1) |
|
|
|
if i % 10 == 0: |
|
|
|
|
|
|
|
|
|
|
|
pca = np.array(per_class_accuracies).mean(0) |
|
pcs = np.array(per_class_accuracies).std(0) |
|
|
|
print('Episode [{}/{}]:\t\t\tAccuracy: {:.2f} ± {:.2f} ({:.2f}) % ({:.2f} %)'\ |
|
.format(i, opt.episode, avg, ci95,std, acc)) |
|
print(f'{label_dict_inv[9]}: {pca[0]:.2f} ± {pcs[0]:.2f} % | {label_dict_inv[10]}: {pca[1]:.2f} ± {pcs[1]:.2f} % | {label_dict_inv[11]}: {pca[2]:.2f} ± {pcs[2]:.2f}%') |
|
|
|
|
|
|
|
pca = np.array(per_class_accuracies).mean(0) |
|
pcs = np.array(per_class_accuracies).std(0) |
|
|
|
print("Mean") |
|
print(pca) |
|
print('Standard Deviation') |
|
print(pcs) |
|
|
|
|
|
y_pred_proba = np.array( |
|
y_pred_list).reshape(-1, 3) |
|
|
|
y_pred = np.argmax(y_pred_proba, axis=1) |
|
|
|
y_true = np.array(y_list).reshape(-1) |
|
|
|
f1 = f1_score(y_true, y_pred, average=None) |
|
|
|
print('F1 Score') |
|
print(f1) |
|
|
|
fpr,tpr, auc = multiclass_roc(y_true,y_pred_proba) |
|
save_tuple = (fpr,tpr,auc) |
|
|
|
print(auc) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class_dict = {'Fibrosis': 0, 'Hernia': 1, 'Pneumonia': 2} |
|
|
|
|
|
|
|
|
|
|
|
class_dict_inv = {v: k for k, v in class_dict.items()} |
|
|
|
y_true = np.array([class_dict_inv[i] |
|
for i in np.array(y_list).reshape(-1)]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|