Oleh Kuznetsov commited on
Commit
7058ffd
·
1 Parent(s): 7397d2d

feat(ui): Add simple side-by-side ui for comparison

Browse files
Files changed (1) hide show
  1. app.py +112 -8
app.py CHANGED
@@ -1,15 +1,119 @@
1
  import gradio as gr
 
2
 
3
 
4
- def greet(name, intensity):
5
- return "Hello, " + name + "!" * int(intensity)
 
6
 
7
 
8
- app = gr.Interface(
9
- fn=greet,
10
- inputs=["text", "slider"],
11
- outputs=["text"],
12
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  if __name__ == "__main__":
15
- app.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
+ import random
3
 
4
 
5
+ # Dummy model functions for demonstration
6
+ def recommend_sadaimrec(query: str):
7
+ return f"SADAIMREC: response to '{query}'"
8
 
9
 
10
+ def recommend_chatgpt(query: str):
11
+ return f"CHATGPT: response to '{query}'"
12
+
13
+
14
+ # Mapping names to functions
15
+ pipelines = {
16
+ "sadaimrec": recommend_sadaimrec,
17
+ "chatgpt": recommend_chatgpt,
18
+ }
19
+
20
+
21
+ # Interface logic
22
+ def generate_responses(query):
23
+ # Randomize model order
24
+ pipeline_names = list(pipelines.keys())
25
+ random.shuffle(pipeline_names)
26
+
27
+ # Generate responses
28
+ resp1 = pipelines[pipeline_names[0]](query)
29
+ resp2 = pipelines[pipeline_names[1]](query)
30
+
31
+ # Return texts and hidden labels
32
+ return resp1, resp2, pipeline_names[0], pipeline_names[1]
33
+
34
+
35
+ # Callback to capture vote
36
+ def handle_vote(selected, label1, label2, resp1, resp2):
37
+ chosen_name = label1 if selected == "Option 1" else label2
38
+ chosen_resp = resp1 if selected == "Option 1" else resp2
39
+ print(f"User voted for {chosen_name}: '{chosen_resp}'")
40
+ return (
41
+ "Thank you for your vote! Restarting in 2 seconds...",
42
+ gr.update(active=True),
43
+ )
44
+
45
+
46
+ def reset_ui():
47
+ return (
48
+ gr.update(value="", visible=False), # hide row
49
+ gr.update(value=""), # clear query
50
+ gr.update(visible=False), # hide radio
51
+ gr.update(visible=False), # hide vote button
52
+ gr.update(value=""), # clear Option 1 text
53
+ gr.update(value=""), # clear Option 2 text
54
+ gr.update(value=""), # clear result
55
+ gr.update(active=False),
56
+ )
57
+
58
+
59
+ with gr.Blocks() as demo:
60
+ gr.Markdown("# Music Genre Recommendation Side-By-Side Comparison")
61
+ query = gr.Textbox(label="Your Query")
62
+ submit_btn = gr.Button("Submit")
63
+ # timer that resets ui after feedback is sent
64
+ reset_timer = gr.Timer(value=2.0, active=False)
65
+
66
+ # Hidden components to store model responses and names
67
+ with gr.Row(visible=False) as response_row:
68
+ response_1 = gr.Textbox(label="Option 1", interactive=False)
69
+ response_2 = gr.Textbox(label="Option 2", interactive=False)
70
+ model_label_1 = gr.Textbox(visible=False)
71
+ model_label_2 = gr.Textbox(visible=False)
72
+
73
+ # Feedback
74
+ vote = gr.Radio(
75
+ ["Option 1", "Option 2"], label="Select Best Response", visible=False
76
+ )
77
+ vote_btn = gr.Button("Vote", visible=False)
78
+ result = gr.Textbox(label="Console", interactive=False)
79
+
80
+ # On submit
81
+ submit_btn.click( # generate
82
+ fn=generate_responses,
83
+ inputs=[query],
84
+ outputs=[response_1, response_2, model_label_1, model_label_2],
85
+ )
86
+ submit_btn.click( # update ui
87
+ fn=lambda: (
88
+ gr.update(visible=True),
89
+ gr.update(visible=True),
90
+ gr.update(visible=True),
91
+ ),
92
+ inputs=None,
93
+ outputs=[response_row, vote, vote_btn],
94
+ )
95
+
96
+ # Feedback handling
97
+ vote_btn.click(
98
+ fn=handle_vote,
99
+ inputs=[vote, model_label_1, model_label_2, response_1, response_2],
100
+ outputs=[result, reset_timer],
101
+ )
102
+ reset_timer.tick(
103
+ fn=reset_ui,
104
+ inputs=None,
105
+ outputs=[
106
+ response_row,
107
+ query,
108
+ vote,
109
+ vote_btn,
110
+ response_1,
111
+ response_2,
112
+ result,
113
+ reset_timer,
114
+ ],
115
+ trigger_mode="once",
116
+ )
117
 
118
  if __name__ == "__main__":
119
+ demo.launch(server_name="0.0.0.0", server_port=7860)