import json
import os
import torch
import psutil
import gc
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from src.data.objaverse import load_obj
from src.utils import mesh
from src.utils.material import Material
import argparse


def bytes_to_megabytes(bytes):
    """Convert a byte count to megabytes, using 1 MB = 1024 * 1024 bytes.

    The parameter keeps its original name (which shadows the ``bytes``
    builtin inside this function) so keyword callers continue to work.
    """
    bytes_per_megabyte = 1 << 20
    return bytes / bytes_per_megabyte


def bytes_to_gigabytes(bytes):
    """Convert a byte count to gigabytes, using 1 GB = 1024 ** 3 bytes.

    The parameter keeps its original name (which shadows the ``bytes``
    builtin inside this function) so keyword callers continue to work.
    """
    bytes_per_gigabyte = 1 << 30
    return bytes / bytes_per_gigabyte


def print_memory_usage(stage):
    """Print a one-line memory report labelled with *stage*.

    The line shows the current process's resident set size (via psutil)
    and the values of ``torch.cuda.memory_allocated()`` and
    ``torch.cuda.memory_reserved()``, all converted to MB
    (1 MB = 1024 ** 2 bytes) with two decimals.
    """
    bytes_per_mb = 1024**2
    rss_mb = psutil.Process(os.getpid()).memory_info().rss / bytes_per_mb
    allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
    reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb
    print(
        f"[{stage}] Process memory: {rss_mb:.2f} MB, "
        f"Allocated CUDA memory: {allocated_mb:.2f} MB, Cached CUDA memory: {reserved_mb:.2f} MB"
    )


def _save_atomically(payload, destination):
    """Save *payload* with ``torch.save`` without ever exposing a partial file.

    The data is first written to ``<destination>.tmp`` and then moved into
    place with ``os.replace``. Because ``process_obj`` skips any item whose
    output file already exists, a truncated file left by an interrupted
    direct write would otherwise be treated as finished on the next run.
    """
    staging_path = f"{destination}.tmp"
    try:
        torch.save(payload, staging_path)
        os.replace(staging_path, destination)
    except BaseException:
        # Do not leave a half-written staging file behind on failure.
        if os.path.exists(staging_path):
            os.remove(staging_path)
        raise


def process_obj(index, root_dir, final_save_dir, paths, max_mtl_lines=250):
    """Convert one OBJ asset into a ``.pth`` file of mesh attributes.

    The asset is expected at ``<root_dir>/<name>/<name>.obj`` with a sibling
    ``<name>.mtl``, where ``name = paths[index]``. The result is written to
    ``<final_save_dir>/<name>.pth``.

    Args:
        index: Position of the asset name inside *paths*.
        root_dir: Directory holding one sub-directory per asset.
        final_save_dir: Directory that receives the ``.pth`` files.
        paths: Sequence of asset names.
        max_mtl_lines: Assets whose ``.mtl`` file has this many lines or
            more are skipped. Defaults to 250, the previously hard-coded
            threshold.

    Returns:
        The asset name when a new file was saved; ``None`` when the output
        already existed, the ``.mtl`` file was too long, or any step raised
        (the failure is printed, not propagated).
    """
    name = paths[index]
    save_path = os.path.join(final_save_dir, f"{name}.pth")

    # Resume support: anything already saved is left untouched.
    if os.path.exists(save_path):
        return None

    obj_path = os.path.join(root_dir, name, name + '.obj')
    mtl_path = os.path.join(root_dir, name, name + '.mtl')

    try:
        # Apply the cheap .mtl size filter before loading the mesh, so
        # rejected assets never pay for loading or tangent computation.
        with open(mtl_path, 'r') as file:
            mtl_line_count = sum(1 for _ in file)

        if mtl_line_count >= max_mtl_lines:
            return None

        with torch.no_grad():
            # Only the first element of the returned tuple (the mesh) is
            # used; the remaining attributes are ignored.
            ref_mesh, *_ = load_obj(obj_path, return_attributes=True)
            ref_mesh = mesh.compute_tangents(ref_mesh)

        final_mesh_attributes = {
            "v_pos": ref_mesh.v_pos.detach().cpu(),
            "v_nrm": ref_mesh.v_nrm.detach().cpu(),
            "v_tex": ref_mesh.v_tex.detach().cpu(),
            "v_tng": ref_mesh.v_tng.detach().cpu(),
            "t_pos_idx": ref_mesh.t_pos_idx.detach().cpu(),
            "t_nrm_idx": ref_mesh.t_nrm_idx.detach().cpu(),
            "t_tex_idx": ref_mesh.t_tex_idx.detach().cpu(),
            "t_tng_idx": ref_mesh.t_tng_idx.detach().cpu(),
            "mat_dict": {key: ref_mesh.material[key] for key in ref_mesh.material.mat_keys},
        }

        _save_atomically(final_mesh_attributes, save_path)
        print(f"==> Saved to {final_save_dir}/{name}.pth")

        # Drop the references before the cleanup in `finally` runs so the
        # memory they hold can actually be reclaimed there.
        del ref_mesh, final_mesh_attributes
        return name

    except Exception as e:
        # Best-effort batch job: report the failure and move on.
        print(f"Failed to process {name}: {e}")
        return None

    finally:
        gc.collect()
        torch.cuda.empty_cache()


def main(root_dir, save_dir, batch_size=100, max_workers=8):
    """Process every entry of *root_dir* in batches and save results to *save_dir*.

    Each entry name returned by ``os.listdir(root_dir)`` is handed to
    ``process_obj`` on a thread pool. Batches are processed one after
    another; memory usage is printed and caches are flushed after each.

    Args:
        root_dir: Directory whose entries are the asset names to process.
        save_dir: Output directory; created if it does not exist.
        batch_size: Number of entries submitted per batch (default 100).
        max_workers: Thread-pool size (default 8).

    Returns:
        List of the names for which ``process_obj`` returned a non-``None``
        result during this run.

    Raises:
        ValueError: If *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    os.makedirs(save_dir, exist_ok=True)
    paths = os.listdir(root_dir)

    valid_uid = []

    print_memory_usage("Start")

    num_batches = (len(paths) + batch_size - 1) // batch_size

    # One pool is reused for all batches instead of being rebuilt each time.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for batch in tqdm(range(num_batches)):
            start_index = batch * batch_size
            end_index = min(start_index + batch_size, len(paths))

            futures = [
                executor.submit(process_obj, index, root_dir, save_dir, paths)
                for index in range(start_index, end_index)
            ]
            # Draining every future keeps batches strictly sequential.
            for future in as_completed(futures):
                result = future.result()
                if result is not None:
                    valid_uid.append(result)

            print_memory_usage(f"=====> After processing batch {batch + 1}")
            torch.cuda.empty_cache()
            gc.collect()

    print_memory_usage("End")
    return valid_uid


if __name__ == "__main__":
    # Command-line entry point: two required positional directory arguments.
    cli = argparse.ArgumentParser(
        description="Process OBJ files and save final results."
    )
    for arg_name, arg_help in (
        ("root_dir", "Directory containing the root OBJ files."),
        ("save_dir", "Directory to save the processed results."),
    ):
        cli.add_argument(arg_name, type=str, help=arg_help)
    options = cli.parse_args()

    main(options.root_dir, options.save_dir)