rui3000 committed
Commit 4b69b6e · verified · 1 Parent(s): 2e54946

Update app.py

Files changed (1):
  1. app.py +180 -73
app.py CHANGED
@@ -8,20 +8,18 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"
 
 # --- Load Model and Tokenizer ---
-# Note: Model loading happens when the Space starts.
-# device_map="auto" will attempt to use the GPU when allocated by @spaces.GPU
 print(f"Loading model: {MODEL_ID}")
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
     torch_dtype="auto",
-    device_map="auto"  # Keep this, it helps distribute within the allocated GPU(s)
+    device_map="auto"
 )
 print("Model loaded successfully.")
 
 
 # --- Generation Function (Returns response and token count) ---
-# This function will run on the GPU allocated via the decorator on process_input
+# No changes needed here
 def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
     """Generate a response and return it along with the number of generated tokens."""
     num_generated_tokens = 0
@@ -31,8 +29,6 @@ def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
             tokenize=False,
             add_generation_prompt=True
         )
-        # Ensure model_inputs are sent to the correct device the model is on
-        # device_map='auto' handles this, but explicitly checking model.device is safer
         device = model.device
         model_inputs = tokenizer([prompt_text], return_tensors="pt").to(device)
         input_ids_len = model_inputs.input_ids.shape[-1]
@@ -47,40 +43,48 @@ def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
 
         print("Generating response...")
         with torch.no_grad():
-            # Generate response
             generated_ids = model.generate(model_inputs.input_ids, **generation_kwargs)
 
-        # Calculate generated tokens
         output_ids = generated_ids[0, input_ids_len:]
         num_generated_tokens = len(output_ids)
-
        response = tokenizer.decode(output_ids, skip_special_tokens=True)
         print("Generation complete.")
         return response.strip(), num_generated_tokens
 
     except Exception as e:
         print(f"Error during generation: {e}")
-        # Ensure error message is returned correctly even if tokens couldn't be counted
         return f"An error occurred: {str(e)}", num_generated_tokens
 
-# --- Input Processing Function (Decorated for ZeroGPU) ---
-@spaces.GPU  # Add the ZeroGPU decorator here
+# --- Input Processing Function (Adapts based on mode) ---
+@spaces.GPU  # Keep ZeroGPU decorator
 def process_input(
+    analysis_mode,          # New input: mode selector
     player_stats,
-    ai_stats,
-    system_prompt,
-    user_query,
+    player_last_move,
+    markov_prediction_text,
+    system_prompt_freq,     # Specific system prompt for frequency mode
+    system_prompt_markov,   # Specific system prompt for Markov mode
+    user_query,             # User query might need slight adaptation based on mode
     max_length,
     temperature,
     top_p
 ):
-    """Process inputs, generate response, and return display info, response, time, and token count."""
-    print("GPU requested via decorator, starting processing...")  # Add a log message
-    # Construct the user message content
-    user_content = f"Player Move Frequency Stats:\n{player_stats}\n\n"
-    if ai_stats and ai_stats.strip():
-        user_content += f"AI Move Frequency Stats:\n{ai_stats}\n\n"
-    user_content += f"User Query:\n{user_query}"
+    """Process inputs based on the selected analysis mode and generate a response."""
+    print(f"GPU requested via decorator, starting processing in mode: {analysis_mode}")
+
+    # Select the appropriate system prompt and construct user content based on mode
+    if analysis_mode == "Frequency Only":
+        system_prompt = system_prompt_freq
+        user_content = f"Player Move Frequency Stats (Long-Term):\n{player_stats}\n\n"
+        user_content += f"User Query:\n{user_query}"  # Query might need adjustment
+    elif analysis_mode == "Markov Prediction Only":
+        system_prompt = system_prompt_markov
+        user_content = f"Player's Last Move:\n{player_last_move}\n\n"
+        user_content += f"Predicted Next Move (Short-Term Markov Analysis):\n{markov_prediction_text}\n\n"
+        user_content += f"User Query:\n{user_query}"  # Query might need adjustment
+    else:
+        # Default or error case
+        return "Invalid analysis mode selected.", "", "0 seconds", 0
 
     # Create the messages list
     messages = []
@@ -92,7 +96,7 @@ def process_input(
     start_time = time.time()
 
     # Generate response from the model
-    response, generated_tokens = generate_response(
+    response, generated_tokens = generate_response(  # Capture token count
        messages,
        max_length=max_length,
        temperature=temperature,
@@ -101,71 +105,139 @@ def process_input(
 
     # --- Time Measurement End ---
     end_time = time.time()
-    duration = round(end_time - start_time, 2)
+    duration = round(end_time - start_time, 2)  # Calculate duration
 
     # For display purposes
-    display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
+    display_prompt = f"Selected Mode: {analysis_mode}\nSystem Prompt:\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
 
-    print(f"Processing finished in {duration} seconds.")  # Add a log message
+    print(f"Processing finished in {duration} seconds.")
     # Return all results including time and tokens
     return display_prompt, response, f"{duration} seconds", generated_tokens
 
-# --- Gradio Interface (No changes needed here) ---
 
-DEFAULT_SYSTEM_PROMPT = """You are an expert Rock-Paper-Scissors (RPS) strategist focusing on statistical analysis.
-Your task is to recommend the optimal AI move based *only* on the provided move frequency statistics for the player.
-
-Follow these steps:
-1. **Identify Player's Most Frequent Move:** Note the move (Rock, Paper, or Scissors) the player uses most often according to the stats.
-2. **Determine Best Counter:** Identify the RPS move that directly beats the player's most frequent move (Rock beats Scissors, Scissors beats Paper, Paper beats Rock).
-3. **Justify Recommendation:** Explain *why* this counter-move is statistically optimal. You can mention the expected outcome. For example: 'Playing Paper counters the player's most frequent move, Rock (40% frequency). This offers the highest probability of winning against the player's likely action.' Avoid irrelevant justifications based on the AI's own move frequencies.
-4. **State Recommendation:** Clearly state the recommended move (Rock, Paper, or Scissors).
-
-Base your analysis strictly on the provided frequencies and standard RPS rules."""
+# --- System Prompts ---
+
+# Refined system prompt for Frequency Analysis
+DEFAULT_SYSTEM_PROMPT_FREQ = """You are an assistant that analyzes Rock-Paper-Scissors (RPS) player statistics. Your ONLY goal is to find the best single AI move to counter the player's MOST frequent move based on the provided frequency stats.
+
+Follow these steps EXACTLY. Do NOT deviate.
+
+Step 1: Identify Player's Most Frequent Move.
+- Look ONLY at the 'Player Move Frequency Stats'.
+- List the percentages: Rock (%), Paper (%), Scissors (%).
+- State which move name has the highest percentage number.
+
+Step 2: Determine the Counter Move using RPS Rules.
+- REMEMBER THE RULES: Paper beats Rock. Rock beats Scissors. Scissors beats Paper.
+- Based *only* on the move identified in Step 1, state the single move name that beats it according to the rules. State the rule you used (e.g., "Paper beats Rock").
+
+Step 3: Explain the Counter Choice.
+- Briefly state: "Playing [Counter Move from Step 2] is recommended because it directly beats the player's most frequent move, [Most Frequent Move from Step 1]."
+
+Step 4: State Final Recommendation.
+- State *only* the recommended AI move name from Step 2. Example: "Recommendation: Paper"
+
+Base your analysis strictly on the provided frequencies and the stated RPS rules.
+"""
+
+# New system prompt for Markov Analysis
+DEFAULT_SYSTEM_PROMPT_MARKOV = """You are an assistant that analyzes Rock-Paper-Scissors (RPS) short-term player patterns. Your ONLY goal is to find the best single AI move to counter the player's PREDICTED next move, based on their LAST move.
+
+Information Provided:
+1. **Player's Last Move:** The actual move the player just made.
+2. **Predicted Next Move (Short-Term Markov Analysis):** The player's statistically most likely *next* move based on their *last* move.
+
+Follow these steps EXACTLY:
+
+Step 1: Identify Predicted Player Move.
+- Look at the 'Predicted Next Move (Short-Term Markov Analysis)' text.
+- State the player's predicted next move (Rock, Paper, or Scissors). Note the probability if provided.
+
+Step 2: Determine Counter Move using RPS Rules.
+- REMEMBER THE RULES: Paper beats Rock. Rock beats Scissors. Scissors beats Paper.
+- Based *only* on the predicted move identified in Step 1, state the single AI move name that beats it. State the rule used (e.g., "Rock beats Scissors").
+
+Step 3: Explain the Counter Choice.
+- Briefly state: "Playing [Counter Move from Step 2] is recommended because it directly beats the player's predicted next move, [Predicted Move from Step 1]."
+
+Step 4: State Final Recommendation.
+- State *only* the recommended AI move name from Step 2. Example: "Recommendation: Rock"
+
+Base your analysis strictly on the provided prediction and the standard RPS rules.
+"""
+
+# --- Default Input Values ---
 DEFAULT_PLAYER_STATS = "Rock: 40%\nPaper: 30%\nScissors: 30%"
-DEFAULT_AI_STATS = ""
-DEFAULT_USER_QUERY = "Based *only* on the player's move frequencies, what single move should the AI make next to maximize its statistical chance of winning? Explain your reasoning clearly step-by-step as instructed."
+DEFAULT_PLAYER_LAST_MOVE = "Rock"
+DEFAULT_MARKOV_PREDICTION = "Based on the last move (Rock), the player's most likely next move is Paper (60% probability)."
+# Default query might need to be generic or adapted based on mode
+DEFAULT_USER_QUERY = "Based on the provided information for the selected analysis mode, what single move should the AI make next? Explain your reasoning step-by-step as instructed."
 
+# --- Gradio Interface ---
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown(f"# {MODEL_ID} - RPS Frequency Analysis Tester")
-    gr.Markdown("Test model advice based on Player/AI move frequencies. Includes Generation Time and Token Count.")
+    gr.Markdown(f"# {MODEL_ID} - RPS Strategy Tester")
+    gr.Markdown("Test model advice using either Frequency Stats OR Short-Term (Markov) Predictions.")
+
+    # Mode Selector
+    analysis_mode_selector = gr.Radio(
+        label="Select Analysis Mode",
+        choices=["Frequency Only", "Markov Prediction Only"],
+        value="Frequency Only"  # Default mode
+    )
+
+    # Input Sections (conditionally visible)
+    with gr.Group(visible=True) as frequency_inputs:  # Visible by default
+        gr.Markdown("### Frequency Analysis Inputs")
+        player_stats_input = gr.Textbox(
+            label="Player Move Frequency Stats (Long-Term)", value=DEFAULT_PLAYER_STATS, lines=4,
+            info="Overall player move distribution."
+        )
+        # Hidden system prompt for frequency mode (can be edited if needed)
+        system_prompt_freq_input = gr.Textbox(
+            label="System Prompt (Frequency Mode - Edit if needed)", value=DEFAULT_SYSTEM_PROMPT_FREQ,
+            lines=15, visible=False  # Hidden by default, but can be shown for advanced editing
+        )
+
+    with gr.Group(visible=False) as markov_inputs:  # Hidden by default
+        gr.Markdown("### Markov Prediction Analysis Inputs")
+        player_last_move_input = gr.Dropdown(  # Dropdown suits a fixed set of choices
+            label="Player's Last Move", choices=["Rock", "Paper", "Scissors"], value=DEFAULT_PLAYER_LAST_MOVE,
+            info="The player's most recent actual move."
+        )
+        markov_prediction_input = gr.Textbox(
+            label="Predicted Next Move (Short-Term Markov Analysis)", value=DEFAULT_MARKOV_PREDICTION, lines=3,
+            info="Provide the pre-calculated prediction based on the last move (e.g., 'Player likely plays Paper (60%)')."
+        )
+        # Hidden system prompt for markov mode (can be edited if needed)
+        system_prompt_markov_input = gr.Textbox(
+            label="System Prompt (Markov Mode - Edit if needed)", value=DEFAULT_SYSTEM_PROMPT_MARKOV,
+            lines=15, visible=False  # Hidden by default
+        )
+
+    # General Inputs / Parameters / Outputs
     with gr.Row():
-        with gr.Column(scale=2):  # Input column
-            player_stats_input = gr.Textbox(
-                label="Player Move Frequency Stats", value=DEFAULT_PLAYER_STATS, lines=4,
-                info="Enter player's move frequencies (e.g., Rock: 50%, Paper: 30%, Scissors: 20%)."
-            )
-            ai_stats_input = gr.Textbox(
-                label="AI Move Frequency Stats (Optional)", value=DEFAULT_AI_STATS, lines=4,
-                info="Optionally, enter AI's own move frequencies."
-            )
-            user_query_input = gr.Textbox(
-                label="Your Query / Instruction", value=DEFAULT_USER_QUERY, lines=3,
-                info="Ask the specific question based on the stats."
-            )
-            system_prompt_input = gr.Textbox(
-                label="System Prompt", value=DEFAULT_SYSTEM_PROMPT,
-                lines=12
-            )
-
-        with gr.Column(scale=1):  # Params/Output column
-            gr.Markdown("## Generation Parameters")
+        with gr.Column(scale=2):
+            user_query_input = gr.Textbox(
+                label="Your Query / Instruction", value=DEFAULT_USER_QUERY, lines=3,
+                info="Ask the specific question based on the selected mode's analysis."
+            )
+        with gr.Column(scale=1):
+            gr.Markdown("#### Generation Parameters")
             max_length_slider = gr.Slider(minimum=50, maximum=1024, value=300, step=16, label="Max New Tokens")
             temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Temperature")
             top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
-            submit_btn = gr.Button("Generate Response", variant="primary")
 
-            gr.Markdown("## Performance Metrics")
-            time_output = gr.Textbox(label="Generation Time", interactive=False)
-            tokens_output = gr.Number(label="Generated Tokens", interactive=False)
 
-            gr.Markdown("""
-            ## Testing Tips
-            - Focus on player stats for optimal counter strategy.
-            - Use the refined **System Prompt** for better reasoning guidance.
-            - Lower **Temperature** encourages more direct, statistical answers.
+    submit_btn = gr.Button("Generate Response", variant="primary")
+
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("#### Performance Metrics")
+            time_output = gr.Textbox(label="Generation Time", interactive=False)
+            tokens_output = gr.Number(label="Generated Tokens", interactive=False)  # Use Number for token count
+        with gr.Column():
+            gr.Markdown("""
+            #### Testing Tips
+            - Select the desired **Analysis Mode**.
+            - Fill in the inputs for the **selected mode only**.
+            - Use low **Temperature** for factual analysis.
             """)
 
     with gr.Row():
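The Markov mode above consumes a pre-calculated prediction string (see `DEFAULT_MARKOV_PREDICTION`); the commit itself does not compute it. For reference, a minimal sketch of how a first-order transition-table prediction could be derived from a move history. The helper name and the history format are illustrative assumptions, not part of this commit:

```python
from collections import Counter, defaultdict

def build_markov_prediction(history):
    """Illustrative: derive a first-order Markov prediction string, in the
    format DEFAULT_MARKOV_PREDICTION expects, from a list of past player
    moves such as ["Rock", "Paper", "Rock", ...]."""
    # Count transitions prev -> next over consecutive pairs of moves
    transitions = defaultdict(Counter)
    for prev, nxt in zip(history, history[1:]):
        transitions[prev][nxt] += 1

    last = history[-1]
    counts = transitions[last]
    if not counts:
        return f"No transition data available for last move ({last})."
    predicted, hits = counts.most_common(1)[0]
    prob = round(100 * hits / sum(counts.values()))
    return (f"Based on the last move ({last}), the player's most likely "
            f"next move is {predicted} ({prob}% probability).")

# e.g. a player who tends to follow Rock with Paper:
print(build_markov_prediction(["Rock", "Paper", "Rock", "Paper", "Rock"]))
```

A first-order table conditions only on the immediately preceding move, which matches the "Short-Term Markov Analysis" framing used in the new system prompt.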
@@ -176,21 +248,56 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             label="Model Response", lines=20, show_copy_button=True
         )
 
+    # --- Event Handlers ---
+
+    # Function to update UI visibility based on mode selection
+    def update_ui_visibility(mode):
+        if mode == "Frequency Only":
+            return {
+                frequency_inputs: gr.update(visible=True),
+                markov_inputs: gr.update(visible=False)
+            }
+        elif mode == "Markov Prediction Only":
+            return {
+                frequency_inputs: gr.update(visible=False),
+                markov_inputs: gr.update(visible=True)
+            }
+        else:  # Default case
+            return {
+                frequency_inputs: gr.update(visible=True),
+                markov_inputs: gr.update(visible=False)
+            }
+
+    # Link the radio button change to the UI update function
+    analysis_mode_selector.change(
+        fn=update_ui_visibility,
+        inputs=analysis_mode_selector,
+        outputs=[frequency_inputs, markov_inputs]  # Components to update
+    )
+
+    # Handle button click - pass all inputs; the function selects based on mode
     submit_btn.click(
         process_input,
         inputs=[
-            player_stats_input, ai_stats_input, system_prompt_input,
-            user_query_input, max_length_slider, temperature_slider, top_p_slider
+            analysis_mode_selector,      # Mode selector first
+            player_stats_input,
+            player_last_move_input,
+            markov_prediction_input,
+            system_prompt_freq_input,    # Pass both system prompts
+            system_prompt_markov_input,
+            user_query_input,
+            max_length_slider,
+            temperature_slider,
+            top_p_slider
        ],
        outputs=[
            final_prompt_display, response_display,
            time_output, tokens_output
        ],
-        api_name="generate_rps_frequency_analysis_v2"
+        api_name="generate_rps_selectable_analysis"  # Updated api_name
    )
 
 # --- Launch the demo ---
 if __name__ == "__main__":
-    # Share=True is needed for ZeroGPU to work correctly if running locally for testing
-    # but usually not needed when deployed on HF Spaces platform.
+    # Share=True might be needed for ZeroGPU if running locally for testing
     demo.launch()
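Both system prompts hard-code the same counter relationships (Paper beats Rock, Rock beats Scissors, Scissors beats Paper), so the model's final "Recommendation:" line can be spot-checked mechanically against the stats. A small ground-truth helper, again illustrative rather than part of the commit:

```python
# Counter table as stated in both system prompts:
# Paper beats Rock, Rock beats Scissors, Scissors beats Paper.
COUNTERS = {"Rock": "Paper", "Paper": "Scissors", "Scissors": "Rock"}

def expected_recommendation(stats):
    """Return the counter to the most frequent move in a stats dict
    like {"Rock": 40, "Paper": 30, "Scissors": 30} (percentages)."""
    most_frequent = max(stats, key=stats.get)
    return COUNTERS[most_frequent]

# Matches DEFAULT_PLAYER_STATS: Rock is most frequent, so Paper should win
assert expected_recommendation({"Rock": 40, "Paper": 30, "Scissors": 30}) == "Paper"
```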
 
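With the endpoint renamed to `generate_rps_selectable_analysis`, a deployed Space could be exercised programmatically with `gradio_client`. A hedged sketch: the Space id is a placeholder, the argument order follows the `submit_btn.click()` wiring above, and the shortened system-prompt strings stand in for the app's default prompt texts:

```python
from gradio_client import Client

# Placeholder Space id; substitute the actual deployment.
client = Client("rui3000/rps-strategy-tester")

display_prompt, response, gen_time, n_tokens = client.predict(
    "Markov Prediction Only",                # analysis_mode_selector
    "Rock: 40%\nPaper: 30%\nScissors: 30%",  # player_stats_input (ignored in this mode)
    "Rock",                                  # player_last_move_input
    "Based on the last move (Rock), the player's most likely next move is Paper (60% probability).",
    "(frequency-mode system prompt text)",   # system_prompt_freq_input (ignored in this mode)
    "(markov-mode system prompt text)",      # system_prompt_markov_input
    "What single move should the AI make next? Explain step-by-step.",
    300,   # max_length_slider
    0.4,   # temperature_slider
    0.9,   # top_p_slider
    api_name="/generate_rps_selectable_analysis",
)
print(f"{gen_time}, {n_tokens} tokens:\n{response}")
```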