File size: 3,939 Bytes
6221b96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# metrics.py
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from models.srcnn import SRCNN
from models.vdsr import VDSR
from models.edsr import EDSR
import math
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt


def calculate_psnr(img1, img2):
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(1.0 / math.sqrt(mse.item()))

def process_image(model, lr_img):
    with torch.no_grad():
        # Convert to YCbCr and extract Y channel
        ycbcr = lr_img.convert('YCbCr')
        y, cb, cr = ycbcr.split()
        
        # Transform Y channel
        transform = transforms.Compose([transforms.ToTensor()])
        input_tensor = transform(y).unsqueeze(0)
        
        # Process through model
        output = model(input_tensor)
        
        # Post-process output
        output = output.squeeze().clamp(0, 1).numpy()
        output_y = Image.fromarray((output * 255).astype(np.uint8))
        
        # Merge channels back
        output_ycbcr = Image.merge('YCbCr', [output_y, cb, cr])
        return output_ycbcr.convert('RGB')

def calculate_ssim(img1, img2):
    # Move channel axis to the end for SSIM calculation
    img1_np = img1.cpu().numpy().transpose(1, 2, 0)
    img2_np = img2.cpu().numpy().transpose(1, 2, 0)
    return ssim(img1_np, img2_np, data_range=1.0, channel_axis=2, win_size=7)

def evaluate_models(test_image_path):
    # Load models
    models = {
        'SRCNN': SRCNN(),
        'VDSR': VDSR(),
        'EDSR': EDSR()
    }
    
    # Load weights
    for name, model in models.items():
        model.load_state_dict(torch.load(f'checkpoints/{name.lower()}_best.pth', weights_only=True))
        model.eval()
    
    # Load test image
    lr_img = Image.open(test_image_path)
    hr_img = Image.open(test_image_path)  # Using same image as reference
    
    # Results dictionary
    results = {model_name: {} for model_name in models.keys()}
    
    # Process image with each model and calculate metrics
    for name, model in models.items():
        # Generate SR image
        sr_img = process_image(model, lr_img)
        
        # Convert images to tensors for metric calculation
        transform = transforms.Compose([
            transforms.Resize((256, 256)),  # Ensure minimum size for SSIM
            transforms.ToTensor()
        ])
        
        sr_tensor = transform(sr_img)
        hr_tensor = transform(hr_img)
        
        # Calculate metrics
        results[name]['PSNR'] = calculate_psnr(sr_tensor, hr_tensor)
        results[name]['SSIM'] = calculate_ssim(sr_tensor, hr_tensor)
        
        # Save output images
        sr_img.save(f'results/{name.lower()}_output.png')
    
    # Display results
    print("\nModel Performance Metrics:")
    print("-" * 50)
    print(f"{'Model':<10} {'PSNR (dB)':<15} {'SSIM':<15}")
    print("-" * 50)
    
    for model_name, metrics in results.items():
        print(f"{model_name:<10} {metrics['PSNR']:<15.2f} {metrics['SSIM']:<15.4f}")
    
    # Plot results
    plt.figure(figsize=(12, 6))
    
    # PSNR comparison
    plt.subplot(1, 2, 1)
    plt.bar(results.keys(), [m['PSNR'] for m in results.values()])
    plt.title('PSNR Comparison')
    plt.ylabel('PSNR (dB)')
    
    # SSIM comparison
    plt.subplot(1, 2, 2)
    plt.bar(results.keys(), [m['SSIM'] for m in results.values()])
    plt.title('SSIM Comparison')
    plt.ylabel('SSIM')
    
    plt.tight_layout()
    plt.savefig('results/metrics_comparison.png')
    plt.close()

if __name__ == "__main__":
    import os
    os.makedirs('results', exist_ok=True)
    test_image_path = r"data\DIV2K_train_LR_bicubic_X4\DIV2K_train_LR_bicubic\X4\0001x4.png"  # Replace with your test image path
    evaluate_models(test_image_path)