|  | import torch | 
					
						
						|  | import torch.nn.functional as F | 
					
						
						|  |  | 
					
						
						|  | def total_variation_loss(x): | 
					
						
						|  | """Total variation regularization""" | 
					
						
						|  | batch_size = x.size(0) | 
					
						
						|  | h_tv = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]).sum() | 
					
						
						|  | w_tv = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]).sum() | 
					
						
						|  | return (h_tv + w_tv) / batch_size | 
					
						
						|  |  | 
					
						
						|  | def gradient_loss(x): | 
					
						
						|  | """Sobel gradient loss""" | 
					
						
						|  | sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=x.device).view(1, 1, 3, 3) | 
					
						
						|  | sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32, device=x.device).view(1, 1, 3, 3) | 
					
						
						|  |  | 
					
						
						|  | grad_x = F.conv2d(x, sobel_x.repeat(x.size(1), 1, 1, 1), padding=1, groups=x.size(1)) | 
					
						
						|  | grad_y = F.conv2d(x, sobel_y.repeat(x.size(1), 1, 1, 1), padding=1, groups=x.size(1)) | 
					
						
						|  |  | 
					
						
						|  | return torch.mean(grad_x**2 + grad_y**2) | 
					
						
						|  |  | 
					
						
						|  | def diffusion_loss(model, x0, t, noise_scheduler, config): | 
					
						
						|  | xt, noise = noise_scheduler.apply_noise(x0, t) | 
					
						
						|  | pred_noise = model(xt, t) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mse_loss = F.mse_loss(pred_noise, noise) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | tv_loss = total_variation_loss(xt) | 
					
						
						|  | grad_loss = gradient_loss(xt) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | total_loss = mse_loss + config.tv_weight * tv_loss + 0.001 * grad_loss | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if torch.isnan(total_loss) or total_loss > 1e6: | 
					
						
						|  | print(f"WARNING: Extreme loss detected!") | 
					
						
						|  | print(f"MSE: {mse_loss.item():.4f}, TV: {tv_loss.item():.4f}, Grad: {grad_loss.item():.4f}") | 
					
						
						|  | print(f"Noise range: [{noise.min().item():.4f}, {noise.max().item():.4f}]") | 
					
						
						|  | print(f"Pred range: [{pred_noise.min().item():.4f}, {pred_noise.max().item():.4f}]") | 
					
						
						|  |  | 
					
						
						|  | return total_loss |