rui3000 committed on
Commit
9b05877
·
verified ·
1 Parent(s): a556cd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +185 -86
app.py CHANGED
@@ -1,12 +1,17 @@
1
  import gradio as gr
2
  import torch
3
  import time
4
- import spaces # Import the spaces library
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
- from db import save_test_result
 
 
 
 
 
7
 
8
  # --- Configuration ---
9
- MODEL_ID = "Qwen/Qwen2.5-Math-1.5B" # Replace with actual ID if found
10
  # --- Load Model and Tokenizer ---
11
  print(f"Loading model: {MODEL_ID}")
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
@@ -17,7 +22,6 @@ model = AutoModelForCausalLM.from_pretrained(
17
  )
18
  print("Model loaded successfully.")
19
 
20
-
21
  # --- Generation Function (Returns response and token count) ---
22
  def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
23
  """Generate a response and return it along with the number of generated tokens."""
@@ -48,17 +52,6 @@ def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
48
  num_generated_tokens = len(output_ids)
49
  response = tokenizer.decode(output_ids, skip_special_tokens=True)
50
  print("Generation complete.")
51
-
52
- save_test_result(
53
- analysis_mode='',
54
- system_prompt='',
55
- input_content='',
56
- model_response= response,
57
- generation_time='',
58
- tokens_generated='',
59
- temperature='',
60
- top_p='',
61
- max_length='')
62
  return response.strip(), num_generated_tokens
63
 
64
  except Exception as e:
@@ -73,7 +66,8 @@ def process_input(
73
  system_prompt, # Single system prompt from UI
74
  max_length,
75
  temperature,
76
- top_p
 
77
  ):
78
  """Process inputs based on selected analysis mode using the provided system prompt."""
79
  print(f"GPU requested via decorator, starting processing in mode: {analysis_mode}")
@@ -117,11 +111,29 @@ def process_input(
117
  display_prompt = f"Selected Mode: {analysis_mode}\nSystem Prompt:\n{system_prompt}"
118
 
119
  print(f"Processing finished in {duration} seconds.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  # Return all results including time and tokens
121
  return display_prompt, response, f"{duration} seconds", generated_tokens
122
 
123
  # --- System Prompts (Defaults only, UI will hold the editable version) ---
124
-
125
  DEFAULT_SYSTEM_PROMPT_FREQ = """You are an assistant that analyzes Rock-Paper-Scissors (RPS) player statistics. Your ONLY goal is to find the best single AI move to counter the player's MOST frequent move based on the provided frequency stats.
126
 
127
  Follow these steps EXACTLY. Do NOT deviate.
@@ -144,7 +156,6 @@ Step 4: State Final Recommendation.
144
  Base your analysis strictly on the provided frequencies and the stated RPS rules.
145
  """
146
 
147
- # Updated Markov system prompt - no changes to content
148
  DEFAULT_SYSTEM_PROMPT_MARKOV = """You are analyzing a Rock-Paper-Scissors (RPS) game using a Markov transition matrix.
149
 
150
  ### TRANSITION MATRIX:
@@ -186,7 +197,6 @@ Predicted Next Move: [Move with highest probability]
186
  Optimal Counter: [Move that beats the predicted move]
187
  """
188
 
189
- # New Behavior Analysis prompt
190
  DEFAULT_SYSTEM_PROMPT_BEHAVIOR = """You are an RPS assistant analyzing player behavior after wins, losses, and ties. Predict the player's next move and give counter strategy based on the Behavioral probabilities.
191
 
192
  **Behavioral Probabilities P(Change/not change | Win/Loss/Tie):**
@@ -224,74 +234,116 @@ DEFAULT_PLAYER_BEHAVIOR = "Player's Last Outcome: Win\nPlayer's Last Move: Rock"
224
 
225
  # --- Gradio Interface ---
226
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
227
- gr.Markdown(f"# {MODEL_ID} - RPS Strategy Tester")
228
- gr.Markdown("Test model advice using Frequency Stats, Markov Predictions, or Win/Loss/Tie Behavior Analysis.")
229
-
230
- # Mode Selector - now with three options
231
- analysis_mode_selector = gr.Radio(
232
- label="Select Analysis Mode",
233
- choices=["Frequency Only", "Markov Prediction Only", "Behavior Analysis"],
234
- value="Frequency Only" # Default mode
235
- )
236
-
237
- # --- Visible System Prompt Textbox ---
238
- system_prompt_input = gr.Textbox(
239
- label="System Prompt (Edit based on selected mode)",
240
- value=DEFAULT_SYSTEM_PROMPT_FREQ, # Start with frequency prompt
241
- lines=15
242
  )
243
 
244
- # Input Sections (conditionally visible)
245
- with gr.Group(visible=True) as frequency_inputs: # Visible by default
246
- gr.Markdown("### Frequency Analysis Inputs")
247
- player_stats_input = gr.Textbox(
248
- label="Player Move Frequency Stats (Long-Term)", value=DEFAULT_PLAYER_STATS, lines=4,
249
- info="Overall player move distribution."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  )
251
-
252
- with gr.Group(visible=False) as markov_inputs: # Hidden by default
253
- gr.Markdown("### Markov Prediction Analysis Inputs")
254
- gr.Markdown("*Use the System Prompt field to directly input your Markov analysis instructions.*")
255
-
256
- # New behavior analysis inputs
257
- with gr.Group(visible=False) as behavior_inputs:
258
- gr.Markdown("### Win/Loss/Tie Behavior Analysis Inputs")
259
- player_behavior_input = gr.Textbox(
260
- label="Player's Last Outcome and Move", value=DEFAULT_PLAYER_BEHAVIOR, lines=4,
261
- info="Enter the last outcome (Win/Loss/Tie) and move (Rock/Paper/Scissors)."
262
- )
263
-
264
- # General Inputs / Parameters / Outputs
265
- with gr.Row():
266
- with gr.Column():
267
- gr.Markdown("#### Generation Parameters")
268
- max_length_slider = gr.Slider(minimum=50, maximum=1024, value=300, step=16, label="Max New Tokens")
269
- temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Temperature")
270
- top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
271
-
272
- submit_btn = gr.Button("Generate Response", variant="primary")
273
-
274
- with gr.Row():
275
- with gr.Column():
276
- gr.Markdown("#### Performance Metrics")
277
- time_output = gr.Textbox(label="Generation Time", interactive=False)
278
- tokens_output = gr.Number(label="Generated Tokens", interactive=False)
279
- with gr.Column():
280
- gr.Markdown("""
281
- #### Testing Tips
282
- - Select the desired **Analysis Mode**.
283
- - Fill in the inputs for the **selected mode only**.
284
- - **Edit the System Prompt** above as needed for testing.
285
- - Use low **Temperature** for factual analysis.
286
- """)
287
-
288
- with gr.Row():
289
- final_prompt_display = gr.Textbox(
290
- label="Formatted Input Sent to Model (via Chat Template)", lines=20
291
- )
292
- response_display = gr.Textbox(
293
- label="Model Response", lines=20, show_copy_button=True
294
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
  # --- Event Handlers ---
297
 
@@ -326,6 +378,34 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
326
  system_prompt_input: gr.update(value=DEFAULT_SYSTEM_PROMPT_FREQ)
327
  }
328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  # Link the radio button change to the UI update function
330
  analysis_mode_selector.change(
331
  fn=update_ui_visibility_and_prompt, # Use the combined update function
@@ -343,14 +423,33 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
343
  system_prompt_input, # Pass the visible system prompt textbox
344
  max_length_slider,
345
  temperature_slider,
346
- top_p_slider
 
347
  ],
348
  outputs=[
349
  final_prompt_display, response_display,
350
  time_output, tokens_output
351
- ],
352
- api_name="generate_rps_selectable_analysis_v4" # Updated api_name
353
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
 
355
  # --- Launch the demo ---
356
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import torch
3
  import time
4
+ import spaces
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
+ from db import init_db, save_test_result, get_test_history, get_test_details
7
+
8
+ # --- Initialize Database ---
9
+ db_initialized = init_db()
10
+ if not db_initialized:
11
+ print("WARNING: Database initialization failed. Test history will not be saved.")
12
 
13
  # --- Configuration ---
14
+ MODEL_ID = "Qwen/Qwen2.5-Math-1.5B" # Replace with actual ID if found
15
  # --- Load Model and Tokenizer ---
16
  print(f"Loading model: {MODEL_ID}")
17
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
22
  )
23
  print("Model loaded successfully.")
24
 
 
25
  # --- Generation Function (Returns response and token count) ---
26
  def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
27
  """Generate a response and return it along with the number of generated tokens."""
 
52
  num_generated_tokens = len(output_ids)
53
  response = tokenizer.decode(output_ids, skip_special_tokens=True)
54
  print("Generation complete.")
 
 
 
 
 
 
 
 
 
 
 
55
  return response.strip(), num_generated_tokens
56
 
57
  except Exception as e:
 
66
  system_prompt, # Single system prompt from UI
67
  max_length,
68
  temperature,
69
+ top_p,
70
+ save_to_db=True # New parameter to toggle database saving
71
  ):
72
  """Process inputs based on selected analysis mode using the provided system prompt."""
73
  print(f"GPU requested via decorator, starting processing in mode: {analysis_mode}")
 
111
  display_prompt = f"Selected Mode: {analysis_mode}\nSystem Prompt:\n{system_prompt}"
112
 
113
  print(f"Processing finished in {duration} seconds.")
114
+
115
+ # Save to database if requested and if database is available
116
+ if save_to_db and db_initialized:
117
+ test_id = save_test_result(
118
+ analysis_mode=analysis_mode,
119
+ system_prompt=system_prompt,
120
+ input_content=user_content if user_content else "",
121
+ model_response=response,
122
+ generation_time=duration,
123
+ tokens_generated=generated_tokens,
124
+ temperature=temperature,
125
+ top_p=top_p,
126
+ max_length=max_length
127
+ )
128
+ if test_id:
129
+ print(f"Test saved to database with ID: {test_id}")
130
+ else:
131
+ print("Failed to save test to database")
132
+
133
  # Return all results including time and tokens
134
  return display_prompt, response, f"{duration} seconds", generated_tokens
135
 
136
  # --- System Prompts (Defaults only, UI will hold the editable version) ---
 
137
  DEFAULT_SYSTEM_PROMPT_FREQ = """You are an assistant that analyzes Rock-Paper-Scissors (RPS) player statistics. Your ONLY goal is to find the best single AI move to counter the player's MOST frequent move based on the provided frequency stats.
138
 
139
  Follow these steps EXACTLY. Do NOT deviate.
 
156
  Base your analysis strictly on the provided frequencies and the stated RPS rules.
157
  """
158
 
 
159
  DEFAULT_SYSTEM_PROMPT_MARKOV = """You are analyzing a Rock-Paper-Scissors (RPS) game using a Markov transition matrix.
160
 
161
  ### TRANSITION MATRIX:
 
197
  Optimal Counter: [Move that beats the predicted move]
198
  """
199
 
 
200
  DEFAULT_SYSTEM_PROMPT_BEHAVIOR = """You are an RPS assistant analyzing player behavior after wins, losses, and ties. Predict the player's next move and give counter strategy based on the Behavioral probabilities.
201
 
202
  **Behavioral Probabilities P(Change/not change | Win/Loss/Tie):**
 
234
 
235
  # --- Gradio Interface ---
236
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
237
+ with gr.Tab("Model Testing"):
238
+ gr.Markdown(f"# {MODEL_ID} - RPS Strategy Tester with Test History")
239
+ gr.Markdown("Test model advice using Frequency Stats, Markov Predictions, or Win/Loss/Tie Behavior Analysis.")
240
+
241
+ # Mode Selector - now with three options
242
+ analysis_mode_selector = gr.Radio(
243
+ label="Select Analysis Mode",
244
+ choices=["Frequency Only", "Markov Prediction Only", "Behavior Analysis"],
245
+ value="Frequency Only" # Default mode
 
 
 
 
 
 
246
  )
247
 
248
+ # --- Visible System Prompt Textbox ---
249
+ system_prompt_input = gr.Textbox(
250
+ label="System Prompt (Edit based on selected mode)",
251
+ value=DEFAULT_SYSTEM_PROMPT_FREQ, # Start with frequency prompt
252
+ lines=15
253
+ )
254
+
255
+ # Input Sections (conditionally visible)
256
+ with gr.Group(visible=True) as frequency_inputs: # Visible by default
257
+ gr.Markdown("### Frequency Analysis Inputs")
258
+ player_stats_input = gr.Textbox(
259
+ label="Player Move Frequency Stats (Long-Term)", value=DEFAULT_PLAYER_STATS, lines=4,
260
+ info="Overall player move distribution."
261
+ )
262
+
263
+ with gr.Group(visible=False) as markov_inputs: # Hidden by default
264
+ gr.Markdown("### Markov Prediction Analysis Inputs")
265
+ gr.Markdown("*Use the System Prompt field to directly input your Markov analysis instructions.*")
266
+
267
+ # New behavior analysis inputs
268
+ with gr.Group(visible=False) as behavior_inputs:
269
+ gr.Markdown("### Win/Loss/Tie Behavior Analysis Inputs")
270
+ player_behavior_input = gr.Textbox(
271
+ label="Player's Last Outcome and Move", value=DEFAULT_PLAYER_BEHAVIOR, lines=4,
272
+ info="Enter the last outcome (Win/Loss/Tie) and move (Rock/Paper/Scissors)."
273
+ )
274
+
275
+ # General Inputs / Parameters / Outputs
276
+ with gr.Row():
277
+ with gr.Column():
278
+ gr.Markdown("#### Generation Parameters")
279
+ max_length_slider = gr.Slider(minimum=50, maximum=1024, value=300, step=16, label="Max New Tokens")
280
+ temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Temperature")
281
+ top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
282
+
283
+ # Add a checkbox to control saving to database
284
+ save_to_db_checkbox = gr.Checkbox(
285
+ label="Save this test to database",
286
+ value=True,
287
+ info="Store input and output in SQLite database for later reference"
288
+ )
289
+
290
+ submit_btn = gr.Button("Generate Response", variant="primary")
291
+
292
+ with gr.Row():
293
+ with gr.Column():
294
+ gr.Markdown("#### Performance Metrics")
295
+ time_output = gr.Textbox(label="Generation Time", interactive=False)
296
+ tokens_output = gr.Number(label="Generated Tokens", interactive=False)
297
+ with gr.Column():
298
+ gr.Markdown("""
299
+ #### Testing Tips
300
+ - Select the desired **Analysis Mode**.
301
+ - Fill in the inputs for the **selected mode only**.
302
+ - **Edit the System Prompt** above as needed for testing.
303
+ - Use low **Temperature** for factual analysis.
304
+ """)
305
+
306
+ with gr.Row():
307
+ final_prompt_display = gr.Textbox(
308
+ label="Formatted Input Sent to Model (via Chat Template)", lines=20
309
+ )
310
+ response_display = gr.Textbox(
311
+ label="Model Response", lines=20, show_copy_button=True
312
+ )
313
+
314
+ # Add a new tab for test history
315
+ with gr.Tab("Test History"):
316
+ gr.Markdown("### Saved Test Results")
317
+
318
+ refresh_btn = gr.Button("Refresh History")
319
+
320
+ # Display test history as a dataframe
321
+ test_history_df = gr.Dataframe(
322
+ headers=["Test ID", "Analysis Mode", "Timestamp", "Generation Time", "Tokens"],
323
+ label="Recent Tests",
324
+ interactive=False
325
  )
326
+
327
+ # Add a number input to load a specific test
328
+ test_id_input = gr.Number(
329
+ label="Test ID",
330
+ precision=0,
331
+ info="Enter a Test ID to load details"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  )
333
+ load_test_btn = gr.Button("Load Test")
334
+
335
+ # Display test details
336
+ with gr.Group():
337
+ test_mode_display = gr.Textbox(label="Analysis Mode", interactive=False)
338
+ test_prompt_display = gr.Textbox(label="System Prompt", interactive=False, lines=8)
339
+ test_input_display = gr.Textbox(label="Input Content", interactive=False, lines=4)
340
+ test_response_display = gr.Textbox(label="Model Response", interactive=False, lines=8)
341
+
342
+ with gr.Row():
343
+ test_time_display = gr.Number(label="Generation Time (s)", interactive=False)
344
+ test_tokens_display = gr.Number(label="Tokens Generated", interactive=False)
345
+ test_temp_display = gr.Number(label="Temperature", interactive=False)
346
+ test_topp_display = gr.Number(label="Top P", interactive=False)
347
 
348
  # --- Event Handlers ---
349
 
 
378
  system_prompt_input: gr.update(value=DEFAULT_SYSTEM_PROMPT_FREQ)
379
  }
380
 
381
# Refresh the "Recent Tests" dataframe from the saved-results store.
def update_test_history():
    """Return up to 20 recent test rows for the history dataframe.

    Each row is [test id, analysis mode, timestamp, generation time,
    tokens]; a single placeholder row is returned when the database
    was not initialized.
    """
    if not db_initialized:
        return [["N/A", "Database Not Available", "N/A", 0, 0]]
    rows = []
    for record in get_test_history(limit=20):
        rows.append([record[0], record[1], record[2], record[3], record[4]])
    return rows
388
+
389
# Fetch one saved test by id and fan its fields out to the detail widgets.
def load_test_details(test_id):
    """Look up a saved test and return its fields for the detail display.

    Returns an 8-element list in widget order: [analysis mode, system
    prompt, input content, model response, generation time, tokens,
    temperature, top_p]. Placeholder values are returned when the
    database is unavailable or the id is unknown.
    """
    if not db_initialized:
        return ["Database Not Available", "", "", "", 0, 0, 0, 0]

    record = get_test_details(test_id)
    if not record:
        return ["Test not found", "", "", "", 0, 0, 0, 0]

    return [
        record["analysis_mode"],
        record["system_prompt"],
        record["input_content"] or "",
        record["model_response"],
        record["generation_time"],
        record["tokens_generated"],
        record["temperature"],
        record["top_p"],
    ]
408
+
409
  # Link the radio button change to the UI update function
410
  analysis_mode_selector.change(
411
  fn=update_ui_visibility_and_prompt, # Use the combined update function
 
423
  system_prompt_input, # Pass the visible system prompt textbox
424
  max_length_slider,
425
  temperature_slider,
426
+ top_p_slider,
427
+ save_to_db_checkbox # Pass the checkbox value
428
  ],
429
  outputs=[
430
  final_prompt_display, response_display,
431
  time_output, tokens_output
432
+ ]
 
433
  )
434
+
435
+ # Connect buttons for test history tab
436
+ refresh_btn.click(
437
+ update_test_history,
438
+ outputs=[test_history_df]
439
+ )
440
+
441
+ load_test_btn.click(
442
+ load_test_details,
443
+ inputs=[test_id_input],
444
+ outputs=[
445
+ test_mode_display, test_prompt_display, test_input_display,
446
+ test_response_display, test_time_display, test_tokens_display,
447
+ test_temp_display, test_topp_display
448
+ ]
449
+ )
450
+
451
+ # Initialize history on page load
452
+ demo.load(update_test_history, outputs=[test_history_df])
453
 
454
  # --- Launch the demo ---
455
  if __name__ == "__main__":