rui3000 committed on
Commit
3ef427a
·
verified ·
1 Parent(s): 480da6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -111
app.py CHANGED
@@ -4,22 +4,19 @@ import json
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
  # --- Configuration ---
7
- # Updated to the 1.5B Instruct model as requested
8
  MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"
9
 
10
  # --- Load Model and Tokenizer ---
11
  print(f"Loading model: {MODEL_ID}")
12
- # Removed trust_remote_code=True as it's generally not needed for standard HF models
13
- # Using torch_dtype="auto" for flexibility (can use bfloat16 if available)
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
15
  model = AutoModelForCausalLM.from_pretrained(
16
  MODEL_ID,
17
- torch_dtype="auto", # Changed from float16 to auto
18
  device_map="auto"
19
  )
20
  print("Model loaded successfully.")
21
 
22
- # --- Predefined Data (Keep user's structure) ---
23
  PREDEFINED_GAMES = {
24
  "rps_simple": {
25
  "description": "Rock-Paper-Scissors (Simple Format)",
@@ -47,8 +44,7 @@ PREDEFINED_GAMES = {
47
  }
48
  }
49
 
50
- # --- Predefined Prompts (Keep user's structure) ---
51
- # Updated default prompts to be more aligned with the goal
52
  PROMPT_TEMPLATES = {
53
  "detailed_analysis_recommendation": "Analyze the game history provided. Identify patterns in the player's moves. Based on your analysis, explain the reasoning and recommend the best move for the AI (or player if specified) in the next round.",
54
  "player_pattern_focus": "Focus specifically on the player's move patterns. Do they favor a specific move? Do they follow sequences? Do they react predictably after winning or losing?",
@@ -56,31 +52,36 @@ PROMPT_TEMPLATES = {
56
  "structured_output_request": "Provide a structured analysis with these sections: 1) Obvious player patterns, 2) Potential opponent counter-strategies, 3) Final move recommendation with reasoning."
57
  }
58
 
59
- # --- Formatting Functions (Keep user's functions) ---
60
  def format_rps_simple(game_data):
61
- """Format the RPS data in a simple way that's easy for small models to understand"""
62
  game = game_data["data"]
63
  move_names = {0: "Rock", 1: "Paper", 2: "Scissors"}
64
- result_names = {0: "AI wins", 1: "Player wins", 2: "Tie"}
65
  player_moves = {"Rock": 0, "Paper": 0, "Scissors": 0}
 
66
  formatted_data = "Game: Rock-Paper-Scissors\n"
67
- formatted_data += "Format explanation: [Round#, Player move, AI move, Result]\n"
68
  formatted_data += "Move codes: 0=Rock, 1=Paper, 2=Scissors\n"
69
- formatted_data += "Result codes: 0=AI wins, 1=Player wins, 2=Tie\n\n"
70
- formatted_data += "Game Data:\n"
 
71
  for round_data in game["rounds"]:
72
- r_num, p_move, ai_move, result = round_data["round"], round_data["player"], round_data["ai"], round_data["result"]
73
  player_moves[move_names[p_move]] += 1
74
- formatted_data += f"[{r_num}, {p_move}, {ai_move}, {result}] # R{r_num}: Player {move_names[p_move]}, AI {move_names[ai_move]}, {result_names[result]}\n"
 
 
 
75
  formatted_data += "\nSummary:\n"
76
  formatted_data += f"Player wins: {game['summary']['player_wins']}\n"
77
  formatted_data += f"AI wins: {game['summary']['ai_wins']}\n"
78
  formatted_data += f"Ties: {game['summary']['ties']}\n\n"
 
79
  formatted_data += "Player move frequencies:\n"
80
  total_rounds = len(game["rounds"])
81
  for move, count in player_moves.items():
82
  percentage = round((count / total_rounds) * 100) if total_rounds > 0 else 0
83
- formatted_data += f"{move}: {count} times ({percentage}%)\n" # Corrected percentage calc
84
  return formatted_data
85
 
86
  def format_rps_numeric(game_data):
@@ -98,21 +99,16 @@ FORMAT_FUNCTIONS = {
98
  "rps_numeric": format_rps_numeric
99
  }
100
 
101
- # --- Generation Function (Updated for Chat Template) ---
102
  def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
103
  """Generate a response from the Qwen2 model using chat template."""
104
  try:
105
- # Apply the chat template
106
  prompt_text = tokenizer.apply_chat_template(
107
  messages,
108
  tokenize=False,
109
- add_generation_prompt=True # Important for instruct models
110
  )
111
-
112
- # Tokenize the formatted prompt
113
  model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
114
-
115
- # Generation arguments
116
  generation_kwargs = {
117
  "max_new_tokens": max_length,
118
  "temperature": temperature,
@@ -120,13 +116,9 @@ def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
120
  "do_sample": True,
121
  "pad_token_id": tokenizer.eos_token_id,
122
  }
123
-
124
- # Generate response
125
  print("Generating response...")
126
  with torch.no_grad():
127
  generated_ids = model.generate(model_inputs.input_ids, **generation_kwargs)
128
-
129
- # Decode the response, excluding the input tokens
130
  input_ids_len = model_inputs.input_ids.shape[-1]
131
  output_ids = generated_ids[0, input_ids_len:]
132
  response = tokenizer.decode(output_ids, skip_special_tokens=True)
@@ -137,7 +129,7 @@ def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
137
  print(f"Error during generation: {e}")
138
  return f"An error occurred: {str(e)}"
139
 
140
- # --- Input Processing Function (Updated for Chat Template) ---
141
  def process_input(
142
  game_format,
143
  prompt_template,
@@ -149,135 +141,94 @@ def process_input(
149
  top_p
150
  ):
151
  """Process the input, format using chat template, and generate response."""
152
-
153
- # Get the selected game data and format it
154
  game_data = PREDEFINED_GAMES[game_format]
155
- formatted_game_data = FORMAT_FUNCTIONS[game_format](game_data) #
156
-
157
- # Determine which prompt question to use
158
- user_question = custom_prompt if use_custom_prompt else PROMPT_TEMPLATES[prompt_template] #
159
-
160
- # Construct the user message content
161
  user_content = f"Game History:\n{formatted_game_data}\n\nQuestion:\n{user_question}"
162
-
163
- # Create the messages list for the chat template
164
  messages = []
165
- if system_prompt and system_prompt.strip(): # Add system prompt if provided
166
  messages.append({"role": "system", "content": system_prompt})
167
  messages.append({"role": "user", "content": user_content})
168
-
169
- # Generate response from the model
170
  response = generate_response(
171
  messages,
172
  max_length=max_length,
173
  temperature=temperature,
174
  top_p=top_p
175
  )
176
-
177
- # For display purposes, show the "user" part of the prompt
178
- # (The system prompt isn't usually shown in the final input display)
179
  display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
180
-
181
  return display_prompt, response
182
 
183
- # --- Gradio Interface (Minor updates) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
185
- gr.Markdown(f"# {MODEL_ID} - RPS Game Analysis Tester") # Updated Title
186
  gr.Markdown("Test how the model responds to different RPS game data formats and prompts using its chat template.")
187
 
188
  with gr.Row():
189
  with gr.Column():
190
- # Game data selection
191
  game_format = gr.Dropdown(
192
- choices=list(PREDEFINED_GAMES.keys()),
193
- value="rps_simple",
194
- label="Game Data Format"
195
  )
196
-
197
- # System prompt (optional)
198
- # Added a more relevant placeholder based on the user's goal
199
  system_prompt = gr.Textbox(
200
  label="System Prompt (Optional)",
201
- placeholder="e.g., You are an expert RPS analyst. Analyze the provided game history, identify patterns, explain your reasoning clearly, and recommend the next move. Structure your output with observations, reasoning, and a final recommendation.",
202
- lines=4 # Increased lines slightly
 
203
  )
204
-
205
- # Prompt selection
206
  with gr.Row():
207
  prompt_template = gr.Dropdown(
208
- choices=list(PROMPT_TEMPLATES.keys()),
209
- value="detailed_analysis_recommendation", # Updated default
210
- label="Prompt Template"
211
  )
212
- use_custom_prompt = gr.Checkbox(
213
- label="Use Custom Prompt",
214
- value=False
215
- )
216
-
217
  custom_prompt = gr.Textbox(
218
  label="Custom Prompt (if Use Custom Prompt is checked)",
219
- placeholder="Enter your custom prompt/question here",
220
- lines=3 # Increased lines slightly
221
  )
222
-
223
- # Generation parameters
224
  with gr.Row():
225
- max_length = gr.Slider(
226
- minimum=50,
227
- maximum=1024, # Increased max
228
- value=512, # Increased default
229
- step=16, # Step size power of 2
230
- label="Max New Tokens" # Renamed label
231
- )
232
- temperature = gr.Slider(
233
- minimum=0.1, maximum=1.5, value=0.7, step=0.05, label="Temperature" # Step size finer
234
- )
235
- top_p = gr.Slider(
236
- minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P" # Step size finer
237
- )
238
-
239
- # Generate button
240
- submit_btn = gr.Button("Generate Response", variant="primary") # Added variant
241
 
242
  with gr.Column():
243
- # Display final prompt and model response
244
- # Renamed label for clarity
245
  final_prompt_display = gr.Textbox(
246
- label="Formatted Input Sent to Model (via Chat Template)",
247
- lines=15 # Increased lines
248
  )
249
  response_display = gr.Textbox(
250
- label="Model Response",
251
- lines=15, # Increased lines
252
- show_copy_button=True # Added copy button
253
  )
254
-
255
- # Tips for using the interface
256
  gr.Markdown("""
257
  ## Testing Tips
258
-
259
- - **Game Data Format**: Selects how the history is structured. 'rps_simple' is often easier for models to parse.
260
- - **System Prompt**: Crucial for setting the AI's role and desired output style (like your example image). Be descriptive!
261
- - **Prompt Template / Custom Prompt**: Asks the specific question based on the history and system instructions.
262
- - **Generation Params**: Tune `Temperature` and `Top P` to control creativity vs. focus. Adjust `Max New Tokens` for response length.
263
- - **Chat Template**: This version now correctly uses the model's chat template for better instruction following.
264
  """)
265
 
266
- # Handle button click
267
  submit_btn.click(
268
  process_input,
269
  inputs=[
270
- game_format,
271
- prompt_template,
272
- custom_prompt,
273
- use_custom_prompt,
274
- system_prompt,
275
- max_length,
276
- temperature,
277
- top_p
278
  ],
279
  outputs=[final_prompt_display, response_display],
280
- api_name="generate_rps_analysis" # Added api_name
281
  )
282
 
283
  # --- Launch the demo ---
 
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
  # --- Configuration ---
 
7
  MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"
8
 
9
  # --- Load Model and Tokenizer ---
10
  print(f"Loading model: {MODEL_ID}")
 
 
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
12
  model = AutoModelForCausalLM.from_pretrained(
13
  MODEL_ID,
14
+ torch_dtype="auto",
15
  device_map="auto"
16
  )
17
  print("Model loaded successfully.")
18
 
19
+ # --- Predefined Data (User's structure) ---
20
  PREDEFINED_GAMES = {
21
  "rps_simple": {
22
  "description": "Rock-Paper-Scissors (Simple Format)",
 
44
  }
45
  }
46
 
47
+ # --- Predefined Prompts (User's structure) ---
 
48
  PROMPT_TEMPLATES = {
49
  "detailed_analysis_recommendation": "Analyze the game history provided. Identify patterns in the player's moves. Based on your analysis, explain the reasoning and recommend the best move for the AI (or player if specified) in the next round.",
50
  "player_pattern_focus": "Focus specifically on the player's move patterns. Do they favor a specific move? Do they follow sequences? Do they react predictably after winning or losing?",
 
52
  "structured_output_request": "Provide a structured analysis with these sections: 1) Obvious player patterns, 2) Potential opponent counter-strategies, 3) Final move recommendation with reasoning."
53
  }
54
 
55
+ # --- Formatting Functions (Updated format_rps_simple) ---
56
def format_rps_simple(game_data):
    """Render a Rock-Paper-Scissors game record as labelled plain text.

    Args:
        game_data: Mapping whose "data" entry holds "rounds" (a sequence of
            mappings with "round", "player", "ai" and "result" keys, the
            last three being integer codes) and "summary" (a mapping with
            "player_wins", "ai_wins" and "ties").

    Returns:
        A newline-terminated string containing the move/result code legends,
        one line per round, the win/loss/tie summary and the player's move
        frequencies as rounded percentages.

    Raises:
        KeyError: If an expected key is missing, or a move/result code is
            not one of 0, 1 or 2.
    """
    game = game_data["data"]
    rounds = game["rounds"]
    move_names = {0: "Rock", 1: "Paper", 2: "Scissors"}
    result_map = {0: "AI wins", 1: "Player wins", 2: "Tie"}
    # Pre-seeded so every move is reported, even when it was never played.
    player_moves = {"Rock": 0, "Paper": 0, "Scissors": 0}

    # Collect lines and join once instead of growing a string with "+=".
    # Empty entries produce the blank separator lines between sections.
    lines = [
        "Game: Rock-Paper-Scissors",
        "Move codes: 0=Rock, 1=Paper, 2=Scissors",
        "Result codes: 0=AI wins, 1=Player wins, 2=Tie",
        "",
        "Game Data (Round, Player Move, AI Move, Result Text):",
    ]

    for round_data in rounds:
        r_num = round_data["round"]
        p_move = round_data["player"]
        ai_move = round_data["ai"]
        result_code = round_data["result"]

        player_name = move_names[p_move]
        player_moves[player_name] += 1
        # Each round shows both the readable name and the numeric code.
        lines.append(
            f"R{r_num}: Player={player_name}({p_move}), "
            f"AI={move_names[ai_move]}({ai_move}), "
            f"Result={result_map[result_code]}"
        )

    summary = game["summary"]
    lines += [
        "",
        "Summary:",
        f"Player wins: {summary['player_wins']}",
        f"AI wins: {summary['ai_wins']}",
        f"Ties: {summary['ties']}",
        "",
        "Player move frequencies:",
    ]

    total_rounds = len(rounds)
    for move, count in player_moves.items():
        # Guard against division by zero when no rounds were recorded.
        percentage = round((count / total_rounds) * 100) if total_rounds > 0 else 0
        lines.append(f"{move}: {count} times ({percentage}%)")

    return "\n".join(lines) + "\n"
86
 
87
  def format_rps_numeric(game_data):
 
99
  "rps_numeric": format_rps_numeric
100
  }
101
 
102
+ # --- Generation Function (Using Chat Template) ---
103
  def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
104
  """Generate a response from the Qwen2 model using chat template."""
105
  try:
 
106
  prompt_text = tokenizer.apply_chat_template(
107
  messages,
108
  tokenize=False,
109
+ add_generation_prompt=True
110
  )
 
 
111
  model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
 
 
112
  generation_kwargs = {
113
  "max_new_tokens": max_length,
114
  "temperature": temperature,
 
116
  "do_sample": True,
117
  "pad_token_id": tokenizer.eos_token_id,
118
  }
 
 
119
  print("Generating response...")
120
  with torch.no_grad():
121
  generated_ids = model.generate(model_inputs.input_ids, **generation_kwargs)
 
 
122
  input_ids_len = model_inputs.input_ids.shape[-1]
123
  output_ids = generated_ids[0, input_ids_len:]
124
  response = tokenizer.decode(output_ids, skip_special_tokens=True)
 
129
  print(f"Error during generation: {e}")
130
  return f"An error occurred: {str(e)}"
131
 
132
+ # --- Input Processing Function (Using Chat Template) ---
133
  def process_input(
134
  game_format,
135
  prompt_template,
 
141
  top_p
142
  ):
143
  """Process the input, format using chat template, and generate response."""
 
 
144
  game_data = PREDEFINED_GAMES[game_format]
145
+ formatted_game_data = FORMAT_FUNCTIONS[game_format](game_data)
146
+ user_question = custom_prompt if use_custom_prompt else PROMPT_TEMPLATES[prompt_template]
 
 
 
 
147
  user_content = f"Game History:\n{formatted_game_data}\n\nQuestion:\n{user_question}"
 
 
148
  messages = []
149
+ if system_prompt and system_prompt.strip():
150
  messages.append({"role": "system", "content": system_prompt})
151
  messages.append({"role": "user", "content": user_content})
 
 
152
  response = generate_response(
153
  messages,
154
  max_length=max_length,
155
  temperature=temperature,
156
  top_p=top_p
157
  )
 
 
 
158
  display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
 
159
  return display_prompt, response
160
 
161
+ # --- Gradio Interface (Updated system prompt placeholder) ---
162
+
163
+ # Define the improved default system prompt
164
+ DEFAULT_SYSTEM_PROMPT = """You are a highly accurate and methodical Rock-Paper-Scissors (RPS) strategy analyst.
165
+ Your goal is to analyze the provided game history and give the user strategic advice for their next move.
166
+
167
+ Follow these steps precisely:
168
+ 1. **Verify Rules:** Remember: Rock (0) beats Scissors (2), Scissors (2) beats Paper (1), Paper (1) beats Rock (0).
169
+ 2. **Analyze Player Moves:** Go through the 'Game Data' round by round. List the player's move and the result (Win, Loss, Tie) for each round accurately.
170
+ 3. **Calculate Frequencies:** Use the provided 'Player move frequencies' or calculate them from the rounds. Note any strong preference.
171
+ 4. **Identify Patterns:** Look for sequences (e.g., did the player repeat Rock twice?), reactions (e.g., what did the player do after winning/losing?), or other tendencies based *only* on the provided data. State the patterns clearly.
172
+ 5. **Reasoning:** Explain your reasoning for the recommendation based *only* on the verified round data and identified patterns. Do not invent patterns.
173
+ 6. **Recommendation:** Provide a single, clear recommendation (Rock, Paper, or Scissors) for the *next* round and justify it concisely based on your reasoning.
174
+
175
+ Structure your response clearly with sections for Analysis, Patterns, Reasoning, and Recommendation. Be factual and base everything strictly on the provided game history."""
176
+
177
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
178
+ gr.Markdown(f"# {MODEL_ID} - RPS Game Analysis Tester")
179
  gr.Markdown("Test how the model responds to different RPS game data formats and prompts using its chat template.")
180
 
181
  with gr.Row():
182
  with gr.Column():
 
183
  game_format = gr.Dropdown(
184
+ choices=list(PREDEFINED_GAMES.keys()), value="rps_simple", label="Game Data Format"
 
 
185
  )
186
+ # Use the detailed DEFAULT_SYSTEM_PROMPT as the placeholder/default value
 
 
187
  system_prompt = gr.Textbox(
188
  label="System Prompt (Optional)",
189
+ placeholder=DEFAULT_SYSTEM_PROMPT, # Set placeholder
190
+ value=DEFAULT_SYSTEM_PROMPT, # Set default value
191
+ lines=15 # Increased lines to show more default text
192
  )
 
 
193
  with gr.Row():
194
  prompt_template = gr.Dropdown(
195
+ choices=list(PROMPT_TEMPLATES.keys()), value="detailed_analysis_recommendation", label="Prompt Template"
 
 
196
  )
197
+ use_custom_prompt = gr.Checkbox(label="Use Custom Prompt", value=False)
 
 
 
 
198
  custom_prompt = gr.Textbox(
199
  label="Custom Prompt (if Use Custom Prompt is checked)",
200
+ placeholder="Enter your custom prompt/question here", lines=3
 
201
  )
 
 
202
  with gr.Row():
203
+ max_length = gr.Slider(minimum=50, maximum=1024, value=512, step=16, label="Max New Tokens")
204
+ temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.05, label="Temperature")
205
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
206
+ submit_btn = gr.Button("Generate Response", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  with gr.Column():
 
 
209
  final_prompt_display = gr.Textbox(
210
+ label="Formatted Input Sent to Model (via Chat Template)", lines=15
 
211
  )
212
  response_display = gr.Textbox(
213
+ label="Model Response", lines=15, show_copy_button=True
 
 
214
  )
 
 
215
  gr.Markdown("""
216
  ## Testing Tips
217
+ - **Game Data Format**: Selects how history is structured. 'rps_simple' uses the improved format now.
218
+ - **System Prompt**: Crucial for setting the AI's role and desired output style. The default is now much more detailed.
219
+ - **Prompt Template / Custom Prompt**: Asks the specific question.
220
+ - **Generation Params**: Try lowering `Temperature` (e.g., to 0.3-0.5) for more factual, less random output.
221
+ - **Chat Template**: This version uses the model's chat template correctly.
 
222
  """)
223
 
 
224
  submit_btn.click(
225
  process_input,
226
  inputs=[
227
+ game_format, prompt_template, custom_prompt, use_custom_prompt,
228
+ system_prompt, max_length, temperature, top_p
 
 
 
 
 
 
229
  ],
230
  outputs=[final_prompt_display, response_display],
231
+ api_name="generate_rps_analysis"
232
  )
233
 
234
  # --- Launch the demo ---