rui3000 committed on
Commit
5f32a93
·
verified ·
1 Parent(s): 3ef427a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -143
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import gradio as gr
2
  import torch
3
- import json
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
  # --- Configuration ---
@@ -16,90 +15,8 @@ model = AutoModelForCausalLM.from_pretrained(
16
  )
17
  print("Model loaded successfully.")
18
 
19
- # --- Predefined Data (User's structure) ---
20
- PREDEFINED_GAMES = {
21
- "rps_simple": {
22
- "description": "Rock-Paper-Scissors (Simple Format)",
23
- "data": {
24
- "game_type": "rps",
25
- "encoding": {"rock": 0, "paper": 1, "scissors": 2},
26
- "result_encoding": {"ai_win": 0, "player_win": 1, "tie": 2},
27
- "rounds": [
28
- {"round": 1, "player": 0, "ai": 2, "result": 1}, {"round": 2, "player": 1, "ai": 1, "result": 2},
29
- {"round": 3, "player": 2, "ai": 0, "result": 0}, {"round": 4, "player": 0, "ai": 0, "result": 2},
30
- {"round": 5, "player": 1, "ai": 0, "result": 1}, {"round": 6, "player": 2, "ai": 2, "result": 2},
31
- {"round": 7, "player": 0, "ai": 1, "result": 0}, {"round": 8, "player": 1, "ai": 2, "result": 0},
32
- {"round": 9, "player": 2, "ai": 1, "result": 1}, {"round": 10, "player": 0, "ai": 2, "result": 1}
33
- ],
34
- "summary": {"player_wins": 4, "ai_wins": 3, "ties": 3}
35
- }
36
- },
37
- "rps_numeric": {
38
- "description": "Rock-Paper-Scissors (Compressed Numeric Format)",
39
- "data": {
40
- "rules": "RPS: 0=Rock,1=Paper,2=Scissors. Result: 0=AI_win,1=Player_win,2=Tie",
41
- "rounds": [[1,0,2,1],[2,1,1,2],[3,2,0,0],[4,0,0,2],[5,1,0,1],[6,2,2,2],[7,0,1,0],[8,1,2,0],[9,2,1,1],[10,0,2,1]],
42
- "score": {"P": 4, "AI": 3, "Tie": 3}
43
- }
44
- }
45
- }
46
-
47
- # --- Predefined Prompts (User's structure) ---
48
- PROMPT_TEMPLATES = {
49
- "detailed_analysis_recommendation": "Analyze the game history provided. Identify patterns in the player's moves. Based on your analysis, explain the reasoning and recommend the best move for the AI (or player if specified) in the next round.",
50
- "player_pattern_focus": "Focus specifically on the player's move patterns. Do they favor a specific move? Do they follow sequences? Do they react predictably after winning or losing?",
51
- "brief_recommendation": "Based on the history, what single move (Rock, Paper, or Scissors) should be played next and give a one-sentence justification?",
52
- "structured_output_request": "Provide a structured analysis with these sections: 1) Obvious player patterns, 2) Potential opponent counter-strategies, 3) Final move recommendation with reasoning."
53
- }
54
-
55
- # --- Formatting Functions (Updated format_rps_simple) ---
56
- def format_rps_simple(game_data):
57
- """Format the RPS data clearly, explicitly stating moves and results."""
58
- game = game_data["data"]
59
- move_names = {0: "Rock", 1: "Paper", 2: "Scissors"}
60
- result_map = {0: "AI wins", 1: "Player wins", 2: "Tie"} # Changed name
61
- player_moves = {"Rock": 0, "Paper": 0, "Scissors": 0}
62
-
63
- formatted_data = "Game: Rock-Paper-Scissors\n"
64
- formatted_data += "Move codes: 0=Rock, 1=Paper, 2=Scissors\n"
65
- formatted_data += "Result codes: 0=AI wins, 1=Player wins, 2=Tie\n\n" # Simplified explanation
66
-
67
- formatted_data += "Game Data (Round, Player Move, AI Move, Result Text):\n" # Clarified header
68
- for round_data in game["rounds"]:
69
- r_num, p_move, ai_move, result_code = round_data["round"], round_data["player"], round_data["ai"], round_data["result"]
70
- player_moves[move_names[p_move]] += 1
71
- result_text = result_map[result_code]
72
- # Explicitly add text names and result text in the main data line
73
- formatted_data += f"R{r_num}: Player={move_names[p_move]}({p_move}), AI={move_names[ai_move]}({ai_move}), Result={result_text}\n"
74
-
75
- formatted_data += "\nSummary:\n"
76
- formatted_data += f"Player wins: {game['summary']['player_wins']}\n"
77
- formatted_data += f"AI wins: {game['summary']['ai_wins']}\n"
78
- formatted_data += f"Ties: {game['summary']['ties']}\n\n"
79
-
80
- formatted_data += "Player move frequencies:\n"
81
- total_rounds = len(game["rounds"])
82
- for move, count in player_moves.items():
83
- percentage = round((count / total_rounds) * 100) if total_rounds > 0 else 0
84
- formatted_data += f"{move}: {count} times ({percentage}%)\n"
85
- return formatted_data
86
-
87
- def format_rps_numeric(game_data):
88
- """Format the RPS data in a highly compressed numeric format"""
89
- game = game_data["data"]
90
- formatted_data = "RPS Game Data (compressed format)\n"
91
- formatted_data += f"Rules: {game['rules']}\n\n"
92
- rounds_str = ",".join([str(r) for r in game['rounds']])
93
- formatted_data += f"Rounds: {rounds_str}\n\n"
94
- formatted_data += f"Score: Player={game['score']['P']} AI={game['score']['AI']} Ties={game['score']['Tie']}\n"
95
- return formatted_data
96
-
97
- FORMAT_FUNCTIONS = {
98
- "rps_simple": format_rps_simple,
99
- "rps_numeric": format_rps_numeric
100
- }
101
-
102
- # --- Generation Function (Using Chat Template) ---
103
  def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
104
  """Generate a response from the Qwen2 model using chat template."""
105
  try:
@@ -129,106 +46,142 @@ def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
129
  print(f"Error during generation: {e}")
130
  return f"An error occurred: {str(e)}"
131
 
132
- # --- Input Processing Function (Using Chat Template) ---
133
  def process_input(
134
- game_format,
135
- prompt_template,
136
- custom_prompt,
137
- use_custom_prompt,
138
  system_prompt,
 
139
  max_length,
140
  temperature,
141
  top_p
142
  ):
143
- """Process the input, format using chat template, and generate response."""
144
- game_data = PREDEFINED_GAMES[game_format]
145
- formatted_game_data = FORMAT_FUNCTIONS[game_format](game_data)
146
- user_question = custom_prompt if use_custom_prompt else PROMPT_TEMPLATES[prompt_template]
147
- user_content = f"Game History:\n{formatted_game_data}\n\nQuestion:\n{user_question}"
 
 
 
 
148
  messages = []
149
- if system_prompt and system_prompt.strip():
150
  messages.append({"role": "system", "content": system_prompt})
151
  messages.append({"role": "user", "content": user_content})
 
 
152
  response = generate_response(
153
  messages,
154
  max_length=max_length,
155
  temperature=temperature,
156
  top_p=top_p
157
  )
 
 
158
  display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
159
- return display_prompt, response
160
 
161
- # --- Gradio Interface (Updated system prompt placeholder) ---
162
 
163
- # Define the improved default system prompt
164
- DEFAULT_SYSTEM_PROMPT = """You are a highly accurate and methodical Rock-Paper-Scissors (RPS) strategy analyst.
165
- Your goal is to analyze the provided game history and give the user strategic advice for their next move.
166
 
167
- Follow these steps precisely:
168
- 1. **Verify Rules:** Remember: Rock (0) beats Scissors (2), Scissors (2) beats Paper (1), Paper (1) beats Rock (0).
169
- 2. **Analyze Player Moves:** Go through the 'Game Data' round by round. List the player's move and the result (Win, Loss, Tie) for each round accurately.
170
- 3. **Calculate Frequencies:** Use the provided 'Player move frequencies' or calculate them from the rounds. Note any strong preference.
171
- 4. **Identify Patterns:** Look for sequences (e.g., did the player repeat Rock twice?), reactions (e.g., what did the player do after winning/losing?), or other tendencies based *only* on the provided data. State the patterns clearly.
172
- 5. **Reasoning:** Explain your reasoning for the recommendation based *only* on the verified round data and identified patterns. Do not invent patterns.
173
- 6. **Recommendation:** Provide a single, clear recommendation (Rock, Paper, or Scissors) for the *next* round and justify it concisely based on your reasoning.
174
 
175
- Structure your response clearly with sections for Analysis, Patterns, Reasoning, and Recommendation. Be factual and base everything strictly on the provided game history."""
 
 
 
176
 
177
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
178
- gr.Markdown(f"# {MODEL_ID} - RPS Game Analysis Tester")
179
- gr.Markdown("Test how the model responds to different RPS game data formats and prompts using its chat template.")
180
 
181
  with gr.Row():
182
- with gr.Column():
183
- game_format = gr.Dropdown(
184
- choices=list(PREDEFINED_GAMES.keys()), value="rps_simple", label="Game Data Format"
 
 
 
 
185
  )
186
- # Use the detailed DEFAULT_SYSTEM_PROMPT as the placeholder/default value
187
- system_prompt = gr.Textbox(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  label="System Prompt (Optional)",
189
- placeholder=DEFAULT_SYSTEM_PROMPT, # Set placeholder
190
- value=DEFAULT_SYSTEM_PROMPT, # Set default value
191
- lines=15 # Increased lines to show more default text
192
  )
193
- with gr.Row():
194
- prompt_template = gr.Dropdown(
195
- choices=list(PROMPT_TEMPLATES.keys()), value="detailed_analysis_recommendation", label="Prompt Template"
196
- )
197
- use_custom_prompt = gr.Checkbox(label="Use Custom Prompt", value=False)
198
- custom_prompt = gr.Textbox(
199
- label="Custom Prompt (if Use Custom Prompt is checked)",
200
- placeholder="Enter your custom prompt/question here", lines=3
201
  )
202
- with gr.Row():
203
- max_length = gr.Slider(minimum=50, maximum=1024, value=512, step=16, label="Max New Tokens")
204
- temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.05, label="Temperature")
205
- top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
 
 
 
206
  submit_btn = gr.Button("Generate Response", variant="primary")
207
 
208
- with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
209
  final_prompt_display = gr.Textbox(
210
  label="Formatted Input Sent to Model (via Chat Template)", lines=15
211
  )
 
212
  response_display = gr.Textbox(
213
  label="Model Response", lines=15, show_copy_button=True
214
  )
215
- gr.Markdown("""
216
- ## Testing Tips
217
- - **Game Data Format**: Selects how history is structured. 'rps_simple' uses the improved format now.
218
- - **System Prompt**: Crucial for setting the AI's role and desired output style. The default is now much more detailed.
219
- - **Prompt Template / Custom Prompt**: Asks the specific question.
220
- - **Generation Params**: Try lowering `Temperature` (e.g., to 0.3-0.5) for more factual, less random output.
221
- - **Chat Template**: This version uses the model's chat template correctly.
222
- """)
223
 
 
 
224
  submit_btn.click(
225
  process_input,
226
  inputs=[
227
- game_format, prompt_template, custom_prompt, use_custom_prompt,
228
- system_prompt, max_length, temperature, top_p
 
 
 
 
 
229
  ],
230
  outputs=[final_prompt_display, response_display],
231
- api_name="generate_rps_analysis"
232
  )
233
 
234
  # --- Launch the demo ---
 
1
  import gradio as gr
2
  import torch
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
  # --- Configuration ---
 
15
  )
16
  print("Model loaded successfully.")
17
 
18
+
19
+ # --- Generation Function (Using Chat Template - No changes needed here) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
21
  """Generate a response from the Qwen2 model using chat template."""
22
  try:
 
46
  print(f"Error during generation: {e}")
47
  return f"An error occurred: {str(e)}"
48
 
49
+ # --- Input Processing Function (Simplified for Frequency Stats) ---
50
  def process_input(
51
+ player_stats,
52
+ ai_stats,
 
 
53
  system_prompt,
54
+ user_query, # Changed from prompt template/custom
55
  max_length,
56
  temperature,
57
  top_p
58
  ):
59
+ """Process frequency stats and user query for the model."""
60
+
61
+ # Construct the user message content using the provided stats and query
62
+ user_content = f"Player Move Frequency Stats:\n{player_stats}\n\n"
63
+ if ai_stats and ai_stats.strip(): # Include AI stats if provided
64
+ user_content += f"AI Move Frequency Stats:\n{ai_stats}\n\n"
65
+ user_content += f"User Query:\n{user_query}"
66
+
67
+ # Create the messages list for the chat template
68
  messages = []
69
+ if system_prompt and system_prompt.strip(): # Add system prompt if provided
70
  messages.append({"role": "system", "content": system_prompt})
71
  messages.append({"role": "user", "content": user_content})
72
+
73
+ # Generate response from the model
74
  response = generate_response(
75
  messages,
76
  max_length=max_length,
77
  temperature=temperature,
78
  top_p=top_p
79
  )
80
+
81
+ # For display purposes, show the constructed input
82
  display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
 
83
 
84
+ return display_prompt, response
85
 
86
+ # --- Gradio Interface (Simplified for Frequency Stats) ---
 
 
87
 
88
+ # Define a default system prompt suitable for frequency analysis
89
+ DEFAULT_SYSTEM_PROMPT = """You are an expert Rock-Paper-Scissors (RPS) strategist.
90
+ Analyze the provided frequency statistics for the player's (and potentially AI's) past moves.
91
+ Based *only* on these statistics, determine the statistically optimal counter-strategy or recommendation for the AI's next move.
92
+ Explain your reasoning clearly based on the probabilities implied by the frequencies and the rules of RPS (Rock beats Scissors, Scissors beats Paper, Paper beats Rock).
93
+ Provide a clear recommendation (Rock, Paper, or Scissors) and justify it using expected outcomes or probabilities."""
 
94
 
95
+ # Default example stats
96
+ DEFAULT_PLAYER_STATS = "Rock: 40%\nPaper: 30%\nScissors: 30%"
97
+ DEFAULT_AI_STATS = "Rock: 33%\nPaper: 34%\nScissors: 33%" # Example AI stats
98
+ DEFAULT_USER_QUERY = "Based on the player's move frequencies, what move should the AI make next to maximize its statistical chance of winning? Explain your reasoning."
99
 
100
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
101
+ gr.Markdown(f"# {MODEL_ID} - RPS Frequency Analysis Tester")
102
+ gr.Markdown("Test how the model provides strategic advice based *only* on Player and AI move frequency statistics.")
103
 
104
  with gr.Row():
105
+ with gr.Column(scale=2): # Make input column wider
106
+ # Input for Player Stats
107
+ player_stats_input = gr.Textbox(
108
+ label="Player Move Frequency Stats",
109
+ value=DEFAULT_PLAYER_STATS,
110
+ lines=4,
111
+ info="Enter the observed frequencies of the player's moves."
112
  )
113
+ # Input for AI Stats (Optional)
114
+ ai_stats_input = gr.Textbox(
115
+ label="AI Move Frequency Stats (Optional)",
116
+ value=DEFAULT_AI_STATS,
117
+ lines=4,
118
+ info="Optionally, enter the AI's own move frequencies if relevant."
119
+ )
120
+ # Input for User Query
121
+ user_query_input = gr.Textbox(
122
+ label="Your Query / Instruction",
123
+ value=DEFAULT_USER_QUERY,
124
+ lines=3,
125
+ info="Ask the specific question based on the frequency stats."
126
+ )
127
+ # System prompt (optional)
128
+ system_prompt_input = gr.Textbox(
129
  label="System Prompt (Optional)",
130
+ value=DEFAULT_SYSTEM_PROMPT,
131
+ placeholder="Define the AI's role and task based on frequency stats...",
132
+ lines=10 # Reduced lines needed
133
  )
134
+
135
+ with gr.Column(scale=1): # Make params/output column narrower
136
+ # Generation parameters
137
+ gr.Markdown("## Generation Parameters")
138
+ max_length_slider = gr.Slider(
139
+ minimum=50, maximum=1024, value=350, step=16, label="Max New Tokens" # Reduced default length needed
 
 
140
  )
141
+ temperature_slider = gr.Slider(
142
+ minimum=0.1, maximum=1.5, value=0.5, step=0.05, label="Temperature" # Defaulting lower for stats analysis
143
+ )
144
+ top_p_slider = gr.Slider(
145
+ minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P"
146
+ )
147
+ # Generate button
148
  submit_btn = gr.Button("Generate Response", variant="primary")
149
 
150
+ # Tips for using the interface
151
+ gr.Markdown("""
152
+ ## Testing Tips
153
+ - Input player move frequencies directly. AI stats are optional.
154
+ - Refine the **User Query** to guide the model's task.
155
+ - Adjust the **System Prompt** for role/task definition.
156
+ - Use lower **Temperature** for more deterministic, calculation-like responses based on stats.
157
+ """)
158
+
159
+ with gr.Row():
160
+ with gr.Column():
161
+ # Display final prompt and model response
162
  final_prompt_display = gr.Textbox(
163
  label="Formatted Input Sent to Model (via Chat Template)", lines=15
164
  )
165
+ with gr.Column():
166
  response_display = gr.Textbox(
167
  label="Model Response", lines=15, show_copy_button=True
168
  )
 
 
 
 
 
 
 
 
169
 
170
+
171
+ # Handle button click - Updated inputs list
172
  submit_btn.click(
173
  process_input,
174
  inputs=[
175
+ player_stats_input,
176
+ ai_stats_input,
177
+ system_prompt_input,
178
+ user_query_input, # New input
179
+ max_length_slider,
180
+ temperature_slider,
181
+ top_p_slider
182
  ],
183
  outputs=[final_prompt_display, response_display],
184
+ api_name="generate_rps_frequency_analysis" # Updated api_name
185
  )
186
 
187
  # --- Launch the demo ---