import os
import pandas as pd
import plotly.express as px
import gradio as gr
import urllib.parse
import plotly.graph_objects as go
import numpy as np



def read_google_sheet(sheet_id, sheet_name):
    """Download one worksheet of a Google Sheet as a pandas DataFrame.

    The worksheet is addressed by name through the sheet's ``gviz`` CSV export
    endpoint; the name is URL-quoted before being placed in the query string.

    Args:
        sheet_id: Identifier of the spreadsheet (the ``/d/<id>/`` URL segment).
        sheet_name: Name of the worksheet tab to export.

    Returns:
        The parsed DataFrame, or ``None`` if downloading/parsing raised; in that
        case the exception is printed rather than propagated.
    """
    csv_url = (
        f"https://docs.google.com/spreadsheets/d/{sheet_id}/gviz/tq"
        f"?tqx=out:csv&sheet={urllib.parse.quote(sheet_name)}"
    )

    try:
        return pd.read_csv(csv_url)
    except Exception as err:
        # Best-effort: callers receive None instead of an exception.
        print(f"An error occurred: {err}")
    return None

def log2_ticks(values):
    """Return integer log2 tick positions and power-of-two labels for an axis.

    *values* are expected to already be log2-scaled. A tick is placed at every
    integer from ``floor(min)`` to ``ceil(max)`` of the finite entries, and each
    tick ``k`` is labelled with the raw value ``2**k``.

    Args:
        values: Array-like of log2-scaled numbers (e.g. a pandas Series).
            NaN and +/-inf entries are ignored.

    Returns:
        ``(tick_vals, tick_text)``: a float numpy array of tick positions and a
        list of label strings of the same length. Both are empty when *values*
        contains no finite entries.
    """
    arr = np.asarray(values, dtype=float)
    finite = arr[np.isfinite(arr)]
    if finite.size == 0:
        # Nothing to label; np.arange(nan, nan + 1) would raise here.
        return np.array([], dtype=float), []

    min_val, max_val = np.floor(finite.min()), np.ceil(finite.max())
    tick_vals = np.arange(min_val, max_val + 1)
    # Whole-number labels for k >= 0; negative exponents keep their fraction
    # (0.5, 0.25, ...) instead of being rounded to "0".
    tick_text = [
        f"{2 ** val:.0f}" if val >= 0 else f"{2 ** val:g}" for val in tick_vals
    ]
    return tick_vals, tick_text

# Load data: one worksheet per node count ("1 node", "8 nodes").
sheet_id = "1g07tdGf9ocOZ8XZgLGepI5Q4u6ZH961J0T9O9P64rYw"
sheet_names = [f"{i} node" if i == 1 else f"{i} nodes" for i in [1, 8]]

# read_google_sheet returns None on failure. pd.concat would drop those
# silently (and raise an opaque error if every sheet failed), so report
# missing sheets explicitly instead.
frames = [read_google_sheet(sheet_id, sheet_name) for sheet_name in sheet_names]
missing = [name for name, frame in zip(sheet_names, frames) if frame is None]
if missing:
    print(f"Skipping sheets that could not be loaded: {missing}")
loaded = [frame for frame in frames if frame is not None]
if not loaded:
    raise RuntimeError(f"None of the benchmark sheets could be loaded: {sheet_names}")

# Each sheet's rows start at index 0, so renumber after stacking them.
df = pd.concat(loaded, ignore_index=True)
df = df.rename(columns={"micro_batch_size": "mbs", "batch_accumulation_per_replica": "gradacc"})
# NOTE(review): -1 looks like a sentinel for runs without a measured
# throughput; those runs are plotted as 0 tok/s/gpu.
df["tok/s/gpu"] = df["tok/s/gpu"].replace(-1, 0)
# Total throughput = per-GPU throughput x number of nodes x GPUs per node.
GPUS_PER_NODE = 8
df["throughput"] = df["tok/s/gpu"] * df["nnodes"] * GPUS_PER_NODE



def get_figure(nodes, hide_nans):
    """Build a parallel-coordinates plot of the benchmark runs for one node count.

    Args:
        nodes: Node count to display. Only rows of the module-level ``df``
            whose ``nnodes`` column equals this value are plotted.
        hide_nans: When truthy, rows containing a NaN in any column are
            dropped before plotting.

    Returns:
        A ``plotly.graph_objects.Figure`` with:
        - one log2-scaled axis for each of ``dp``, ``tp``, ``pp``, ``mbs`` and
          ``gradacc``, with ticks labelled by the raw powers of two;
        - a linear ``throughput`` axis;
        - lines coloured by throughput.
        If no rows remain after filtering, an empty figure whose title says so
        is returned instead.
    """
    # NOTE(review): depending on the Gradio version, Dropdown values may be
    # delivered as strings ("8", "False"). Normalise them so that "False" is
    # not treated as truthy and "8" still matches the numeric nnodes column.
    if isinstance(nodes, str):
        nodes = pd.to_numeric(nodes)
    if isinstance(hide_nans, str):
        hide_nans = hide_nans.strip().lower() == "true"

    # Keep only the runs for the requested node count.
    df_tmp = df[df["nnodes"] == nodes].reset_index(drop=True)
    if hide_nans:
        df_tmp = df_tmp.dropna()

    layout = dict(plot_bgcolor="white", paper_bgcolor="white")
    if df_tmp.empty:
        # Nothing to plot: axis ranges would be NaN and there are no ticks.
        fig = go.Figure()
        fig.update_layout(title=f"No benchmark results for {nodes} node(s)", **layout)
        return fig

    # Parallelism / batch settings are drawn on log2 axes; throughput is not.
    log_columns = ["dp", "tp", "pp", "mbs", "gradacc"]
    dimensions = []
    for col in log_columns:
        log_values = np.log2(df_tmp[col])
        ticks, labels = log2_ticks(log_values)
        dimensions.append(
            dict(
                range=[log_values.min(), log_values.max()],
                label=col,
                values=log_values,
                tickvals=ticks,
                ticktext=labels,
            )
        )

    throughput = df_tmp["throughput"]
    dimensions.append(
        dict(
            range=[throughput.min(), throughput.max()],
            label="throughput",
            values=throughput,
        )
    )

    fig = go.Figure(
        data=go.Parcoords(
            line=dict(
                color=throughput,
                colorscale="GnBu",
                showscale=True,
                cmin=throughput.min(),
                cmax=throughput.max(),
            ),
            dimensions=dimensions,
        )
    )
    fig.update_layout(title="3D parallel setup throughput ", **layout)
    return fig


with gr.Blocks() as demo:
    gr.Markdown("# 3D parallel benchmark")
    with gr.Row():
        nnodes = gr.Dropdown(choices=[1, 8], label="Number of nodes", value=8)
        # A checkbox hands get_figure a plain bool for this on/off option.
        hide_nan = gr.Checkbox(label="Hide NaNs", value=False)

    plot = gr.Plot()
    inputs = [nnodes, hide_nan]
    # Draw once on page load, then redraw whenever either control changes.
    demo.load(get_figure, inputs, [plot])
    for control in inputs:
        control.change(get_figure, inputs, [plot])

if __name__ == "__main__":
    demo.launch(show_api=False)