rui3000 commited on
Commit
24a51a6
·
verified ·
1 Parent(s): 5f32a93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -84
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import torch
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
  # --- Configuration ---
@@ -16,9 +17,10 @@ model = AutoModelForCausalLM.from_pretrained(
16
  print("Model loaded successfully.")
17
 
18
 
19
- # --- Generation Function (Using Chat Template - No changes needed here) ---
20
  def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
21
- """Generate a response from the Qwen2 model using chat template."""
 
22
  try:
23
  prompt_text = tokenizer.apply_chat_template(
24
  messages,
@@ -26,6 +28,8 @@ def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
26
  add_generation_prompt=True
27
  )
28
  model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
 
 
29
  generation_kwargs = {
30
  "max_new_tokens": max_length,
31
  "temperature": temperature,
@@ -33,155 +37,151 @@ def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
33
  "do_sample": True,
34
  "pad_token_id": tokenizer.eos_token_id,
35
  }
 
36
  print("Generating response...")
37
  with torch.no_grad():
 
38
  generated_ids = model.generate(model_inputs.input_ids, **generation_kwargs)
39
- input_ids_len = model_inputs.input_ids.shape[-1]
 
40
  output_ids = generated_ids[0, input_ids_len:]
 
 
41
  response = tokenizer.decode(output_ids, skip_special_tokens=True)
42
  print("Generation complete.")
43
- return response.strip()
44
 
45
  except Exception as e:
46
  print(f"Error during generation: {e}")
47
- return f"An error occurred: {str(e)}"
48
 
49
- # --- Input Processing Function (Simplified for Frequency Stats) ---
50
  def process_input(
51
  player_stats,
52
  ai_stats,
53
  system_prompt,
54
- user_query, # Changed from prompt template/custom
55
  max_length,
56
  temperature,
57
  top_p
58
  ):
59
- """Process frequency stats and user query for the model."""
60
 
61
- # Construct the user message content using the provided stats and query
62
  user_content = f"Player Move Frequency Stats:\n{player_stats}\n\n"
63
- if ai_stats and ai_stats.strip(): # Include AI stats if provided
64
  user_content += f"AI Move Frequency Stats:\n{ai_stats}\n\n"
65
  user_content += f"User Query:\n{user_query}"
66
 
67
- # Create the messages list for the chat template
68
  messages = []
69
- if system_prompt and system_prompt.strip(): # Add system prompt if provided
70
  messages.append({"role": "system", "content": system_prompt})
71
  messages.append({"role": "user", "content": user_content})
72
 
 
 
 
73
  # Generate response from the model
74
- response = generate_response(
75
  messages,
76
  max_length=max_length,
77
  temperature=temperature,
78
  top_p=top_p
79
  )
80
 
81
- # For display purposes, show the constructed input
 
 
 
 
82
  display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
83
 
84
- return display_prompt, response
 
 
 
85
 
86
- # --- Gradio Interface (Simplified for Frequency Stats) ---
 
 
87
 
88
- # Define a default system prompt suitable for frequency analysis
89
- DEFAULT_SYSTEM_PROMPT = """You are an expert Rock-Paper-Scissors (RPS) strategist.
90
- Analyze the provided frequency statistics for the player's (and potentially AI's) past moves.
91
- Based *only* on these statistics, determine the statistically optimal counter-strategy or recommendation for the AI's next move.
92
- Explain your reasoning clearly based on the probabilities implied by the frequencies and the rules of RPS (Rock beats Scissors, Scissors beats Paper, Paper beats Rock).
93
- Provide a clear recommendation (Rock, Paper, or Scissors) and justify it using expected outcomes or probabilities."""
94
 
95
- # Default example stats
 
 
96
  DEFAULT_PLAYER_STATS = "Rock: 40%\nPaper: 30%\nScissors: 30%"
97
- DEFAULT_AI_STATS = "Rock: 33%\nPaper: 34%\nScissors: 33%" # Example AI stats
98
- DEFAULT_USER_QUERY = "Based on the player's move frequencies, what move should the AI make next to maximize its statistical chance of winning? Explain your reasoning."
99
 
100
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
101
  gr.Markdown(f"# {MODEL_ID} - RPS Frequency Analysis Tester")
102
- gr.Markdown("Test how the model provides strategic advice based *only* on Player and AI move frequency statistics.")
103
 
104
  with gr.Row():
105
- with gr.Column(scale=2): # Make input column wider
106
- # Input for Player Stats
107
  player_stats_input = gr.Textbox(
108
- label="Player Move Frequency Stats",
109
- value=DEFAULT_PLAYER_STATS,
110
- lines=4,
111
- info="Enter the observed frequencies of the player's moves."
112
  )
113
- # Input for AI Stats (Optional)
114
  ai_stats_input = gr.Textbox(
115
- label="AI Move Frequency Stats (Optional)",
116
- value=DEFAULT_AI_STATS,
117
- lines=4,
118
- info="Optionally, enter the AI's own move frequencies if relevant."
119
  )
120
- # Input for User Query
121
  user_query_input = gr.Textbox(
122
- label="Your Query / Instruction",
123
- value=DEFAULT_USER_QUERY,
124
- lines=3,
125
- info="Ask the specific question based on the frequency stats."
126
  )
127
- # System prompt (optional)
128
  system_prompt_input = gr.Textbox(
129
- label="System Prompt (Optional)",
130
- value=DEFAULT_SYSTEM_PROMPT,
131
- placeholder="Define the AI's role and task based on frequency stats...",
132
- lines=10 # Reduced lines needed
133
  )
134
 
135
- with gr.Column(scale=1): # Make params/output column narrower
136
- # Generation parameters
137
  gr.Markdown("## Generation Parameters")
138
- max_length_slider = gr.Slider(
139
- minimum=50, maximum=1024, value=350, step=16, label="Max New Tokens" # Reduced default length needed
140
- )
141
- temperature_slider = gr.Slider(
142
- minimum=0.1, maximum=1.5, value=0.5, step=0.05, label="Temperature" # Defaulting lower for stats analysis
143
- )
144
- top_p_slider = gr.Slider(
145
- minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P"
146
- )
147
- # Generate button
148
  submit_btn = gr.Button("Generate Response", variant="primary")
149
 
150
- # Tips for using the interface
 
 
 
 
151
  gr.Markdown("""
152
  ## Testing Tips
153
- - Input player move frequencies directly. AI stats are optional.
154
- - Refine the **User Query** to guide the model's task.
155
- - Adjust the **System Prompt** for role/task definition.
156
- - Use lower **Temperature** for more deterministic, calculation-like responses based on stats.
157
  """)
158
 
159
  with gr.Row():
160
- with gr.Column():
161
- # Display final prompt and model response
162
- final_prompt_display = gr.Textbox(
163
- label="Formatted Input Sent to Model (via Chat Template)", lines=15
164
- )
165
- with gr.Column():
166
- response_display = gr.Textbox(
167
- label="Model Response", lines=15, show_copy_button=True
168
- )
169
-
170
 
171
- # Handle button click - Updated inputs list
172
  submit_btn.click(
173
  process_input,
174
  inputs=[
175
- player_stats_input,
176
- ai_stats_input,
177
- system_prompt_input,
178
- user_query_input, # New input
179
- max_length_slider,
180
- temperature_slider,
181
- top_p_slider
182
  ],
183
- outputs=[final_prompt_display, response_display],
184
- api_name="generate_rps_frequency_analysis" # Updated api_name
185
  )
186
 
187
  # --- Launch the demo ---
 
1
  import gradio as gr
2
  import torch
3
+ import time # Import time module
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
  # --- Configuration ---
 
17
  print("Model loaded successfully.")
18
 
19
 
20
+ # --- Generation Function (Updated to return token count) ---
21
  def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
22
+ """Generate a response and return it along with the number of generated tokens."""
23
+ num_generated_tokens = 0
24
  try:
25
  prompt_text = tokenizer.apply_chat_template(
26
  messages,
 
28
  add_generation_prompt=True
29
  )
30
  model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
31
+ input_ids_len = model_inputs.input_ids.shape[-1] # Length of input tokens
32
+
33
  generation_kwargs = {
34
  "max_new_tokens": max_length,
35
  "temperature": temperature,
 
37
  "do_sample": True,
38
  "pad_token_id": tokenizer.eos_token_id,
39
  }
40
+
41
  print("Generating response...")
42
  with torch.no_grad():
43
+ # Generate response - Ensure output_scores or similar isn't needed if just counting
44
  generated_ids = model.generate(model_inputs.input_ids, **generation_kwargs)
45
+
46
+ # Calculate generated tokens
47
  output_ids = generated_ids[0, input_ids_len:]
48
+ num_generated_tokens = len(output_ids)
49
+
50
  response = tokenizer.decode(output_ids, skip_special_tokens=True)
51
  print("Generation complete.")
52
+ return response.strip(), num_generated_tokens # Return response and token count
53
 
54
  except Exception as e:
55
  print(f"Error during generation: {e}")
56
+ return f"An error occurred: {str(e)}", num_generated_tokens # Return error and token count
57
 
58
+ # --- Input Processing Function (Updated for Time/Token outputs) ---
59
  def process_input(
60
  player_stats,
61
  ai_stats,
62
  system_prompt,
63
+ user_query,
64
  max_length,
65
  temperature,
66
  top_p
67
  ):
68
+ """Process inputs, generate response, and return display info, response, time, and token count."""
69
 
70
+ # Construct the user message content
71
  user_content = f"Player Move Frequency Stats:\n{player_stats}\n\n"
72
+ if ai_stats and ai_stats.strip():
73
  user_content += f"AI Move Frequency Stats:\n{ai_stats}\n\n"
74
  user_content += f"User Query:\n{user_query}"
75
 
76
+ # Create the messages list
77
  messages = []
78
+ if system_prompt and system_prompt.strip():
79
  messages.append({"role": "system", "content": system_prompt})
80
  messages.append({"role": "user", "content": user_content})
81
 
82
+ # --- Time Measurement Start ---
83
+ start_time = time.time()
84
+
85
  # Generate response from the model
86
+ response, generated_tokens = generate_response( # Capture token count
87
  messages,
88
  max_length=max_length,
89
  temperature=temperature,
90
  top_p=top_p
91
  )
92
 
93
+ # --- Time Measurement End ---
94
+ end_time = time.time()
95
+ duration = round(end_time - start_time, 2) # Calculate duration
96
+
97
+ # For display purposes
98
  display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
99
 
100
+ # Return all results including time and tokens
101
+ return display_prompt, response, f"{duration} seconds", generated_tokens
102
+
103
+ # --- Gradio Interface (Added Time/Token displays, refined System Prompt) ---
104
 
105
+ # Refined default system prompt for better reasoning
106
+ DEFAULT_SYSTEM_PROMPT = """You are an expert Rock-Paper-Scissors (RPS) strategist focusing on statistical analysis.
107
+ Your task is to recommend the optimal AI move based *only* on the provided move frequency statistics for the player.
108
 
109
+ Follow these steps:
110
+ 1. **Identify Player's Most Frequent Move:** Note the move (Rock, Paper, or Scissors) the player uses most often according to the stats.
111
+ 2. **Determine Best Counter:** Identify the RPS move that directly beats the player's most frequent move (Rock beats Scissors, Scissors beats Paper, Paper beats Rock).
112
+ 3. **Justify Recommendation:** Explain *why* this counter-move is statistically optimal. You can mention the expected outcome. For example: 'Playing Paper counters the player's most frequent move, Rock (40% frequency). This offers the highest probability of winning against the player's likely action.' Avoid irrelevant justifications based on the AI's own move frequencies.
113
+ 4. **State Recommendation:** Clearly state the recommended move (Rock, Paper, or Scissors).
 
114
 
115
+ Base your analysis strictly on the provided frequencies and standard RPS rules."""
116
+
117
+ # Default example stats and query
118
  DEFAULT_PLAYER_STATS = "Rock: 40%\nPaper: 30%\nScissors: 30%"
119
+ DEFAULT_AI_STATS = "" # Keep AI stats optional and clear by default
120
+ DEFAULT_USER_QUERY = "Based *only* on the player's move frequencies, what single move should the AI make next to maximize its statistical chance of winning? Explain your reasoning clearly step-by-step as instructed."
121
 
122
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
123
  gr.Markdown(f"# {MODEL_ID} - RPS Frequency Analysis Tester")
124
+ gr.Markdown("Test model advice based on Player/AI move frequencies. Includes Generation Time and Token Count.")
125
 
126
  with gr.Row():
127
+ with gr.Column(scale=2): # Input column
 
128
  player_stats_input = gr.Textbox(
129
+ label="Player Move Frequency Stats", value=DEFAULT_PLAYER_STATS, lines=4,
130
+ info="Enter player's move frequencies (e.g., Rock: 50%, Paper: 30%, Scissors: 20%)."
 
 
131
  )
 
132
  ai_stats_input = gr.Textbox(
133
+ label="AI Move Frequency Stats (Optional)", value=DEFAULT_AI_STATS, lines=4,
134
+ info="Optionally, enter AI's own move frequencies."
 
 
135
  )
 
136
  user_query_input = gr.Textbox(
137
+ label="Your Query / Instruction", value=DEFAULT_USER_QUERY, lines=3,
138
+ info="Ask the specific question based on the stats."
 
 
139
  )
 
140
  system_prompt_input = gr.Textbox(
141
+ label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, # Set default value
142
+ lines=12 # Adjusted lines
 
 
143
  )
144
 
145
+ with gr.Column(scale=1): # Params/Output column
 
146
  gr.Markdown("## Generation Parameters")
147
+ max_length_slider = gr.Slider(minimum=50, maximum=1024, value=300, step=16, label="Max New Tokens")
148
+ temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Temperature") # Lowered default further
149
+ top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
 
 
 
 
 
 
 
150
  submit_btn = gr.Button("Generate Response", variant="primary")
151
 
152
+ gr.Markdown("## Performance Metrics")
153
+ # Outputs for Time and Tokens
154
+ time_output = gr.Textbox(label="Generation Time", interactive=False)
155
+ tokens_output = gr.Number(label="Generated Tokens", interactive=False) # Use Number for token count
156
+
157
  gr.Markdown("""
158
  ## Testing Tips
159
+ - Focus on player stats for optimal counter strategy.
160
+ - Use the refined **System Prompt** for better reasoning guidance.
161
+ - Lower **Temperature** encourages more direct, statistical answers.
 
162
  """)
163
 
164
  with gr.Row():
165
+ # Display final prompt and model response (side-by-side)
166
+ final_prompt_display = gr.Textbox(
167
+ label="Formatted Input Sent to Model (via Chat Template)", lines=20 # Increased lines
168
+ )
169
+ response_display = gr.Textbox(
170
+ label="Model Response", lines=20, show_copy_button=True # Increased lines
171
+ )
 
 
 
172
 
173
+ # Handle button click - Updated inputs and outputs list
174
  submit_btn.click(
175
  process_input,
176
  inputs=[
177
+ player_stats_input, ai_stats_input, system_prompt_input,
178
+ user_query_input, max_length_slider, temperature_slider, top_p_slider
179
+ ],
180
+ outputs=[
181
+ final_prompt_display, response_display,
182
+ time_output, tokens_output # Added new outputs
 
183
  ],
184
+ api_name="generate_rps_frequency_analysis_v2" # Updated api_name
 
185
  )
186
 
187
  # --- Launch the demo ---