import logging
import time
import os
import math
import copy
import argparse
import warnings

import pandas as pd
import numpy as np
import scipy.io
import scipy.stats
import joblib
import seaborn as sns
import matplotlib.pyplot as plt

from scipy.optimize import curve_fit
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.utils.data import DataLoader, TensorDataset

from data_processing import split_train_test

warnings.filterwarnings("ignore", category=DeprecationWarning)


class Mlp(nn.Module):
    """Small MLP regression head: Linear -> GELU -> Dropout twice, then a final Linear to one score."""
    def __init__(self, input_features, hidden_features=256, out_features=1, drop_rate=0.2, act_layer=nn.GELU):
        super().__init__()
        self.fc1 = nn.Linear(input_features, hidden_features)
        self.act1 = act_layer()
        self.drop1 = nn.Dropout(drop_rate)
        self.fc2 = nn.Linear(hidden_features, hidden_features // 2)
        self.act2 = act_layer()
        self.drop2 = nn.Dropout(drop_rate)
        self.fc3 = nn.Linear(hidden_features // 2, out_features)

    def forward(self, input_feature):
        x = self.fc1(input_feature)
        x = self.act1(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.act2(x)
        x = self.drop2(x)
        output = self.fc3(x)
        return output

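# Shape sketch (illustrative; the 4096-dim input is an assumed feature size,
# not one fixed by this script -- the real dimension comes from X_train.shape[1]):
#   head = Mlp(input_features=4096, hidden_features=256, out_features=1)
#   scores = head(torch.randn(8, 4096))   # -> tensor of shape (8, 1)

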
class MAEAndRankLoss(nn.Module):
    def __init__(self, l1_w=1.0, rank_w=1.0, margin=0.0, use_margin=False):
        super(MAEAndRankLoss, self).__init__()
        self.l1_w = l1_w
        self.rank_w = rank_w
        self.margin = margin
        self.use_margin = use_margin

    def forward(self, y_pred, y_true):
        # MAE term
        l_mae = F.l1_loss(y_pred, y_true, reduction='mean') * self.l1_w

        # Pairwise differences between all predictions and all targets
        n = y_pred.size(0)
        pred_diff = y_pred.unsqueeze(1) - y_pred.unsqueeze(0)
        true_diff = y_true.unsqueeze(1) - y_true.unsqueeze(0)

        # Direction of each ground-truth pair: -1, 0 or +1
        masks = torch.sign(true_diff)

        if self.use_margin and self.margin > 0:
            # Pairs whose true gap is within the margin are zeroed out; note the
            # rectified diff is unsigned here, so masks becomes 0/1 rather than signed
            true_diff = true_diff.abs() - self.margin
            true_diff = F.relu(true_diff)
            masks = true_diff.sign()

        # Hinge on pairs whose predicted gap disagrees with the true ordering,
        # averaged over the n*(n-1) off-diagonal pairs
        l_rank = F.relu(true_diff - masks * pred_diff)
        l_rank = l_rank.sum() / (n * (n - 1))

        loss = l_mae + l_rank * self.rank_w
        return loss

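# Rank-loss sketch, as implemented above (default path, use_margin=False):
#   L_rank = mean over i != j of max(0, (t_i - t_j) - sign(t_i - t_j) * (p_i - p_j))
# so a pair contributes whenever the predicted gap p_i - p_j falls short of the
# true gap t_i - t_j in the true direction, encouraging correct MOS ordering.

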
def load_data(csv_file, mat_file, features, data_name, set_name):
    try:
        df = pd.read_csv(csv_file, skiprows=[], header=None)
    except Exception as e:
        logging.error(f'Read CSV file error: {e}')
        raise

    try:
        if data_name == 'lsvq_train':
            X_mat = features
        else:
            X_mat = scipy.io.loadmat(mat_file)
    except Exception as e:
        logging.error(f'Read MAT file error: {e}')
        raise

    # Third CSV column (after the header row) holds the MOS labels
    y_data = df.values[1:, 2]
    y = np.array(list(y_data), dtype=float)

    if data_name == 'cross_dataset':
        y[y > 5] = 5
        if set_name == 'test':
            print(f"Modified y_true: {y}")
    if data_name == 'lsvq_train':
        X = np.asarray(X_mat, dtype=float)
    else:
        data_name = f'{data_name}_{set_name}_features'
        X = np.asarray(X_mat[data_name], dtype=float)

    return X, y

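# The .mat files are read with the feature matrix stored under the key
# '{data_name}_{set_name}_features'. This presumably mirrors the naming used
# when split_train_test exports the features (an inference from the key built
# above, not something this file defines).

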
def preprocess_data(X, y):
    # NaN/Inf entries are zeroed first, so the fitted imputer below is
    # effectively an identity transform; it is still fitted and returned so a
    # consistent preprocessing object can be applied elsewhere.
    X[np.isnan(X)] = 0
    X[np.isinf(X)] = 0
    imp = SimpleImputer(missing_values=np.nan, strategy='mean').fit(X)
    X = imp.transform(X)

    # Scale each feature to [0, 1]
    scaler = MinMaxScaler().fit(X)
    X = scaler.transform(X)
    logging.info(f'Scaler: {scaler}')

    y = y.reshape(-1, 1).squeeze()
    return X, y, imp, scaler


def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
    logisticPart = 1 + np.exp(np.negative(np.divide(X - bayta3, np.abs(bayta4))))
    yhat = bayta2 + np.divide(bayta1 - bayta2, logisticPart)
    return yhat


def fit_logistic_regression(y_pred, y_true):
    beta = [np.max(y_true), np.min(y_true), np.mean(y_pred), 0.5]
    popt, _ = curve_fit(logistic_func, y_pred, y_true, p0=beta, maxfev=100000000)
    y_pred_logistic = logistic_func(y_pred, *popt)
    return y_pred_logistic, beta, popt

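# The 4-parameter logistic fitted above, written out:
#   f(x) = beta2 + (beta1 - beta2) / (1 + exp(-(x - beta3) / |beta4|))
# It monotonically maps raw model outputs onto the MOS scale before PLCC and
# RMSE are computed, the standard practice in VQA evaluation.

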
def compute_correlation_metrics(y_true, y_pred):
    y_pred_logistic, beta, popt = fit_logistic_regression(y_pred, y_true)

    plcc = scipy.stats.pearsonr(y_true, y_pred_logistic)[0]
    rmse = np.sqrt(mean_squared_error(y_true, y_pred_logistic))
    srcc = scipy.stats.spearmanr(y_true, y_pred)[0]

    try:
        krcc = scipy.stats.kendalltau(y_true, y_pred)[0]
    except Exception as e:
        logging.error(f'krcc calculation: {e}')
        krcc = scipy.stats.kendalltau(y_true, y_pred, method='asymptotic')[0]
    return y_pred_logistic, plcc, rmse, srcc, krcc

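# Note: PLCC and RMSE are computed on the logistic-mapped predictions, while
# SRCC and KRCC are rank-based and so use the raw predictions unchanged.

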
def plot_results(y_test, y_test_pred_logistic, df_pred_score, model_name, data_name, network_name, select_criteria):
    mos1 = y_test
    y1 = y_test_pred_logistic

    try:
        beta = [np.max(mos1), np.min(mos1), np.mean(y1), 0.5]
        popt, pcov = curve_fit(logistic_func, y1, mos1, p0=beta, maxfev=100000000)
        sigma = np.sqrt(np.diag(pcov))
    except Exception:
        raise Exception('Fitting logistic function time-out!!')
    x_values1 = np.linspace(np.min(y1), np.max(y1), len(y1))
    plt.plot(x_values1, logistic_func(x_values1, *popt), '-', color='#c72e29', label='Fitted f(x)')

    fig1 = sns.scatterplot(x="y_test_pred_logistic", y="MOS", data=df_pred_score, marker='o', color='steelblue', label=network_name)
    plt.legend(loc='upper left')
    if data_name in ('live_vqc', 'live_qualcomm', 'cvd_2014', 'lsvq_train'):
        plt.ylim(0, 100)
        plt.xlim(0, 100)
    else:
        plt.ylim(1, 5)
        plt.xlim(1, 5)
    plt.title(f"Algorithm {network_name} with {model_name} on dataset {data_name}", fontsize=10)
    plt.xlabel('Predicted Score')
    plt.ylabel('MOS')
    reg_fig1 = fig1.get_figure()

    fig_path = f'../figs/{data_name}/'
    os.makedirs(fig_path, exist_ok=True)
    reg_fig1.savefig(fig_path + f"{network_name}_{model_name}_{data_name}_by{select_criteria}.png", dpi=300)
    plt.clf()
    plt.close()


def plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, test_vids, i):
    plt.figure(figsize=(10, 6))

    plt.plot(avg_train_losses, label='Average Training Loss')
    plt.plot(avg_val_losses, label='Average Validation Loss')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Average Training and Validation Loss Across Folds - {network_name} with {model_name} (test_vids: {test_vids})', fontsize=10)

    plt.legend()
    fig_par_path = f'../log/result/{data_name}/'
    os.makedirs(fig_par_path, exist_ok=True)
    plt.savefig(f'{fig_par_path}/{network_name}_Average_Training_Loss_test{i}.png', dpi=50)
    plt.clf()
    plt.close()


def configure_logging(log_path, model_name, data_name, network_name, select_criteria):
    log_file_name = os.path.join(log_path, f"{data_name}_{network_name}_{model_name}_corr_{select_criteria}.log")
    logging.basicConfig(filename=log_file_name, filemode='w', level=logging.DEBUG, format='%(levelname)s - %(message)s')
    logging.getLogger('matplotlib').setLevel(logging.WARNING)
    logging.info(f"Evaluating algorithm {network_name} with {model_name} on dataset {data_name}")
    logging.info(f"torch cuda: {torch.cuda.is_available()}")


def load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features):
    if data_name == 'cross_dataset':
        data_name1 = 'youtube_ugc_all'
        data_name2 = 'cvd_2014_all'
        csv_train_file = os.path.join(metadata_path, f'mos_files/{data_name1}_MOS_train.csv')
        csv_test_file = os.path.join(metadata_path, f'mos_files/{data_name2}_MOS_test.csv')
        mat_train_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name1}_{network_name}_train_features.mat')
        mat_test_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name2}_{network_name}_test_features.mat')
        X_train, y_train = load_data(csv_train_file, mat_train_file, None, data_name1, 'train')
        X_test, y_test = load_data(csv_test_file, mat_test_file, None, data_name2, 'test')

    elif data_name == 'lsvq_train':
        csv_train_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
        csv_test_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')
        X_train, y_train = load_data(csv_train_file, None, train_features, data_name, 'train')
        X_test, y_test = load_data(csv_test_file, None, test_features, data_name, 'test')

    else:
        csv_train_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
        csv_test_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')
        mat_train_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name}_{network_name}_train_features.mat')
        mat_test_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name}_{network_name}_test_features.mat')
        X_train, y_train = load_data(csv_train_file, mat_train_file, None, data_name, 'train')
        X_test, y_test = load_data(csv_test_file, mat_test_file, None, data_name, 'test')

    X_train, y_train, _, _ = preprocess_data(X_train, y_train)
    X_test, y_test, _, _ = preprocess_data(X_test, y_test)

    return X_train, y_train, X_test, y_test


def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Train the model for one epoch"""
    model.train()
    train_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets.view(-1, 1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * inputs.size(0)
    train_loss /= len(train_loader.dataset)
    return train_loss


def evaluate(model, val_loader, criterion, device):
    """Evaluate model performance on validation sets"""
    model.eval()
    val_loss = 0.0
    y_val_pred = []
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            y_val_pred.extend(outputs.view(-1).tolist())
            loss = criterion(outputs, targets.view(-1, 1))
            val_loss += loss.item() * inputs.size(0)
    val_loss /= len(val_loader.dataset)
    return val_loss, np.array(y_val_pred)

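# Both loops above weight each batch loss by inputs.size(0) before dividing by
# len(dataset), so the returned value is a true per-sample average even when
# the final batch is smaller than batch_size.

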
def update_best_model(select_criteria, best_metric, current_val, model):
    is_better = False
    if select_criteria == 'byrmse' and current_val < best_metric:
        is_better = True
    elif select_criteria == 'bykrcc' and current_val > best_metric:
        is_better = True

    if is_better:
        return current_val, copy.deepcopy(model), is_better
    return best_metric, model, is_better

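# 'byrmse' keeps the model with the lowest validation RMSE; 'bykrcc' keeps the
# one with the highest validation KRCC. The deep copy snapshots the weights so
# later training epochs cannot mutate the stored best model.

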
def train_and_evaluate(X_train, y_train, config):
    n_repeats = config['n_repeats']
    batch_size = config['batch_size']
    epochs = config['epochs']
    hidden_features = config['hidden_features']
    drop_rate = config['drop_rate']
    loss_type = config['loss_type']
    optimizer_type = config['optimizer_type']
    select_criteria = config['select_criteria']
    initial_lr = config['initial_lr']
    weight_decay = config['weight_decay']
    patience = config['patience']
    l1_w = config['l1_w']
    rank_w = config['rank_w']
    use_swa = config.get('use_swa', False)
    logging.info(f'Parameters - Number of repeats for 80-20 hold out test: {n_repeats}, Batch size: {batch_size}, Number of epochs: {epochs}')
    logging.info(f'Network Parameters - hidden_features: {hidden_features}, drop_rate: {drop_rate}, patience: {patience}')
    logging.info(f'Optimizer Parameters - loss_type: {loss_type}, optimizer_type: {optimizer_type}, initial_lr: {initial_lr}, weight_decay: {weight_decay}, use_swa: {use_swa}')
    logging.info(f'MAEAndRankLoss - l1_w: {l1_w}, rank_w: {rank_w}')

    # Hold out 20% of the training portion as a fixed validation split
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

    best_model = None
    best_metric = float('inf') if select_criteria == 'byrmse' else float('-inf')

    all_train_losses = []
    all_val_losses = []

    model = Mlp(input_features=X_train.shape[1], hidden_features=hidden_features, drop_rate=drop_rate)
    model = model.to(device)

    if loss_type == 'MAERankLoss':
        criterion = MAEAndRankLoss()
        criterion.l1_w = l1_w
        criterion.rank_w = rank_w
    else:
        # MSE fallback; the loss must be assigned to `criterion` to be usable below
        criterion = nn.MSELoss()

    if optimizer_type == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9, weight_decay=weight_decay)
        scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)
    else:
        optimizer = optim.Adam(model.parameters(), lr=initial_lr, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95)
    if use_swa:
        swa_model = AveragedModel(model).to(device)
        swa_scheduler = SWALR(optimizer, swa_lr=initial_lr, anneal_strategy='cos')

    train_dataset = TensorDataset(torch.FloatTensor(X_train), torch.FloatTensor(y_train))
    val_dataset = TensorDataset(torch.FloatTensor(X_val), torch.FloatTensor(y_val))
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    train_losses, val_losses = [], []

    best_val_loss = float('inf')
    epochs_no_improve = 0
    early_stop_active = False
    swa_start = int(epochs * 0.7) if use_swa else epochs
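
    # SWA weight averaging only begins at swa_start (70% of the epoch budget);
    # until then the base scheduler (cosine for SGD, step decay for Adam)
    # drives the learning rate on its own.
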
    for epoch in range(epochs):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        scheduler.step()
        if use_swa and epoch >= swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()
            early_stop_active = True
            print(f"Current learning rate with SWA: {swa_scheduler.get_last_lr()}")

        lr = optimizer.param_groups[0]['lr']
        print('Epoch %d: Learning rate: %f' % (epoch + 1, lr))

        # Validate with the SWA-averaged weights once averaging has begun
        current_model = swa_model if use_swa and epoch >= swa_start else model
        current_model.eval()
        val_loss, y_val_pred = evaluate(current_model, val_loader, criterion, device)
        val_losses.append(val_loss)
        print(f"Epoch {epoch + 1}, Training Loss: {train_loss}, Validation Loss: {val_loss}")

        y_val_pred = np.array(list(y_val_pred), dtype=float)
        _, _, rmse_val, _, krcc_val = compute_correlation_metrics(y_val, y_val_pred)
        current_metric = rmse_val if select_criteria == 'byrmse' else krcc_val
        best_metric, best_model, is_better = update_best_model(select_criteria, best_metric, current_metric, current_model)
        if is_better:
            logging.info(f"Epoch {epoch + 1}:")
            y_val_pred_logistic_tmp, plcc_valid_tmp, rmse_valid_tmp, srcc_valid_tmp, krcc_valid_tmp = compute_correlation_metrics(y_val, y_val_pred)
            logging.info(f'Validation set - Evaluation Results - SRCC: {srcc_valid_tmp}, KRCC: {krcc_valid_tmp}, PLCC: {plcc_valid_tmp}, RMSE: {rmse_valid_tmp}')

            X_train_fold_tensor = torch.FloatTensor(X_train).to(device)
            y_tra_pred_tmp = best_model(X_train_fold_tensor).detach().cpu().numpy().squeeze()
            y_tra_pred_tmp = np.array(list(y_tra_pred_tmp), dtype=float)
            y_tra_pred_logistic_tmp, plcc_train_tmp, rmse_train_tmp, srcc_train_tmp, krcc_train_tmp = compute_correlation_metrics(y_train, y_tra_pred_tmp)
            logging.info(f'Train set - Evaluation Results - SRCC: {srcc_train_tmp}, KRCC: {krcc_train_tmp}, PLCC: {plcc_train_tmp}, RMSE: {rmse_train_tmp}')

        # Early stopping only becomes active once SWA averaging has started
        if early_stop_active:
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model = copy.deepcopy(model)
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print(f"Early stopping triggered after {epoch + 1} epochs.")
                    break

    if use_swa:
        train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_to_device(x, device))
        best_model = best_model.to(device)
        best_model.eval()
        torch.optim.swa_utils.update_bn(train_loader, best_model)

    all_train_losses.append(train_losses)
    all_val_losses.append(val_losses)
    # Pad loss curves to a common length (early stopping can end a run early)
    max_length = max(len(x) for x in all_train_losses)
    all_train_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_train_losses]
    max_length = max(len(x) for x in all_val_losses)
    all_val_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_val_losses]

    return best_model, all_train_losses, all_val_losses

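# update_bn recomputes BatchNorm running statistics under the selected weights.
# The Mlp defined here contains no normalisation layers, so the call is
# effectively a no-op, but it keeps the SWA recipe correct should BatchNorm
# ever be added to the head.

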
def collate_to_device(batch, device):
    data, targets = zip(*batch)
    return torch.stack(data).to(device), torch.stack(targets).to(device)


def model_test(best_model, X, y, device):
    test_dataset = TensorDataset(torch.FloatTensor(X), torch.FloatTensor(y))
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

    best_model.eval()
    y_pred = []
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)

            outputs = best_model(inputs)
            y_pred.extend(outputs.view(-1).tolist())

    return y_pred


def main(config):
    model_name = config['model_name']
    data_name = config['data_name']
    network_name = config['network_name']

    metadata_path = config['metadata_path']
    feature_path = config['feature_path']
    log_path = config['log_path']
    save_path = config['save_path']
    score_path = config['score_path']
    result_path = config['result_path']

    select_criteria = config['select_criteria']
    n_repeats = config['n_repeats']

    os.makedirs(log_path, exist_ok=True)
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(score_path, exist_ok=True)
    os.makedirs(result_path, exist_ok=True)
    result_file = f'{result_path}{data_name}_{network_name}_{select_criteria}.mat'
    pred_score_filename = os.path.join(score_path, f"{data_name}_{network_name}_{select_criteria}.csv")
    file_path = os.path.join(save_path, f"{data_name}_{network_name}_{select_criteria}_trained_median_model_param.pth")
    configure_logging(log_path, model_name, data_name, network_name, select_criteria)

    '''======================== Main Body ==========================='''
    PLCC_all_repeats_test = []
    SRCC_all_repeats_test = []
    KRCC_all_repeats_test = []
    RMSE_all_repeats_test = []
    PLCC_all_repeats_train = []
    SRCC_all_repeats_train = []
    KRCC_all_repeats_train = []
    RMSE_all_repeats_train = []
    all_repeats_test_vids = []
    all_repeats_df_test_pred = []
    best_model_list = []

    for i in range(1, n_repeats + 1):
        print(f"{i}th repeated 80-20 hold out test")
        logging.info(f"{i}th repeated 80-20 hold out test")
        t0 = time.time()

        # ceil(8.8 * i) yields a distinct deterministic split seed per repeat (9, 18, 27, ...)
        test_size = 0.2
        random_state = math.ceil(8.8 * i)

        if data_name == 'lsvq_train':
            test_data_name = 'lsvq_test'
            train_features, test_features, test_vids = split_train_test.process_lsvq(data_name, test_data_name, metadata_path, feature_path, network_name)
        elif data_name == 'cross_dataset':
            train_data_name = 'youtube_ugc_all'
            test_data_name = 'cvd_2014_all'
            _, _, test_vids = split_train_test.process_cross_dataset(train_data_name, test_data_name, metadata_path, feature_path, network_name)
        else:
            _, _, test_vids = split_train_test.process_other(data_name, test_size, random_state, metadata_path, feature_path, network_name)

        '''======================== read files =============================== '''
        if data_name == 'lsvq_train':
            X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features)
        else:
            X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, None, None)

        '''======================== regression model =============================== '''
        best_model, all_train_losses, all_val_losses = train_and_evaluate(X_train, y_train, config)

        avg_train_losses = np.mean(all_train_losses, axis=0)
        avg_val_losses = np.mean(all_val_losses, axis=0)
        test_vids = test_vids.tolist()
        plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, len(test_vids), i)

        # Evaluate the selected model on the train and test portions
        y_train_pred = model_test(best_model, X_train, y_train, device)
        y_train_pred = np.array(list(y_train_pred), dtype=float)
        y_train_pred_logistic, plcc_train, rmse_train, srcc_train, krcc_train = compute_correlation_metrics(y_train, y_train_pred)

        y_test_pred = model_test(best_model, X_test, y_test, device)
        y_test_pred = np.array(list(y_test_pred), dtype=float)
        y_test_pred_logistic, plcc_test, rmse_test, srcc_test, krcc_test = compute_correlation_metrics(y_test, y_test_pred)

        test_pred_score = {'MOS': y_test, 'y_test_pred': y_test_pred, 'y_test_pred_logistic': y_test_pred_logistic}
        df_test_pred = pd.DataFrame(test_pred_score)

        logging.info("============================================================================================================")
        SRCC_all_repeats_test.append(srcc_test)
        KRCC_all_repeats_test.append(krcc_test)
        PLCC_all_repeats_test.append(plcc_test)
        RMSE_all_repeats_test.append(rmse_test)
        SRCC_all_repeats_train.append(srcc_train)
        KRCC_all_repeats_train.append(krcc_train)
        PLCC_all_repeats_train.append(plcc_train)
        RMSE_all_repeats_train.append(rmse_train)
        all_repeats_test_vids.append(test_vids)
        all_repeats_df_test_pred.append(df_test_pred)
        best_model_list.append(copy.deepcopy(best_model))

        logging.info('Best results in Mlp model within one split')
        logging.info(f'MODEL: {best_model}')
        logging.info('======================================================')
        logging.info('Train set - Evaluation Results')
        logging.info(f'SRCC_train: {srcc_train}')
        logging.info(f'KRCC_train: {krcc_train}')
        logging.info(f'PLCC_train: {plcc_train}')
        logging.info(f'RMSE_train: {rmse_train}')
        logging.info('======================================================')
        logging.info('Test set - Evaluation Results')
        logging.info(f'SRCC_test: {srcc_test}')
        logging.info(f'KRCC_test: {krcc_test}')
        logging.info(f'PLCC_test: {plcc_test}')
        logging.info(f'RMSE_test: {rmse_test}')
        logging.info('======================================================')
        logging.info(' -- {} seconds elapsed...\n\n'.format(time.time() - t0))

    logging.info('')
    SRCC_all_repeats_test = np.nan_to_num(SRCC_all_repeats_test)
    KRCC_all_repeats_test = np.nan_to_num(KRCC_all_repeats_test)
    PLCC_all_repeats_test = np.nan_to_num(PLCC_all_repeats_test)
    RMSE_all_repeats_test = np.nan_to_num(RMSE_all_repeats_test)
    SRCC_all_repeats_train = np.nan_to_num(SRCC_all_repeats_train)
    KRCC_all_repeats_train = np.nan_to_num(KRCC_all_repeats_train)
    PLCC_all_repeats_train = np.nan_to_num(PLCC_all_repeats_train)
    RMSE_all_repeats_train = np.nan_to_num(RMSE_all_repeats_train)
    logging.info('======================================================')
    logging.info('Median training results among all repeated 80-20 holdouts:')
    logging.info('SRCC: %f (std: %f)', np.median(SRCC_all_repeats_train), np.std(SRCC_all_repeats_train))
    logging.info('KRCC: %f (std: %f)', np.median(KRCC_all_repeats_train), np.std(KRCC_all_repeats_train))
    logging.info('PLCC: %f (std: %f)', np.median(PLCC_all_repeats_train), np.std(PLCC_all_repeats_train))
    logging.info('RMSE: %f (std: %f)', np.median(RMSE_all_repeats_train), np.std(RMSE_all_repeats_train))
    logging.info('======================================================')
    logging.info('Median testing results among all repeated 80-20 holdouts:')
    logging.info('SRCC: %f (std: %f)', np.median(SRCC_all_repeats_test), np.std(SRCC_all_repeats_test))
    logging.info('KRCC: %f (std: %f)', np.median(KRCC_all_repeats_test), np.std(KRCC_all_repeats_test))
    logging.info('PLCC: %f (std: %f)', np.median(PLCC_all_repeats_test), np.std(PLCC_all_repeats_test))
    logging.info('RMSE: %f (std: %f)', np.median(RMSE_all_repeats_test), np.std(RMSE_all_repeats_test))
    logging.info('======================================================')
    logging.info('\n')

    print('======================================================')
    if select_criteria == 'byrmse':
        median_metrics = np.median(RMSE_all_repeats_test)
        indices = np.where(RMSE_all_repeats_test == median_metrics)[0]
        select_criteria = select_criteria.replace('by', '').upper()
        print(RMSE_all_repeats_test)
        logging.info(f'all {select_criteria}: {RMSE_all_repeats_test}')
    elif select_criteria == 'bykrcc':
        median_metrics = np.median(KRCC_all_repeats_test)
        indices = np.where(KRCC_all_repeats_test == median_metrics)[0]
        select_criteria = select_criteria.replace('by', '').upper()
        print(KRCC_all_repeats_test)
        logging.info(f'all {select_criteria}: {KRCC_all_repeats_test}')

    median_test_vids = [all_repeats_test_vids[i] for i in indices]
    test_vids = [arr.tolist() for arr in median_test_vids] if len(median_test_vids) > 1 else (median_test_vids[0] if median_test_vids else [])
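
    # Caveat: np.where(arr == np.median(arr)) only matches when the median is an
    # actual element of the array, which holds for an odd n_repeats (default 21)
    # but can fail for even counts, where np.median interpolates between the two
    # middle values and `indices` comes back empty.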

    median_model = None
    if len(indices) > 0:
        median_index = indices[0]
        median_model = best_model_list[median_index]
        median_model_df_test_pred = all_repeats_df_test_pred[median_index]

        median_model_df_test_pred.to_csv(pred_score_filename, index=False)
        plot_results(y_test, y_test_pred_logistic, median_model_df_test_pred, model_name, data_name, network_name, select_criteria)
        # The prediction table only exists when a median model was found
        logging.info(f'Best model predict score: {median_model_df_test_pred}')

    print(f'Median Metrics: {median_metrics}')
    print(f'Indices: {indices}')
    print(f'Best model: {median_model}')

    logging.info(f'median test {select_criteria}: {median_metrics}')
    logging.info(f"Indices of median metrics: {indices}")
    logging.info(f'Best model: {median_model}')

    scipy.io.savemat(result_file, mdict={
        'SRCC_train': np.asarray(SRCC_all_repeats_train, dtype=float),
        'KRCC_train': np.asarray(KRCC_all_repeats_train, dtype=float),
        'PLCC_train': np.asarray(PLCC_all_repeats_train, dtype=float),
        'RMSE_train': np.asarray(RMSE_all_repeats_train, dtype=float),
        'SRCC_test': np.asarray(SRCC_all_repeats_test, dtype=float),
        'KRCC_test': np.asarray(KRCC_all_repeats_test, dtype=float),
        'PLCC_test': np.asarray(PLCC_all_repeats_test, dtype=float),
        'RMSE_test': np.asarray(RMSE_all_repeats_test, dtype=float),
        f'Median_{select_criteria}': median_metrics,
        'Test_Videos_list': all_repeats_test_vids,
        'Test_videos_Median_model': test_vids,
    })

    torch.save(median_model.state_dict(), file_path)
    print(f"Model state_dict saved to {file_path}")

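# Reloading sketch (hypothetical usage; only the state_dict is saved, so the
# Mlp must be rebuilt with the same dimensions first -- and if SWA was active,
# the saved object may be an AveragedModel whose keys carry a 'module.' prefix):
#   model = Mlp(input_features=n_features, hidden_features=256)
#   model.load_state_dict(torch.load(file_path, map_location='cpu'))
#   model.eval()

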
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--model_name', type=str, default='Mlp')
    parser.add_argument('--data_name', type=str, default='cvd_2014', help='konvid_1k, youtube_ugc, live_vqc, cvd_2014, lsvq_train, cross_dataset')
    parser.add_argument('--network_name', type=str, default='relaxvqa', help='relaxvqa, {frag_name}_{network_name}_{layer_name}')

    parser.add_argument('--metadata_path', type=str, default='../metadata/')
    parser.add_argument('--feature_path', type=str, default='../features/')
    parser.add_argument('--log_path', type=str, default='../log/')
    parser.add_argument('--save_path', type=str, default='../model/')
    parser.add_argument('--score_path', type=str, default='../log/predict_score/')
    parser.add_argument('--result_path', type=str, default='../log/result/')

    parser.add_argument('--select_criteria', type=str, default='byrmse', help='byrmse, bykrcc')
    parser.add_argument('--n_repeats', type=int, default=21, help='Number of repeats for 80-20 hold out test')
    parser.add_argument('--batch_size', type=int, default=256, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=120, help='Epochs for training')
    parser.add_argument('--hidden_features', type=int, default=256, help='Hidden features')
    parser.add_argument('--drop_rate', type=float, default=0.1, help='Dropout rate')

    parser.add_argument('--loss_type', type=str, default='MAERankLoss', help='MSEloss or MAERankLoss')
    parser.add_argument('--optimizer_type', type=str, default='sgd', help='adam or sgd')
    parser.add_argument('--initial_lr', type=float, default=1e-2, help='Initial learning rate: 1e-2')
    parser.add_argument('--weight_decay', type=float, default=0.0005, help='Weight decay (L2 penalty): 5e-4')
    parser.add_argument('--patience', type=int, default=5, help='Early stopping patience')
    # An explicit on/off flag pair (--use_swa / --no-use_swa, Python 3.9+);
    # argparse's type=bool would parse any non-empty string, even "False", as True
    parser.add_argument('--use_swa', action=argparse.BooleanOptionalAction, default=True, help='Use Stochastic Weight Averaging')
    parser.add_argument('--l1_w', type=float, default=0.6, help='MAE loss weight')
    parser.add_argument('--rank_w', type=float, default=1.0, help='Rank loss weight')

    args = parser.parse_args()
    config = vars(args)
    print(config)

    # `device` is defined at module level here and read as a global by the
    # training/evaluation functions above
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    if device.type == "cuda":
        torch.cuda.set_device(0)

    main(config)