"""
SAR-DDPM inference on real SAR images.

Runs a trained SAR-DDPM model over cyclically shifted copies of each real SAR image,
averages the despeckled samples, and writes the results to disk as grayscale images.
"""

import argparse
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

from guided_diffusion import dist_util, logger
from guided_diffusion.image_datasets import load_data
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    sr_model_and_diffusion_defaults,
    sr_create_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)
from guided_diffusion.train_util import TrainLoop
from valdata import ValData, ValDataNew, ValDataNewReal

# Paths (placeholders; edit before running).
val_dir = 'path_to_validation_data/'               # directory with real SAR validation images
base_path = 'path_to_save_results/'                # directory where despeckled outputs are written
resume_checkpoint_clean = './weights/sar_ddpm.pt'  # trained SAR-DDPM checkpoint
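
# Example invocation (a sketch; the script filename is a placeholder and the model /
# diffusion flags exposed through sr_model_and_diffusion_defaults() must match the
# configuration the checkpoint was trained with):
#   python sar_ddpm_inference_real.py [--<model/diffusion flags>]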


def main():
    args = create_argparser().parse_args()
    print(args)

    # Build the conditional UNet and the diffusion process. The SR variants are used
    # because the model is conditioned on a second image (the speckled input passed as 'SR').
    model_clean, diffusion = sr_create_model_and_diffusion(
        **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
    )

    print(torch.device('cuda'))

    # Not used during inference below.
    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
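
    # Real SAR validation images, loaded one at a time. ValDataNewReal is expected to
    # yield (clean_batch, model_kwargs) pairs whose 'SR' entry holds the speckled input
    # and whose 'Index' entry holds the image filename (both are used in the loop below).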
    val_data = DataLoader(ValDataNewReal(dataset_path=val_dir), batch_size=1, shuffle=False, num_workers=1)

    device0 = torch.device("cuda:0")

    # Load the trained despeckling weights and move the model to the GPU.
    model_clean.load_state_dict(torch.load(resume_checkpoint_clean, map_location="cuda:0"))
    model_clean.to(device0)

    params = list(model_clean.parameters())

    print('model clean device:')
    print(next(model_clean.parameters()).device)

    with torch.no_grad():
        number = 0

        for batch_id1, data_var in enumerate(val_data):
            number = number + 1
            clean_batch, model_kwargs1 = data_var

            # The 'SR' entry is the speckled SAR image that conditions the diffusion model.
            single_img = model_kwargs1['SR'].to(dist_util.dev())

            count = 0
            [t1, t2, max_r, max_c] = single_img.size()

            # Number of shifted copies being averaged: with 256x256 inputs and a shift
            # stride of 100, range(0, 256, 100) gives 3 offsets per axis, i.e. 9 copies.
            N = 9

            val_inputv = single_img.clone()
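
            # Cyclic-shift ensembling: each (row, col) offset below rolls the input by
            # (row, col) pixels (the four slice assignments implement
            # torch.roll(single_img, shifts=(row, col), dims=(2, 3)) piecewise). Every
            # rolled copy is despeckled independently, rolled back, and averaged into
            # sample_new, which smooths out sampling variance and shift-dependent artifacts.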
            for row in range(0, max_r, 100):
                for col in range(0, max_c, 100):

                    val_inputv[:, :, :row, :col] = single_img[:, :, max_r - row:, max_c - col:]
                    val_inputv[:, :, row:, col:] = single_img[:, :, :max_r - row, :max_c - col]
                    val_inputv[:, :, row:, :col] = single_img[:, :, :max_r - row, max_c - col:]
                    val_inputv[:, :, :row, col:] = single_img[:, :, max_r - row:, :max_c - col]
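
                    # Rebuild model_kwargs for this shift: the rolled image becomes the
                    # conditioning input, while the 'Index' entry (the image filename) is
                    # set aside for saving the result later.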
                    model_kwargs = {}
                    for k, v in model_kwargs1.items():
                        if 'Index' in k:
                            img_name = v
                        elif 'SR' in k:
                            model_kwargs[k] = val_inputv.to(dist_util.dev())
                        else:
                            model_kwargs[k] = v.to(dist_util.dev())
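
                    # Run the full reverse diffusion chain, conditioned on the rolled
                    # speckled image, to obtain one despeckled 256x256 sample in [-1, 1].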
                    sample = diffusion.p_sample_loop(
                        model_clean,
                        (clean_batch.shape[0], 3, 256, 256),
                        clip_denoised=True,
                        model_kwargs=model_kwargs,
                    )
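
                    # Undo the (row, col) roll on the sample and add its 1/N-weighted
                    # contribution to the running average; the first pass (row = col = 0)
                    # needs no un-rolling and simply initializes sample_new.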
                    if count == 0:
                        sample_new = (1.0 / N) * sample
                    else:
                        sample_new[:, :, max_r - row:, max_c - col:] += (1.0 / N) * sample[:, :, :row, :col]
                        sample_new[:, :, :max_r - row, :max_c - col] += (1.0 / N) * sample[:, :, row:, col:]
                        sample_new[:, :, :max_r - row, max_c - col:] += (1.0 / N) * sample[:, :, row:, :col]
                        sample_new[:, :, max_r - row:, :max_c - col] += (1.0 / N) * sample[:, :, :row, col:]

                    count += 1
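
            # Post-process the averaged result: map [-1, 1] to [0, 255] uint8, reorder to
            # HWC, flip the channels to BGR for OpenCV, convert to grayscale, and save it
            # to base_path with a 'pred_' prefix on the original filename.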
            sample_new = ((sample_new + 1) * 127.5)
            sample_new = sample_new.clamp(0, 255).to(torch.uint8)
            sample_new = sample_new.permute(0, 2, 3, 1)
            sample_new = sample_new.contiguous().cpu().numpy()
            sample_new = sample_new[0][:, :, ::-1]

            sample_new = cv2.cvtColor(sample_new, cv2.COLOR_BGR2GRAY)
            print(img_name[0])
            cv2.imwrite(base_path + 'pred_' + img_name[0], sample_new)


def create_argparser():
    defaults = dict(
        data_dir=val_dir,
        schedule_sampler="uniform",
        lr=1e-4,
        weight_decay=0.0,
        lr_anneal_steps=0,
        batch_size=2,
        microbatch=1,
        ema_rate="0.9999",
        log_interval=100,
        save_interval=200,
        use_fp16=False,
        fp16_scale_growth=1e-3,
    )
    # Model and diffusion hyperparameters come from guided_diffusion's SR defaults and can
    # be overridden on the command line; most of the training-related defaults above
    # (lr, ema_rate, save_interval, ...) are not used by this inference script.
    defaults.update(sr_model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()