Spaces:
Running
Running
import gradio as gr | |
import plotly.express as px | |
import numpy as np | |
import pandas as pd | |
from sklearn.metrics import confusion_matrix | |
from PIL import Image | |
from io import BytesIO | |
def _parse_sequence(text: str, name: str) -> list:
    """Parse a comma-separated string of numbers into a list of floats.

    Whitespace around items and empty items (e.g. from a trailing comma)
    are ignored.

    Raises:
    - gr.Error: if no numbers are given or any item is not a valid number.
    """
    items = [item.strip() for item in (text or "").split(",")]
    items = [item for item in items if item]
    if not items:
        raise gr.Error(f"Please enter at least one number for the {name} sequence.")
    try:
        return [float(item) for item in items]
    except ValueError:
        raise gr.Error(
            "Invalid input. Please enter sequences of numbers separated by commas."
        ) from None


def generate_plot(
    x_sequence: str,
    y_sequence: str,
    plot_type: str,
    x_label: str,
    y_label: str,
    width: int,
    height: int,
) -> Image.Image:
    """
    Generate a plot from comma-separated x and y sequences.

    Parameters:
    - x_sequence (str): Comma-separated x values. For 'Confusion Matrix'
      these are the actual labels, binarized at 0.5.
    - y_sequence (str): Comma-separated y values. For 'Confusion Matrix'
      these are the predicted labels/scores, binarized at 0.5.
    - plot_type (str): One of 'Bar', 'Scatter', 'Confusion Matrix'.
    - x_label (str): x-axis title for Bar/Scatter; "x" when empty.
    - y_label (str): y-axis title for Bar/Scatter; "y" when empty.
    - width (int): Figure width in pixels; falsy values fall back to 800.
    - height (int): Figure height in pixels; falsy values fall back to 600.

    Returns:
    - Image.Image: A PIL image of the figure rendered as PNG at 2x scale.

    Raises:
    - gr.Error: on unparsable or empty input, sequences of different
      lengths, or an unknown plot type. Gradio displays the message in the
      UI. Returning a string would instead be fed to the image output,
      which cannot display it.
    """
    x_data = _parse_sequence(x_sequence, "x")
    y_data = _parse_sequence(y_sequence, "y")
    if len(x_data) != len(y_data):
        raise gr.Error("The x and y sequences must have the same length.")

    # gr.Number may deliver floats (or nothing when cleared); pixel sizes must be ints.
    width = int(width) if width else 800
    height = int(height) if height else 600

    # The labels are optional in the UI; an empty string would blank the axis title.
    axis_labels = {"x": x_label or "x", "y": y_label or "y"}
    df = pd.DataFrame({"x": x_data, "y": y_data})

    if plot_type == "Bar":
        fig = px.bar(
            df,
            x="x",
            y="y",
            title="Bar Plot",
            labels=axis_labels,
            width=width,
            height=height,
        )
    elif plot_type == "Scatter":
        fig = px.scatter(
            df,
            x="x",
            y="y",
            title="Scatter Plot",
            labels=axis_labels,
            width=width,
            height=height,
        )
    elif plot_type == "Confusion Matrix":
        # x supplies the ground truth and y the predictions, both binarized at
        # 0.5. Random ground truth would make the plot irreproducible.
        y_true = np.asarray(x_data) > 0.5
        y_pred = np.asarray(y_data) > 0.5
        # Fixing the label set keeps the matrix 2x2 even when one class is absent.
        cm = confusion_matrix(y_true, y_pred, labels=[False, True])
        fig = px.imshow(
            cm,
            text_auto=True,
            title="Confusion Matrix",
            # sklearn rows are actual classes, columns are predicted classes.
            labels={"x": "Predicted", "y": "Actual", "color": "Count"},
            width=width,
            height=height,
        )
    else:
        raise gr.Error("Invalid plot type selected.")

    # Render via Kaleido and hand Gradio a PIL image (the output uses type="pil").
    img_bytes = fig.to_image(
        format="png", width=width, height=height, scale=2, engine="kaleido"
    )
    return Image.open(BytesIO(img_bytes))
# Gradio UI: one input component per generate_plot parameter, in call order.
app = gr.Interface(
    fn=generate_plot,
    inputs=[
        gr.Textbox(
            lines=2,
            placeholder="Enter x sequence of numbers separated by commas",
            label="X",
        ),
        gr.Textbox(
            lines=2,
            placeholder="Enter y sequence of numbers separated by commas",
            label="Y",
        ),
        gr.Radio(["Bar", "Scatter", "Confusion Matrix"], label="Type", value="Bar"),
        gr.Textbox(
            placeholder="Enter x-axis label (optional)", label="X_Label", value=""
        ),
        gr.Textbox(
            placeholder="Enter y-axis label (optional)", label="Y_Label", value=""
        ),
        # precision=0 makes Gradio round and pass an int: sizes are whole pixels.
        gr.Number(value=800, label="Width", precision=0),
        gr.Number(value=600, label="Height", precision=0),
    ],
    outputs=gr.Image(type="pil", label="Generated Plot"),
    title="Plotly Plot Generator",
    description="Generate plots using Plotly based on inputted sequences. Choose from Bar, Scatter, or Confusion Matrix plots.",
)
# Launch the app only when this file is run as a script, so that importing
# the module does not start a web server as a side effect.
if __name__ == "__main__":
    app.launch()