# RADAR_Imaging/scripts/sarddpm_test.py
# Uploaded by TTrain404 ("Upload 24 files"), commit 39aef76.
"""
SAR-DDPM Inference on real SAR images.
"""
import argparse
import torch
import os
import cv2
import numpy as np
import torch.nn.functional as F
from guided_diffusion import dist_util, logger
from guided_diffusion.image_datasets import load_data
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
sr_model_and_diffusion_defaults,
sr_create_model_and_diffusion,
args_to_dict,
add_dict_to_argparser,
)
from guided_diffusion.train_util import TrainLoop
from torch.utils.data import DataLoader
from torch.optim import AdamW
from valdata import ValData, ValDataNew, ValDataNewReal
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
# Placeholder locations: the validation images read by ValDataNewReal, and the
# directory that receives the predicted images.
val_dir, base_path = 'path_to_validation_data/', 'path_to_save_results/'

# Pretrained SAR-DDPM weights loaded into the model before inference.
resume_checkpoint_clean = './weights/sar_ddpm.pt'
def _split_batch_kwargs(model_kwargs1, device):
    """Split one DataLoader condition dict into (image name, fixed kwargs, SR keys).

    Keys containing 'Index' carry the image name(s). Keys containing 'SR' are the
    conditioning-image slots that get replaced by the shifted image on every
    pass. Every other entry is moved to ``device`` once and reused unchanged.

    Raises:
        KeyError: if no key containing 'Index' is present.
    """
    img_name = None
    static_kwargs = {}
    sr_keys = []
    for key, value in model_kwargs1.items():
        if 'Index' in key:
            img_name = value
        elif 'SR' in key:
            sr_keys.append(key)
        else:
            static_kwargs[key] = value.to(device)
    if img_name is None:
        # Without this check a stale name from the previous batch would be reused.
        raise KeyError("batch has no 'Index' entry; cannot name the output file")
    return img_name, static_kwargs, sr_keys


def _shift_averaged_sample(diffusion, model, cond_img, static_kwargs, sr_keys,
                           batch_size, shift_step=100, out_channels=3):
    """Sample under circular shifts of ``cond_img`` and average the un-shifted results.

    For every (row, col) offset on a ``shift_step`` grid over the image, the
    conditioning image is rolled by (row, col), sampled with
    ``diffusion.p_sample_loop``, and the sample is rolled back by (-row, -col).
    The return value is the mean of all these passes.

    Args:
        diffusion: object providing ``p_sample_loop(model, shape, clip_denoised,
            model_kwargs)``.
        model: denoising network passed through to ``p_sample_loop``.
        cond_img: 4-D conditioning tensor (B, C, H, W); dims 2 and 3 are shifted.
        static_kwargs: model kwargs that stay fixed across passes.
        sr_keys: kwarg names that receive the shifted conditioning image.
        batch_size: leading dimension of the generated sample.
        shift_step: spacing in pixels between successive shift offsets.
        out_channels: channel count of the generated sample.

    Returns:
        Tensor of shape (batch_size, out_channels, H, W).
    """
    _, _, height, width = cond_img.shape
    row_shifts = range(0, height, shift_step)
    col_shifts = range(0, width, shift_step)
    # Weight by the number of passes actually run, so the result is a true mean
    # for any image size.
    weight = 1.0 / (len(row_shifts) * len(col_shifts))
    averaged = None
    for row in row_shifts:
        for col in col_shifts:
            shifted = torch.roll(cond_img, shifts=(row, col), dims=(2, 3))
            kwargs = dict(static_kwargs)
            for key in sr_keys:
                kwargs[key] = shifted
            # Sample at the input's spatial size so un-rolling lines up pixel-wise.
            sample = diffusion.p_sample_loop(
                model,
                (batch_size, out_channels, height, width),
                clip_denoised=True,
                model_kwargs=kwargs,
            )
            contribution = weight * torch.roll(sample, shifts=(-row, -col), dims=(2, 3))
            averaged = contribution if averaged is None else averaged + contribution
    return averaged


def _to_gray_uint8(sample):
    """Convert the first item of a (B, 3, H, W) sample to an (H, W) uint8 gray image.

    Values are mapped with (x + 1) * 127.5 (i.e. [-1, 1] -> [0, 255]), clamped,
    truncated to uint8, and the three channels are combined as RGB.
    """
    img = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    img = img.permute(0, 2, 3, 1).contiguous().cpu().numpy()[0]
    return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)


def main():
    """Run SAR-DDPM shift-averaged inference on every image in ``val_dir``.

    Builds the model and diffusion from the command-line arguments and loads the
    weights from ``resume_checkpoint_clean``. Each image is then denoised with
    ``_shift_averaged_sample``, and an 8-bit grayscale prediction named
    ``pred_<image name>`` is written into ``base_path``.

    Raises:
        KeyError: if a batch carries no 'Index' entry to name its output.
        OSError: if OpenCV fails to write an output image.
    """
    args = create_argparser().parse_args()
    print(args)
    model_clean, diffusion = sr_create_model_and_diffusion(
        **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
    )
    val_data = DataLoader(
        ValDataNewReal(dataset_path=val_dir),
        batch_size=1,
        shuffle=False,
        num_workers=1,
    )

    device0 = torch.device("cuda:0")
    # NOTE: torch.load unpickles the checkpoint; only load trusted weight files.
    model_clean.load_state_dict(torch.load(resume_checkpoint_clean, map_location="cuda:0"))
    model_clean.to(device0)
    # Sampling only: switch layers such as dropout to inference behaviour.
    model_clean.eval()
    print('model clean device:')
    print(next(model_clean.parameters()).device)

    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs(base_path, exist_ok=True)

    device = dist_util.dev()
    with torch.no_grad():
        for clean_batch, model_kwargs1 in val_data:
            img_name, static_kwargs, sr_keys = _split_batch_kwargs(model_kwargs1, device)
            cond_img = model_kwargs1['SR'].to(device)
            averaged = _shift_averaged_sample(
                diffusion,
                model_clean,
                cond_img,
                static_kwargs,
                sr_keys,
                batch_size=clean_batch.shape[0],
            )
            gray = _to_gray_uint8(averaged)
            print(img_name[0])
            out_path = os.path.join(base_path, 'pred_' + img_name[0])
            if not cv2.imwrite(out_path, gray):
                raise OSError(f"cv2.imwrite failed to write {out_path}")
def create_argparser():
    """Build the command-line parser for this script.

    Flags come from a table of training-style options (most are not read by this
    inference script), merged with ``sr_model_and_diffusion_defaults()``. On a
    key collision, the model/diffusion default wins.
    """
    base_options = {
        "data_dir": val_dir,
        "schedule_sampler": "uniform",
        "lr": 1e-4,
        "weight_decay": 0.0,
        "lr_anneal_steps": 0,
        "batch_size": 2,
        "microbatch": 1,
        "ema_rate": "0.9999",
        "log_interval": 100,
        "save_interval": 200,
        "use_fp16": False,
        "fp16_scale_growth": 1e-3,
    }
    merged_options = {**base_options, **sr_model_and_diffusion_defaults()}
    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, merged_options)
    return arg_parser
if __name__ == "__main__":
main()