File size: 2,042 Bytes
27b2e81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edf95fc
27b2e81
 
 
 
 
 
 
 
 
 
 
 
edf95fc
27b2e81
 
 
 
2ce4950
27b2e81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
from vega_datasets import data

# Load the two sample DataFrames (cars first, then iris) used by the demo.
cars, iris = data.cars(), data.iris()

# Alternatively, generate your own fake data instead of using vega_datasets:

# import pandas as pd
# import random

# cars_data = {
#     "Name": ["car name " + f" {int(i/10)}" for i in range(400)],
#     "Miles_per_Gallon": [random.randint(10, 30) for _ in range(400)],
#     "Origin": [random.choice(["USA", "Europe", "Japan"]) for _ in range(400)],
#     "Horsepower": [random.randint(50, 250) for _ in range(400)],
# }

# iris_data = {
#     "petalWidth": [round(random.uniform(0, 2.5), 2) for _ in range(150)],
#     "petalLength": [round(random.uniform(0, 7), 2) for _ in range(150)],
#     "species": [
#         random.choice(["setosa", "versicolor", "virginica"]) for _ in range(150)
#     ],
# }

# cars = pd.DataFrame(cars_data)
# iris = pd.DataFrame(iris_data)

def scatter_plot_fn(dataset):
    """Build the scatter plot for the dataset selected in the dropdown.

    The returned ``gr.ScatterPlot`` is wired as the output of the dropdown's
    ``change`` and the Blocks ``load`` events. Gradio treats a component
    returned from an event handler as an update that changes only the
    properties passed to its constructor. Each display property that one
    branch sets is therefore also set in the other branch, so nothing from
    the previously shown dataset lingers. The ``caption=""`` in the iris
    branch and the ``x_title`` in the cars branch exist for that reason.

    Args:
        dataset: The dropdown selection. ``"iris"`` selects the iris data;
            any other value falls back to the cars data.

    Returns:
        A ``gr.ScatterPlot`` configured for the selected dataset.
    """
    if dataset == "iris":
        return gr.ScatterPlot(
            value=iris,
            x="petalWidth",
            y="petalLength",
            color="species",
            title="Iris Dataset",
            color_legend_title="Species",
            x_title="Petal Width",
            y_title="Petal Length",
            tooltip=["petalWidth", "petalLength", "species"],
            # Clear the caption the cars plot sets.
            caption="",
        )
    else:
        return gr.ScatterPlot(
            value=cars,
            x="Horsepower",
            y="Miles_per_Gallon",
            color="Origin",
            tooltip=["Name"],
            title="Car Data",
            # Set explicitly so a stale "Petal Width" title from the iris
            # plot is not kept after switching datasets.
            x_title="Horsepower",
            y_title="Miles per Gallon",
            color_legend_title="Origin of Car",
            caption="MPG vs Horsepower of various cars",
        )

# Demo layout: a dataset selector on the left and the plot on the right.
with gr.Blocks() as scatter_plot:
    with gr.Row():
        with gr.Column():
            dataset = gr.Dropdown(value="cars", choices=["cars", "iris"])
        with gr.Column():
            plot = gr.ScatterPlot()
    # Redraw whenever the selection changes, and once when the page loads
    # so the plot is populated for the default "cars" selection.
    dataset.change(fn=scatter_plot_fn, inputs=dataset, outputs=plot)
    scatter_plot.load(fn=scatter_plot_fn, inputs=dataset, outputs=plot)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    scatter_plot.launch()