#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Adapted from the MONAI tutorial: https://github.com/Project-MONAI/tutorials/tree/main/2d_segmentation/torch
"""

import argparse
import os

join = os.path.join

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from stardist import star_dist,edt_prob
from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
from stardist import random_label_cmap,ray_angles
import monai
from collections import OrderedDict
from compute_metric import eval_tp_fp_fn,remove_boundary_cells
from monai.data import decollate_batch, PILReader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AsChannelFirstd,
    AddChanneld,
    AsDiscrete,
    Compose,
    LoadImaged,
    SpatialPadd,
    RandSpatialCropd,
    RandRotate90d,
    ScaleIntensityd,
    RandAxisFlipd,
    RandZoomd,
    RandGaussianNoised,
    RandAdjustContrastd,
    RandGaussianSmoothd,
    RandHistogramShiftd,
    EnsureTyped,
    EnsureType,
)
from monai.visualize import plot_2d_or_3d_image
import matplotlib.pyplot as plt
from datetime import datetime
import shutil
import tqdm
from models.unetr2d import UNETR2D
from models.swin_unetr import SwinUNETR
from models.flexible_unet import FlexibleUNet 
from models.flexible_unet_convext import FlexibleUNetConvext
# Reached only after every import above has succeeded.
print("Successfully imported all requirements!")
# Turn cuDNN off for all PyTorch ops in this process.
# NOTE(review): the reason is not recorded in this file — confirm it is still required.
torch.backends.cudnn.enabled =False
#os.environ["OMP_NUM_THREADS"] = "1" 
#os.environ["MKL_NUM_THREADS"] = "1" 
def _parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace holding the dataset, model and training settings.
    """
    parser = argparse.ArgumentParser("Baseline for Microscopy image segmentation")
    # Dataset parameters
    parser.add_argument(
        "--data_path",
        default="",
        type=str,
        help="data root; subfolders: train/images, train/tif, valid/images, valid/tif",
    )
    parser.add_argument(
        "--work_dir", default="/mntnfs/med_data5/louwei/nips_comp/stardist_finetune1/", help="path where to save models and logs"
    )
    parser.add_argument(
        "--model_dir", default="/", help="path of the pretrained checkpoint file to load"
    )
    parser.add_argument("--seed", default=2022, type=int)
    parser.add_argument("--num_workers", default=4, type=int)
    # Model parameters
    parser.add_argument(
        "--model_name", default="efficientunet", help="select model: efficientunet"
    )
    parser.add_argument("--num_class", default=3, type=int, help="segmentation classes")
    parser.add_argument(
        "--input_size", default=512, type=int, help="side length of the training pad/crop"
    )
    # Training parameters
    parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU")
    parser.add_argument("--max_epochs", default=2000, type=int)
    parser.add_argument("--val_interval", default=10, type=int)
    # Kept for command-line compatibility; not read by the training loop.
    parser.add_argument("--epoch_tolerance", default=100, type=int)
    parser.add_argument("--initial_lr", type=float, default=1e-4, help="learning rate")
    return parser.parse_args()


def _list_image_label_pairs(img_dir, gt_dir):
    """Pair every file in ``img_dir`` with a same-stem ``.tif`` path in ``gt_dir``.

    The stem is the part of the file name before its first ``"."``. The label
    files are not checked for existence here.

    Returns:
        List of ``{"img": path, "label": path}`` dicts, sorted by image name.
    """
    return [
        {
            "img": join(img_dir, img_name),
            "label": join(gt_dir, img_name.split(".")[0] + ".tif"),
        }
        for img_name in sorted(os.listdir(img_dir))
    ]


def _build_transforms(input_size):
    """Create the MONAI dictionary transforms for training and validation.

    Args:
        input_size: spatial size used for padding and random cropping of the
            training samples.

    Returns:
        Tuple ``(train_transforms, val_transforms)``.
    """
    # NOTE: ScaleIntensityd is not part of either pipeline, so images keep the
    # intensity range they are loaded with.
    train_transforms = Compose(
        [
            LoadImaged(
                keys=["img", "label"], reader=PILReader, dtype=np.float32
            ),  # image three channels (H, W, 3); label: (H, W)
            AddChanneld(keys=["label"], allow_missing_keys=True),  # label: (1, H, W)
            AsChannelFirstd(
                keys=["img"], channel_dim=-1, allow_missing_keys=True
            ),  # image: (3, H, W)
            SpatialPadd(keys=["img", "label"], spatial_size=input_size),
            RandSpatialCropd(
                keys=["img", "label"], roi_size=input_size, random_size=False
            ),
            RandAxisFlipd(keys=["img", "label"], prob=0.5),
            RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
            # intensity transforms (image only)
            RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
            RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
            RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
            RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
            RandZoomd(
                keys=["img", "label"],
                prob=0.15,
                min_zoom=0.5,
                max_zoom=2,
                mode=["area", "nearest"],
            ),
            EnsureTyped(keys=["img", "label"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "label"], reader=PILReader, dtype=np.float32),
            AddChanneld(keys=["label"], allow_missing_keys=True),
            AsChannelFirstd(keys=["img"], channel_dim=-1, allow_missing_keys=True),
            EnsureTyped(keys=["img", "label"]),
        ]
    )
    return train_transforms, val_transforms


def _stardist_targets(labels, n_rays):
    """Turn a batch of label maps into stacked StarDist regression targets.

    For every sample, channel 0 of ``labels`` is passed to ``star_dist`` and
    ``edt_prob``; the ``n_rays`` distance maps (moved to the leading axis) and
    the single probability map are concatenated along axis 0.

    Returns:
        ``np.ndarray`` with ``n_rays + 1`` channels per sample; the last
        channel is the ``edt_prob`` output.
    """
    targets = []
    for i in range(labels.shape[0]):
        label = labels[i][0]
        # star_dist output is transposed so that the ray axis comes first
        distances = np.transpose(star_dist(label, n_rays, mode='opencl'), (2, 0, 1))
        obj_probabilities = np.expand_dims(edt_prob(label.astype(int)), 0)
        targets.append(np.concatenate((distances, obj_probabilities), axis=0))
    return np.stack(targets)


def _f1_score(tp, fp, fn):
    """Return the F1 score for the given match counts (0 when ``tp`` is 0)."""
    if tp == 0:
        return 0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * (precision * recall) / (precision + recall)


def _validate(model, val_loader, device):
    """Run sliding-window inference on ``val_loader`` and return the mean F1.

    Each prediction is converted to an instance label image with StarDist
    non-maximum suppression, boundary cells are removed from both prediction
    and ground truth, and the per-image F1 (rounded to 4 decimals) is averaged.
    """
    roi_size = (512, 512)
    sw_batch_size = 4
    f1_scores = []
    model.eval()
    with torch.no_grad():
        for val_data in tqdm.tqdm(val_loader):
            val_images = val_data["img"].to(device)
            # the label is only consumed as a NumPy array, so it never needs the GPU
            val_labels = val_data["label"][0][0].cpu().numpy()
            output_dist, output_prob = sliding_window_inference(
                val_images, roi_size, sw_batch_size, model
            )
            prob = output_prob[0][0].cpu().numpy()
            dist = np.transpose(output_dist[0].cpu().numpy(), (1, 2, 0))
            # clamp distances from below before non-maximum suppression
            dist = np.maximum(1e-3, dist)
            points, probi, disti = non_maximum_suppression(
                dist, prob, prob_thresh=0.5, nms_thresh=0.4
            )
            star_label = polygons_to_label(disti, points, prob=probi, shape=prob.shape)
            gt = remove_boundary_cells(val_labels.astype(np.int32))
            seg = remove_boundary_cells(star_label.astype(np.int32))
            tp, fp, fn = eval_tp_fp_fn(gt, seg, threshold=0.5)
            f1_scores.append(np.round(_f1_score(tp, fp, fn), 4))
    return np.mean(f1_scores)


def main():
    """Fine-tune a StarDist-style network from a pretrained checkpoint.

    Steps, in order: parse the CLI options, copy this script and the model
    source file into the output folder, build the data loaders, create the
    network/optimizer, load the pretrained weights from ``--model_dir`` and
    train for ``--max_epochs`` epochs. Starting at epoch 40, every
    ``--val_interval``-th epoch writes ``<epoch>.pth``, runs validation and
    updates ``best_model.pth`` when the mean F1 improves. The loss history and
    the validation F1 history are written to ``train_log.npz`` at the end.

    Raises:
        ValueError: if no training images are found or ``--model_name`` is
            not supported.
    """
    args = _parse_args()
    monai.config.print_config()
    n_rays = 32
    pre_trained = True
    # epochs below this index are trained without checkpointing or validation
    first_val_epoch = 40

    np.random.seed(args.seed)
    pre_trained_path = args.model_dir
    model_path = join(args.work_dir, args.model_name + "_3class")
    os.makedirs(model_path, exist_ok=True)

    # Keep a copy of the training script and the model definition next to the
    # results. NOTE: model_file must be updated by hand whenever a different
    # model source file is used.
    model_file = "models/flexible_unet_convext.py"
    for src in (__file__, model_file):
        shutil.copyfile(src, join(model_path, os.path.basename(src)))

    #%% fixed training/validation split taken from the folder layout
    train_files = _list_image_label_pairs(
        join(args.data_path, "train/images"), join(args.data_path, "train/tif")
    )
    val_files = _list_image_label_pairs(
        join(args.data_path, "valid/images"), join(args.data_path, "valid/tif")
    )
    if not train_files:
        raise ValueError(
            "no training images found in " + join(args.data_path, "train/images")
        )
    print(
        f"training image num: {len(train_files)}, validation image num: {len(val_files)}"
    )

    #%% define transforms, datasets and data loaders
    train_transforms, val_transforms = _build_transforms(args.input_size)
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)

    # Load a single sample up front so that data problems show before training.
    check_loader = DataLoader(train_ds, batch_size=1, num_workers=args.num_workers)
    check_data = monai.utils.misc.first(check_loader)
    print(
        "sanity check:",
        check_data["img"].shape,
        torch.max(check_data["img"]),
        check_data["label"].shape,
        torch.max(check_data["label"]),
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=1)

    #%% create the network, losses and optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.model_name.lower() == "efficientunet":
        model = FlexibleUNetConvext(
            in_channels=3,
            out_channels=n_rays + 1,
            backbone='convnext_small',
            pretrained=False,
        ).to(device)
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError(
            f"unsupported --model_name {args.model_name!r}; expected 'efficientunet'"
        )

    loss_dice = monai.losses.DiceLoss(squared_pred=True, jaccard=True)
    loss_bce = nn.BCELoss()
    loss_dist_mae = nn.L1Loss()

    # Encoder parameters get a 10x smaller learning rate than the remaining ones.
    initial_lr = args.initial_lr
    encoder_param_ids = {id(p) for p in model.encoder.parameters()}
    base_params = [p for p in model.parameters() if id(p) not in encoder_param_ids]
    params = [
        {"params": base_params, "lr": initial_lr},
        {"params": model.encoder.parameters(), "lr": initial_lr * 0.1},
    ]
    optimizer = torch.optim.AdamW(params, initial_lr)

    if pre_trained:
        checkpoint = torch.load(pre_trained_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print('Load pretrained weights...')

    max_epochs = args.max_epochs
    val_interval = args.val_interval
    epoch_loss_values = []
    metric_values = []
    writer = SummaryWriter(model_path)
    max_f1 = 0
    # train_files is non-empty and the loader keeps partial batches, so this is >= 1
    num_steps = len(train_loader)

    for epoch in range(max_epochs):
        model.train()
        epoch_loss = 0
        epoch_loss_prob = 0
        epoch_loss_dist_2 = 0
        epoch_loss_dist_1 = 0
        for step, batch_data in enumerate(train_loader, 1):
            print(step)
            inputs = batch_data["img"]
            labels = _stardist_targets(batch_data["label"], n_rays)
            inputs, labels = torch.tensor(inputs).to(device), torch.tensor(labels).to(device)

            optimizer.zero_grad()
            output_dist, output_prob = model(inputs)
            dist_label = labels[:, :n_rays, :, :]
            prob_label = torch.unsqueeze(labels[:, -1, :, :], 1)

            # Both distance losses are weighted by the target probability map.
            loss_dist_1 = loss_dice(output_dist * prob_label, dist_label * prob_label)
            loss_prob = loss_bce(output_prob, prob_label)
            loss_dist_2 = loss_dist_mae(output_dist * prob_label, dist_label * prob_label)
            loss = loss_prob + loss_dist_2 * 0.3 + loss_dist_1
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_loss_prob += loss_prob.item()
            epoch_loss_dist_2 += loss_dist_2.item()
            epoch_loss_dist_1 += loss_dist_1.item()

        epoch_loss /= num_steps
        epoch_loss_prob /= num_steps
        epoch_loss_dist_2 /= num_steps
        epoch_loss_dist_1 /= num_steps
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch} average loss: {epoch_loss:.4f}")
        writer.add_scalar("train_loss", epoch_loss, epoch)
        print('dist dice: '+str(epoch_loss_dist_1)+' dist mae: '+str(epoch_loss_dist_2)+' prob bce: '+str(epoch_loss_prob))

        # Checkpointing and validation only happen on every val_interval-th
        # epoch once the first first_val_epoch epochs are done.
        if epoch < first_val_epoch or epoch % val_interval != 0:
            continue

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": epoch_loss_values,
        }
        torch.save(checkpoint, join(model_path, str(epoch) + ".pth"))

        avg_f1 = _validate(model, val_loader, device)
        metric_values.append(avg_f1)
        writer.add_scalar("val_f1score", avg_f1, epoch)
        if avg_f1 > max_f1:
            max_f1 = avg_f1
            print(str(epoch) + 'f1 score: ' + str(max_f1))
            torch.save(checkpoint, join(model_path, "best_model.pth"))

    # "val_dice" is the historical key name; the values are the mean
    # validation F1 scores, one per validation run.
    np.savez_compressed(
        join(model_path, "train_log.npz"),
        val_dice=metric_values,
        epoch_loss=epoch_loss_values,
    )
    writer.close()


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()