import numpy as np
import pandas as pd
import streamlit as st

from PIL import Image

import torch
import torch.nn.functional as F
import pytesseract

import plotly.express as px

from torch.utils.data import Dataset, DataLoader,  Subset
import os
import io
import pytesseract
import fitz
from typing import List
import json

import sys
from pathlib import Path

from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForSequenceClassification

# Inference device: the GPU whenever torch can see one, otherwise the CPU.
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face hub id of the tokenizer paired with the classifier below.
TOKENIZER: str = "microsoft/layoutlmv3-base"
# Hugging Face hub id of the fine-tuned LayoutLMv3 sequence classifier.
MODEL_NAME: str = "fsommers/layoutlmv3-autofinance-classification-us-v01"

# Extra command-line flags for Tesseract; "--psm 3" selects fully automatic
# page segmentation.
TESS_OPTIONS: str = "--psm 3"

@st.cache_resource
def create_ocr_reader():
    """Build (once per Streamlit process) a callable that OCRs a PIL image.

    Returns:
        ``prepare_image(image) -> (words, boxes)``. It runs Tesseract on
        ``image`` and returns two parallel lists: the recognised words and
        their ``[x0, y0, x1, y1]`` boxes rescaled to the 0-1000 coordinate
        range that LayoutLMv3 expects.
    """

    def scale_bounding_box(box: List[int], w_scale: float = 1.0, h_scale: float = 1.0) -> List[int]:
        """Scale an ``[x0, y0, x1, y1]`` box, truncating each coordinate to int."""
        return [
            int(box[0] * w_scale),
            int(box[1] * h_scale),
            int(box[2] * w_scale),
            int(box[3] * h_scale),
        ]

    def ocr_page(image) -> dict:
        """OCR ``image`` and return ``{"bbox": boxes, "words": words}``.

        ``boxes[i]`` is the pixel-space ``[x0, y0, x1, y1]`` box of
        ``words[i]``. Both lists always have the same length.
        """
        ocr_df = pytesseract.image_to_data(image, output_type='data.frame', config=TESS_OPTIONS)
        # Drop rows that contain any missing value, such as rows with no recognised text.
        ocr_df = ocr_df.dropna().reset_index(drop=True)
        float_cols = ocr_df.select_dtypes('float').columns
        ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
        # Whitespace-only "words" carry no information; treat them as missing too.
        ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
        ocr_df = ocr_df.dropna().reset_index(drop=True)

        words = [str(w) for w in ocr_df.text]

        # Tesseract reports (left, top, width, height). Convert every row to
        # corner form in one vectorised step instead of iterating row by row.
        left, top, width, height = ocr_df[['left', 'top', 'width', 'height']].to_numpy().T
        boxes = np.column_stack([left, top, left + width, top + height]).tolist()

        assert len(words) == len(boxes)
        return {"bbox": boxes, "words": words}

    def prepare_image(image):
        """OCR ``image`` and return ``(words, boxes)`` with boxes on a 0-1000 scale.

        Raises:
            ValueError: if a scaled box falls outside ``[0, 1000]``, the range
                LayoutLMv3's layout embeddings accept.
        """
        ocr_data = ocr_page(image)
        width, height = image.size
        width_scale = 1000 / width
        height_scale = 1000 / height
        words = list(ocr_data["words"])
        boxes = [scale_bounding_box(b, width_scale, height_scale) for b in ocr_data["bbox"]]

        assert len(words) == len(boxes)
        # Fail early with a descriptive error rather than letting the model
        # hit an out-of-range embedding index later.
        for box in boxes:
            if any(coord < 0 or coord > 1000 for coord in box):
                raise ValueError(
                    f"Bounding box {box} lies outside the 0-1000 range expected by LayoutLMv3"
                )
        return words, boxes

    return prepare_image

@st.cache_resource
def create_model():
    """Load the fine-tuned LayoutLMv3 classifier once per Streamlit process.

    Loads the weights from ``MODEL_NAME``, switches the module to evaluation
    mode and moves it to ``DEVICE``.
    """
    classifier = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
    classifier.eval()  # disable dropout etc. for inference
    return classifier.to(DEVICE)

@st.cache_resource
def create_processor():
    """Build the LayoutLMv3 processor (image pre-processing plus tokenizer) once.

    The processor's built-in OCR is switched off (``apply_ocr=False``) because
    words and boxes are supplied by the caller.
    """
    text_tokenizer = LayoutLMv3TokenizerFast.from_pretrained(TOKENIZER)
    image_side = LayoutLMv3FeatureExtractor(apply_ocr=False)
    return LayoutLMv3Processor(tokenizer=text_tokenizer, feature_extractor=image_side)

def predict(image, reader, processor: LayoutLMv3Processor, model: LayoutLMv3ForSequenceClassification):
    """Classify a single document image with LayoutLMv3.

    Args:
        image: PIL image of the document page.
        reader: callable returning ``(words, boxes)`` for ``image``, such as the
            function built by ``create_ocr_reader``. Boxes are expected on the
            0-1000 scale.
        processor: LayoutLMv3 processor that turns image, words and boxes into
            model inputs.
        model: sequence-classification model. Inputs are moved to ``DEVICE``,
            so the model is expected to live there too.

    Returns:
        ``(predicted_class, probabilities)``: the arg-max class index as an int
        and the softmax probability of every class as a list of floats.
    """
    # OCR runs on the image exactly as provided.
    words, boxes = reader(image)
    # The visual branch normalises three channels. Uploaded PNGs are often
    # RGBA, palette or greyscale, so feed the processor an RGB view.
    pixel_image = image if image.mode == "RGB" else image.convert("RGB")
    encoding = processor(
        pixel_image,
        words,
        boxes=boxes,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.inference_mode():
        output = model(
            input_ids=encoding["input_ids"].to(DEVICE),
            attention_mask=encoding["attention_mask"].to(DEVICE),
            bbox=encoding["bbox"].to(DEVICE),
            pixel_values=encoding["pixel_values"].to(DEVICE),
        )
        logits = output.logits
        # One image per call, so take the arg-max over the class axis of the single row.
        predicted_class = logits.argmax(dim=-1)
        probabilities = F.softmax(logits, dim=-1).flatten().tolist()
        return predicted_class.item(), probabilities

# Placeholder output: the interactive upload/classify flow below is currently
# disabled (commented out). Re-enable it by uncommenting the block.
st.markdown("Test")

# NOTE(review): if re-enabled, recent Streamlit releases deprecate
# `use_column_width` in favour of `use_container_width` — confirm against the
# installed version. The uploader label also says JPG while PNG is accepted.
# reader = create_ocr_reader()
# processor = create_processor()
# model = create_model()

# uploaded_file = st.file_uploader("Choose a JPG file", ["jpg", "png"])
# if uploaded_file is not None:
#     bytes_data = io.BytesIO(uploaded_file.read())    
#     image = Image.open(bytes_data)
#     st.image(image, caption="Uploaded Image", use_column_width=True)
#     predicted, probabilities = predict(image, reader, processor, model)
#     predicted_label = model.config.id2label[predicted]
#     st.markdown(f"Predicted Label: {predicted_label}")

#     df = pd.DataFrame({
#         "Label": list(model.config.id2label.values()),
#         "Probability": probabilities
#     })
#     fig = px.bar(df, x="Label", y="Probability")
#     st.plotly_chart(fig, use_container_width=True)