import gradio as gr
import os
import zipfile
import json
from io import BytesIO
import base64
from PIL import Image
import uuid
import tempfile
import numpy as np
import time

# Function to save dataset to zip

def save_dataset_to_zip(dataset_name, dataset):
    """Serialize a dataset into an in-memory zip archive.

    The archive has ``dataset_name`` as its single top-level folder. That
    folder contains an ``images/`` directory of PNG files and an
    ``annotations.jsonl`` file with one ``{"file_name", "text"}`` JSON object
    per entry.

    Args:
        dataset_name: Name of the top-level folder inside the archive.
        dataset: Iterable of entries. Each entry has an ``'image'`` data-URI
            string (``"<header>,<base64 payload>"``) and a ``'prompt'`` string.

    Returns:
        A ``BytesIO`` positioned at offset 0 that holds the zip bytes.
    """
    zip_buffer = BytesIO()

    # The staging tree is only needed until the archive has been written, so
    # let TemporaryDirectory delete it instead of leaking one per call.
    with tempfile.TemporaryDirectory() as temp_dir:
        dataset_path = os.path.join(temp_dir, dataset_name)
        images_dir = os.path.join(dataset_path, 'images')
        os.makedirs(images_dir, exist_ok=True)

        annotations = []
        for entry in dataset:
            image_filename = f"{uuid.uuid4().hex}.png"
            # Drop the "data:<mime>;base64" header and decode the payload.
            encoded = entry['image'].split(",")[1]
            image = Image.open(BytesIO(base64.b64decode(encoded)))
            image.save(os.path.join(images_dir, image_filename))

            annotations.append({
                # Always use "/" so archives written on Windows stay
                # loadable on other platforms.
                'file_name': f"images/{image_filename}",
                'text': entry['prompt'],
            })

        annotations_path = os.path.join(dataset_path, 'annotations.jsonl')
        with open(annotations_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(ann) + '\n' for ann in annotations)

        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, _dirs, files in os.walk(dataset_path):
                for file in files:
                    abs_file = os.path.join(root, file)
                    # Paths are stored relative to the staging root so that
                    # dataset_name is the archive's top-level folder.
                    zipf.write(abs_file, os.path.relpath(abs_file, temp_dir))

    zip_buffer.seek(0)
    return zip_buffer

# Function to load dataset from zip

def load_dataset_from_zip(zip_file_path):
    """Load a dataset from a zip archive into memory.

    The dataset root is the first of these locations that contains an
    ``annotations.jsonl`` file: a top-level folder named after the zip file,
    any other top-level folder (in sorted order), or the archive root itself.
    Each non-blank line of ``annotations.jsonl`` is a JSON object with a
    ``file_name`` (image path relative to the dataset root) and a ``text``.

    Args:
        zip_file_path: Path of the zip archive on disk.

    Returns:
        ``(dataset_name, dataset)`` where ``dataset`` is a list of
        ``{'image': <data URI>, 'prompt': <text>}`` dicts, or ``(None, [])``
        when the archive cannot be read, has no annotations file, or refers
        to a file outside the dataset root.
    """
    import mimetypes

    try:
        # Everything is read into memory before returning, so the extraction
        # directory can be removed as soon as the block exits.
        with tempfile.TemporaryDirectory() as temp_dir:
            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)

            dataset_name_guess = os.path.splitext(os.path.basename(zip_file_path))[0]
            subdirs = sorted(
                entry for entry in os.listdir(temp_dir)
                if os.path.isdir(os.path.join(temp_dir, entry))
            )

            # Candidate roots in priority order. The archive root comes last
            # so that a flat archive (annotations.jsonl next to images/) is
            # still found instead of mistaking images/ for the dataset.
            candidates = []
            if dataset_name_guess in subdirs:
                candidates.append(
                    (dataset_name_guess, os.path.join(temp_dir, dataset_name_guess))
                )
            candidates.extend(
                (entry, os.path.join(temp_dir, entry))
                for entry in subdirs if entry != dataset_name_guess
            )
            candidates.append((dataset_name_guess, temp_dir))

            for dataset_name, dataset_path in candidates:
                annotations_path = os.path.join(dataset_path, 'annotations.jsonl')
                if os.path.isfile(annotations_path):
                    break
            else:
                # If annotations file not found
                return None, []

            root = os.path.realpath(dataset_path)
            dataset = []
            with open(annotations_path, 'r', encoding='utf-8') as f:
                for line in f:
                    if not line.strip():
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue
                    ann = json.loads(line)
                    prompt = ann['text']
                    # Accept Windows-style separators written by other tools.
                    file_name = ann['file_name'].replace('\\', '/')

                    # file_name comes from the uploaded archive: refuse
                    # absolute paths and ".." segments that would read files
                    # outside the extracted dataset.
                    image_path = os.path.realpath(os.path.join(root, file_name))
                    if os.path.commonpath([root, image_path]) != root:
                        raise ValueError(
                            f"annotation refers to a file outside the dataset: {file_name!r}"
                        )

                    with open(image_path, 'rb') as img_f:
                        encoded = base64.b64encode(img_f.read()).decode()

                    guessed = mimetypes.guess_type(image_path)[0]
                    if guessed and guessed.startswith("image/"):
                        mime_type = guessed
                    else:
                        mime_type = "image/png"

                    dataset.append({
                        'image': f"data:{mime_type};base64,{encoded}",
                        'prompt': prompt
                    })

            return dataset_name, dataset
    except Exception as e:
        # Boundary for untrusted uploads: callers rely on (None, []) rather
        # than an exception, so report the failure and return the sentinel.
        print(f"Error loading dataset: {e}")
        return None, []

# Function to display dataset as HTML

def display_dataset_html(dataset, page_number=0, items_per_page=2):
    """Render one page of a dataset as a horizontal strip of HTML cards.

    Each card shows the entry's index in the full dataset, a thumbnail of at
    most 100x100 pixels re-encoded as PNG, and the prompt text.

    Args:
        dataset: List of ``{'image': <data URI>, 'prompt': <text>}`` entries.
        page_number: Zero-based page to render.
        items_per_page: Number of entries per page.

    Returns:
        An HTML string. A falsy ``dataset`` yields a "No entries" placeholder;
        a page past the end yields an empty container.
    """
    from html import escape

    if not dataset:
        return "<div>No entries in dataset.</div>"

    start_idx = page_number * items_per_page
    page_entries = dataset[start_idx:start_idx + items_per_page]

    parts = ['''
        <div style="display: flex; overflow-x: auto; padding: 10px; border: 1px solid #ccc;">''']
    for idx, entry in enumerate(page_entries, start=start_idx):
        # Decode the data-URI payload and shrink it so the page stays light.
        image_bytes = base64.b64decode(entry['image'].split(",")[1])
        image = Image.open(BytesIO(image_bytes))
        image.thumbnail((100, 100))  # In-place resize, keeps aspect ratio
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
        img_data = f"data:image/png;base64,{img_str}"

        # Prompts are user-supplied (typed or loaded from an uploaded zip),
        # so escape them before placing them in markup.
        safe_prompt = escape(str(entry['prompt']))
        parts.append(f"""
            <div style="display: flex; flex-direction: column; align-items: center; margin-right: 20px;">
            <div style="margin-bottom: 5px;">{idx}</div>
            <img src="{img_data}" alt="Image {idx}" style="max-height: 150px;"/>
            <div style="max-width: 150px; word-wrap: break-word; text-align: center;">{safe_prompt}</div>
            </div>
            """)
    parts.append('</div>')
    return "".join(parts)

# Interface
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1px;'>Dataset Creator</h1>")
    gr.Markdown("You must create/upload a dataset before selecting one")
    # Session state: maps dataset name -> list of {'image': data URI, 'prompt': text} entries.
    datasets = gr.State({})
    # Name of the currently selected dataset ("" until one is selected).
    current_dataset_name = gr.State("")
    # Zero-based page index used by the paging handlers.
    current_page_number = gr.State(0)
    # HTML target written by the select/add/edit/delete handlers; hidden further below.
    dataset_html = gr.HTML()  # Define dataset_html here

    # Top-level components
    with gr.Column():
        # Choices are filled in by the create/upload handlers.
        dataset_selector = gr.Dropdown(label="Select Dataset", interactive=True)
        # Read-only status line shared by all handlers.
        message_box = gr.Textbox(interactive=False, label="Message")

    # Tabs
    with gr.Tabs():
        with gr.TabItem("Create / Upload Dataset"):
            with gr.Row():
                # Left column: create an empty dataset by name.
                with gr.Column():
                    gr.Markdown("### Create a New Dataset")
                    dataset_name_input = gr.Textbox(label="New Dataset Name")
                    create_button = gr.Button("Create Dataset")
                # Right column: import a dataset from a zip archive.
                with gr.Column():
                    gr.Markdown("### Upload Existing Dataset")
                    upload_input = gr.File(label="Upload Dataset Zip", type="filepath", file_types=['.zip'])
                    upload_button = gr.Button("Upload Dataset")

            def create_dataset(name, datasets):
                """Register a new, empty dataset and select it in the dropdown.

                ``datasets`` is updated in place. Surrounding whitespace is
                trimmed from ``name``.

                Returns:
                    ``(dropdown update, status message)``. The dropdown is
                    left untouched when the name is rejected.
                """
                name = (name or "").strip()
                if not name:
                    return gr.update(), "Please enter a dataset name."
                # The name is later used as a folder name inside the zip and
                # as the zip file name, so it must be a single path component.
                if name in (".", "..") or "/" in name or "\\" in name:
                    return gr.update(), "Dataset name must not contain path separators."
                if name in datasets:
                    return gr.update(), f"Dataset '{name}' already exists."
                datasets[name] = []
                return gr.update(choices=list(datasets), value=name), f"Dataset '{name}' created."

            create_button.click(
                create_dataset,
                inputs=[dataset_name_input, datasets],
                outputs=[dataset_selector, message_box]
            )

            def upload_dataset(zip_file_path, datasets):
                """Import a dataset from an uploaded zip and select it.

                On success the loaded entries are stored in ``datasets`` (in
                place) under the name reported by ``load_dataset_from_zip``.

                Returns:
                    ``(dropdown update, status message)``; the dropdown is
                    left untouched on any failure.
                """
                if zip_file_path:
                    loaded_name, loaded_entries = load_dataset_from_zip(zip_file_path)
                    if loaded_name is None:
                        status = "Failed to load dataset from zip file."
                    elif loaded_name in datasets:
                        status = f"Dataset '{loaded_name}' already exists."
                    else:
                        datasets[loaded_name] = loaded_entries
                        selector_update = gr.update(choices=list(datasets.keys()), value=loaded_name)
                        return selector_update, f"Dataset '{loaded_name}' uploaded."
                else:
                    status = "Please upload a zip file."
                return gr.update(), status

            upload_button.click(
                upload_dataset,
                inputs=[upload_input, datasets],
                outputs=[dataset_selector, message_box],
            )

        with gr.TabItem("Add Entry"):
            with gr.Row():
                image_input = gr.Image(type="numpy", label="Upload Image")
                prompt_input = gr.Textbox(label="Prompt")
            add_button = gr.Button("Add Entry")

            def add_entry(image_data, prompt, current_dataset_name, datasets):
                """Append an image/prompt pair to the selected dataset.

                The image array is PNG-encoded and stored as a base64 data
                URI. On success the page number is reset to 0 and the first
                page is re-rendered.

                Returns:
                    ``(datasets, page number, HTML update, status message)``.
                    The page number and HTML are left untouched when the
                    input is rejected.
                """
                if not current_dataset_name:
                    failure = "No dataset selected."
                elif image_data is None or not prompt:
                    failure = "Please provide both an image and a prompt."
                else:
                    failure = None
                if failure is not None:
                    return datasets, gr.update(), gr.update(), failure

                # Encode the array as PNG, then wrap it in a data URI.
                png_buffer = BytesIO()
                Image.fromarray(image_data.astype('uint8')).save(png_buffer, format="PNG")
                payload = base64.b64encode(png_buffer.getvalue()).decode('utf-8')

                entries = datasets[current_dataset_name]
                entries.append({'image': f"data:image/png;base64,{payload}", 'prompt': prompt})

                # Reset page number to 0 and refresh HTML
                first_page = 0
                refreshed_html = display_dataset_html(entries, page_number=first_page)
                return (
                    datasets,
                    first_page,
                    gr.update(value=refreshed_html),
                    f"Entry added to dataset '{current_dataset_name}'.",
                )

            add_button.click(
                add_entry,
                inputs=[image_input, prompt_input, current_dataset_name, datasets],
                outputs=[datasets, current_page_number, dataset_html, message_box],
            )

        with gr.TabItem("Edit / Delete Entry"):
            with gr.Column():
                selected_image = gr.Image(label="Selected Image", interactive=False, type="numpy")
                selected_prompt = gr.Textbox(label="Current Prompt", interactive=False)
                # Options are labelled "<index>: <first 30 characters of the prompt>".
                entry_selector = gr.Dropdown(label="Select Entry to Edit/Delete")
                new_prompt_input = gr.Textbox(label="New Prompt (for Edit)")
                with gr.Row():
                    edit_button = gr.Button("Edit Entry")
                    delete_button = gr.Button("Delete Entry")

            def update_selected_entry(entry_option, current_dataset_name, datasets):
                """Show the image and prompt of the entry picked in the dropdown.

                Args:
                    entry_option: Dropdown label of the form ``"<index>: ..."``.
                    current_dataset_name: Name of the selected dataset.
                    datasets: Mapping of dataset name to entry list.

                Returns:
                    ``(image update, prompt update)``. Both previews are
                    cleared when nothing is selected or when the label does
                    not refer to an existing entry, so a stale preview is
                    never left on screen.
                """
                cleared = gr.update(value=None), gr.update(value="")
                if not current_dataset_name or not entry_option:
                    return cleared

                entries = datasets.get(current_dataset_name, [])
                try:
                    index = int(entry_option.split(":")[0])
                except ValueError:
                    return cleared
                if not 0 <= index < len(entries):
                    return cleared

                entry = entries[index]
                # Decode base64 image data to numpy array
                image_bytes = base64.b64decode(entry['image'].split(",")[1])
                image_array = np.array(Image.open(BytesIO(image_bytes)))
                return gr.update(value=image_array), gr.update(value=entry['prompt'])

            entry_selector.change(
                update_selected_entry,
                inputs=[entry_selector, current_dataset_name, datasets],
                outputs=[selected_image, selected_prompt]
            )

            def edit_entry(entry_option, new_prompt, current_dataset_name, datasets, current_page_number):
                """Replace the prompt of the selected entry.

                Args:
                    entry_option: Dropdown label of the form ``"<index>: ..."``.
                    new_prompt: Replacement prompt text.
                    current_dataset_name: Name of the selected dataset.
                    datasets: Mapping of dataset name to entry list (updated
                        in place).
                    current_page_number: Page to re-render after the edit.

                Returns:
                    ``(datasets, HTML update, selector update, new-prompt
                    update, status message)``.
                """
                if not current_dataset_name:
                    return datasets, gr.update(), gr.update(), gr.update(), "No dataset selected."
                if not entry_option or not (new_prompt or "").strip():
                    return datasets, gr.update(), gr.update(), gr.update(), "Please select an entry and provide a new prompt."

                dataset = datasets[current_dataset_name]
                index = int(entry_option.split(":")[0])
                if not 0 <= index < len(dataset):
                    return datasets, gr.update(), gr.update(), gr.update(), "Selected entry no longer exists."

                dataset[index]['prompt'] = new_prompt
                html_content = display_dataset_html(dataset, page_number=current_page_number)
                # Labels embed the prompt, so the edited entry's label changes.
                # Re-select it under its new label so the dropdown never holds
                # a value that is missing from its choices.
                entry_options = [f"{idx}: {entry['prompt'][:30]}" for idx, entry in enumerate(dataset)]
                return (
                    datasets,
                    gr.update(value=html_content),
                    gr.update(choices=entry_options, value=entry_options[index]),
                    gr.update(value=""),
                    f"Entry {index} updated.",
                )

            edit_button.click(
                edit_entry,
                inputs=[entry_selector, new_prompt_input, current_dataset_name, datasets, current_page_number],
                outputs=[datasets, dataset_html, entry_selector, new_prompt_input, message_box]
            )

            def delete_entry(entry_option, current_dataset_name, datasets, current_page_number):
                """Remove the selected entry from the current dataset.

                Args:
                    entry_option: Dropdown label of the form ``"<index>: ..."``.
                    current_dataset_name: Name of the selected dataset.
                    datasets: Mapping of dataset name to entry list (updated
                        in place).
                    current_page_number: Page to re-render after the deletion.

                Returns:
                    Exactly five values, one per wired output: ``(datasets,
                    HTML update, selector update, selected-image update,
                    status message)``.
                """
                if not current_dataset_name:
                    return datasets, gr.update(), gr.update(), gr.update(), "No dataset selected."
                if not entry_option:
                    return datasets, gr.update(), gr.update(), gr.update(), "Please select an entry to delete."

                dataset = datasets[current_dataset_name]
                index = int(entry_option.split(":")[0])
                if not 0 <= index < len(dataset):
                    return datasets, gr.update(), gr.update(), gr.update(), "Selected entry no longer exists."

                del dataset[index]
                html_content = display_dataset_html(dataset, page_number=current_page_number)
                # Indices shift after a deletion, so rebuild the labels and
                # clear the selection. Otherwise the old label would point at
                # a different entry on the next click.
                entry_options = [f"{idx}: {entry['prompt'][:30]}" for idx, entry in enumerate(dataset)]
                return (
                    datasets,
                    gr.update(value=html_content),
                    gr.update(choices=entry_options, value=None),
                    gr.update(value=None),
                    f"Entry {index} deleted.",
                )

            delete_button.click(
                delete_entry,
                inputs=[entry_selector, current_dataset_name, datasets, current_page_number],
                outputs=[datasets, dataset_html, entry_selector, selected_image, message_box]
            )

            # Function to update entry_selector options
            def update_entry_selector(current_dataset_name, datasets):
                """Rebuild the entry dropdown choices for a dataset.

                Returns:
                    A dropdown update whose choices are
                    ``"<index>: <first 30 characters of the prompt>"`` labels,
                    or an empty list when the name is not a known dataset.
                """
                dataset = datasets.get(current_dataset_name, [])
                entry_options = [f"{idx}: {entry['prompt'][:30]}" for idx, entry in enumerate(dataset)]
                return gr.update(choices=entry_options)

            # Update entry_selector when dataset is selected. Read the name
            # from the dropdown itself: the current_dataset_name state is
            # written by a separate listener on this same event and may still
            # hold the previous selection when this handler runs.
            dataset_selector.change(
                update_entry_selector,
                inputs=[dataset_selector, datasets],
                outputs=[entry_selector]
            )

            # Also update entry_selector when an entry is added in "Add Entry" tab
            add_button.click(
                update_entry_selector,
                inputs=[current_dataset_name, datasets],
                outputs=[entry_selector]
            )

        with gr.TabItem("Download Dataset"):
            download_button = gr.Button("Download Dataset")
            download_output = gr.File(interactive=False, label="Download Zip")

            def download_dataset(current_dataset_name, datasets):
                """Write the selected dataset to a zip file for download.

                Returns:
                    ``(zip path, status message)``; the path is ``None`` when
                    no dataset is selected or the dataset has no entries.
                """
                if not current_dataset_name:
                    return None, "No dataset selected."
                entries = datasets[current_dataset_name]
                if not entries:
                    return None, "Dataset is empty."

                archive = save_dataset_to_zip(current_dataset_name, entries)
                # Write the in-memory archive to a file in a fresh temporary
                # directory; the path is handed to the File component.
                out_dir = tempfile.mkdtemp()
                zip_path = os.path.join(out_dir, f"{current_dataset_name}.zip")
                with open(zip_path, 'wb') as out_file:
                    out_file.write(archive.getvalue())
                return zip_path, f"Dataset '{current_dataset_name}' is ready for download."

            download_button.click(
                download_dataset,
                inputs=[current_dataset_name, datasets],
                outputs=[download_output, message_box],
            )

    def select_dataset(dataset_name, datasets):
        """Make ``dataset_name`` the current dataset and render its first page.

        Returns:
            ``(current dataset name, page number, HTML update, status
            message)``. An unknown name resets the current name to ``""`` and
            shows a "Select a dataset." placeholder.
        """
        if dataset_name not in datasets:
            return "", 0, gr.update(value="<div>Select a dataset.</div>"), ""

        first_page_html = display_dataset_html(datasets[dataset_name], page_number=0)
        return (
            dataset_name,
            0,
            gr.update(value=first_page_html),
            f"Dataset '{dataset_name}' selected.",
        )

    dataset_selector.change(
        select_dataset,
        inputs=[dataset_selector, datasets],
        outputs=[current_dataset_name, current_page_number, dataset_html, message_box],
    )

    # Dataset Viewer and Pagination Controls at the Bottom
    with gr.Column():
        gr.Markdown("### Dataset Viewer")
        dataset_viewer = gr.HTML()  # Use dataset_viewer instead of dataset_html
        with gr.Row():
            prev_button = gr.Button("Previous Page")
            next_button = gr.Button("Next Page")

    def change_page(action, current_page_number, datasets, current_dataset_name):
        """Move the viewer one page forward or backward.

        Args:
            action: ``"next"`` or ``"prev"``; any other value re-renders the
                current page.
            current_page_number: Zero-based page currently shown.
            datasets: Mapping of dataset name to entry list.
            current_dataset_name: Name of the selected dataset.

        Returns:
            ``(page number, HTML update, status message)``. The page number
            is clamped to the valid range, so it also recovers when entries
            were deleted and the old page no longer exists.
        """
        if not current_dataset_name:
            return current_page_number, gr.update(), "No dataset selected."
        dataset = datasets[current_dataset_name]

        # The page count must use the same page size as the renderer, which
        # defaults to 2 entries per page. Pass it explicitly so the two
        # cannot drift apart and hide the last pages.
        items_per_page = 2
        total_pages = max(1, -(-len(dataset) // items_per_page))  # ceil division

        if action == "next":
            current_page_number += 1
        elif action == "prev":
            current_page_number -= 1
        current_page_number = min(max(current_page_number, 0), total_pages - 1)

        html_content = display_dataset_html(
            dataset, page_number=current_page_number, items_per_page=items_per_page
        )
        return current_page_number, gr.update(value=html_content), ""

    prev_button.click(
        fn=lambda current_page_number, datasets, current_dataset_name: change_page("prev", current_page_number, datasets, current_dataset_name),
        inputs=[current_page_number, datasets, current_dataset_name],
        outputs=[current_page_number, dataset_viewer, message_box]
    )

    next_button.click(
        fn=lambda current_page_number, datasets, current_dataset_name: change_page("next", current_page_number, datasets, current_dataset_name),
        inputs=[current_page_number, datasets, current_dataset_name],
        outputs=[current_page_number, dataset_viewer, message_box]
    )

    # Initialize dataset_selector
    def initialize_components(datasets):
        """Return a dropdown update listing the names of all known datasets."""
        dataset_names = list(datasets.keys())
        return gr.update(choices=dataset_names)

    # Populate the dataset dropdown when the page loads.
    demo.load(initialize_components, inputs=[datasets], outputs=[dataset_selector])

    # Hide dataset_html
    dataset_html.visible = False

    # Update all components when a dataset is selected
    def update_all_components(current_dataset_name, datasets, current_page_number=0):
        """Refresh the dataset viewer and the entry dropdown for a dataset.

        Args:
            current_dataset_name: Name of the dataset to show.
            datasets: Mapping of dataset name to entry list.
            current_page_number: Zero-based page to render (default 0).

        Returns:
            ``(viewer HTML update, entry dropdown update)``. When the name is
            not a known dataset the viewer shows a "Select a dataset."
            placeholder and the dropdown is emptied. The handler never waits
            for the name to appear: its arguments are fixed for the duration
            of the call, so such a wait could never end.
        """
        dataset = datasets.get(current_dataset_name)
        if dataset is None:
            return gr.update(value="<div>Select a dataset.</div>"), gr.update(choices=[])
        html_content = display_dataset_html(dataset, page_number=current_page_number)
        entry_options = [f"{idx}: {entry['prompt'][:30]}" for idx, entry in enumerate(dataset)]
        return gr.update(value=html_content), gr.update(choices=entry_options)

    # The per-event variants had identical bodies; keep their names as
    # aliases of the single implementation.
    update_all_components_after_add = update_all_components
    update_all_components_after_edit = update_all_components
    update_all_components_after_delete = update_all_components

    # Update all components when a dataset is selected. The name is read from
    # the dropdown itself because the current_dataset_name state is written
    # by another listener on this same event and may still be stale here.
    dataset_selector.change(
        update_all_components,
        inputs=[dataset_selector, datasets],
        outputs=[dataset_viewer, entry_selector]
    )

    # Update all components when an entry is added (the add handler resets
    # the page number to 0, matching the default page rendered here).
    add_button.click(
        update_all_components_after_add,
        inputs=[current_dataset_name, datasets],
        outputs=[dataset_viewer, entry_selector]
    )

    # Update all components when an entry is edited. Edit leaves the page
    # number unchanged, so render that page instead of jumping to page 0.
    edit_button.click(
        update_all_components_after_edit,
        inputs=[current_dataset_name, datasets, current_page_number],
        outputs=[dataset_viewer, entry_selector]
    )

    # Update all components when an entry is deleted; as with edit, keep the
    # viewer on the page held in current_page_number.
    delete_button.click(
        update_all_components_after_delete,
        inputs=[current_dataset_name, datasets, current_page_number],
        outputs=[dataset_viewer, entry_selector]
    )

# Start serving the interface built in the Blocks context above.
demo.launch()