Spaces: Running on Zero

Commit: Upload 4 files

Files changed:
- app.py: +97 -0
- demo_test.py: +166 -0
- model_regression.py: +682 -0
- requirements.txt: +0 -0
app.py
ADDED
@@ -0,0 +1,97 @@
import gradio as gr
import torch
import os
import pandas as pd
from types import SimpleNamespace

from extractor.extract_rf_feats import VideoDataset_feature
from extractor.extract_slowfast_clip import SlowFast, extract_features_slowfast_pool
from extractor.extract_swint_clip import SwinT, extract_features_swint_pool
from model_regression import Mlp, preprocess_data
from demo_test import evaluate_video_quality, load_model, get_transform


def run_diva_vqa(video_path, is_finetune, train_data_name, test_data_name, network_name):
    if not os.path.exists(video_path):
        return "❌ No video uploaded or the uploaded file has expired. Please upload again."

    # print("CUDA available:", torch.cuda.is_available())
    # print("Current device:", torch.cuda.current_device())

    config = SimpleNamespace(**{
        'select_criteria': 'byrmse',
        'is_finetune': is_finetune,
        'save_path': 'model/',
        'train_data_name': train_data_name,
        'test_data_name': test_data_name,
        'test_video_path': video_path,
        'network_name': network_name,
        'num_workers': 0,
        'resize': 224,
        'patch_size': 16,
        'target_size': 224,
        'model_name': 'Mlp',
    })
    print(config.test_video_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # test demo video
    resize_transform = get_transform(config.resize)
    top_n = int(config.target_size / config.patch_size) * int(config.target_size / config.patch_size)
    data = {'vid': [os.path.splitext(os.path.basename(config.test_video_path))[0]],
            'test_data_name': [config.test_data_name],
            'test_video_path': [config.test_video_path]}
    videos_dir = os.path.dirname(config.test_video_path)
    test_df = pd.DataFrame(data)
    print(test_df.T)

    dataset = VideoDataset_feature(videos_dir, test_df, resize_transform, config.resize, config.test_data_name, config.patch_size, config.target_size, top_n)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=min(config.num_workers, os.cpu_count()), pin_memory=True
    )

    # load models to device (the Gradio dropdown restricts network_name to these two choices)
    model_slowfast = SlowFast().to(device)
    if config.network_name == 'diva-vqa':
        model_swint = SwinT(global_pool='avg').to(device)  # 'swin_base_patch4_window7_224.ms_in22k_ft_in1k'
        input_features = 9984
    elif config.network_name == 'diva-vqa_large':
        model_swint = SwinT(model_name='swin_large_patch4_window7_224', global_pool='avg', pretrained=True).to(device)
        input_features = 11520
    model_mlp = load_model(config, device, input_features)

    try:
        score, runtime = evaluate_video_quality(config, data_loader, model_slowfast, model_swint, model_mlp, device)
        return f"Predicted Quality Score: {score:.4f} (in {runtime:.2f}s)"
    except Exception as e:
        return f"❌ Error: {str(e)}"
    finally:
        # Gradio stores uploads in a temp dir whose path contains "gradio"; clean those up after scoring
        if "gradio" in video_path and os.path.exists(video_path):
            os.remove(video_path)


demo = gr.Interface(
    fn=run_diva_vqa,
    inputs=[
        gr.Video(label="Upload a Video (e.g. mp4)"),
        gr.Checkbox(label="Use Finetuning?", value=False),
        gr.Dropdown(label="Train Dataset Name", choices=["konvid_1k", "youtube_ugc", "live_vqc", "lsvq_train", "other"], value="lsvq_train"),
        gr.Dropdown(label="Test Dataset Name", choices=["konvid_1k", "youtube_ugc", "live_vqc", "lsvq", "other"], value="konvid_1k"),
        gr.Dropdown(label="Our Models", choices=["diva-vqa", "diva-vqa_large"], value="diva-vqa_large")
    ],
    outputs=gr.Textbox(label="Predicted Perceptual Quality Score (0–100)"),
    title="🎬 DIVA-VQA Online Demo",
    description=(
        "Upload a short video and get the predicted perceptual quality score from the DIVA-VQA model. "
        "You can try the bundled "
        "<a href='https://huggingface.co/spaces/xinyiW915/DIVA-VQA/blob/main/ugc_original_videos/5636101558_540p.mp4' target='_blank'>demo video</a>. "
        "<br><br>"
        "⚙️ This demo is currently running on <strong>Hugging Face CPU Basic</strong>: 2 vCPU • 16 GB RAM."
        # "⚙️ This demo is currently running on <strong>Hugging Face ZeroGPU Space</strong>: Dynamic resources (NVIDIA A100)."
    ),
)

demo.launch()
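Aside (not part of app.py): for quick local testing, a minimal sketch of calling the handler directly. It assumes `demo.launch()` has been moved behind an `if __name__ == "__main__":` guard so importing app does not start the server, and that the demo clip and pretrained weights under model/ exist locally.

# Minimal local sketch: score the bundled demo clip without the UI.
from app import run_diva_vqa

result = run_diva_vqa(
    "ugc_original_videos/5636101558_540p.mp4",  # repo's demo video
    False,             # is_finetune
    "lsvq_train",      # train_data_name
    "konvid_1k",       # test_data_name
    "diva-vqa_large",  # network_name
)
print(result)  # "Predicted Quality Score: ... (in ...s)" or an error string

Guarding `demo.launch()` this way is a common Gradio pattern and is what makes the module importable by a snippet like this one.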
demo_test.py
ADDED
@@ -0,0 +1,166 @@
import argparse
import time
import os
import pandas as pd
import torch
import torch.nn as nn
from tqdm import tqdm
from torchvision import models, transforms

from extractor.extract_rf_feats import VideoDataset_feature
from extractor.extract_slowfast_clip import SlowFast, extract_features_slowfast_pool
from extractor.extract_swint_clip import SwinT, extract_features_swint_pool
from model_regression import Mlp, preprocess_data


def get_transform(resize):
    return transforms.Compose([transforms.Resize([resize, resize]),
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])])

def setup_device(config):
    if config.device == "gpu":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if device.type == "cuda":
            torch.cuda.set_device(0)
    else:
        device = torch.device("cpu")
    print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")
    return device

def fix_state_dict(state_dict):
    # strip the 'module.' prefix left by nn.DataParallel and drop the
    # 'n_averaged' counter added by SWA's AveragedModel
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            name = k[7:]
        elif k == 'n_averaged':
            continue
        else:
            name = k
        new_state_dict[name] = v
    return new_state_dict

def load_model(config, device, input_features=11520):
    model = Mlp(input_features=input_features, out_features=1, drop_rate=0.1, act_layer=nn.GELU).to(device)
    if config.is_finetune:
        model_path = os.path.join(config.save_path, f"finetune/{config.test_data_name}_{config.network_name}_fine_tuned_model.pth")
    else:
        if config.train_data_name == 'lsvq_train':
            model_path = os.path.join(config.save_path, f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}"
                                                        f"_trained_median_model_param_kfold.pth")
        else:
            model_path = os.path.join(config.save_path, f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}"
                                                        f"_trained_median_model_param.pth")
    # print("Loading model from:", model_path)
    state_dict = torch.load(model_path, map_location=device)
    fixed_state_dict = fix_state_dict(state_dict)
    try:
        model.load_state_dict(fixed_state_dict)
    except RuntimeError as e:
        print(e)
    return model

def evaluate_video_quality(config, data_loader, model_slowfast, model_swint, model_mlp, device):
    # get video features
    model_slowfast.eval()
    model_swint.eval()
    with torch.no_grad():
        for i, (video_segments, video_res_frag_all, video_frag_all, video_name) in enumerate(tqdm(data_loader, desc="Processing Videos")):
            start_time = time.time()
            # slowfast features
            _, _, slowfast_frame_feats = extract_features_slowfast_pool(video_segments, model_slowfast, device)
            _, _, slowfast_res_frag_feats = extract_features_slowfast_pool(video_res_frag_all, model_slowfast, device)
            _, _, slowfast_frame_frag_feats = extract_features_slowfast_pool(video_frag_all, model_slowfast, device)
            slowfast_frame_feats_avg = slowfast_frame_feats.mean(dim=0)
            slowfast_res_frag_feats_avg = slowfast_res_frag_feats.mean(dim=0)
            slowfast_frame_frag_feats_avg = slowfast_frame_frag_feats.mean(dim=0)

            # swinT features
            swint_frame_feats = extract_features_swint_pool(video_segments, model_swint, device)
            swint_res_frag_feats = extract_features_swint_pool(video_res_frag_all, model_swint, device)
            swint_frame_frag_feats = extract_features_swint_pool(video_frag_all, model_swint, device)
            swint_frame_feats_avg = swint_frame_feats.mean(dim=0)
            swint_res_frag_feats_avg = swint_res_frag_feats.mean(dim=0)
            swint_frame_frag_feats_avg = swint_frame_frag_feats.mean(dim=0)

            # frame + residual fragment + frame fragment features
            rf_vqa_feats = torch.cat((slowfast_frame_feats_avg, slowfast_res_frag_feats_avg, slowfast_frame_frag_feats_avg,
                                      swint_frame_feats_avg, swint_res_frag_feats_avg, swint_frame_frag_feats_avg), dim=0)

            feature_tensor, _ = preprocess_data(rf_vqa_feats, None)
            if feature_tensor.dim() == 1:
                feature_tensor = feature_tensor.unsqueeze(0)
            # print(f"Feature tensor shape before MLP: {feature_tensor.shape}")

            model_mlp.eval()
            with torch.no_grad():
                with torch.cuda.amp.autocast():  # no-op on CPU; newer PyTorch prefers torch.amp.autocast
                    prediction = model_mlp(feature_tensor)
                run_time = time.time() - start_time
                predicted_score = prediction.item()
                # single-video demo: return after the first batch
                return predicted_score, run_time


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('-device', type=str, default='gpu', help='cpu or gpu')
    parser.add_argument('-model_name', type=str, default='Mlp', help='Name of the regression model')
    parser.add_argument('-select_criteria', type=str, default='byrmse', help='Selection criteria')
    # caution: argparse's type=bool treats any non-empty string (including 'False') as True
    parser.add_argument('-is_finetune', type=bool, default=True, help='With or without finetune')
    parser.add_argument('-save_path', type=str, default='model/', help='Path to save models')

    parser.add_argument('-train_data_name', type=str, default='lsvq_train', help='Name of the training data')
    parser.add_argument('-test_data_name', type=str, default='konvid_1k', help='Name of the testing data')
    parser.add_argument('-test_video_path', type=str, default='ugc_original_videos/5636101558_540p.mp4', help='demo test video')

    parser.add_argument('--network_name', type=str, default='diva-vqa_large')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--resize', type=int, default=224, help='224 or 384')
    parser.add_argument('--patch_size', type=int, default=16, help='8, 16, or 32')
    parser.add_argument('--target_size', type=int, default=224, help='224 or 384')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    config = parse_arguments()
    device = setup_device(config)

    # test demo video
    resize_transform = get_transform(config.resize)
    top_n = int(config.target_size / config.patch_size) * int(config.target_size / config.patch_size)
    data = {'vid': [os.path.splitext(os.path.basename(config.test_video_path))[0]],
            'test_data_name': [config.test_data_name],
            'test_video_path': [config.test_video_path]}
    videos_dir = os.path.dirname(config.test_video_path)
    test_df = pd.DataFrame(data)
    # print(test_df.T)

    dataset = VideoDataset_feature(videos_dir, test_df, resize_transform, config.resize, config.test_data_name, config.patch_size, config.target_size, top_n)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=min(config.num_workers, os.cpu_count()), pin_memory=True
    )
    # print(f"Dataset loaded. Total videos: {len(dataset)}, Total batches: {len(data_loader)}")

    # load models to device
    model_slowfast = SlowFast().to(device)
    if config.network_name == 'diva-vqa':
        model_swint = SwinT(global_pool='avg').to(device)  # 'swin_base_patch4_window7_224.ms_in22k_ft_in1k'
        input_features = 9984
    elif config.network_name == 'diva-vqa_large':
        model_swint = SwinT(model_name='swin_large_patch4_window7_224', global_pool='avg', pretrained=True).to(device)
        input_features = 11520
    model_mlp = load_model(config, device, input_features)

    total_time = 0
    num_runs = 1
    for i in range(num_runs):
        quality_prediction, run_time = evaluate_video_quality(config, data_loader, model_slowfast, model_swint, model_mlp, device)
        print(f"Run {i + 1} - Time taken: {run_time:.4f} seconds")

        total_time += run_time
    average_time = total_time / num_runs

    print(f"Average running time over {num_runs} runs: {average_time:.4f} seconds")
    print("Predicted Quality Score:", quality_prediction)
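Aside (not part of demo_test.py): the checkpoint handling above is easy to sanity-check in isolation. A minimal sketch, assuming demo_test.py and the repo's extractor package are importable, of what fix_state_dict strips from a DataParallel/SWA checkpoint:

# Sketch: fix_state_dict drops SWA's 'n_averaged' counter and removes the
# 'module.' prefix that nn.DataParallel adds to every parameter name.
import torch
from demo_test import fix_state_dict

raw = {
    "module.fc1.weight": torch.zeros(3, 3),
    "module.fc1.bias": torch.zeros(3),
    "n_averaged": torch.tensor(5),
}
print(sorted(fix_state_dict(raw)))  # ['fc1.bias', 'fc1.weight']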
model_regression.py
ADDED
@@ -0,0 +1,682 @@
import logging
import time
import os
import pandas as pd
import numpy as np
import math
import scipy.io
import scipy.stats
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from scipy.optimize import curve_fit
import joblib

import seaborn as sns
import matplotlib.pyplot as plt
import copy
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split

from data_processing import split_train_test

# ignore all warnings
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)


class Mlp(nn.Module):
    def __init__(self, input_features, hidden_features=256, out_features=1, drop_rate=0.2, act_layer=nn.GELU):
        super().__init__()
        self.fc1 = nn.Linear(input_features, hidden_features)
        self.bn1 = nn.BatchNorm1d(hidden_features)
        self.act1 = act_layer()
        self.drop1 = nn.Dropout(drop_rate)
        self.fc2 = nn.Linear(hidden_features, hidden_features // 2)
        self.act2 = act_layer()
        self.drop2 = nn.Dropout(drop_rate)
        self.fc3 = nn.Linear(hidden_features // 2, out_features)

    def forward(self, input_feature):
        x = self.fc1(input_feature)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.act2(x)
        x = self.drop2(x)
        output = self.fc3(x)
        return output

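Aside (not part of model_regression.py): the regression head is small enough to exercise on its own. A quick REPL check, assuming the module and its dependencies are importable; 11520 is the diva-vqa_large feature size used elsewhere in this commit:

import torch
from model_regression import Mlp

head = Mlp(input_features=11520, out_features=1, drop_rate=0.1)
head.eval()  # BatchNorm1d needs eval mode for batch-of-1 inference
with torch.no_grad():
    print(head(torch.randn(4, 11520)).shape)  # torch.Size([4, 1])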
class MAEAndRankLoss(nn.Module):
    def __init__(self, l1_w=1.0, rank_w=1.0, margin=0.0, use_margin=False):
        super(MAEAndRankLoss, self).__init__()
        self.l1_w = l1_w
        self.rank_w = rank_w
        self.margin = margin
        self.use_margin = use_margin

    def forward(self, y_pred, y_true):
        # L1 loss / MAE loss
        l_mae = F.l1_loss(y_pred, y_true, reduction='mean') * self.l1_w
        # Rank loss
        n = y_pred.size(0)
        pred_diff = y_pred.unsqueeze(1) - y_pred.unsqueeze(0)
        true_diff = y_true.unsqueeze(1) - y_true.unsqueeze(0)

        # e(ytrue_i, ytrue_j): pairwise ordering signs of the ground truth
        masks = torch.sign(true_diff)

        if self.use_margin and self.margin > 0:
            true_diff = true_diff.abs() - self.margin
            true_diff = F.relu(true_diff)
            masks = true_diff.sign()

        l_rank = F.relu(true_diff - masks * pred_diff)
        l_rank = l_rank.sum() / (n * (n - 1))

        loss = l_mae + l_rank * self.rank_w
        return loss

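Aside (not part of model_regression.py): on a mis-ranked pair the rank term activates on top of the MAE term. A tiny worked check; the values follow directly from the forward pass above:

# Worked example: predictions reverse the true ordering.
# MAE  = mean(|0.2-1.0|, |0.8-0.0|) = 0.8
# rank = relu(true_diff - sign(true_diff)*pred_diff).sum() / (n*(n-1))
#      = relu(1.0 - (-0.6)) / 2 = 1.6 / 2 = 0.8
import torch
from model_regression import MAEAndRankLoss

crit = MAEAndRankLoss(l1_w=1.0, rank_w=1.0)
loss = crit(torch.tensor([0.2, 0.8]), torch.tensor([1.0, 0.0]))
print(loss.item())  # 1.6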
def load_data(csv, data, data_name, set_name):
    try:
        df = pd.read_csv(csv, skiprows=[], header=None)
    except Exception as e:
        logging.error(f'Read CSV file error: {e}')
        raise

    y_data = df.values[1:, 2].astype(float)
    y = torch.tensor(y_data, dtype=torch.float32)

    if data_name == 'cross_dataset':
        y = torch.clamp(y, max=5)
        if set_name == 'test':
            print(f"Modified y_true: {y}")
    X = data
    return X, y

def preprocess_data(X, y):
    X[torch.isnan(X)] = 0
    X[torch.isinf(X)] = 0

    # min-max scaling (PyTorch implementation of sklearn's MinMaxScaler)
    X_min = X.min(dim=0, keepdim=True).values
    X_max = X.max(dim=0, keepdim=True).values
    X = (X - X_min) / (X_max - X_min)
    if y is not None:
        y = y.view(-1, 1).squeeze()
    return X, y

# define 4-parameter logistic regression
|
121 |
+
def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
|
122 |
+
logisticPart = 1 + np.exp(np.negative(np.divide(X - bayta3, np.abs(bayta4))))
|
123 |
+
yhat = bayta2 + np.divide(bayta1 - bayta2, logisticPart)
|
124 |
+
return yhat
|
125 |
+
|
126 |
+
def fit_logistic_regression(y_pred, y_true):
|
127 |
+
beta = [np.max(y_true), np.min(y_true), np.mean(y_pred), 0.5]
|
128 |
+
popt, _ = curve_fit(logistic_func, y_pred, y_true, p0=beta, maxfev=100000000)
|
129 |
+
y_pred_logistic = logistic_func(y_pred, *popt)
|
130 |
+
return y_pred_logistic, beta, popt
|
131 |
+
|
132 |
+
def compute_correlation_metrics(y_true, y_pred):
|
133 |
+
y_pred_logistic, beta, popt = fit_logistic_regression(y_pred, y_true)
|
134 |
+
|
135 |
+
plcc = scipy.stats.pearsonr(y_true, y_pred_logistic)[0]
|
136 |
+
rmse = np.sqrt(mean_squared_error(y_true, y_pred_logistic))
|
137 |
+
srcc = scipy.stats.spearmanr(y_true, y_pred)[0]
|
138 |
+
|
139 |
+
try:
|
140 |
+
krcc = scipy.stats.kendalltau(y_true, y_pred)[0]
|
141 |
+
except Exception as e:
|
142 |
+
logging.error(f'krcc calculation: {e}')
|
143 |
+
krcc = scipy.stats.kendalltau(y_true, y_pred, method='asymptotic')[0]
|
144 |
+
return y_pred_logistic, plcc, rmse, srcc, krcc
|
145 |
+
|
146 |
+
def plot_results(y_test, y_test_pred_logistic, df_pred_score, model_name, data_name, network_name, select_criteria):
|
147 |
+
# nonlinear logistic fitted curve / logistic regression
|
148 |
+
if isinstance(y_test, torch.Tensor):
|
149 |
+
mos1 = y_test.numpy()
|
150 |
+
y1 = y_test_pred_logistic
|
151 |
+
|
152 |
+
try:
|
153 |
+
beta = [np.max(mos1), np.min(mos1), np.mean(y1), 0.5]
|
154 |
+
popt, pcov = curve_fit(logistic_func, y1, mos1, p0=beta, maxfev=100000000)
|
155 |
+
sigma = np.sqrt(np.diag(pcov))
|
156 |
+
except:
|
157 |
+
raise Exception('Fitting logistic function time-out!!')
|
158 |
+
x_values1 = np.linspace(np.min(y1), np.max(y1), len(y1))
|
159 |
+
plt.plot(x_values1, logistic_func(x_values1, *popt), '-', color='#c72e29', label='Fitted f(x)')
|
160 |
+
|
161 |
+
fig1 = sns.scatterplot(x="y_test_pred_logistic", y="MOS", data=df_pred_score, markers='o', color='steelblue', label=network_name)
|
162 |
+
plt.legend(loc='upper left')
|
163 |
+
if data_name == 'live_vqc' or data_name == 'live_qualcomm' or data_name == 'cvd_2014' or data_name == 'lsvq_train':
|
164 |
+
plt.ylim(0, 100)
|
165 |
+
plt.xlim(0, 100)
|
166 |
+
else:
|
167 |
+
plt.ylim(1, 5)
|
168 |
+
plt.xlim(1, 5)
|
169 |
+
plt.title(f"Algorithm {network_name} with {model_name} on dataset {data_name}", fontsize=10)
|
170 |
+
plt.xlabel('Predicted Score')
|
171 |
+
plt.ylabel('MOS')
|
172 |
+
reg_fig1 = fig1.get_figure()
|
173 |
+
|
174 |
+
fig_path = f'../figs/{data_name}/'
|
175 |
+
os.makedirs(fig_path, exist_ok=True)
|
176 |
+
reg_fig1.savefig(fig_path + f"{network_name}_{model_name}_{data_name}_by{select_criteria}_kfold.png", dpi=300)
|
177 |
+
plt.clf()
|
178 |
+
plt.close()
|
179 |
+
|
180 |
+
def plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, test_vids, i):
|
181 |
+
plt.figure(figsize=(10, 6))
|
182 |
+
|
183 |
+
plt.plot(avg_train_losses, label='Average Training Loss')
|
184 |
+
plt.plot(avg_val_losses, label='Average Validation Loss')
|
185 |
+
|
186 |
+
plt.xlabel('Epoch')
|
187 |
+
plt.ylabel('Loss')
|
188 |
+
plt.title(f'Average Training and Validation Loss Across Folds - {network_name} with {model_name} (test_vids: {test_vids})', fontsize=10)
|
189 |
+
|
190 |
+
plt.legend()
|
191 |
+
fig_par_path = f'../log/result/{data_name}/'
|
192 |
+
os.makedirs(fig_par_path, exist_ok=True)
|
193 |
+
plt.savefig(f'{fig_par_path}/{network_name}_Average_Training_Loss_test{i}.png', dpi=50)
|
194 |
+
plt.clf()
|
195 |
+
plt.close()
|
196 |
+
|
197 |
+
def configure_logging(log_path, model_name, data_name, network_name, select_criteria):
|
198 |
+
log_file_name = os.path.join(log_path, f"{data_name}_{network_name}_{model_name}_corr_{select_criteria}_kfold.log")
|
199 |
+
logging.basicConfig(filename=log_file_name, filemode='w', level=logging.DEBUG, format='%(levelname)s - %(message)s')
|
200 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
201 |
+
logging.info(f"Evaluating algorithm {network_name} with {model_name} on dataset {data_name}")
|
202 |
+
logging.info(f"torch cuda: {torch.cuda.is_available()}")
|
203 |
+
|
204 |
+
def load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features):
|
205 |
+
if data_name == 'cross_dataset':
|
206 |
+
data_name1 = 'youtube_ugc_all'
|
207 |
+
data_name2 = 'cvd_2014_all'
|
208 |
+
train_csv = os.path.join(metadata_path, f'mos_files/{data_name1}_MOS_train.csv')
|
209 |
+
test_csv = os.path.join(metadata_path, f'mos_files/{data_name2}_MOS_test.csv')
|
210 |
+
train_data = torch.load(f'{feature_path}split_train_test/{network_name}_{data_name1}_train_features.pt')
|
211 |
+
test_data = torch.load(f'{feature_path}split_train_test/{network_name}_{data_name2}_test_features.pt')
|
212 |
+
X_train, y_train = load_data(train_csv, train_data, data_name1, 'train')
|
213 |
+
X_test, y_test = load_data(test_csv, test_data, data_name2, 'test')
|
214 |
+
|
215 |
+
elif data_name == 'lsvq_train':
|
216 |
+
train_csv = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
|
217 |
+
test_csv = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')
|
218 |
+
X_train, y_train = load_data(train_csv, train_features, data_name, 'train')
|
219 |
+
X_test, y_test = load_data(test_csv, test_features, data_name, 'test')
|
220 |
+
|
221 |
+
else:
|
222 |
+
train_csv = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
|
223 |
+
test_csv = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')
|
224 |
+
train_data = torch.load(f'{feature_path}split_train_test/{network_name}_{data_name}_train_features.pt')
|
225 |
+
test_data = torch.load(f'{feature_path}split_train_test/{network_name}_{data_name}_test_features.pt')
|
226 |
+
X_train, y_train = load_data(train_csv, train_data, data_name, 'train')
|
227 |
+
X_test, y_test = load_data(test_csv, test_data, data_name, 'test')
|
228 |
+
|
229 |
+
# standard min-max normalization of training features
|
230 |
+
X_train, y_train = preprocess_data(X_train, y_train)
|
231 |
+
X_test, y_test = preprocess_data(X_test, y_test)
|
232 |
+
|
233 |
+
return X_train, y_train, X_test, y_test
|
234 |
+
|
235 |
+
def train_one_epoch(model, train_loader, criterion, optimizer, device):
|
236 |
+
"""Train the model for one epoch"""
|
237 |
+
model.train()
|
238 |
+
train_loss = 0.0
|
239 |
+
for inputs, targets in train_loader:
|
240 |
+
inputs, targets = inputs.to(device), targets.to(device)
|
241 |
+
|
242 |
+
optimizer.zero_grad()
|
243 |
+
outputs = model(inputs)
|
244 |
+
loss = criterion(outputs, targets.view(-1, 1))
|
245 |
+
loss.backward()
|
246 |
+
optimizer.step()
|
247 |
+
train_loss += loss.item() * inputs.size(0)
|
248 |
+
train_loss /= len(train_loader.dataset)
|
249 |
+
return train_loss
|
250 |
+
|
251 |
+
def evaluate(model, val_loader, criterion, device):
|
252 |
+
"""Evaluate model performance on validation sets"""
|
253 |
+
model.eval()
|
254 |
+
val_loss = 0.0
|
255 |
+
y_val_pred = []
|
256 |
+
y_val_true = []
|
257 |
+
with torch.no_grad():
|
258 |
+
for inputs, targets in val_loader:
|
259 |
+
inputs, targets = inputs.to(device), targets.to(device)
|
260 |
+
|
261 |
+
outputs = model(inputs)
|
262 |
+
y_val_pred.append(outputs)
|
263 |
+
y_val_true.append(targets)
|
264 |
+
loss = criterion(outputs, targets.view(-1, 1))
|
265 |
+
val_loss += loss.item() * inputs.size(0)
|
266 |
+
|
267 |
+
val_loss /= len(val_loader.dataset)
|
268 |
+
y_val_pred = torch.cat(y_val_pred, dim=0)
|
269 |
+
y_val_true = torch.cat(y_val_true, dim=0)
|
270 |
+
return val_loss, y_val_pred, y_val_true
|
271 |
+
|
272 |
+
def update_best_model(select_criteria, best_metric, current_val, model):
|
273 |
+
is_better = False
|
274 |
+
if select_criteria == 'byrmse' and current_val < best_metric:
|
275 |
+
is_better = True
|
276 |
+
elif select_criteria == 'bykrcc' and current_val > best_metric:
|
277 |
+
is_better = True
|
278 |
+
|
279 |
+
if is_better:
|
280 |
+
return current_val, copy.deepcopy(model), is_better
|
281 |
+
return best_metric, model, is_better
|
282 |
+
|
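Aside (not part of model_regression.py): train_and_evaluate below reads its hyperparameters from a plain dict, normally vars(args) from the parser at the bottom of this file, and it also relies on a module-level `device` being set as done in `__main__`. A minimal sketch of the keys it reads, filled with the script's own argparse defaults:

# Keys read by train_and_evaluate (values are the argparse defaults below).
config = {
    'n_repeats': 21, 'n_splits': 10, 'batch_size': 256, 'epochs': 50,
    'hidden_features': 256, 'drop_rate': 0.1,
    'loss_type': 'MAERankLoss', 'optimizer_type': 'sgd',
    'select_criteria': 'byrmse', 'initial_lr': 1e-1, 'weight_decay': 0.005,
    'patience': 5, 'l1_w': 0.6, 'rank_w': 1.0, 'use_swa': True,
}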
def train_and_evaluate(X_train, y_train, config):
    # parameters
    n_repeats = config['n_repeats']
    n_splits = config['n_splits']
    batch_size = config['batch_size']
    epochs = config['epochs']
    hidden_features = config['hidden_features']
    drop_rate = config['drop_rate']
    loss_type = config['loss_type']
    optimizer_type = config['optimizer_type']
    select_criteria = config['select_criteria']
    initial_lr = config['initial_lr']
    weight_decay = config['weight_decay']
    patience = config['patience']
    l1_w = config['l1_w']
    rank_w = config['rank_w']
    use_swa = config.get('use_swa', False)
    logging.info(f'Parameters - Number of repeats for 80-20 hold out test: {n_repeats}, Number of splits for kfold: {n_splits}, Batch size: {batch_size}, Number of epochs: {epochs}')
    logging.info(f'Network Parameters - hidden_features: {hidden_features}, drop_rate: {drop_rate}, patience: {patience}')
    logging.info(f'Optimizer Parameters - loss_type: {loss_type}, optimizer_type: {optimizer_type}, initial_lr: {initial_lr}, weight_decay: {weight_decay}, use_swa: {use_swa}')
    logging.info(f'MAEAndRankLoss - l1_w: {l1_w}, rank_w: {rank_w}')

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    best_model = None
    best_metric = float('inf') if select_criteria == 'byrmse' else float('-inf')

    # loss for every fold
    all_train_losses = []
    all_val_losses = []
    for fold, (train_idx, val_idx) in enumerate(kf.split(X_train)):
        print(f"Fold {fold + 1}/{n_splits}")

        X_train_fold, X_val_fold = X_train[train_idx], X_train[val_idx]
        y_train_fold, y_val_fold = y_train[train_idx], y_train[val_idx]

        # initialisation of model, loss function, optimiser
        model = Mlp(input_features=X_train_fold.shape[1], hidden_features=hidden_features, drop_rate=drop_rate)
        model = model.to(device)  # note: relies on the module-level `device` set in __main__

        if loss_type == 'MAERankLoss':
            criterion = MAEAndRankLoss()
            criterion.l1_w = l1_w
            criterion.rank_w = rank_w
        else:
            criterion = nn.MSELoss()

        if optimizer_type == 'sgd':
            optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9, weight_decay=weight_decay)
            scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)  # initial eta_min=1e-5
        else:
            optimizer = optim.Adam(model.parameters(), lr=initial_lr, weight_decay=weight_decay)  # L2 regularisation, initial: 0.01, 1e-5
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95)  # step_size=10, gamma=0.1: every 10 epochs lr*0.1
        if use_swa:
            swa_model = AveragedModel(model).to(device)
            swa_scheduler = SWALR(optimizer, swa_lr=initial_lr, anneal_strategy='cos')

        # dataset loader
        train_dataset = TensorDataset(X_train_fold, y_train_fold)
        val_dataset = TensorDataset(X_val_fold, y_val_fold)
        train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

        train_losses, val_losses = [], []

        # early stopping parameters
        best_val_loss = float('inf')
        epochs_no_improve = 0
        early_stop_active = False
        swa_start = int(epochs * 0.7) if use_swa else epochs  # SWA starts after 70% of total epochs; only set if SWA is used

        for epoch in range(epochs):
            train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
            train_losses.append(train_loss)
            scheduler.step()  # update learning rate
            if use_swa and epoch >= swa_start:
                swa_model.update_parameters(model)
                swa_scheduler.step()
                early_stop_active = True
                print(f"Current learning rate with SWA: {swa_scheduler.get_last_lr()}")

            lr = optimizer.param_groups[0]['lr']
            print('Epoch %d: Learning rate: %f' % (epoch + 1, lr))

            # decide which model to evaluate: SWA model or regular model
            current_model = swa_model if use_swa and epoch >= swa_start else model
            current_model.eval()
            val_loss, y_val_pred, y_val_true = evaluate(current_model, val_loader, criterion, device)
            val_losses.append(val_loss)
            print(f"Epoch {epoch + 1}, Fold {fold + 1}, Training Loss: {train_loss}, Validation Loss: {val_loss}")

            y_val_pred = torch.cat([pred for pred in y_val_pred])
            _, _, rmse_val, _, krcc_val = compute_correlation_metrics(y_val_fold.cpu().numpy(), y_val_pred.cpu().numpy())
            current_metric = rmse_val if select_criteria == 'byrmse' else krcc_val
            best_metric, best_model, is_better = update_best_model(select_criteria, best_metric, current_metric, current_model)
            if is_better:
                logging.info(f"Epoch {epoch + 1}, Fold {fold + 1}:")
                y_val_pred_logistic_tmp, plcc_valid_tmp, rmse_valid_tmp, srcc_valid_tmp, krcc_valid_tmp = compute_correlation_metrics(y_val_fold.cpu().numpy(), y_val_pred.cpu().numpy())
                logging.info(f'Validation set - Evaluation Results - SRCC: {srcc_valid_tmp}, KRCC: {krcc_valid_tmp}, PLCC: {plcc_valid_tmp}, RMSE: {rmse_valid_tmp}')

                X_train_fold_tensor = X_train_fold
                y_tra_pred_tmp = best_model(X_train_fold_tensor).detach().cpu().squeeze()
                y_tra_pred_logistic_tmp, plcc_train_tmp, rmse_train_tmp, srcc_train_tmp, krcc_train_tmp = compute_correlation_metrics(y_train_fold.cpu().numpy(), y_tra_pred_tmp.cpu().numpy())
                logging.info(f'Train set - Evaluation Results - SRCC: {srcc_train_tmp}, KRCC: {krcc_train_tmp}, PLCC: {plcc_train_tmp}, RMSE: {rmse_train_tmp}')

            # check for loss improvement
            if early_stop_active:
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # save the best model if validation loss improves
                    best_model = copy.deepcopy(model)
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1
                    if epochs_no_improve >= patience:
                        # epochs to wait for improvement before stopping
                        print(f"Early stopping triggered after {epoch + 1} epochs.")
                        break

        # saving SWA models and updating BN statistics
        if use_swa:
            train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
            best_model = best_model.to(device)
            best_model.eval()
            torch.optim.swa_utils.update_bn(train_loader, best_model)
            # swa_model_path = os.path.join('../model/', f'model_swa_fold{fold}.pth')
            # torch.save(swa_model.state_dict(), swa_model_path)
            # logging.info(f'SWA model saved at {swa_model_path}')

        all_train_losses.append(train_losses)
        all_val_losses.append(val_losses)
    # pad per-fold loss curves to equal length (early stopping can shorten some folds)
    max_length = max(len(x) for x in all_train_losses)
    all_train_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_train_losses]
    max_length = max(len(x) for x in all_val_losses)
    all_val_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_val_losses]

    return best_model, all_train_losses, all_val_losses

def collate_to_device(batch, device):
    data, targets = zip(*batch)
    return torch.stack(data).to(device), torch.stack(targets).to(device)

def model_test(best_model, X, y, device):
    test_dataset = TensorDataset(X, y)
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

    best_model.eval()
    y_pred = []
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)

            outputs = best_model(inputs)
            y_pred.extend(outputs.view(-1).tolist())

    return y_pred

+
|
439 |
+
def main(config):
|
440 |
+
model_name = config['model_name']
|
441 |
+
data_name = config['data_name']
|
442 |
+
network_name = config['network_name']
|
443 |
+
|
444 |
+
metadata_path = config['metadata_path']
|
445 |
+
feature_path = config['feature_path']
|
446 |
+
log_path = config['log_path']
|
447 |
+
save_path = config['save_path']
|
448 |
+
score_path = config['score_path']
|
449 |
+
result_path = config['result_path']
|
450 |
+
|
451 |
+
# parameters
|
452 |
+
select_criteria = config['select_criteria']
|
453 |
+
n_repeats = config['n_repeats']
|
454 |
+
|
455 |
+
# logging and result
|
456 |
+
os.makedirs(log_path, exist_ok=True)
|
457 |
+
os.makedirs(save_path, exist_ok=True)
|
458 |
+
os.makedirs(score_path, exist_ok=True)
|
459 |
+
os.makedirs(result_path, exist_ok=True)
|
460 |
+
result_file = f'{result_path}{data_name}_{network_name}_{model_name}_corr_{select_criteria}_kfold.mat'
|
461 |
+
pred_score_filename = os.path.join(score_path, f"{data_name}_{network_name}_{model_name}_Predicted_Score_{select_criteria}_kfold.csv")
|
462 |
+
file_path = os.path.join(save_path, f"{data_name}_{network_name}_{model_name}_{select_criteria}_trained_median_model_param_kfold.pth")
|
463 |
+
configure_logging(log_path, model_name, data_name, network_name, select_criteria)
|
464 |
+
|
465 |
+
'''======================== Main Body ==========================='''
|
466 |
+
PLCC_all_repeats_test = []
|
467 |
+
SRCC_all_repeats_test = []
|
468 |
+
KRCC_all_repeats_test = []
|
469 |
+
RMSE_all_repeats_test = []
|
470 |
+
PLCC_all_repeats_train = []
|
471 |
+
SRCC_all_repeats_train = []
|
472 |
+
KRCC_all_repeats_train = []
|
473 |
+
RMSE_all_repeats_train = []
|
474 |
+
all_repeats_test_vids = []
|
475 |
+
all_repeats_df_test_pred = []
|
476 |
+
best_model_list = []
|
477 |
+
|
478 |
+
for i in range(1, n_repeats + 1):
|
479 |
+
print(f"{i}th repeated 80-20 hold out test")
|
480 |
+
logging.info(f"{i}th repeated 80-20 hold out test")
|
481 |
+
t0 = time.time()
|
482 |
+
|
483 |
+
# train test split
|
484 |
+
test_size = 0.2
|
485 |
+
random_state = math.ceil(8.8 * i)
|
486 |
+
# NR: original
|
487 |
+
if data_name == 'lsvq_train':
|
488 |
+
test_data_name = 'lsvq_test' #lsvq_test, lsvq_test_1080p
|
489 |
+
train_features, test_features, test_vids = split_train_test.process_lsvq(data_name, test_data_name, metadata_path, feature_path, network_name)
|
490 |
+
elif data_name == 'cross_dataset':
|
491 |
+
train_data_name = 'youtube_ugc_all'
|
492 |
+
test_data_name = 'cvd_2014_all'
|
493 |
+
_, _, test_vids = split_train_test.process_cross_dataset(train_data_name, test_data_name, metadata_path, feature_path, network_name)
|
494 |
+
else:
|
495 |
+
_, _, test_vids = split_train_test.process_other(data_name, test_size, random_state, metadata_path, feature_path, network_name)
|
496 |
+
|
497 |
+
'''======================== read files =============================== '''
|
498 |
+
if data_name == 'lsvq_train':
|
499 |
+
X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features)
|
500 |
+
else:
|
501 |
+
X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, None, None)
|
502 |
+
|
503 |
+
'''======================== regression model =============================== '''
|
504 |
+
best_model, all_train_losses, all_val_losses = train_and_evaluate(X_train, y_train, config)
|
505 |
+
|
506 |
+
# average loss plots
|
507 |
+
avg_train_losses = np.mean(all_train_losses, axis=0)
|
508 |
+
avg_val_losses = np.mean(all_val_losses, axis=0)
|
509 |
+
test_vids = test_vids.tolist()
|
510 |
+
plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, len(test_vids), i)
|
511 |
+
|
512 |
+
# predict best model on the train dataset
|
513 |
+
y_train_pred = model_test(best_model, X_train, y_train, device)
|
514 |
+
y_train_pred = torch.tensor(list(y_train_pred), dtype=torch.float32)
|
515 |
+
y_train_pred_logistic, plcc_train, rmse_train, srcc_train, krcc_train = compute_correlation_metrics(y_train.cpu().numpy(), y_train_pred.cpu().numpy())
|
516 |
+
|
517 |
+
# test best model on the test dataset
|
518 |
+
y_test_pred = model_test(best_model, X_test, y_test, device)
|
519 |
+
y_test_pred = torch.tensor(list(y_test_pred), dtype=torch.float32)
|
520 |
+
y_test_pred_logistic, plcc_test, rmse_test, srcc_test, krcc_test = compute_correlation_metrics(y_test.cpu().numpy(), y_test_pred.cpu().numpy())
|
521 |
+
|
522 |
+
# save the predict score results
|
523 |
+
test_pred_score = {'MOS': y_test, 'y_test_pred': y_test_pred, 'y_test_pred_logistic': y_test_pred_logistic}
|
524 |
+
df_test_pred = pd.DataFrame(test_pred_score)
|
525 |
+
|
526 |
+
# logging logistic predicted scores
|
527 |
+
logging.info("============================================================================================================")
|
528 |
+
SRCC_all_repeats_test.append(srcc_test)
|
529 |
+
KRCC_all_repeats_test.append(krcc_test)
|
530 |
+
PLCC_all_repeats_test.append(plcc_test)
|
531 |
+
RMSE_all_repeats_test.append(rmse_test)
|
532 |
+
SRCC_all_repeats_train.append(srcc_train)
|
533 |
+
KRCC_all_repeats_train.append(krcc_train)
|
534 |
+
PLCC_all_repeats_train.append(plcc_train)
|
535 |
+
RMSE_all_repeats_train.append(rmse_train)
|
536 |
+
all_repeats_test_vids.append(test_vids)
|
537 |
+
all_repeats_df_test_pred.append(df_test_pred)
|
538 |
+
best_model_list.append(copy.deepcopy(best_model))
|
539 |
+
|
540 |
+
# logging.info results for each iteration
|
541 |
+
logging.info('Best results in Mlp model within one split')
|
542 |
+
logging.info(f'MODEL: {best_model}')
|
543 |
+
logging.info('======================================================')
|
544 |
+
logging.info(f'Train set - Evaluation Results')
|
545 |
+
logging.info(f'SRCC_train: {srcc_train}')
|
546 |
+
logging.info(f'KRCC_train: {krcc_train}')
|
547 |
+
logging.info(f'PLCC_train: {plcc_train}')
|
548 |
+
logging.info(f'RMSE_train: {rmse_train}')
|
549 |
+
logging.info('======================================================')
|
550 |
+
logging.info(f'Test set - Evaluation Results')
|
551 |
+
logging.info(f'SRCC_test: {srcc_test}')
|
552 |
+
logging.info(f'KRCC_test: {krcc_test}')
|
553 |
+
logging.info(f'PLCC_test: {plcc_test}')
|
554 |
+
logging.info(f'RMSE_test: {rmse_test}')
|
555 |
+
logging.info('======================================================')
|
556 |
+
logging.info(' -- {} seconds elapsed...\n\n'.format(time.time() - t0))
|
557 |
+
|
558 |
+
logging.info('')
|
559 |
+
SRCC_all_repeats_test = torch.tensor(SRCC_all_repeats_test, dtype=torch.float32)
|
560 |
+
KRCC_all_repeats_test = torch.tensor(KRCC_all_repeats_test, dtype=torch.float32)
|
561 |
+
PLCC_all_repeats_test = torch.tensor(PLCC_all_repeats_test, dtype=torch.float32)
|
562 |
+
RMSE_all_repeats_test = torch.tensor(RMSE_all_repeats_test, dtype=torch.float32)
|
563 |
+
SRCC_all_repeats_train = torch.tensor(SRCC_all_repeats_train, dtype=torch.float32)
|
564 |
+
KRCC_all_repeats_train = torch.tensor(KRCC_all_repeats_train, dtype=torch.float32)
|
565 |
+
PLCC_all_repeats_train = torch.tensor(PLCC_all_repeats_train, dtype=torch.float32)
|
566 |
+
RMSE_all_repeats_train = torch.tensor(RMSE_all_repeats_train, dtype=torch.float32)
|
567 |
+
|
568 |
+
logging.info('======================================================')
|
569 |
+
logging.info('Average training results among all repeated 80-20 holdouts:')
|
570 |
+
logging.info('SRCC: %f (std: %f)', torch.median(SRCC_all_repeats_train).item(), torch.std(SRCC_all_repeats_train).item())
|
571 |
+
logging.info('KRCC: %f (std: %f)', torch.median(KRCC_all_repeats_train).item(), torch.std(KRCC_all_repeats_train).item())
|
572 |
+
logging.info('PLCC: %f (std: %f)', torch.median(PLCC_all_repeats_train).item(), torch.std(PLCC_all_repeats_train).item())
|
573 |
+
logging.info('RMSE: %f (std: %f)', torch.median(RMSE_all_repeats_train).item(), torch.std(RMSE_all_repeats_train).item())
|
574 |
+
logging.info('======================================================')
|
575 |
+
logging.info('Average testing results among all repeated 80-20 holdouts:')
|
576 |
+
logging.info('SRCC: %f (std: %f)', torch.median(SRCC_all_repeats_test).item(), torch.std(SRCC_all_repeats_test).item())
|
577 |
+
logging.info('KRCC: %f (std: %f)', torch.median(KRCC_all_repeats_test).item(), torch.std(KRCC_all_repeats_test).item())
|
578 |
+
logging.info('PLCC: %f (std: %f)', torch.median(PLCC_all_repeats_test).item(), torch.std(PLCC_all_repeats_test).item())
|
579 |
+
logging.info('RMSE: %f (std: %f)', torch.median(RMSE_all_repeats_test).item(), torch.std(RMSE_all_repeats_test).item())
|
580 |
+
logging.info('======================================================')
|
581 |
+
logging.info('\n')
|
582 |
+
|
583 |
+
# find the median model and the index of the median
|
584 |
+
print('======================================================')
|
585 |
+
if select_criteria == 'byrmse':
|
586 |
+
median_metrics = torch.median(RMSE_all_repeats_test).item()
|
587 |
+
indices = (RMSE_all_repeats_test == median_metrics).nonzero(as_tuple=True)[0].tolist()
|
588 |
+
select_criteria = select_criteria.replace('by', '').upper()
|
589 |
+
print(RMSE_all_repeats_test)
|
590 |
+
logging.info(f'all {select_criteria}: {RMSE_all_repeats_test}')
|
591 |
+
elif select_criteria == 'bykrcc':
|
592 |
+
median_metrics = torch.median(KRCC_all_repeats_test).item()
|
593 |
+
indices = (KRCC_all_repeats_test == median_metrics).nonzero(as_tuple=True)[0].tolist()
|
594 |
+
select_criteria = select_criteria.replace('by', '').upper()
|
595 |
+
print(KRCC_all_repeats_test)
|
596 |
+
logging.info(f'all {select_criteria}: {KRCC_all_repeats_test}')
|
597 |
+
|
598 |
+
median_test_vids = [all_repeats_test_vids[i] for i in indices]
|
599 |
+
test_vids = [arr.tolist() for arr in median_test_vids] if len(median_test_vids) > 1 else (median_test_vids[0] if median_test_vids else [])
|
600 |
+
|
601 |
+
# select the model with the first index where the median is located
|
602 |
+
# Note: If there are multiple iterations with the same median RMSE, the first index is selected here
|
603 |
+
median_model = None
|
604 |
+
if len(indices) > 0:
|
605 |
+
median_index = indices[0] # select the first index
|
606 |
+
median_model = best_model_list[median_index]
|
607 |
+
median_model_df_test_pred = all_repeats_df_test_pred[median_index]
|
608 |
+
|
609 |
+
median_model_df_test_pred.to_csv(pred_score_filename, index=False)
|
610 |
+
plot_results(y_test, y_test_pred_logistic, median_model_df_test_pred, model_name, data_name, network_name, select_criteria)
|
611 |
+
|
612 |
+
print(f'Median Metrics: {median_metrics}')
|
613 |
+
print(f'Indices: {indices}')
|
614 |
+
# print(f'Test Videos: {test_vids}')
|
615 |
+
print(f'Best model: {median_model}')
|
616 |
+
|
617 |
+
logging.info(f'median test {select_criteria}: {median_metrics}')
|
618 |
+
logging.info(f"Indices of median metrics: {indices}")
|
619 |
+
# logging.info(f'Best training and test dataset: {test_vids}')
|
620 |
+
logging.info(f'Best model predict score: {median_model_df_test_pred}')
|
621 |
+
logging.info(f'Best model: {median_model}')
|
622 |
+
|
623 |
+
# ================================================================================
|
624 |
+
# save mats
|
625 |
+
scipy.io.savemat(result_file, mdict={'SRCC_train': SRCC_all_repeats_train.numpy(),
|
626 |
+
'KRCC_train': KRCC_all_repeats_train.numpy(),
|
627 |
+
'PLCC_train': PLCC_all_repeats_train.numpy(),
|
628 |
+
'RMSE_train': RMSE_all_repeats_train.numpy(),
|
629 |
+
'SRCC_test': SRCC_all_repeats_test.numpy(),
|
630 |
+
'KRCC_test': KRCC_all_repeats_test.numpy(),
|
631 |
+
'PLCC_test': PLCC_all_repeats_test.numpy(),
|
632 |
+
'RMSE_test': RMSE_all_repeats_test.numpy(),
|
633 |
+
f'Median_{select_criteria}': median_metrics,
|
634 |
+
'Test_Videos_list': all_repeats_test_vids,
|
635 |
+
'Test_videos_Median_model': test_vids})
|
636 |
+
|
637 |
+
# save model
|
638 |
+
torch.save(median_model.state_dict(), file_path)
|
639 |
+
print(f"Model state_dict saved to {file_path}")
|
640 |
+
|
641 |
+
|
642 |
+
if __name__ == '__main__':
|
643 |
+
parser = argparse.ArgumentParser()
|
644 |
+
# input parameters
|
645 |
+
parser.add_argument('--model_name', type=str, default='Mlp')
|
646 |
+
parser.add_argument('--data_name', type=str, default='lsvq_train', help='konvid_1k, youtube_ugc, live_vqc, cvd_2014, live_qualcomm, lsvq_train, cross_dataset')
|
647 |
+
parser.add_argument('--network_name', type=str, default='diva-vqa_large' , help='diva-vqa')
|
648 |
+
|
649 |
+
parser.add_argument('--metadata_path', type=str, default='../metadata/')
|
650 |
+
parser.add_argument('--feature_path', type=str, default=f'../features/diva-vqa/diva-vqa_large/')
|
651 |
+
parser.add_argument('--log_path', type=str, default='../log/')
|
652 |
+
parser.add_argument('--save_path', type=str, default='../model/')
|
653 |
+
parser.add_argument('--score_path', type=str, default='../log/predict_score/')
|
654 |
+
parser.add_argument('--result_path', type=str, default='../log/result/')
|
655 |
+
# training parameters
|
656 |
+
parser.add_argument('--select_criteria', type=str, default='byrmse', help='byrmse, bykrcc')
|
657 |
+
parser.add_argument('--n_repeats', type=int, default=21, help='Number of repeats for 80-20 hold out test')
|
658 |
+
parser.add_argument('--n_splits', type=int, default=10, help='Number of splits for k-fold validation')
|
659 |
+
parser.add_argument('--batch_size', type=int, default=256, help='Batch size for training')
|
660 |
+
parser.add_argument('--epochs', type=int, default=50, help='Epochs for training')
|
661 |
+
parser.add_argument('--hidden_features', type=int, default=256, help='Hidden features')
|
662 |
+
parser.add_argument('--drop_rate', type=float, default=0.1, help='Dropout rate.')
|
663 |
+
# misc
|
664 |
+
parser.add_argument('--loss_type', type=str, default='MAERankLoss', help='MSEloss or MAERankLoss')
|
665 |
+
parser.add_argument('--optimizer_type', type=str, default='sgd', help='adam or sgd')
|
666 |
+
parser.add_argument('--initial_lr', type=float, default=1e-1, help='Initial learning rate: 1e-2')
|
667 |
+
parser.add_argument('--weight_decay', type=float, default=0.005, help='Weight decay (L2 loss): 1e-4')
|
668 |
+
parser.add_argument('--patience', type=int, default=5, help='Early stopping patience.')
|
669 |
+
parser.add_argument('--use_swa', type=bool, default=True, help='Use Stochastic Weight Averaging')
|
670 |
+
parser.add_argument('--l1_w', type=float, default=0.6, help='MAE loss weight')
|
671 |
+
parser.add_argument('--rank_w', type=float, default=1.0, help='Rank loss weight')
|
672 |
+
|
673 |
+
args = parser.parse_args()
|
674 |
+
config = vars(args) # args to dict
|
675 |
+
print(config)
|
676 |
+
|
677 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
678 |
+
print(device)
|
679 |
+
if device.type == "cuda":
|
680 |
+
torch.cuda.set_device(0)
|
681 |
+
|
682 |
+
main(config)
|
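Aside: a typical invocation, as a sketch (dataset and network names are taken from the help strings above; paths are the script defaults, which assume running from the source directory):

    python model_regression.py --data_name konvid_1k --network_name diva-vqa_large

The median-selected model is written under ../model/ and the correlation results under ../log/result/ as a .mat file.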
requirements.txt
ADDED
Binary file (428 Bytes).