xinyiW915 committed (verified)
Commit 957b1d0 · 1 Parent(s): caaa040

Upload 4 files

Files changed (4):
  1. app.py +97 -0
  2. demo_test.py +166 -0
  3. model_regression.py +682 -0
  4. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,97 @@
import gradio as gr
import torch
import os
import pandas as pd
from types import SimpleNamespace

from extractor.extract_rf_feats import VideoDataset_feature
from extractor.extract_slowfast_clip import SlowFast, extract_features_slowfast_pool
from extractor.extract_swint_clip import SwinT, extract_features_swint_pool
from model_regression import Mlp, preprocess_data
from demo_test import evaluate_video_quality, load_model, get_transform


def run_diva_vqa(video_path, is_finetune, train_data_name, test_data_name, network_name):
    if not os.path.exists(video_path):
        return "❌ No video uploaded or the uploaded file has expired. Please upload again."

    # print("CUDA available:", torch.cuda.is_available())
    # print("Current device:", torch.cuda.current_device())

    config = SimpleNamespace(**{
        'select_criteria': 'byrmse',
        'is_finetune': is_finetune,
        'save_path': 'model/',
        'train_data_name': train_data_name,
        'test_data_name': test_data_name,
        'test_video_path': video_path,
        'network_name': network_name,
        'num_workers': 0,
        'resize': 224,
        'patch_size': 16,
        'target_size': 224,
        'model_name': 'Mlp',
    })
    print(config.test_video_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # test demo video
    resize_transform = get_transform(config.resize)
    top_n = int(config.target_size / config.patch_size) * int(config.target_size / config.patch_size)
    data = {'vid': [os.path.splitext(os.path.basename(config.test_video_path))[0]],
            'test_data_name': [config.test_data_name],
            'test_video_path': [config.test_video_path]}
    videos_dir = os.path.dirname(config.test_video_path)
    test_df = pd.DataFrame(data)
    print(test_df.T)

    dataset = VideoDataset_feature(videos_dir, test_df, resize_transform, config.resize, config.test_data_name, config.patch_size, config.target_size, top_n)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=min(config.num_workers, os.cpu_count()), pin_memory=True
    )

    # load models to device
    model_slowfast = SlowFast().to(device)
    if config.network_name == 'diva-vqa':
        model_swint = SwinT(global_pool='avg').to(device)  # 'swin_base_patch4_window7_224.ms_in22k_ft_in1k'
        input_features = 9984
    elif config.network_name == 'diva-vqa_large':
        model_swint = SwinT(model_name='swin_large_patch4_window7_224', global_pool='avg', pretrained=True).to(device)
        input_features = 11520
    model_mlp = load_model(config, device, input_features)

    try:
        score, runtime = evaluate_video_quality(config, data_loader, model_slowfast, model_swint, model_mlp, device)
        return f"Predicted Quality Score: {score:.4f} (in {runtime:.2f}s)"
    except Exception as e:
        return f"❌ Error: {str(e)}"
    finally:
        if "gradio" in video_path and os.path.exists(video_path):
            os.remove(video_path)


demo = gr.Interface(
    fn=run_diva_vqa,
    inputs=[
        gr.Video(label="Upload a Video (e.g. mp4)"),
        gr.Checkbox(label="Use Finetuning?", value=False),
        gr.Dropdown(label="Train Dataset Name", choices=["konvid_1k", "youtube_ugc", "live_vqc", "lsvq_train", "other"], value="lsvq_train"),
        gr.Dropdown(label="Test Dataset Name", choices=["konvid_1k", "youtube_ugc", "live_vqc", "lsvq", "other"], value="konvid_1k"),
        gr.Dropdown(label="Our Models", choices=["diva-vqa", "diva-vqa_large"], value="diva-vqa_large")
    ],
    outputs=gr.Textbox(label="Predicted Perceptual Quality Score (0–100)"),
    title="🎬 DIVA-VQA Online Demo",
    description=(
        "Upload a short video and get the predicted perceptual quality score from the DIVA-VQA model. "
        "You can try it with our sample "
        "<a href='https://huggingface.co/spaces/xinyiW915/DIVA-VQA/blob/main/ugc_original_videos/5636101558_540p.mp4' target='_blank'>demo video</a>. "
        "<br><br>"
        "⚙️ This demo is currently running on <strong>Hugging Face CPU Basic</strong>: 2 vCPU • 16 GB RAM."
        # "⚙️ This demo is currently running on <strong>Hugging Face ZeroGPU Space</strong>: Dynamic resources (NVIDIA A100)."
    ),
)

demo.launch()
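The hard-coded input_features values follow from the fused feature vector: three branches (frame, residual fragment, frame fragment), each concatenating a pooled SlowFast vector with a SwinT embedding. A quick arithmetic check as a sketch; the backbone widths (2304 for pooled SlowFast, 1024 for Swin-B, 1536 for Swin-L) are the standard model configurations and are assumed here, not stated in this commit:

# Sanity check of the MLP input sizes (assumed dims: SlowFast pooled = 2304,
# Swin-B = 1024, Swin-L = 1536; three feature branches per backbone).
slowfast_dim = 2304
for swint_dim, expected in [(1024, 9984), (1536, 11520)]:
    fused = 3 * (slowfast_dim + swint_dim)  # frame + residual fragment + frame fragment
    assert fused == expected, (fused, expected)
print("9984 -> diva-vqa, 11520 -> diva-vqa_large")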
demo_test.py ADDED
@@ -0,0 +1,166 @@
import argparse
import time
import os
import pandas as pd
import torch
import torch.nn as nn
from tqdm import tqdm
from torchvision import models, transforms

from extractor.extract_rf_feats import VideoDataset_feature
from extractor.extract_slowfast_clip import SlowFast, extract_features_slowfast_pool
from extractor.extract_swint_clip import SwinT, extract_features_swint_pool
from model_regression import Mlp, preprocess_data


def get_transform(resize):
    return transforms.Compose([transforms.Resize([resize, resize]),
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])])

def setup_device(config):
    if config.device == "gpu":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if device.type == "cuda":
            torch.cuda.set_device(0)
    else:
        device = torch.device("cpu")
    print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")
    return device

def fix_state_dict(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            name = k[7:]
        elif k == 'n_averaged':
            continue
        else:
            name = k
        new_state_dict[name] = v
    return new_state_dict

def load_model(config, device, input_features=11520):
    model = Mlp(input_features=input_features, out_features=1, drop_rate=0.1, act_layer=nn.GELU).to(device)
    if config.is_finetune:
        model_path = os.path.join(config.save_path, f"finetune/{config.test_data_name}_{config.network_name}_fine_tuned_model.pth")
    else:
        if config.train_data_name == 'lsvq_train':
            model_path = os.path.join(config.save_path, f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}"
                                                        f"_trained_median_model_param_kfold.pth")
        else:
            model_path = os.path.join(config.save_path, f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}"
                                                        f"_trained_median_model_param.pth")
    # print("Loading model from:", model_path)
    state_dict = torch.load(model_path, map_location=device)
    fixed_state_dict = fix_state_dict(state_dict)
    try:
        model.load_state_dict(fixed_state_dict)
    except RuntimeError as e:
        print(e)
    return model

def evaluate_video_quality(config, data_loader, model_slowfast, model_swint, model_mlp, device):
    is_finetune = config.is_finetune
    # get video features
    model_slowfast.eval()
    model_swint.eval()
    with torch.no_grad():
        for i, (video_segments, video_res_frag_all, video_frag_all, video_name) in enumerate(tqdm(data_loader, desc="Processing Videos")):
            start_time = time.time()
            # slowfast features
            _, _, slowfast_frame_feats = extract_features_slowfast_pool(video_segments, model_slowfast, device)
            _, _, slowfast_res_frag_feats = extract_features_slowfast_pool(video_res_frag_all, model_slowfast, device)
            _, _, slowfast_frame_frag_feats = extract_features_slowfast_pool(video_frag_all, model_slowfast, device)
            slowfast_frame_feats_avg = slowfast_frame_feats.mean(dim=0)
            slowfast_res_frag_feats_avg = slowfast_res_frag_feats.mean(dim=0)
            slowfast_frame_frag_feats_avg = slowfast_frame_frag_feats.mean(dim=0)

            # swinT features
            swint_frame_feats = extract_features_swint_pool(video_segments, model_swint, device)
            swint_res_frag_feats = extract_features_swint_pool(video_res_frag_all, model_swint, device)
            swint_frame_frag_feats = extract_features_swint_pool(video_frag_all, model_swint, device)
            swint_frame_feats_avg = swint_frame_feats.mean(dim=0)
            swint_res_frag_feats_avg = swint_res_frag_feats.mean(dim=0)
            swint_frame_frag_feats_avg = swint_frame_frag_feats.mean(dim=0)

            # frame + residual fragment + frame fragment features
            rf_vqa_feats = torch.cat((slowfast_frame_feats_avg, slowfast_res_frag_feats_avg, slowfast_frame_frag_feats_avg,
                                      swint_frame_feats_avg, swint_res_frag_feats_avg, swint_frame_frag_feats_avg), dim=0)

            feature_tensor, _ = preprocess_data(rf_vqa_feats, None)
            if feature_tensor.dim() == 1:
                feature_tensor = feature_tensor.unsqueeze(0)
            # print(f"Feature tensor shape before MLP: {feature_tensor.shape}")

            model_mlp.eval()
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    prediction = model_mlp(feature_tensor)
            run_time = time.time() - start_time
            predicted_score = prediction.item()
            return predicted_score, run_time


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('-device', type=str, default='gpu', help='cpu or gpu')
    parser.add_argument('-model_name', type=str, default='Mlp', help='Name of the regression model')
    parser.add_argument('-select_criteria', type=str, default='byrmse', help='Selection criteria')
    parser.add_argument('-is_finetune', type=bool, default=True, help='With or without finetune')
    parser.add_argument('-save_path', type=str, default='model/', help='Path to save models')

    parser.add_argument('-train_data_name', type=str, default='lsvq_train', help='Name of the training data')
    parser.add_argument('-test_data_name', type=str, default='konvid_1k', help='Name of the testing data')
    parser.add_argument('-test_video_path', type=str, default='ugc_original_videos/5636101558_540p.mp4', help='demo test video')

    parser.add_argument('--network_name', type=str, default='diva-vqa_large')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--resize', type=int, default=224, help='224, 384')
    parser.add_argument('--patch_size', type=int, default=16, help='8, 16, 32, 8, 16, 32')
    parser.add_argument('--target_size', type=int, default=224, help='224, 224, 224, 384, 384, 384')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    config = parse_arguments()
    device = setup_device(config)

    # test demo video
    resize_transform = get_transform(config.resize)
    top_n = int(config.target_size / config.patch_size) * int(config.target_size / config.patch_size)
    data = {'vid': [os.path.splitext(os.path.basename(config.test_video_path))[0]],
            'test_data_name': [config.test_data_name],
            'test_video_path': [config.test_video_path]}
    videos_dir = os.path.dirname(config.test_video_path)
    test_df = pd.DataFrame(data)
    # print(test_df.T)

    dataset = VideoDataset_feature(videos_dir, test_df, resize_transform, config.resize, config.test_data_name, config.patch_size, config.target_size, top_n)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=min(config.num_workers, os.cpu_count()), pin_memory=True
    )
    # print(f"Dataset loaded. Total videos: {len(dataset)}, Total batches: {len(data_loader)}")

    # load models to device
    model_slowfast = SlowFast().to(device)
    if config.network_name == 'diva-vqa':
        model_swint = SwinT(global_pool='avg').to(device)  # 'swin_base_patch4_window7_224.ms_in22k_ft_in1k'
        input_features = 9984
    elif config.network_name == 'diva-vqa_large':
        model_swint = SwinT(model_name='swin_large_patch4_window7_224', global_pool='avg', pretrained=True).to(device)
        input_features = 11520
    model_mlp = load_model(config, device, input_features)

    total_time = 0
    num_runs = 1
    for i in range(num_runs):
        quality_prediction, run_time = evaluate_video_quality(config, data_loader, model_slowfast, model_swint, model_mlp, device)
        print(f"Run {i + 1} - Time taken: {run_time:.4f} seconds")

        total_time += run_time
    average_time = total_time / num_runs

    print(f"Average running time over {num_runs} runs: {average_time:.4f} seconds")
    print("Predicted Quality Score:", quality_prediction)
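One caveat before scripting demo_test.py: argparse's type=bool calls bool() on the raw argument string, and bool() is falsy only for the empty string, so -is_finetune False still parses as True. A minimal standalone sketch of the pitfall, using a toy parser that mirrors the flag above:

# bool('False') is True in Python, so only an empty string turns the flag off.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-is_finetune', type=bool, default=True)

print(parser.parse_args(['-is_finetune', 'False']).is_finetune)  # True (surprising)
print(parser.parse_args(['-is_finetune', '']).is_finetune)       # False
print(parser.parse_args([]).is_finetune)                         # True (the default)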
model_regression.py ADDED
@@ -0,0 +1,682 @@
import logging
import time
import os
import pandas as pd
import numpy as np
import math
import scipy.io
import scipy.stats
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from scipy.optimize import curve_fit
import joblib

import seaborn as sns
import matplotlib.pyplot as plt
import copy
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split

from data_processing import split_train_test

# ignore all warnings
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)


class Mlp(nn.Module):
    def __init__(self, input_features, hidden_features=256, out_features=1, drop_rate=0.2, act_layer=nn.GELU):
        super().__init__()
        self.fc1 = nn.Linear(input_features, hidden_features)
        self.bn1 = nn.BatchNorm1d(hidden_features)
        self.act1 = act_layer()
        self.drop1 = nn.Dropout(drop_rate)
        self.fc2 = nn.Linear(hidden_features, hidden_features // 2)
        self.act2 = act_layer()
        self.drop2 = nn.Dropout(drop_rate)
        self.fc3 = nn.Linear(hidden_features // 2, out_features)

    def forward(self, input_feature):
        x = self.fc1(input_feature)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.act2(x)
        x = self.drop2(x)
        output = self.fc3(x)
        return output


class MAEAndRankLoss(nn.Module):
    def __init__(self, l1_w=1.0, rank_w=1.0, margin=0.0, use_margin=False):
        super(MAEAndRankLoss, self).__init__()
        self.l1_w = l1_w
        self.rank_w = rank_w
        self.margin = margin
        self.use_margin = use_margin

    def forward(self, y_pred, y_true):
        # L1 loss/MAE loss
        l_mae = F.l1_loss(y_pred, y_true, reduction='mean') * self.l1_w
        # Rank loss
        n = y_pred.size(0)
        pred_diff = y_pred.unsqueeze(1) - y_pred.unsqueeze(0)
        true_diff = y_true.unsqueeze(1) - y_true.unsqueeze(0)

        # e(ytrue_i, ytrue_j)
        masks = torch.sign(true_diff)

        if self.use_margin and self.margin > 0:
            true_diff = true_diff.abs() - self.margin
            true_diff = F.relu(true_diff)
            masks = true_diff.sign()

        l_rank = F.relu(true_diff - masks * pred_diff)
        l_rank = l_rank.sum() / (n * (n - 1))

        loss = l_mae + l_rank * self.rank_w
        return loss

def load_data(csv, data, data_name, set_name):
    try:
        df = pd.read_csv(csv, skiprows=[], header=None)
    except Exception as e:
        logging.error(f'Read CSV file error: {e}')
        raise

    y_data = df.values[1:, 2].astype(float)
    y = torch.tensor(y_data, dtype=torch.float32)

    if data_name == 'cross_dataset':
        y = torch.clamp(y, max=5)
        if set_name == 'test':
            print(f"Modified y_true: {y}")
    X = data
    return X, y

def preprocess_data(X, y):
    X[torch.isnan(X)] = 0
    X[torch.isinf(X)] = 0

    # MinMaxScaler (use PyTorch implementation)
    X_min = X.min(dim=0, keepdim=True).values
    X_max = X.max(dim=0, keepdim=True).values
    X = (X - X_min) / (X_max - X_min)
    if y is not None:
        y = y.view(-1, 1).squeeze()
    return X, y

# define 4-parameter logistic regression
def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
    logisticPart = 1 + np.exp(np.negative(np.divide(X - bayta3, np.abs(bayta4))))
    yhat = bayta2 + np.divide(bayta1 - bayta2, logisticPart)
    return yhat

def fit_logistic_regression(y_pred, y_true):
    beta = [np.max(y_true), np.min(y_true), np.mean(y_pred), 0.5]
    popt, _ = curve_fit(logistic_func, y_pred, y_true, p0=beta, maxfev=100000000)
    y_pred_logistic = logistic_func(y_pred, *popt)
    return y_pred_logistic, beta, popt

def compute_correlation_metrics(y_true, y_pred):
    y_pred_logistic, beta, popt = fit_logistic_regression(y_pred, y_true)

    plcc = scipy.stats.pearsonr(y_true, y_pred_logistic)[0]
    rmse = np.sqrt(mean_squared_error(y_true, y_pred_logistic))
    srcc = scipy.stats.spearmanr(y_true, y_pred)[0]

    try:
        krcc = scipy.stats.kendalltau(y_true, y_pred)[0]
    except Exception as e:
        logging.error(f'krcc calculation: {e}')
        krcc = scipy.stats.kendalltau(y_true, y_pred, method='asymptotic')[0]
    return y_pred_logistic, plcc, rmse, srcc, krcc

def plot_results(y_test, y_test_pred_logistic, df_pred_score, model_name, data_name, network_name, select_criteria):
    # nonlinear logistic fitted curve / logistic regression
    if isinstance(y_test, torch.Tensor):
        mos1 = y_test.numpy()
    y1 = y_test_pred_logistic

    try:
        beta = [np.max(mos1), np.min(mos1), np.mean(y1), 0.5]
        popt, pcov = curve_fit(logistic_func, y1, mos1, p0=beta, maxfev=100000000)
        sigma = np.sqrt(np.diag(pcov))
    except:
        raise Exception('Fitting logistic function time-out!!')
    x_values1 = np.linspace(np.min(y1), np.max(y1), len(y1))
    plt.plot(x_values1, logistic_func(x_values1, *popt), '-', color='#c72e29', label='Fitted f(x)')

    fig1 = sns.scatterplot(x="y_test_pred_logistic", y="MOS", data=df_pred_score, markers='o', color='steelblue', label=network_name)
    plt.legend(loc='upper left')
    if data_name == 'live_vqc' or data_name == 'live_qualcomm' or data_name == 'cvd_2014' or data_name == 'lsvq_train':
        plt.ylim(0, 100)
        plt.xlim(0, 100)
    else:
        plt.ylim(1, 5)
        plt.xlim(1, 5)
    plt.title(f"Algorithm {network_name} with {model_name} on dataset {data_name}", fontsize=10)
    plt.xlabel('Predicted Score')
    plt.ylabel('MOS')
    reg_fig1 = fig1.get_figure()

    fig_path = f'../figs/{data_name}/'
    os.makedirs(fig_path, exist_ok=True)
    reg_fig1.savefig(fig_path + f"{network_name}_{model_name}_{data_name}_by{select_criteria}_kfold.png", dpi=300)
    plt.clf()
    plt.close()

def plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, test_vids, i):
    plt.figure(figsize=(10, 6))

    plt.plot(avg_train_losses, label='Average Training Loss')
    plt.plot(avg_val_losses, label='Average Validation Loss')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Average Training and Validation Loss Across Folds - {network_name} with {model_name} (test_vids: {test_vids})', fontsize=10)

    plt.legend()
    fig_par_path = f'../log/result/{data_name}/'
    os.makedirs(fig_par_path, exist_ok=True)
    plt.savefig(f'{fig_par_path}/{network_name}_Average_Training_Loss_test{i}.png', dpi=50)
    plt.clf()
    plt.close()

def configure_logging(log_path, model_name, data_name, network_name, select_criteria):
    log_file_name = os.path.join(log_path, f"{data_name}_{network_name}_{model_name}_corr_{select_criteria}_kfold.log")
    logging.basicConfig(filename=log_file_name, filemode='w', level=logging.DEBUG, format='%(levelname)s - %(message)s')
    logging.getLogger('matplotlib').setLevel(logging.WARNING)
    logging.info(f"Evaluating algorithm {network_name} with {model_name} on dataset {data_name}")
    logging.info(f"torch cuda: {torch.cuda.is_available()}")

def load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features):
    if data_name == 'cross_dataset':
        data_name1 = 'youtube_ugc_all'
        data_name2 = 'cvd_2014_all'
        train_csv = os.path.join(metadata_path, f'mos_files/{data_name1}_MOS_train.csv')
        test_csv = os.path.join(metadata_path, f'mos_files/{data_name2}_MOS_test.csv')
        train_data = torch.load(f'{feature_path}split_train_test/{network_name}_{data_name1}_train_features.pt')
        test_data = torch.load(f'{feature_path}split_train_test/{network_name}_{data_name2}_test_features.pt')
        X_train, y_train = load_data(train_csv, train_data, data_name1, 'train')
        X_test, y_test = load_data(test_csv, test_data, data_name2, 'test')

    elif data_name == 'lsvq_train':
        train_csv = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
        test_csv = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')
        X_train, y_train = load_data(train_csv, train_features, data_name, 'train')
        X_test, y_test = load_data(test_csv, test_features, data_name, 'test')

    else:
        train_csv = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
        test_csv = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')
        train_data = torch.load(f'{feature_path}split_train_test/{network_name}_{data_name}_train_features.pt')
        test_data = torch.load(f'{feature_path}split_train_test/{network_name}_{data_name}_test_features.pt')
        X_train, y_train = load_data(train_csv, train_data, data_name, 'train')
        X_test, y_test = load_data(test_csv, test_data, data_name, 'test')

    # standard min-max normalization of training features
    X_train, y_train = preprocess_data(X_train, y_train)
    X_test, y_test = preprocess_data(X_test, y_test)

    return X_train, y_train, X_test, y_test

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Train the model for one epoch"""
    model.train()
    train_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets.view(-1, 1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * inputs.size(0)
    train_loss /= len(train_loader.dataset)
    return train_loss

def evaluate(model, val_loader, criterion, device):
    """Evaluate model performance on validation sets"""
    model.eval()
    val_loss = 0.0
    y_val_pred = []
    y_val_true = []
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            y_val_pred.append(outputs)
            y_val_true.append(targets)
            loss = criterion(outputs, targets.view(-1, 1))
            val_loss += loss.item() * inputs.size(0)

    val_loss /= len(val_loader.dataset)
    y_val_pred = torch.cat(y_val_pred, dim=0)
    y_val_true = torch.cat(y_val_true, dim=0)
    return val_loss, y_val_pred, y_val_true

def update_best_model(select_criteria, best_metric, current_val, model):
    is_better = False
    if select_criteria == 'byrmse' and current_val < best_metric:
        is_better = True
    elif select_criteria == 'bykrcc' and current_val > best_metric:
        is_better = True

    if is_better:
        return current_val, copy.deepcopy(model), is_better
    return best_metric, model, is_better

def train_and_evaluate(X_train, y_train, config):
    # parameters
    n_repeats = config['n_repeats']
    n_splits = config['n_splits']
    batch_size = config['batch_size']
    epochs = config['epochs']
    hidden_features = config['hidden_features']
    drop_rate = config['drop_rate']
    loss_type = config['loss_type']
    optimizer_type = config['optimizer_type']
    select_criteria = config['select_criteria']
    initial_lr = config['initial_lr']
    weight_decay = config['weight_decay']
    patience = config['patience']
    l1_w = config['l1_w']
    rank_w = config['rank_w']
    use_swa = config.get('use_swa', False)
    logging.info(f'Parameters - Number of repeats for 80-20 hold out test: {n_repeats}, Number of splits for kfold: {n_splits}, Batch size: {batch_size}, Number of epochs: {epochs}')
    logging.info(f'Network Parameters - hidden_features: {hidden_features}, drop_rate: {drop_rate}, patience: {patience}')
    logging.info(f'Optimizer Parameters - loss_type: {loss_type}, optimizer_type: {optimizer_type}, initial_lr: {initial_lr}, weight_decay: {weight_decay}, use_swa: {use_swa}')
    logging.info(f'MAEAndRankLoss - l1_w: {l1_w}, rank_w: {rank_w}')

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    best_model = None
    best_metric = float('inf') if select_criteria == 'byrmse' else float('-inf')

    # loss for every fold
    all_train_losses = []
    all_val_losses = []
    for fold, (train_idx, val_idx) in enumerate(kf.split(X_train)):
        print(f"Fold {fold + 1}/{n_splits}")

        X_train_fold, X_val_fold = X_train[train_idx], X_train[val_idx]
        y_train_fold, y_val_fold = y_train[train_idx], y_train[val_idx]

        # initialisation of model, loss function, optimiser
        model = Mlp(input_features=X_train_fold.shape[1], hidden_features=hidden_features, drop_rate=drop_rate)
        model = model.to(device)  # to gpu

        if loss_type == 'MAERankLoss':
            criterion = MAEAndRankLoss()
            criterion.l1_w = l1_w
            criterion.rank_w = rank_w
        else:
            criterion = nn.MSELoss()

        if optimizer_type == 'sgd':
            optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9, weight_decay=weight_decay)
            scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)  # initial eta_min=1e-5
        else:
            optimizer = optim.Adam(model.parameters(), lr=initial_lr, weight_decay=weight_decay)  # L2 Regularisation initial: 0.01, 1e-5
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95)  # step_size=10, gamma=0.1: every 10 epochs lr*0.1
        if use_swa:
            swa_model = AveragedModel(model).to(device)
            swa_scheduler = SWALR(optimizer, swa_lr=initial_lr, anneal_strategy='cos')

        # dataset loader
        train_dataset = TensorDataset(X_train_fold, y_train_fold)
        val_dataset = TensorDataset(X_val_fold, y_val_fold)
        train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

        train_losses, val_losses = [], []

        # early stopping parameters
        best_val_loss = float('inf')
        epochs_no_improve = 0
        early_stop_active = False
        swa_start = int(epochs * 0.7) if use_swa else epochs  # SWA starts after 70% of total epochs, only set SWA start if SWA is used

        for epoch in range(epochs):
            train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
            train_losses.append(train_loss)
            scheduler.step()  # update learning rate
            if use_swa and epoch >= swa_start:
                swa_model.update_parameters(model)
                swa_scheduler.step()
                early_stop_active = True
                print(f"Current learning rate with SWA: {swa_scheduler.get_last_lr()}")

            lr = optimizer.param_groups[0]['lr']
            print('Epoch %d: Learning rate: %f' % (epoch + 1, lr))

            # decide which model to evaluate: SWA model or regular model
            current_model = swa_model if use_swa and epoch >= swa_start else model
            current_model.eval()
            val_loss, y_val_pred, y_val_true = evaluate(current_model, val_loader, criterion, device)
            val_losses.append(val_loss)
            print(f"Epoch {epoch + 1}, Fold {fold + 1}, Training Loss: {train_loss}, Validation Loss: {val_loss}")

            y_val_pred = torch.cat([pred for pred in y_val_pred])
            _, _, rmse_val, _, krcc_val = compute_correlation_metrics(y_val_fold.cpu().numpy(), y_val_pred.cpu().numpy())
            current_metric = rmse_val if select_criteria == 'byrmse' else krcc_val
            best_metric, best_model, is_better = update_best_model(select_criteria, best_metric, current_metric, current_model)
            if is_better:
                logging.info(f"Epoch {epoch + 1}, Fold {fold + 1}:")
                y_val_pred_logistic_tmp, plcc_valid_tmp, rmse_valid_tmp, srcc_valid_tmp, krcc_valid_tmp = compute_correlation_metrics(y_val_fold.cpu().numpy(), y_val_pred.cpu().numpy())
                logging.info(f'Validation set - Evaluation Results - SRCC: {srcc_valid_tmp}, KRCC: {krcc_valid_tmp}, PLCC: {plcc_valid_tmp}, RMSE: {rmse_valid_tmp}')

                X_train_fold_tensor = X_train_fold
                y_tra_pred_tmp = best_model(X_train_fold_tensor).detach().cpu().squeeze()
                y_tra_pred_logistic_tmp, plcc_train_tmp, rmse_train_tmp, srcc_train_tmp, krcc_train_tmp = compute_correlation_metrics(y_train_fold.cpu().numpy(), y_tra_pred_tmp.cpu().numpy())
                logging.info(f'Train set - Evaluation Results - SRCC: {srcc_train_tmp}, KRCC: {krcc_train_tmp}, PLCC: {plcc_train_tmp}, RMSE: {rmse_train_tmp}')

            # check for loss improvement
            if early_stop_active:
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # save the best model if validation loss improves
                    best_model = copy.deepcopy(model)
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1
                    if epochs_no_improve >= patience:
                        # epochs to wait for improvement before stopping
                        print(f"Early stopping triggered after {epoch + 1} epochs.")
                        break

        # saving SWA models and updating BN statistics
        if use_swa:
            train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
            best_model = best_model.to(device)
            best_model.eval()
            torch.optim.swa_utils.update_bn(train_loader, best_model)
            # swa_model_path = os.path.join('../model/', f'model_swa_fold{fold}.pth')
            # torch.save(swa_model.state_dict(), swa_model_path)
            # logging.info(f'SWA model saved at {swa_model_path}')

        all_train_losses.append(train_losses)
        all_val_losses.append(val_losses)

    max_length = max(len(x) for x in all_train_losses)
    all_train_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_train_losses]
    max_length = max(len(x) for x in all_val_losses)
    all_val_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_val_losses]

    return best_model, all_train_losses, all_val_losses

def collate_to_device(batch, device):
    data, targets = zip(*batch)
    return torch.stack(data).to(device), torch.stack(targets).to(device)

def model_test(best_model, X, y, device):
    test_dataset = TensorDataset(X, y)
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

    best_model.eval()
    y_pred = []
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)

            outputs = best_model(inputs)
            y_pred.extend(outputs.view(-1).tolist())

    return y_pred

def main(config):
    model_name = config['model_name']
    data_name = config['data_name']
    network_name = config['network_name']

    metadata_path = config['metadata_path']
    feature_path = config['feature_path']
    log_path = config['log_path']
    save_path = config['save_path']
    score_path = config['score_path']
    result_path = config['result_path']

    # parameters
    select_criteria = config['select_criteria']
    n_repeats = config['n_repeats']

    # logging and result
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(score_path, exist_ok=True)
    os.makedirs(result_path, exist_ok=True)
    result_file = f'{result_path}{data_name}_{network_name}_{model_name}_corr_{select_criteria}_kfold.mat'
    pred_score_filename = os.path.join(score_path, f"{data_name}_{network_name}_{model_name}_Predicted_Score_{select_criteria}_kfold.csv")
    file_path = os.path.join(save_path, f"{data_name}_{network_name}_{model_name}_{select_criteria}_trained_median_model_param_kfold.pth")
    configure_logging(log_path, model_name, data_name, network_name, select_criteria)

    '''======================== Main Body ==========================='''
    PLCC_all_repeats_test = []
    SRCC_all_repeats_test = []
    KRCC_all_repeats_test = []
    RMSE_all_repeats_test = []
    PLCC_all_repeats_train = []
    SRCC_all_repeats_train = []
    KRCC_all_repeats_train = []
    RMSE_all_repeats_train = []
    all_repeats_test_vids = []
    all_repeats_df_test_pred = []
    best_model_list = []

    for i in range(1, n_repeats + 1):
        print(f"{i}th repeated 80-20 hold out test")
        logging.info(f"{i}th repeated 80-20 hold out test")
        t0 = time.time()

        # train test split
        test_size = 0.2
        random_state = math.ceil(8.8 * i)
        # NR: original
        if data_name == 'lsvq_train':
            test_data_name = 'lsvq_test'  # lsvq_test, lsvq_test_1080p
            train_features, test_features, test_vids = split_train_test.process_lsvq(data_name, test_data_name, metadata_path, feature_path, network_name)
        elif data_name == 'cross_dataset':
            train_data_name = 'youtube_ugc_all'
            test_data_name = 'cvd_2014_all'
            _, _, test_vids = split_train_test.process_cross_dataset(train_data_name, test_data_name, metadata_path, feature_path, network_name)
        else:
            _, _, test_vids = split_train_test.process_other(data_name, test_size, random_state, metadata_path, feature_path, network_name)

        '''======================== read files =============================== '''
        if data_name == 'lsvq_train':
            X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features)
        else:
            X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, None, None)

        '''======================== regression model =============================== '''
        best_model, all_train_losses, all_val_losses = train_and_evaluate(X_train, y_train, config)

        # average loss plots
        avg_train_losses = np.mean(all_train_losses, axis=0)
        avg_val_losses = np.mean(all_val_losses, axis=0)
        test_vids = test_vids.tolist()
        plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, len(test_vids), i)

        # predict best model on the train dataset
        y_train_pred = model_test(best_model, X_train, y_train, device)
        y_train_pred = torch.tensor(list(y_train_pred), dtype=torch.float32)
        y_train_pred_logistic, plcc_train, rmse_train, srcc_train, krcc_train = compute_correlation_metrics(y_train.cpu().numpy(), y_train_pred.cpu().numpy())

        # test best model on the test dataset
        y_test_pred = model_test(best_model, X_test, y_test, device)
        y_test_pred = torch.tensor(list(y_test_pred), dtype=torch.float32)
        y_test_pred_logistic, plcc_test, rmse_test, srcc_test, krcc_test = compute_correlation_metrics(y_test.cpu().numpy(), y_test_pred.cpu().numpy())

        # save the predict score results
        test_pred_score = {'MOS': y_test, 'y_test_pred': y_test_pred, 'y_test_pred_logistic': y_test_pred_logistic}
        df_test_pred = pd.DataFrame(test_pred_score)

        # logging logistic predicted scores
        logging.info("============================================================================================================")
        SRCC_all_repeats_test.append(srcc_test)
        KRCC_all_repeats_test.append(krcc_test)
        PLCC_all_repeats_test.append(plcc_test)
        RMSE_all_repeats_test.append(rmse_test)
        SRCC_all_repeats_train.append(srcc_train)
        KRCC_all_repeats_train.append(krcc_train)
        PLCC_all_repeats_train.append(plcc_train)
        RMSE_all_repeats_train.append(rmse_train)
        all_repeats_test_vids.append(test_vids)
        all_repeats_df_test_pred.append(df_test_pred)
        best_model_list.append(copy.deepcopy(best_model))

        # logging.info results for each iteration
        logging.info('Best results in Mlp model within one split')
        logging.info(f'MODEL: {best_model}')
        logging.info('======================================================')
        logging.info(f'Train set - Evaluation Results')
        logging.info(f'SRCC_train: {srcc_train}')
        logging.info(f'KRCC_train: {krcc_train}')
        logging.info(f'PLCC_train: {plcc_train}')
        logging.info(f'RMSE_train: {rmse_train}')
        logging.info('======================================================')
        logging.info(f'Test set - Evaluation Results')
        logging.info(f'SRCC_test: {srcc_test}')
        logging.info(f'KRCC_test: {krcc_test}')
        logging.info(f'PLCC_test: {plcc_test}')
        logging.info(f'RMSE_test: {rmse_test}')
        logging.info('======================================================')
        logging.info(' -- {} seconds elapsed...\n\n'.format(time.time() - t0))

    logging.info('')
    SRCC_all_repeats_test = torch.tensor(SRCC_all_repeats_test, dtype=torch.float32)
    KRCC_all_repeats_test = torch.tensor(KRCC_all_repeats_test, dtype=torch.float32)
    PLCC_all_repeats_test = torch.tensor(PLCC_all_repeats_test, dtype=torch.float32)
    RMSE_all_repeats_test = torch.tensor(RMSE_all_repeats_test, dtype=torch.float32)
    SRCC_all_repeats_train = torch.tensor(SRCC_all_repeats_train, dtype=torch.float32)
    KRCC_all_repeats_train = torch.tensor(KRCC_all_repeats_train, dtype=torch.float32)
    PLCC_all_repeats_train = torch.tensor(PLCC_all_repeats_train, dtype=torch.float32)
    RMSE_all_repeats_train = torch.tensor(RMSE_all_repeats_train, dtype=torch.float32)

    logging.info('======================================================')
    logging.info('Average training results among all repeated 80-20 holdouts:')
    logging.info('SRCC: %f (std: %f)', torch.median(SRCC_all_repeats_train).item(), torch.std(SRCC_all_repeats_train).item())
    logging.info('KRCC: %f (std: %f)', torch.median(KRCC_all_repeats_train).item(), torch.std(KRCC_all_repeats_train).item())
    logging.info('PLCC: %f (std: %f)', torch.median(PLCC_all_repeats_train).item(), torch.std(PLCC_all_repeats_train).item())
    logging.info('RMSE: %f (std: %f)', torch.median(RMSE_all_repeats_train).item(), torch.std(RMSE_all_repeats_train).item())
    logging.info('======================================================')
    logging.info('Average testing results among all repeated 80-20 holdouts:')
    logging.info('SRCC: %f (std: %f)', torch.median(SRCC_all_repeats_test).item(), torch.std(SRCC_all_repeats_test).item())
    logging.info('KRCC: %f (std: %f)', torch.median(KRCC_all_repeats_test).item(), torch.std(KRCC_all_repeats_test).item())
    logging.info('PLCC: %f (std: %f)', torch.median(PLCC_all_repeats_test).item(), torch.std(PLCC_all_repeats_test).item())
    logging.info('RMSE: %f (std: %f)', torch.median(RMSE_all_repeats_test).item(), torch.std(RMSE_all_repeats_test).item())
    logging.info('======================================================')
    logging.info('\n')

    # find the median model and the index of the median
    print('======================================================')
    if select_criteria == 'byrmse':
        median_metrics = torch.median(RMSE_all_repeats_test).item()
        indices = (RMSE_all_repeats_test == median_metrics).nonzero(as_tuple=True)[0].tolist()
        select_criteria = select_criteria.replace('by', '').upper()
        print(RMSE_all_repeats_test)
        logging.info(f'all {select_criteria}: {RMSE_all_repeats_test}')
    elif select_criteria == 'bykrcc':
        median_metrics = torch.median(KRCC_all_repeats_test).item()
        indices = (KRCC_all_repeats_test == median_metrics).nonzero(as_tuple=True)[0].tolist()
        select_criteria = select_criteria.replace('by', '').upper()
        print(KRCC_all_repeats_test)
        logging.info(f'all {select_criteria}: {KRCC_all_repeats_test}')

    median_test_vids = [all_repeats_test_vids[i] for i in indices]
    test_vids = [arr.tolist() for arr in median_test_vids] if len(median_test_vids) > 1 else (median_test_vids[0] if median_test_vids else [])

    # select the model with the first index where the median is located
    # Note: If there are multiple iterations with the same median RMSE, the first index is selected here
    median_model = None
    if len(indices) > 0:
        median_index = indices[0]  # select the first index
        median_model = best_model_list[median_index]
        median_model_df_test_pred = all_repeats_df_test_pred[median_index]

        median_model_df_test_pred.to_csv(pred_score_filename, index=False)
        plot_results(y_test, y_test_pred_logistic, median_model_df_test_pred, model_name, data_name, network_name, select_criteria)

    print(f'Median Metrics: {median_metrics}')
    print(f'Indices: {indices}')
    # print(f'Test Videos: {test_vids}')
    print(f'Best model: {median_model}')

    logging.info(f'median test {select_criteria}: {median_metrics}')
    logging.info(f"Indices of median metrics: {indices}")
    # logging.info(f'Best training and test dataset: {test_vids}')
    logging.info(f'Best model predict score: {median_model_df_test_pred}')
    logging.info(f'Best model: {median_model}')

    # ================================================================================
    # save mats
    scipy.io.savemat(result_file, mdict={'SRCC_train': SRCC_all_repeats_train.numpy(),
                                         'KRCC_train': KRCC_all_repeats_train.numpy(),
                                         'PLCC_train': PLCC_all_repeats_train.numpy(),
                                         'RMSE_train': RMSE_all_repeats_train.numpy(),
                                         'SRCC_test': SRCC_all_repeats_test.numpy(),
                                         'KRCC_test': KRCC_all_repeats_test.numpy(),
                                         'PLCC_test': PLCC_all_repeats_test.numpy(),
                                         'RMSE_test': RMSE_all_repeats_test.numpy(),
                                         f'Median_{select_criteria}': median_metrics,
                                         'Test_Videos_list': all_repeats_test_vids,
                                         'Test_videos_Median_model': test_vids})

    # save model
    torch.save(median_model.state_dict(), file_path)
    print(f"Model state_dict saved to {file_path}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # input parameters
    parser.add_argument('--model_name', type=str, default='Mlp')
    parser.add_argument('--data_name', type=str, default='lsvq_train', help='konvid_1k, youtube_ugc, live_vqc, cvd_2014, live_qualcomm, lsvq_train, cross_dataset')
    parser.add_argument('--network_name', type=str, default='diva-vqa_large', help='diva-vqa')

    parser.add_argument('--metadata_path', type=str, default='../metadata/')
    parser.add_argument('--feature_path', type=str, default='../features/diva-vqa/diva-vqa_large/')
    parser.add_argument('--log_path', type=str, default='../log/')
    parser.add_argument('--save_path', type=str, default='../model/')
    parser.add_argument('--score_path', type=str, default='../log/predict_score/')
    parser.add_argument('--result_path', type=str, default='../log/result/')
    # training parameters
    parser.add_argument('--select_criteria', type=str, default='byrmse', help='byrmse, bykrcc')
    parser.add_argument('--n_repeats', type=int, default=21, help='Number of repeats for 80-20 hold out test')
    parser.add_argument('--n_splits', type=int, default=10, help='Number of splits for k-fold validation')
    parser.add_argument('--batch_size', type=int, default=256, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=50, help='Epochs for training')
    parser.add_argument('--hidden_features', type=int, default=256, help='Hidden features')
    parser.add_argument('--drop_rate', type=float, default=0.1, help='Dropout rate.')
    # misc
    parser.add_argument('--loss_type', type=str, default='MAERankLoss', help='MSEloss or MAERankLoss')
    parser.add_argument('--optimizer_type', type=str, default='sgd', help='adam or sgd')
    parser.add_argument('--initial_lr', type=float, default=1e-1, help='Initial learning rate: 1e-2')
    parser.add_argument('--weight_decay', type=float, default=0.005, help='Weight decay (L2 loss): 1e-4')
    parser.add_argument('--patience', type=int, default=5, help='Early stopping patience.')
    parser.add_argument('--use_swa', type=bool, default=True, help='Use Stochastic Weight Averaging')
    parser.add_argument('--l1_w', type=float, default=0.6, help='MAE loss weight')
    parser.add_argument('--rank_w', type=float, default=1.0, help='Rank loss weight')

    args = parser.parse_args()
    config = vars(args)  # args to dict
    print(config)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    if device.type == "cuda":
        torch.cuda.set_device(0)

    main(config)
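To see what MAEAndRankLoss adds over plain MAE, here is a minimal sketch with toy tensors: two predictions with identical mean absolute error, one preserving the ground-truth ordering and one inverting it. It assumes model_regression.py and its data_processing dependency are importable from the repo root:

# Toy comparison: identical MAE, different pair ordering.
import torch
from model_regression import MAEAndRankLoss

criterion = MAEAndRankLoss(l1_w=0.6, rank_w=1.0)  # the training defaults above

y_true = torch.tensor([[1.0], [1.1]])
ordered = torch.tensor([[0.9], [1.2]])   # errors of 0.1 each, ranking preserved
inverted = torch.tensor([[1.1], [1.0]])  # errors of 0.1 each, ranking flipped

print(criterion(ordered, y_true).item())   # ~0.06: MAE term only, rank term is zero
print(criterion(inverted, y_true).item())  # ~0.16: same MAE plus the rank penalty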
requirements.txt ADDED
Binary file (428 Bytes).
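With these four files plus the repository's extractor/ package and model/ checkpoints in place, the Space can also be reproduced locally: install the pinned dependencies from requirements.txt and run python app.py; by default Gradio serves the interface at http://127.0.0.1:7860.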