rui3000 committed on
Commit
480da6f
·
verified ·
1 Parent(s): 58e6bf8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -143
app.py CHANGED
@@ -3,17 +3,23 @@ import torch
3
  import json
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
- # Load the Qwen2 0.5B model
7
- model_id = "Qwen/Qwen2-0.5B"
8
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 
 
 
 
 
 
9
  model = AutoModelForCausalLM.from_pretrained(
10
- model_id,
11
- torch_dtype=torch.float16,
12
- device_map="auto",
13
- trust_remote_code=True
14
  )
 
15
 
16
- # Predefined game data in compressed formats
17
  PREDEFINED_GAMES = {
18
  "rps_simple": {
19
  "description": "Rock-Paper-Scissors (Simple Format)",
@@ -22,16 +28,11 @@ PREDEFINED_GAMES = {
22
  "encoding": {"rock": 0, "paper": 1, "scissors": 2},
23
  "result_encoding": {"ai_win": 0, "player_win": 1, "tie": 2},
24
  "rounds": [
25
- {"round": 1, "player": 0, "ai": 2, "result": 1},
26
- {"round": 2, "player": 1, "ai": 1, "result": 2},
27
- {"round": 3, "player": 2, "ai": 0, "result": 0},
28
- {"round": 4, "player": 0, "ai": 0, "result": 2},
29
- {"round": 5, "player": 1, "ai": 0, "result": 1},
30
- {"round": 6, "player": 2, "ai": 2, "result": 2},
31
- {"round": 7, "player": 0, "ai": 1, "result": 0},
32
- {"round": 8, "player": 1, "ai": 2, "result": 0},
33
- {"round": 9, "player": 2, "ai": 1, "result": 1},
34
- {"round": 10, "player": 0, "ai": 2, "result": 1}
35
  ],
36
  "summary": {"player_wins": 4, "ai_wins": 3, "ties": 3}
37
  }
@@ -46,146 +47,144 @@ PREDEFINED_GAMES = {
46
  }
47
  }
48
 
49
- # Predefined prompt templates
 
50
  PROMPT_TEMPLATES = {
51
- "basic_analysis": "Who is winning right now? What patterns do you notice in the player's choices?",
52
- "prediction": "Based on the player's past choices, predict what the player will choose in the next round. Explain your reasoning.",
53
- "strategy": "What strategy should the AI use to improve its win rate? Provide specific recommendations.",
54
- "pattern_analysis": "Analyze the frequency of each choice (rock, paper, scissors) made by the player. Is there a dominant pattern?",
55
- "structured_analysis": "Provide a structured analysis with these sections: 1) Current winner, 2) Player choice patterns, 3) AI performance, 4) Recommended strategy for AI."
56
  }
57
 
58
- # Prompt formatters
59
  def format_rps_simple(game_data):
60
  """Format the RPS data in a simple way that's easy for small models to understand"""
61
  game = game_data["data"]
62
-
63
- # Create a mapping for move names
64
  move_names = {0: "Rock", 1: "Paper", 2: "Scissors"}
65
  result_names = {0: "AI wins", 1: "Player wins", 2: "Tie"}
66
-
67
- # Initialize counters for frequency analysis
68
  player_moves = {"Rock": 0, "Paper": 0, "Scissors": 0}
69
-
70
- # Format each round in a simple way
71
  formatted_data = "Game: Rock-Paper-Scissors\n"
72
  formatted_data += "Format explanation: [Round#, Player move, AI move, Result]\n"
73
  formatted_data += "Move codes: 0=Rock, 1=Paper, 2=Scissors\n"
74
  formatted_data += "Result codes: 0=AI wins, 1=Player wins, 2=Tie\n\n"
75
-
76
  formatted_data += "Game Data:\n"
77
  for round_data in game["rounds"]:
78
- r_num = round_data["round"]
79
- p_move = round_data["player"]
80
- ai_move = round_data["ai"]
81
- result = round_data["result"]
82
-
83
- # Update player move counter
84
  player_moves[move_names[p_move]] += 1
85
-
86
- # Format as [round, player, ai, result]
87
  formatted_data += f"[{r_num}, {p_move}, {ai_move}, {result}] # R{r_num}: Player {move_names[p_move]}, AI {move_names[ai_move]}, {result_names[result]}\n"
88
-
89
- # Add summary statistics
90
  formatted_data += "\nSummary:\n"
91
  formatted_data += f"Player wins: {game['summary']['player_wins']}\n"
92
  formatted_data += f"AI wins: {game['summary']['ai_wins']}\n"
93
  formatted_data += f"Ties: {game['summary']['ties']}\n\n"
94
-
95
- # Add player move frequencies
96
  formatted_data += "Player move frequencies:\n"
 
97
  for move, count in player_moves.items():
98
- formatted_data += f"{move}: {count} times ({count*10}%)\n"
99
-
100
  return formatted_data
101
 
102
  def format_rps_numeric(game_data):
103
  """Format the RPS data in a highly compressed numeric format"""
104
  game = game_data["data"]
105
-
106
  formatted_data = "RPS Game Data (compressed format)\n"
107
  formatted_data += f"Rules: {game['rules']}\n\n"
108
-
109
- # Format all rounds on a single line
110
  rounds_str = ",".join([str(r) for r in game['rounds']])
111
  formatted_data += f"Rounds: {rounds_str}\n\n"
112
-
113
- # Add score summary
114
  formatted_data += f"Score: Player={game['score']['P']} AI={game['score']['AI']} Ties={game['score']['Tie']}\n"
115
-
116
  return formatted_data
117
 
118
- # Format selectors
119
  FORMAT_FUNCTIONS = {
120
  "rps_simple": format_rps_simple,
121
  "rps_numeric": format_rps_numeric
122
  }
123
 
124
- def generate_response(prompt, max_length=512, temperature=0.7, top_p=0.9):
125
- """Generate a response from the Qwen2 model based on the input prompt."""
126
-
127
- # Tokenize the input prompt
128
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
129
-
130
- # Generate response
131
- with torch.no_grad():
132
- outputs = model.generate(
133
- **inputs,
134
- max_new_tokens=max_length,
135
- do_sample=True,
136
- temperature=temperature,
137
- top_p=top_p,
138
  )
139
-
140
- # Decode the response
141
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
142
-
143
- # Extract only the model's response (remove the input prompt)
144
- if prompt in response:
145
- response = response[len(prompt):]
146
-
147
- return response.strip()
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def process_input(
150
  game_format,
151
  prompt_template,
152
  custom_prompt,
153
  use_custom_prompt,
154
  system_prompt,
155
- max_length,
156
- temperature,
157
  top_p
158
  ):
159
- """Process the input and generate a response from the model."""
160
-
161
  # Get the selected game data and format it
162
  game_data = PREDEFINED_GAMES[game_format]
163
- formatted_game_data = FORMAT_FUNCTIONS[game_format](game_data)
164
-
165
- # Determine which prompt to use
166
- prompt_text = custom_prompt if use_custom_prompt else PROMPT_TEMPLATES[prompt_template]
167
-
168
- # Create the final prompt with optional system prompt
169
- if system_prompt:
170
- final_prompt = f"{system_prompt}\n\n{formatted_game_data}\n\n{prompt_text}"
171
- else:
172
- final_prompt = f"{formatted_game_data}\n\n{prompt_text}"
173
-
 
 
 
174
  # Generate response from the model
175
  response = generate_response(
176
- final_prompt,
177
  max_length=max_length,
178
  temperature=temperature,
179
  top_p=top_p
180
  )
181
-
182
- return final_prompt, response
183
-
184
- # Create the Gradio interface
185
- with gr.Blocks() as demo:
186
- gr.Markdown("# Qwen2 0.5B Game Analysis Tester")
187
- gr.Markdown("Test how the Qwen2 0.5B model responds to different game data formats and prompts")
188
-
 
 
 
 
189
  with gr.Row():
190
  with gr.Column():
191
  # Game data selection
@@ -194,81 +193,76 @@ with gr.Blocks() as demo:
194
  value="rps_simple",
195
  label="Game Data Format"
196
  )
197
-
198
  # System prompt (optional)
 
199
  system_prompt = gr.Textbox(
200
  label="System Prompt (Optional)",
201
- placeholder="e.g., You are an expert game analyzer. Your task is to analyze game patterns and provide insights.",
202
- lines=2
203
  )
204
-
205
  # Prompt selection
206
  with gr.Row():
207
  prompt_template = gr.Dropdown(
208
  choices=list(PROMPT_TEMPLATES.keys()),
209
- value="basic_analysis",
210
  label="Prompt Template"
211
  )
212
  use_custom_prompt = gr.Checkbox(
213
  label="Use Custom Prompt",
214
  value=False
215
  )
216
-
217
  custom_prompt = gr.Textbox(
218
- label="Custom Prompt (if enabled above)",
219
- placeholder="Enter your custom prompt here",
220
- lines=2
221
  )
222
-
223
  # Generation parameters
224
  with gr.Row():
225
  max_length = gr.Slider(
226
- minimum=50,
227
- maximum=512,
228
- value=256,
229
- step=1,
230
- label="Max Response Length"
231
  )
232
  temperature = gr.Slider(
233
- minimum=0.1,
234
- maximum=1.5,
235
- value=0.7,
236
- step=0.1,
237
- label="Temperature"
238
  )
239
  top_p = gr.Slider(
240
- minimum=0.1,
241
- maximum=1.0,
242
- value=0.9,
243
- step=0.1,
244
- label="Top P"
245
  )
246
-
247
  # Generate button
248
- submit_btn = gr.Button("Generate Response")
249
-
250
  with gr.Column():
251
  # Display final prompt and model response
 
252
  final_prompt_display = gr.Textbox(
253
- label="Final Prompt Sent to Model",
254
- lines=12
255
  )
256
  response_display = gr.Textbox(
257
- label="Model Response",
258
- lines=12
 
259
  )
260
-
261
  # Tips for using the interface
262
  gr.Markdown("""
263
  ## Testing Tips
264
-
265
- - The **Game Data Format** determines how game information is presented to the model
266
- - The **System Prompt** can be used to provide context or role instructions
267
- - **Prompt Templates** offer pre-made queries, or you can use a custom prompt
268
- - Experiment with **Temperature** (higher = more creative/random, lower = more focused)
269
- - Document successful prompts for fine-tuning datasets
270
  """)
271
-
272
  # Handle button click
273
  submit_btn.click(
274
  process_input,
@@ -282,8 +276,10 @@ with gr.Blocks() as demo:
282
  temperature,
283
  top_p
284
  ],
285
- outputs=[final_prompt_display, response_display]
 
286
  )
287
 
288
- # Launch the demo
289
- demo.launch()
 
 
3
  import json
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
+ # --- Configuration ---
7
+ # Updated to the 1.5B Instruct model as requested
8
+ MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"
9
+
10
+ # --- Load Model and Tokenizer ---
11
+ print(f"Loading model: {MODEL_ID}")
12
+ # Removed trust_remote_code=True as it's generally not needed for standard HF models
13
+ # Using torch_dtype="auto" for flexibility (can use bfloat16 if available)
14
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
15
  model = AutoModelForCausalLM.from_pretrained(
16
+ MODEL_ID,
17
+ torch_dtype="auto", # Changed from float16 to auto
18
+ device_map="auto"
 
19
  )
20
+ print("Model loaded successfully.")
21
 
22
+ # --- Predefined Data (Keep user's structure) ---
23
  PREDEFINED_GAMES = {
24
  "rps_simple": {
25
  "description": "Rock-Paper-Scissors (Simple Format)",
 
28
  "encoding": {"rock": 0, "paper": 1, "scissors": 2},
29
  "result_encoding": {"ai_win": 0, "player_win": 1, "tie": 2},
30
  "rounds": [
31
+ {"round": 1, "player": 0, "ai": 2, "result": 1}, {"round": 2, "player": 1, "ai": 1, "result": 2},
32
+ {"round": 3, "player": 2, "ai": 0, "result": 0}, {"round": 4, "player": 0, "ai": 0, "result": 2},
33
+ {"round": 5, "player": 1, "ai": 0, "result": 1}, {"round": 6, "player": 2, "ai": 2, "result": 2},
34
+ {"round": 7, "player": 0, "ai": 1, "result": 0}, {"round": 8, "player": 1, "ai": 2, "result": 0},
35
+ {"round": 9, "player": 2, "ai": 1, "result": 1}, {"round": 10, "player": 0, "ai": 2, "result": 1}
 
 
 
 
 
36
  ],
37
  "summary": {"player_wins": 4, "ai_wins": 3, "ties": 3}
38
  }
 
47
  }
48
  }
49
 
50
+ # --- Predefined Prompts (Keep user's structure) ---
51
+ # Updated default prompts to be more aligned with the goal
52
  PROMPT_TEMPLATES = {
53
+ "detailed_analysis_recommendation": "Analyze the game history provided. Identify patterns in the player's moves. Based on your analysis, explain the reasoning and recommend the best move for the AI (or player if specified) in the next round.",
54
+ "player_pattern_focus": "Focus specifically on the player's move patterns. Do they favor a specific move? Do they follow sequences? Do they react predictably after winning or losing?",
55
+ "brief_recommendation": "Based on the history, what single move (Rock, Paper, or Scissors) should be played next and give a one-sentence justification?",
56
+ "structured_output_request": "Provide a structured analysis with these sections: 1) Obvious player patterns, 2) Potential opponent counter-strategies, 3) Final move recommendation with reasoning."
 
57
  }
58
 
59
+ # --- Formatting Functions (Keep user's functions) ---
60
  def format_rps_simple(game_data):
61
  """Format the RPS data in a simple way that's easy for small models to understand"""
62
  game = game_data["data"]
 
 
63
  move_names = {0: "Rock", 1: "Paper", 2: "Scissors"}
64
  result_names = {0: "AI wins", 1: "Player wins", 2: "Tie"}
 
 
65
  player_moves = {"Rock": 0, "Paper": 0, "Scissors": 0}
 
 
66
  formatted_data = "Game: Rock-Paper-Scissors\n"
67
  formatted_data += "Format explanation: [Round#, Player move, AI move, Result]\n"
68
  formatted_data += "Move codes: 0=Rock, 1=Paper, 2=Scissors\n"
69
  formatted_data += "Result codes: 0=AI wins, 1=Player wins, 2=Tie\n\n"
 
70
  formatted_data += "Game Data:\n"
71
  for round_data in game["rounds"]:
72
+ r_num, p_move, ai_move, result = round_data["round"], round_data["player"], round_data["ai"], round_data["result"]
 
 
 
 
 
73
  player_moves[move_names[p_move]] += 1
 
 
74
  formatted_data += f"[{r_num}, {p_move}, {ai_move}, {result}] # R{r_num}: Player {move_names[p_move]}, AI {move_names[ai_move]}, {result_names[result]}\n"
 
 
75
  formatted_data += "\nSummary:\n"
76
  formatted_data += f"Player wins: {game['summary']['player_wins']}\n"
77
  formatted_data += f"AI wins: {game['summary']['ai_wins']}\n"
78
  formatted_data += f"Ties: {game['summary']['ties']}\n\n"
 
 
79
  formatted_data += "Player move frequencies:\n"
80
+ total_rounds = len(game["rounds"])
81
  for move, count in player_moves.items():
82
+ percentage = round((count / total_rounds) * 100) if total_rounds > 0 else 0
83
+ formatted_data += f"{move}: {count} times ({percentage}%)\n" # Corrected percentage calc
84
  return formatted_data
85
 
86
  def format_rps_numeric(game_data):
87
  """Format the RPS data in a highly compressed numeric format"""
88
  game = game_data["data"]
 
89
  formatted_data = "RPS Game Data (compressed format)\n"
90
  formatted_data += f"Rules: {game['rules']}\n\n"
 
 
91
  rounds_str = ",".join([str(r) for r in game['rounds']])
92
  formatted_data += f"Rounds: {rounds_str}\n\n"
 
 
93
  formatted_data += f"Score: Player={game['score']['P']} AI={game['score']['AI']} Ties={game['score']['Tie']}\n"
 
94
  return formatted_data
95
 
 
96
  FORMAT_FUNCTIONS = {
97
  "rps_simple": format_rps_simple,
98
  "rps_numeric": format_rps_numeric
99
  }
100
 
101
+ # --- Generation Function (Updated for Chat Template) ---
102
def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a response from the Qwen2 model using its chat template.

    Relies on the module-level ``tokenizer`` and ``model`` objects.

    Args:
        messages: List of chat messages (dicts with "role" and "content")
            handed to ``tokenizer.apply_chat_template``.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded completion, without the prompt tokens and with surrounding
        whitespace stripped. If any step raises, the exception is printed and
        an "An error occurred: ..." string is returned instead, so the UI
        callback always receives text.
    """
    try:
        # Render the conversation with the model's chat template;
        # add_generation_prompt appends the assistant-turn header.
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize the rendered prompt and move the tensors to the model's device.
        model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)

        generation_kwargs = {
            # UI sliders may deliver floats; max_new_tokens must be an integer.
            "max_new_tokens": int(max_length),
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }

        print("Generating response...")
        with torch.no_grad():
            # Unpack the whole encoding so the attention mask is passed along
            # with input_ids. With pad_token_id set to EOS, generate() cannot
            # reliably infer the mask on its own.
            generated_ids = model.generate(**model_inputs, **generation_kwargs)

        # Keep only the newly generated tokens (everything after the prompt).
        prompt_len = model_inputs["input_ids"].shape[-1]
        output_ids = generated_ids[0, prompt_len:]
        response = tokenizer.decode(output_ids, skip_special_tokens=True)
        print("Generation complete.")
        return response.strip()

    except Exception as e:
        # Deliberately broad: report the failure as text instead of letting
        # the exception escape to the caller.
        print(f"Error during generation: {e}")
        return f"An error occurred: {str(e)}"
139
+
140
+ # --- Input Processing Function (Updated for Chat Template) ---
141
  def process_input(
142
  game_format,
143
  prompt_template,
144
  custom_prompt,
145
  use_custom_prompt,
146
  system_prompt,
147
+ max_length,
148
+ temperature,
149
  top_p
150
  ):
151
+ """Process the input, format using chat template, and generate response."""
152
+
153
  # Get the selected game data and format it
154
  game_data = PREDEFINED_GAMES[game_format]
155
+ formatted_game_data = FORMAT_FUNCTIONS[game_format](game_data) #
156
+
157
+ # Determine which prompt question to use
158
+ user_question = custom_prompt if use_custom_prompt else PROMPT_TEMPLATES[prompt_template] #
159
+
160
+ # Construct the user message content
161
+ user_content = f"Game History:\n{formatted_game_data}\n\nQuestion:\n{user_question}"
162
+
163
+ # Create the messages list for the chat template
164
+ messages = []
165
+ if system_prompt and system_prompt.strip(): # Add system prompt if provided
166
+ messages.append({"role": "system", "content": system_prompt})
167
+ messages.append({"role": "user", "content": user_content})
168
+
169
  # Generate response from the model
170
  response = generate_response(
171
+ messages,
172
  max_length=max_length,
173
  temperature=temperature,
174
  top_p=top_p
175
  )
176
+
177
+ # For display purposes, show the "user" part of the prompt
178
+ # (The system prompt isn't usually shown in the final input display)
179
+ display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
180
+
181
+ return display_prompt, response
182
+
183
+ # --- Gradio Interface (Minor updates) ---
184
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
185
+ gr.Markdown(f"# {MODEL_ID} - RPS Game Analysis Tester") # Updated Title
186
+ gr.Markdown("Test how the model responds to different RPS game data formats and prompts using its chat template.")
187
+
188
  with gr.Row():
189
  with gr.Column():
190
  # Game data selection
 
193
  value="rps_simple",
194
  label="Game Data Format"
195
  )
196
+
197
  # System prompt (optional)
198
+ # Added a more relevant placeholder based on the user's goal
199
  system_prompt = gr.Textbox(
200
  label="System Prompt (Optional)",
201
+ placeholder="e.g., You are an expert RPS analyst. Analyze the provided game history, identify patterns, explain your reasoning clearly, and recommend the next move. Structure your output with observations, reasoning, and a final recommendation.",
202
+ lines=4 # Increased lines slightly
203
  )
204
+
205
  # Prompt selection
206
  with gr.Row():
207
  prompt_template = gr.Dropdown(
208
  choices=list(PROMPT_TEMPLATES.keys()),
209
+ value="detailed_analysis_recommendation", # Updated default
210
  label="Prompt Template"
211
  )
212
  use_custom_prompt = gr.Checkbox(
213
  label="Use Custom Prompt",
214
  value=False
215
  )
216
+
217
  custom_prompt = gr.Textbox(
218
+ label="Custom Prompt (if Use Custom Prompt is checked)",
219
+ placeholder="Enter your custom prompt/question here",
220
+ lines=3 # Increased lines slightly
221
  )
222
+
223
  # Generation parameters
224
  with gr.Row():
225
  max_length = gr.Slider(
226
+ minimum=50,
227
+ maximum=1024, # Increased max
228
+ value=512, # Increased default
229
+ step=16, # Step size power of 2
230
+ label="Max New Tokens" # Renamed label
231
  )
232
  temperature = gr.Slider(
233
+ minimum=0.1, maximum=1.5, value=0.7, step=0.05, label="Temperature" # Step size finer
 
 
 
 
234
  )
235
  top_p = gr.Slider(
236
+ minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P" # Step size finer
 
 
 
 
237
  )
238
+
239
  # Generate button
240
+ submit_btn = gr.Button("Generate Response", variant="primary") # Added variant
241
+
242
  with gr.Column():
243
  # Display final prompt and model response
244
+ # Renamed label for clarity
245
  final_prompt_display = gr.Textbox(
246
+ label="Formatted Input Sent to Model (via Chat Template)",
247
+ lines=15 # Increased lines
248
  )
249
  response_display = gr.Textbox(
250
+ label="Model Response",
251
+ lines=15, # Increased lines
252
+ show_copy_button=True # Added copy button
253
  )
254
+
255
  # Tips for using the interface
256
  gr.Markdown("""
257
  ## Testing Tips
258
+
259
+ - **Game Data Format**: Selects how the history is structured. 'rps_simple' is often easier for models to parse.
260
+ - **System Prompt**: Crucial for setting the AI's role and desired output style (like your example image). Be descriptive!
261
+ - **Prompt Template / Custom Prompt**: Asks the specific question based on the history and system instructions.
262
+ - **Generation Params**: Tune `Temperature` and `Top P` to control creativity vs. focus. Adjust `Max New Tokens` for response length.
263
+ - **Chat Template**: This version now correctly uses the model's chat template for better instruction following.
264
  """)
265
+
266
  # Handle button click
267
  submit_btn.click(
268
  process_input,
 
276
  temperature,
277
  top_p
278
  ],
279
+ outputs=[final_prompt_display, response_display],
280
+ api_name="generate_rps_analysis" # Added api_name
281
  )
282
 
283
+ # --- Launch the demo ---
284
+ if __name__ == "__main__":
285
+ demo.launch()