import os

from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import trange

from Dataloader import *
from .utils import image_quality
from .utils.cls import CyclicLR
from .utils.prepare_images import *

# Folder paths kept for reference; the datasets below are read from .db files,
# and these names are not referenced elsewhere in this script.
train_folder = "./dataset/train"
test_folder = "./dataset/test"

# Training data. max_images=24 presumably caps how many rows are loaded — confirm in ImageDBData.
img_dataset = ImageDBData(
    db_file="dataset/images.db",
    db_table="train_images_size_128_noise_1_rgb",
    max_images=24,
)
img_data = DataLoader(
    img_dataset,
    batch_size=6,
    shuffle=True,
    num_workers=6,
)

# Number of batches per epoch; the LR scheduler's step size is derived from it.
total_batch = len(img_data)
print(len(img_dataset))

# Evaluation data: no image cap, one image per batch, fixed order.
test_dataset = ImageDBData(
    db_file="dataset/test2.db",
    db_table="test_images_size_128_noise_1_rgb",
    max_images=None,
)
num_test = len(test_dataset)
test_data = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)

# Mean absolute error between the network output and the high-resolution target.
criteria = nn.L1Loss()

model = CARN_V2(
    color_channels=3,
    mid_channels=64,
    conv=nn.Conv2d,
    single_conv_size=3,
    single_conv_group=1,
    scale=2,
    activation=nn.LeakyReLU(0.1),
    SEBlock=True,
    repeat_blocks=3,
    atrous=(1, 1, 1),
)
# NOTE(review): presumably reports the parameter count — confirm in CARN_V2.
model.total_parameters()

# Optional fresh initialisation. The checkpoint load below would overwrite it,
# so it is only useful when that load is disabled.
# model.initialize_weights_xavier_uniform()

# fp16 training is available in GPU only: convert to half precision, then move to CUDA.
# The checkpoint is loaded after conversion because it was saved from the converted model.
model = network_to_half(model).cuda()
model.load_state_dict(torch.load("CARN_model_checkpoint.pt"))

learning_rate = 1e-4
weight_decay = 1e-6
# NOTE(review): if network_to_half leaves the weights in fp16, plain Adam updates
# them in half precision. The FP16_Optimizer wrapper below is the loss-scaling
# alternative — confirm which is intended.
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
    amsgrad=True,
)
# Alternative optimizer:
# optimizer = optim.SGD(model.parameters(), momentum=0.9, nesterov=True, weight_decay=weight_decay, lr=learning_rate)

# Optional loss-scaling wrapper (FP16_Optimizer must be in scope to enable it):
# optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0, verbose=False)
# Resume: restore the optimizer state the training loop saves.
# optimizer.load_state_dict(torch.load("CARN_adam_checkpoint.pt"))

# -1 starts the LR schedule from scratch. To resume, load the iteration the
# training loop saves (note the .pt extension):
# last_iter = torch.load("CARN_scheduler_last_iter.pt")
last_iter = -1
# base_lr == max_lr gives the cycle zero amplitude, so the LR should stay at
# learning_rate. Both bounds are tied to learning_rate so the value is set in one place.
# step_size is measured in batches: 3 epochs' worth.
scheduler = CyclicLR(optimizer, base_lr=learning_rate, max_lr=learning_rate,
                     step_size=3 * total_batch, mode="triangular",
                     last_batch_iteration=last_iter)

# Metric history: training metrics are appended per batch, test metrics per epoch.
train_loss = []
train_ssim = []
train_psnr = []

test_loss = []
test_ssim = []
test_psnr = []

# Resume: reload the metric history the training loop saves instead.
# train_loss = torch.load("train_loss.pt")
# train_ssim = torch.load("train_ssim.pt")
# train_psnr = torch.load("train_psnr.pt")
#
# test_loss = torch.load("test_loss.pt")
# test_ssim = torch.load("test_ssim.pt")
# test_psnr = torch.load("test_psnr.pt")


counter = 0  # total optimisation steps taken across all epochs
iteration = 2  # number of epochs to run
checkpoint_every = 500  # save a mid-epoch checkpoint every this many steps


def _backward(loss):
    """Backpropagate *loss* through whichever optimizer is configured.

    Loss-scaling wrappers (e.g. apex's FP16_Optimizer) expose
    ``optimizer.backward(loss)``. A plain ``torch.optim`` optimizer such as
    Adam has no such method, so fall back to ``loss.backward()``.
    """
    if hasattr(optimizer, "backward"):
        optimizer.backward(loss)
    else:
        loss.backward()


def _save_train_checkpoint():
    """Persist model, optimizer, scheduler position and training metric history."""
    torch.save(model.state_dict(), 'CARN_model_checkpoint.pt')
    torch.save(optimizer.state_dict(), 'CARN_adam_checkpoint.pt')
    torch.save(train_loss, 'train_loss.pt')
    torch.save(train_ssim, "train_ssim.pt")
    torch.save(train_psnr, 'train_psnr.pt')
    torch.save(scheduler.last_batch_iteration, "CARN_scheduler_last_iter.pt")


ibar = trange(iteration, ascii=True, maxinterval=1,
              postfix={"avg_loss": 0, "train_ssim": 0, "test_ssim": 0})
for epoch in ibar:
    # +++++++++++++++++++++++++++++++++++++
    #           Train
    # -------------------------------------
    # Re-enter training mode each epoch; the test pass below switches to eval mode.
    model.train()
    for index, batch in enumerate(img_data):
        scheduler.batch_step()
        lr_img, hr_img = batch
        # Inputs go to the fp16 model as half; the target stays fp32 for the loss.
        lr_img = lr_img.cuda().half()
        hr_img = hr_img.cuda()

        optimizer.zero_grad()
        # Cast the output back to fp32 before computing loss and metrics.
        outputs = model(lr_img).float()
        loss = criteria(outputs, hr_img)
        _backward(loss)
        # nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()

        counter += 1
        loss_value = loss.item()
        ssim = image_quality.msssim(outputs, hr_img).item()
        psnr = image_quality.psnr(outputs, hr_img).item()

        ibar.set_postfix(ratio=index / total_batch, loss=loss_value,
                         ssim=ssim, batch=index,
                         psnr=psnr,
                         lr=scheduler.current_lr)
        train_loss.append(loss_value)
        train_ssim.append(ssim)
        train_psnr.append(psnr)

        # Save checkpoints by iterations. counter was already incremented above,
        # so this fires after exactly checkpoint_every, 2*checkpoint_every, ... steps.
        if counter % checkpoint_every == 0:
            _save_train_checkpoint()

    # +++++++++++++++++++++++++++++++++++++
    #           End of One Epoch
    # -------------------------------------
    # The same artefacts as the mid-epoch checkpoint, including the scheduler
    # position, so a resume from either point is consistent.
    _save_train_checkpoint()

    # +++++++++++++++++++++++++++++++++++++
    #           Test
    # -------------------------------------
    model.eval()
    os.makedirs("check_test_imgs", exist_ok=True)  # save_image does not create folders
    epoch_loss = []
    epoch_ssim = []
    epoch_psnr = []
    with torch.no_grad():
        for index, test_batch in enumerate(test_data):
            lr_img, hr_img = test_batch
            # Cast to half to match the training path (harmless if the fp16
            # wrapper already casts its input).
            lr_img = lr_img.cuda().half()
            hr_img = hr_img.cuda()

            lr_img_up = model(lr_img).float()
            loss = criteria(lr_img_up, hr_img)

            save_image([lr_img_up[0], hr_img[0]], f"check_test_imgs/{index}.png")
            epoch_loss.append(loss.item())
            epoch_ssim.append(image_quality.msssim(lr_img_up, hr_img).item())
            epoch_psnr.append(image_quality.psnr(lr_img_up, hr_img).item())

    test_ssim.append(np.mean(epoch_ssim))
    test_loss.append(np.mean(epoch_loss))
    test_psnr.append(np.mean(epoch_psnr))

    torch.save(test_loss, 'test_loss.pt')
    torch.save(test_ssim, "test_ssim.pt")
    torch.save(test_psnr, "test_psnr.pt")

# import subprocess

# subprocess.call(["shutdown", "/s"])