# Grad-CDM / debug_model.py
# Uploaded by nazgut ("Upload 24 files", commit 8abfb97).
import torch
import torchvision
from torchvision.utils import save_image, make_grid
import os
from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from sample import frequency_aware_sample
import numpy as np
# Timesteps probed in the prediction sanity check.
# NOTE(review): assumes the noise schedule has at least 500 steps — TODO confirm
# against the scheduler/config before changing the schedule length.
_PROBE_TIMESTEPS = (0, 50, 100, 250, 499)


def _find_latest_checkpoint(log_root):
    """Return the path of the newest ``model_epoch_<N>.pth`` under *log_root*, or None.

    The run directory whose name sorts last is used, and within it the
    checkpoint with the largest integer epoch. When nothing usable exists a
    message is printed and ``None`` is returned.
    """
    # isdir (not exists): a plain file named like the log root would make
    # os.listdir raise.
    if not os.path.isdir(log_root):
        print("No log directories found!")
        return None

    log_dirs = sorted(
        name for name in os.listdir(log_root)
        if os.path.isdir(os.path.join(log_root, name))
    )
    if not log_dirs:
        print("No log directories found!")
        return None

    # Assumes run-directory names sort chronologically (e.g. timestamped) —
    # the newest run is taken to be the last name in sorted order.
    log_path = os.path.join(log_root, log_dirs[-1])

    checkpoints = []
    for filename in os.listdir(log_path):
        if not (filename.startswith('model_epoch_') and filename.endswith('.pth')):
            continue
        try:
            epoch = int(filename.split('_')[2].split('.')[0])
        except ValueError:
            # e.g. "model_epoch_best.pth" has no numeric epoch; skip it instead
            # of aborting the whole debug run.
            continue
        checkpoints.append((epoch, filename))

    if not checkpoints:
        print("No checkpoint files found!")
        return None

    # Highest epoch wins (numeric comparison, so epoch 10 beats epoch 9).
    _, latest_file = max(checkpoints)
    print(f"Loading {latest_file}")
    return os.path.join(log_path, latest_file)


def _load_model(checkpoint_path, device):
    """Build the UNet and noise scheduler from a checkpoint; return ``(model, scheduler)``.

    Accepts either a dict with ``'model_state_dict'`` (and optionally
    ``'config'``) or a bare state dict. The model is put in eval mode.
    """
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # arbitrary objects such as a stored Config. These are locally produced
    # training checkpoints, so full unpickling is intended here — never point
    # this at untrusted files.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = checkpoint.get('config', Config())
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    return model, noise_scheduler


def _probe_timesteps(model, device, timesteps=_PROBE_TIMESTEPS):
    """Feed one random input at several timesteps and report prediction statistics."""
    print("\n=== DEBUGGING MODEL PREDICTIONS ===")
    with torch.no_grad():
        # NOTE(review): input shape (1, 3, 64, 64) is hard-coded — confirm it
        # matches the configured image size/channels.
        x_test = torch.randn(1, 3, 64, 64, device=device)
        for t_val in timesteps:
            t_tensor = torch.full((1,), t_val, device=device, dtype=torch.long)
            pred_noise = model(x_test, t_tensor)

            print(f"\nTimestep {t_val}:")
            print(f" Input range: [{x_test.min().item():.3f}, {x_test.max().item():.3f}]")
            print(f" Input mean/std: {x_test.mean().item():.3f} / {x_test.std().item():.3f}")
            print(f" Predicted noise range: [{pred_noise.min().item():.3f}, {pred_noise.max().item():.3f}]")
            print(f" Predicted noise mean/std: {pred_noise.mean().item():.3f} / {pred_noise.std().item():.3f}")

            # Check finiteness first: an Inf anywhere makes std() NaN, and NaN
            # fails both threshold comparisons below, which would otherwise
            # fall through to the "looks reasonable" branch.
            pred_std = pred_noise.std().item()
            if not torch.isfinite(pred_noise).all():
                print(" ❌ NaN/Inf detected in predictions!")
            elif pred_std < 0.01:
                print(" ⚠️ Very low variance - model might be collapsed")
            elif pred_std > 10:
                print(" ⚠️ Very high variance - model might be unstable")
            else:
                print(" ✓ Prediction variance looks reasonable")


def _simulate_training_step(model, noise_scheduler, device):
    """Noise a synthetic clean image as in training and report the prediction MSE."""
    print("\n=== TESTING TRAINING DATA SIMULATION ===")
    with torch.no_grad():
        # Synthetic "clean" image scaled to a more moderate range.
        x0 = torch.randn(1, 3, 64, 64, device=device) * 0.5
        t = torch.randint(100, 400, (1,), device=device)  # mid-range timestep

        xt, noise_target = noise_scheduler.apply_noise(x0, t)
        pred_noise = model(xt, t)

        print("\nTraining simulation:")
        print(f" Clean image range: [{x0.min().item():.3f}, {x0.max().item():.3f}]")
        print(f" Noisy image range: [{xt.min().item():.3f}, {xt.max().item():.3f}]")
        print(f" Target noise range: [{noise_target.min().item():.3f}, {noise_target.max().item():.3f}]")
        print(f" Target noise mean/std: {noise_target.mean().item():.3f} / {noise_target.std().item():.3f}")
        print(f" Predicted noise range: [{pred_noise.min().item():.3f}, {pred_noise.max().item():.3f}]")
        print(f" Predicted noise mean/std: {pred_noise.mean().item():.3f} / {pred_noise.std().item():.3f}")

        mse_tensor = torch.mean((pred_noise - noise_target) ** 2)
        mse = mse_tensor.item()
        print(f" MSE between prediction and target: {mse:.6f}")

        # A NaN MSE fails both comparisons, so report non-finite values
        # explicitly instead of labelling them "reasonable".
        if not torch.isfinite(mse_tensor):
            print(" ❌ Non-finite MSE - predictions or targets contain NaN/Inf")
        elif mse > 1.0:
            print(" ⚠️ High MSE suggests poor training")
        elif mse < 0.001:
            print(" ✓ Very low MSE - model learned well")
        else:
            print(" ✓ Reasonable MSE")


def _try_sampling(model, noise_scheduler, device, output_path):
    """Run the sampler, save the returned grid to *output_path*, and print sample stats.

    Any failure is reported with a traceback rather than raised, so the
    diagnostics printed earlier are not lost.
    """
    print("\n=== ATTEMPTING CORRECTED SAMPLING ===")
    try:
        samples, grid = frequency_aware_sample(model, noise_scheduler, device, n_samples=4)
        save_image(grid, output_path, normalize=False)
        print(f"Samples saved to {output_path}")
        print("Sample statistics:")
        print(f" Range: [{samples.min().item():.3f}, {samples.max().item():.3f}]")
        print(f" Mean: {samples.mean().item():.3f}")
        print(f" Std: {samples.std().item():.3f}")
    except Exception as e:  # top-level diagnostic boundary: report, don't crash
        print(f"Sampling failed: {e}")
        import traceback
        traceback.print_exc()


def debug_model_predictions(log_root='./logs', output_path="debug_samples.png"):
    """Load the newest training checkpoint and print diagnostics about it.

    Steps: pick the newest ``model_epoch_<N>.pth`` under *log_root*, probe
    the model's noise predictions at fixed timesteps, compare its prediction
    against the scheduler's noise target on a synthetic image, and attempt
    sampling, saving the sample grid to *output_path*.

    Args:
        log_root: Directory containing one sub-directory per training run.
        output_path: Where the sampled image grid is written.

    Returns:
        None. All results are printed. Returns early (after printing a
        message) when no run directory or no numbered checkpoint is found.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    checkpoint_path = _find_latest_checkpoint(log_root)
    if checkpoint_path is None:
        return

    model, noise_scheduler = _load_model(checkpoint_path, device)
    _probe_timesteps(model, device)
    _simulate_training_step(model, noise_scheduler, device)
    _try_sampling(model, noise_scheduler, device, output_path)
# Script entry point: run the diagnostics only when this file is executed
# directly, never as a side effect of importing it.
if __name__ == "__main__":
    debug_model_predictions()