Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
-
import json
|
4 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
5 |
|
6 |
# --- Configuration ---
|
@@ -16,90 +15,8 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
16 |
)
|
17 |
print("Model loaded successfully.")
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
"rps_simple": {
|
22 |
-
"description": "Rock-Paper-Scissors (Simple Format)",
|
23 |
-
"data": {
|
24 |
-
"game_type": "rps",
|
25 |
-
"encoding": {"rock": 0, "paper": 1, "scissors": 2},
|
26 |
-
"result_encoding": {"ai_win": 0, "player_win": 1, "tie": 2},
|
27 |
-
"rounds": [
|
28 |
-
{"round": 1, "player": 0, "ai": 2, "result": 1}, {"round": 2, "player": 1, "ai": 1, "result": 2},
|
29 |
-
{"round": 3, "player": 2, "ai": 0, "result": 0}, {"round": 4, "player": 0, "ai": 0, "result": 2},
|
30 |
-
{"round": 5, "player": 1, "ai": 0, "result": 1}, {"round": 6, "player": 2, "ai": 2, "result": 2},
|
31 |
-
{"round": 7, "player": 0, "ai": 1, "result": 0}, {"round": 8, "player": 1, "ai": 2, "result": 0},
|
32 |
-
{"round": 9, "player": 2, "ai": 1, "result": 1}, {"round": 10, "player": 0, "ai": 2, "result": 1}
|
33 |
-
],
|
34 |
-
"summary": {"player_wins": 4, "ai_wins": 3, "ties": 3}
|
35 |
-
}
|
36 |
-
},
|
37 |
-
"rps_numeric": {
|
38 |
-
"description": "Rock-Paper-Scissors (Compressed Numeric Format)",
|
39 |
-
"data": {
|
40 |
-
"rules": "RPS: 0=Rock,1=Paper,2=Scissors. Result: 0=AI_win,1=Player_win,2=Tie",
|
41 |
-
"rounds": [[1,0,2,1],[2,1,1,2],[3,2,0,0],[4,0,0,2],[5,1,0,1],[6,2,2,2],[7,0,1,0],[8,1,2,0],[9,2,1,1],[10,0,2,1]],
|
42 |
-
"score": {"P": 4, "AI": 3, "Tie": 3}
|
43 |
-
}
|
44 |
-
}
|
45 |
-
}
|
46 |
-
|
47 |
-
# --- Predefined Prompts (User's structure) ---
|
48 |
-
PROMPT_TEMPLATES = {
|
49 |
-
"detailed_analysis_recommendation": "Analyze the game history provided. Identify patterns in the player's moves. Based on your analysis, explain the reasoning and recommend the best move for the AI (or player if specified) in the next round.",
|
50 |
-
"player_pattern_focus": "Focus specifically on the player's move patterns. Do they favor a specific move? Do they follow sequences? Do they react predictably after winning or losing?",
|
51 |
-
"brief_recommendation": "Based on the history, what single move (Rock, Paper, or Scissors) should be played next and give a one-sentence justification?",
|
52 |
-
"structured_output_request": "Provide a structured analysis with these sections: 1) Obvious player patterns, 2) Potential opponent counter-strategies, 3) Final move recommendation with reasoning."
|
53 |
-
}
|
54 |
-
|
55 |
-
# --- Formatting Functions (Updated format_rps_simple) ---
|
56 |
-
def format_rps_simple(game_data):
|
57 |
-
"""Format the RPS data clearly, explicitly stating moves and results."""
|
58 |
-
game = game_data["data"]
|
59 |
-
move_names = {0: "Rock", 1: "Paper", 2: "Scissors"}
|
60 |
-
result_map = {0: "AI wins", 1: "Player wins", 2: "Tie"} # Changed name
|
61 |
-
player_moves = {"Rock": 0, "Paper": 0, "Scissors": 0}
|
62 |
-
|
63 |
-
formatted_data = "Game: Rock-Paper-Scissors\n"
|
64 |
-
formatted_data += "Move codes: 0=Rock, 1=Paper, 2=Scissors\n"
|
65 |
-
formatted_data += "Result codes: 0=AI wins, 1=Player wins, 2=Tie\n\n" # Simplified explanation
|
66 |
-
|
67 |
-
formatted_data += "Game Data (Round, Player Move, AI Move, Result Text):\n" # Clarified header
|
68 |
-
for round_data in game["rounds"]:
|
69 |
-
r_num, p_move, ai_move, result_code = round_data["round"], round_data["player"], round_data["ai"], round_data["result"]
|
70 |
-
player_moves[move_names[p_move]] += 1
|
71 |
-
result_text = result_map[result_code]
|
72 |
-
# Explicitly add text names and result text in the main data line
|
73 |
-
formatted_data += f"R{r_num}: Player={move_names[p_move]}({p_move}), AI={move_names[ai_move]}({ai_move}), Result={result_text}\n"
|
74 |
-
|
75 |
-
formatted_data += "\nSummary:\n"
|
76 |
-
formatted_data += f"Player wins: {game['summary']['player_wins']}\n"
|
77 |
-
formatted_data += f"AI wins: {game['summary']['ai_wins']}\n"
|
78 |
-
formatted_data += f"Ties: {game['summary']['ties']}\n\n"
|
79 |
-
|
80 |
-
formatted_data += "Player move frequencies:\n"
|
81 |
-
total_rounds = len(game["rounds"])
|
82 |
-
for move, count in player_moves.items():
|
83 |
-
percentage = round((count / total_rounds) * 100) if total_rounds > 0 else 0
|
84 |
-
formatted_data += f"{move}: {count} times ({percentage}%)\n"
|
85 |
-
return formatted_data
|
86 |
-
|
87 |
-
def format_rps_numeric(game_data):
|
88 |
-
"""Format the RPS data in a highly compressed numeric format"""
|
89 |
-
game = game_data["data"]
|
90 |
-
formatted_data = "RPS Game Data (compressed format)\n"
|
91 |
-
formatted_data += f"Rules: {game['rules']}\n\n"
|
92 |
-
rounds_str = ",".join([str(r) for r in game['rounds']])
|
93 |
-
formatted_data += f"Rounds: {rounds_str}\n\n"
|
94 |
-
formatted_data += f"Score: Player={game['score']['P']} AI={game['score']['AI']} Ties={game['score']['Tie']}\n"
|
95 |
-
return formatted_data
|
96 |
-
|
97 |
-
FORMAT_FUNCTIONS = {
|
98 |
-
"rps_simple": format_rps_simple,
|
99 |
-
"rps_numeric": format_rps_numeric
|
100 |
-
}
|
101 |
-
|
102 |
-
# --- Generation Function (Using Chat Template) ---
|
103 |
def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
|
104 |
"""Generate a response from the Qwen2 model using chat template."""
|
105 |
try:
|
@@ -129,106 +46,142 @@ def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
|
|
129 |
print(f"Error during generation: {e}")
|
130 |
return f"An error occurred: {str(e)}"
|
131 |
|
132 |
-
# --- Input Processing Function (
|
133 |
def process_input(
|
134 |
-
|
135 |
-
|
136 |
-
custom_prompt,
|
137 |
-
use_custom_prompt,
|
138 |
system_prompt,
|
|
|
139 |
max_length,
|
140 |
temperature,
|
141 |
top_p
|
142 |
):
|
143 |
-
"""Process
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
|
|
|
|
|
|
|
|
148 |
messages = []
|
149 |
-
if system_prompt and system_prompt.strip():
|
150 |
messages.append({"role": "system", "content": system_prompt})
|
151 |
messages.append({"role": "user", "content": user_content})
|
|
|
|
|
152 |
response = generate_response(
|
153 |
messages,
|
154 |
max_length=max_length,
|
155 |
temperature=temperature,
|
156 |
top_p=top_p
|
157 |
)
|
|
|
|
|
158 |
display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
|
159 |
-
return display_prompt, response
|
160 |
|
161 |
-
|
162 |
|
163 |
-
#
|
164 |
-
DEFAULT_SYSTEM_PROMPT = """You are a highly accurate and methodical Rock-Paper-Scissors (RPS) strategy analyst.
|
165 |
-
Your goal is to analyze the provided game history and give the user strategic advice for their next move.
|
166 |
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
6. **Recommendation:** Provide a single, clear recommendation (Rock, Paper, or Scissors) for the *next* round and justify it concisely based on your reasoning.
|
174 |
|
175 |
-
|
|
|
|
|
|
|
176 |
|
177 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
178 |
-
gr.Markdown(f"# {MODEL_ID} - RPS
|
179 |
-
gr.Markdown("Test how the model
|
180 |
|
181 |
with gr.Row():
|
182 |
-
with gr.Column():
|
183 |
-
|
184 |
-
|
|
|
|
|
|
|
|
|
185 |
)
|
186 |
-
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
label="System Prompt (Optional)",
|
189 |
-
|
190 |
-
|
191 |
-
lines=
|
192 |
)
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
label="Custom Prompt (if Use Custom Prompt is checked)",
|
200 |
-
placeholder="Enter your custom prompt/question here", lines=3
|
201 |
)
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
|
|
|
|
|
|
206 |
submit_btn = gr.Button("Generate Response", variant="primary")
|
207 |
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
final_prompt_display = gr.Textbox(
|
210 |
label="Formatted Input Sent to Model (via Chat Template)", lines=15
|
211 |
)
|
|
|
212 |
response_display = gr.Textbox(
|
213 |
label="Model Response", lines=15, show_copy_button=True
|
214 |
)
|
215 |
-
gr.Markdown("""
|
216 |
-
## Testing Tips
|
217 |
-
- **Game Data Format**: Selects how history is structured. 'rps_simple' uses the improved format now.
|
218 |
-
- **System Prompt**: Crucial for setting the AI's role and desired output style. The default is now much more detailed.
|
219 |
-
- **Prompt Template / Custom Prompt**: Asks the specific question.
|
220 |
-
- **Generation Params**: Try lowering `Temperature` (e.g., to 0.3-0.5) for more factual, less random output.
|
221 |
-
- **Chat Template**: This version uses the model's chat template correctly.
|
222 |
-
""")
|
223 |
|
|
|
|
|
224 |
submit_btn.click(
|
225 |
process_input,
|
226 |
inputs=[
|
227 |
-
|
228 |
-
|
|
|
|
|
|
|
|
|
|
|
229 |
],
|
230 |
outputs=[final_prompt_display, response_display],
|
231 |
-
api_name="
|
232 |
)
|
233 |
|
234 |
# --- Launch the demo ---
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
|
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
|
5 |
# --- Configuration ---
|
|
|
15 |
)
|
16 |
print("Model loaded successfully.")
|
17 |
|
18 |
+
|
19 |
+
# --- Generation Function (Using Chat Template - No changes needed here) ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
|
21 |
"""Generate a response from the Qwen2 model using chat template."""
|
22 |
try:
|
|
|
46 |
print(f"Error during generation: {e}")
|
47 |
return f"An error occurred: {str(e)}"
|
48 |
|
49 |
+
# --- Input Processing Function (Simplified for Frequency Stats) ---
|
50 |
def process_input(
|
51 |
+
player_stats,
|
52 |
+
ai_stats,
|
|
|
|
|
53 |
system_prompt,
|
54 |
+
user_query, # Changed from prompt template/custom
|
55 |
max_length,
|
56 |
temperature,
|
57 |
top_p
|
58 |
):
|
59 |
+
"""Process frequency stats and user query for the model."""
|
60 |
+
|
61 |
+
# Construct the user message content using the provided stats and query
|
62 |
+
user_content = f"Player Move Frequency Stats:\n{player_stats}\n\n"
|
63 |
+
if ai_stats and ai_stats.strip(): # Include AI stats if provided
|
64 |
+
user_content += f"AI Move Frequency Stats:\n{ai_stats}\n\n"
|
65 |
+
user_content += f"User Query:\n{user_query}"
|
66 |
+
|
67 |
+
# Create the messages list for the chat template
|
68 |
messages = []
|
69 |
+
if system_prompt and system_prompt.strip(): # Add system prompt if provided
|
70 |
messages.append({"role": "system", "content": system_prompt})
|
71 |
messages.append({"role": "user", "content": user_content})
|
72 |
+
|
73 |
+
# Generate response from the model
|
74 |
response = generate_response(
|
75 |
messages,
|
76 |
max_length=max_length,
|
77 |
temperature=temperature,
|
78 |
top_p=top_p
|
79 |
)
|
80 |
+
|
81 |
+
# For display purposes, show the constructed input
|
82 |
display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
|
|
|
83 |
|
84 |
+
return display_prompt, response
|
85 |
|
86 |
+
# --- Gradio Interface (Simplified for Frequency Stats) ---
|
|
|
|
|
87 |
|
88 |
+
# Define a default system prompt suitable for frequency analysis
|
89 |
+
DEFAULT_SYSTEM_PROMPT = """You are an expert Rock-Paper-Scissors (RPS) strategist.
|
90 |
+
Analyze the provided frequency statistics for the player's (and potentially AI's) past moves.
|
91 |
+
Based *only* on these statistics, determine the statistically optimal counter-strategy or recommendation for the AI's next move.
|
92 |
+
Explain your reasoning clearly based on the probabilities implied by the frequencies and the rules of RPS (Rock beats Scissors, Scissors beats Paper, Paper beats Rock).
|
93 |
+
Provide a clear recommendation (Rock, Paper, or Scissors) and justify it using expected outcomes or probabilities."""
|
|
|
94 |
|
95 |
+
# Default example stats
|
96 |
+
DEFAULT_PLAYER_STATS = "Rock: 40%\nPaper: 30%\nScissors: 30%"
|
97 |
+
DEFAULT_AI_STATS = "Rock: 33%\nPaper: 34%\nScissors: 33%" # Example AI stats
|
98 |
+
DEFAULT_USER_QUERY = "Based on the player's move frequencies, what move should the AI make next to maximize its statistical chance of winning? Explain your reasoning."
|
99 |
|
100 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
101 |
+
gr.Markdown(f"# {MODEL_ID} - RPS Frequency Analysis Tester")
|
102 |
+
gr.Markdown("Test how the model provides strategic advice based *only* on Player and AI move frequency statistics.")
|
103 |
|
104 |
with gr.Row():
|
105 |
+
with gr.Column(scale=2): # Make input column wider
|
106 |
+
# Input for Player Stats
|
107 |
+
player_stats_input = gr.Textbox(
|
108 |
+
label="Player Move Frequency Stats",
|
109 |
+
value=DEFAULT_PLAYER_STATS,
|
110 |
+
lines=4,
|
111 |
+
info="Enter the observed frequencies of the player's moves."
|
112 |
)
|
113 |
+
# Input for AI Stats (Optional)
|
114 |
+
ai_stats_input = gr.Textbox(
|
115 |
+
label="AI Move Frequency Stats (Optional)",
|
116 |
+
value=DEFAULT_AI_STATS,
|
117 |
+
lines=4,
|
118 |
+
info="Optionally, enter the AI's own move frequencies if relevant."
|
119 |
+
)
|
120 |
+
# Input for User Query
|
121 |
+
user_query_input = gr.Textbox(
|
122 |
+
label="Your Query / Instruction",
|
123 |
+
value=DEFAULT_USER_QUERY,
|
124 |
+
lines=3,
|
125 |
+
info="Ask the specific question based on the frequency stats."
|
126 |
+
)
|
127 |
+
# System prompt (optional)
|
128 |
+
system_prompt_input = gr.Textbox(
|
129 |
label="System Prompt (Optional)",
|
130 |
+
value=DEFAULT_SYSTEM_PROMPT,
|
131 |
+
placeholder="Define the AI's role and task based on frequency stats...",
|
132 |
+
lines=10 # Reduced lines needed
|
133 |
)
|
134 |
+
|
135 |
+
with gr.Column(scale=1): # Make params/output column narrower
|
136 |
+
# Generation parameters
|
137 |
+
gr.Markdown("## Generation Parameters")
|
138 |
+
max_length_slider = gr.Slider(
|
139 |
+
minimum=50, maximum=1024, value=350, step=16, label="Max New Tokens" # Reduced default length needed
|
|
|
|
|
140 |
)
|
141 |
+
temperature_slider = gr.Slider(
|
142 |
+
minimum=0.1, maximum=1.5, value=0.5, step=0.05, label="Temperature" # Defaulting lower for stats analysis
|
143 |
+
)
|
144 |
+
top_p_slider = gr.Slider(
|
145 |
+
minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P"
|
146 |
+
)
|
147 |
+
# Generate button
|
148 |
submit_btn = gr.Button("Generate Response", variant="primary")
|
149 |
|
150 |
+
# Tips for using the interface
|
151 |
+
gr.Markdown("""
|
152 |
+
## Testing Tips
|
153 |
+
- Input player move frequencies directly. AI stats are optional.
|
154 |
+
- Refine the **User Query** to guide the model's task.
|
155 |
+
- Adjust the **System Prompt** for role/task definition.
|
156 |
+
- Use lower **Temperature** for more deterministic, calculation-like responses based on stats.
|
157 |
+
""")
|
158 |
+
|
159 |
+
with gr.Row():
|
160 |
+
with gr.Column():
|
161 |
+
# Display final prompt and model response
|
162 |
final_prompt_display = gr.Textbox(
|
163 |
label="Formatted Input Sent to Model (via Chat Template)", lines=15
|
164 |
)
|
165 |
+
with gr.Column():
|
166 |
response_display = gr.Textbox(
|
167 |
label="Model Response", lines=15, show_copy_button=True
|
168 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
+
|
171 |
+
# Handle button click - Updated inputs list
|
172 |
submit_btn.click(
|
173 |
process_input,
|
174 |
inputs=[
|
175 |
+
player_stats_input,
|
176 |
+
ai_stats_input,
|
177 |
+
system_prompt_input,
|
178 |
+
user_query_input, # New input
|
179 |
+
max_length_slider,
|
180 |
+
temperature_slider,
|
181 |
+
top_p_slider
|
182 |
],
|
183 |
outputs=[final_prompt_display, response_display],
|
184 |
+
api_name="generate_rps_frequency_analysis" # Updated api_name
|
185 |
)
|
186 |
|
187 |
# --- Launch the demo ---
|