# plotly-tool / app.py
# Author: burtenshaw · Commit: cb64143 ("first commit") · Size: 3.85 kB
import gradio as gr
import plotly.express as px
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from PIL import Image
from io import BytesIO
def _parse_sequence(sequence: str) -> list:
    """Parse a comma-separated string of numbers into a list of floats.

    Whitespace around items is tolerated (``float`` strips it). Empty items are
    skipped, e.g. the one produced by a trailing comma. ``None`` is treated as
    an empty string.

    Raises:
    - ValueError: If any non-empty item is not a valid number.
    """
    return [float(item) for item in (sequence or "").split(",") if item.strip()]


def generate_plot(
    x_sequence: str,
    y_sequence: str,
    plot_type: str,
    x_label: str,
    y_label: str,
    width: int,
    height: int,
) -> Image.Image:
    """
    Generate a plot based on the provided x and y sequences and plot type.
    Parameters:
    - x_sequence (str): A comma-separated string of x values.
      For 'Confusion Matrix' these are the true labels, binarized at 0.5.
    - y_sequence (str): A comma-separated string of y values.
      For 'Confusion Matrix' these are the predicted labels, binarized at 0.5.
    - plot_type (str): The type of plot to generate ('Bar', 'Scatter', 'Confusion Matrix').
    - x_label (str): Label for the x-axis. Falls back to a default when empty.
    - y_label (str): Label for the y-axis. Falls back to a default when empty.
    - width (int): Width of the plot in pixels. Falsy values mean 800.
    - height (int): Height of the plot in pixels. Falsy values mean 600.
    Returns:
    - Image.Image: A PIL image of the generated plot, rendered at 2x scale.
    Raises:
    - gr.Error: On unparsable or empty sequences, mismatched lengths,
      sizes below Plotly's 10 px minimum, or an unknown plot type.
      Gradio displays the message to the user.
    """
    invalid_input_msg = (
        "Invalid input. Please enter sequences of numbers separated by commas."
    )
    # Convert the input sequences to lists of numbers
    try:
        x_data = _parse_sequence(x_sequence)
        y_data = _parse_sequence(y_sequence)
    except ValueError as err:
        raise gr.Error(invalid_input_msg) from err
    if not x_data or not y_data:
        raise gr.Error(invalid_input_msg)
    # Ensure the x and y sequences have the same length
    if len(x_data) != len(y_data):
        raise gr.Error("The x and y sequences must have the same length.")

    # Defaults when not provided. gr.Number yields floats, so coerce to int.
    width = int(width) if width else 800
    height = int(height) if height else 600
    # Plotly rejects layout sizes below 10 px. Report it instead of crashing.
    if width < 10 or height < 10:
        raise gr.Error("Width and height must be at least 10 pixels.")

    # Generate the plot based on the selected type
    if plot_type in ("Bar", "Scatter"):
        plot_fn = px.bar if plot_type == "Bar" else px.scatter
        df = pd.DataFrame({"x": x_data, "y": y_data})
        fig = plot_fn(
            df,
            x="x",
            y="y",
            title=f"{plot_type} Plot",
            labels={"x": x_label or "x", "y": y_label or "y"},
            width=width,
            height=height,
        )
    elif plot_type == "Confusion Matrix":
        # Binarize both sequences at 0.5: x holds the true labels, y the predictions.
        y_true = np.asarray(x_data) > 0.5
        y_pred = np.asarray(y_data) > 0.5
        # Fixing the label set keeps the matrix 2x2 even when a class is absent.
        cm = confusion_matrix(y_true, y_pred, labels=[False, True])
        # confusion_matrix puts true labels on rows (y axis) and predictions on columns (x axis).
        fig = px.imshow(
            cm,
            text_auto=True,
            x=["0", "1"],
            y=["0", "1"],
            labels={"x": x_label or "Predicted", "y": y_label or "Actual"},
            title="Confusion Matrix",
            width=width,
            height=height,
        )
    else:
        raise gr.Error("Invalid plot type selected.")

    # Convert the plot to a PNG image
    img_bytes = fig.to_image(
        format="png", width=width, height=height, scale=2, engine="kaleido"
    )
    return Image.open(BytesIO(img_bytes))
# Define the Gradio interface. The inputs are passed positionally to
# generate_plot, so their order must match its parameter order.
app = gr.Interface(
    fn=generate_plot,
    inputs=[
        gr.Textbox(
            lines=2,
            placeholder="Enter x sequence of numbers separated by commas",
            label="X",
        ),
        gr.Textbox(
            lines=2,
            placeholder="Enter y sequence of numbers separated by commas",
            label="Y",
        ),
        gr.Radio(["Bar", "Scatter", "Confusion Matrix"], label="Type", value="Bar"),
        gr.Textbox(
            placeholder="Enter x-axis label (optional)", label="X_Label", value=""
        ),
        gr.Textbox(
            placeholder="Enter y-axis label (optional)", label="Y_Label", value=""
        ),
        # precision=0 makes Gradio round and pass an int, matching generate_plot's hints.
        gr.Number(value=800, label="Width", precision=0),
        gr.Number(value=600, label="Height", precision=0),
    ],
    outputs=gr.Image(type="pil", label="Generated Plot"),
    title="Plotly Plot Generator",
    description="Generate plots using Plotly based on inputted sequences. Choose from Bar, Scatter, or Confusion Matrix plots.",
)

# Launch the app only when run as a script (e.g. `python app.py`), so that
# importing this module (for example to reuse generate_plot) has no side effects.
if __name__ == "__main__":
    app.launch()