# RPS_game_assist / app.py
# Source: Hugging Face Space by rui3000 — commit 9b05877 ("Update app.py", verified)
import gradio as gr
import torch
import time
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
from db import init_db, save_test_result, get_test_history, get_test_details
# --- Initialize Database ---
# init_db() signals failure with a falsy value; the app keeps running, but the
# history features check this flag and skip saving/loading when it is falsy.
db_initialized = init_db()
if not db_initialized:
    print("WARNING: Database initialization failed. Test history will not be saved.")

# --- Configuration ---
# Hugging Face Hub identifier of the causal LM used for every request.
MODEL_ID = "Qwen/Qwen2.5-Math-1.5B"

# --- Load Model and Tokenizer ---
# Done once at import time so every request reuses the same weights.
print(f"Loading model: {MODEL_ID}")
_MODEL_LOAD_OPTIONS = {
    "torch_dtype": "auto",  # let transformers choose the dtype from the checkpoint
    "device_map": "auto",   # let transformers/accelerate place the weights
}
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_MODEL_LOAD_OPTIONS)
print("Model loaded successfully.")
# --- Generation Function (Returns response and token count) ---
def build_generation_kwargs(max_length, temperature, top_p, pad_token_id):
    """Return the keyword arguments passed to ``model.generate``.

    Args:
        max_length: Maximum number of *new* tokens to generate. Coerced to
            ``int`` because UI sliders may deliver numeric values as floats.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding, because sampling requires a strictly positive temperature.
        top_p: Nucleus-sampling probability mass; only used when sampling.
        pad_token_id: Token id used for padding during generation.

    Returns:
        A dict of generation keyword arguments.
    """
    kwargs = {
        "max_new_tokens": int(max_length),
        "pad_token_id": pad_token_id,
    }
    if temperature > 0:
        kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Greedy decoding: omit temperature/top_p so generate() does not warn
        # about sampling parameters that would be ignored.
        kwargs["do_sample"] = False
    return kwargs


def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a chat completion for ``messages`` and count its tokens.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts, rendered
            with the tokenizer's chat template.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature; ``<= 0`` means greedy decoding.
        top_p: Nucleus-sampling threshold.

    Returns:
        ``(response_text, num_generated_tokens)``. Any exception is caught and
        turned into an ``"An error occurred: ..."`` text, so the UI always gets
        a displayable string. The count is then whatever was recorded before
        the failure (0 if generation itself failed).
    """
    num_generated_tokens = 0
    try:
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
        input_ids_len = model_inputs.input_ids.shape[-1]
        generation_kwargs = build_generation_kwargs(
            max_length, temperature, top_p, tokenizer.eos_token_id
        )
        print("Generating response...")
        with torch.no_grad():
            # Pass the attention mask explicitly: pad_token_id is set to the EOS
            # id, so generate() cannot reliably infer the mask by itself.
            generated_ids = model.generate(
                model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
                **generation_kwargs,
            )
        # generate() returns prompt + continuation; keep only the new tokens.
        output_ids = generated_ids[0, input_ids_len:]
        num_generated_tokens = len(output_ids)
        response = tokenizer.decode(output_ids, skip_special_tokens=True)
        print("Generation complete.")
        return response.strip(), num_generated_tokens
    except Exception as e:
        print(f"Error during generation: {e}")
        return f"An error occurred: {str(e)}", num_generated_tokens
def _build_messages(analysis_mode, player_stats, player_behavior_input, system_prompt):
    """Assemble the chat messages for one request.

    A system message is added only when ``system_prompt`` has non-blank text.
    "Frequency Only" and "Behavior Analysis" add one user message; any other
    mode (i.e. "Markov Prediction Only") sends no user message at all and lets
    the system prompt carry the whole task.

    Returns:
        ``(messages, user_content)``; ``user_content`` is ``""`` when no user
        message was added.
    """
    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})

    if analysis_mode == "Frequency Only":
        user_content = f"Player Move Frequency Stats (Long-Term):\n{player_stats}"
        messages.append({"role": "user", "content": user_content})
    elif analysis_mode == "Behavior Analysis":
        user_content = player_behavior_input
        messages.append({"role": "user", "content": user_content})
    else:
        user_content = ""
    return messages, user_content


def _format_display_prompt(analysis_mode, system_prompt, user_content):
    """Return the human-readable summary of what was sent to the model."""
    header = f"Selected Mode: {analysis_mode}\nSystem Prompt:\n{system_prompt}"
    if user_content:
        return f"{header}\n\n------\n\nUser Content:\n{user_content}"
    return header


@spaces.GPU  # ZeroGPU: a GPU is attached for the duration of this call
def process_input(
    analysis_mode,  # Mode selector value
    player_stats,
    player_behavior_input,
    system_prompt,  # Editable system prompt from the UI
    max_length,
    temperature,
    top_p,
    save_to_db=True,  # Toggle database saving
):
    """Run one model test for the selected analysis mode.

    Args:
        analysis_mode: "Frequency Only", "Markov Prediction Only" or
            "Behavior Analysis".
        player_stats: Frequency-stats text (used in "Frequency Only").
        player_behavior_input: Last outcome/move text (used in "Behavior Analysis").
        system_prompt: System prompt text as currently shown in the UI.
        max_length: Maximum number of new tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        save_to_db: When true and the database is available, persist the result.

    Returns:
        ``(display_prompt, response, "<seconds> seconds", generated_tokens)``.
    """
    print(f"GPU requested via decorator, starting processing in mode: {analysis_mode}")
    messages, user_content = _build_messages(
        analysis_mode, player_stats, player_behavior_input, system_prompt
    )
    display_prompt = _format_display_prompt(analysis_mode, system_prompt, user_content)

    if not messages:
        # Only reachable in Markov mode with a blank system prompt. There is
        # nothing to send, so report it rather than feeding an empty
        # conversation to the chat template (and saving the resulting error).
        return (
            display_prompt,
            "Nothing to send: the System Prompt is empty for this mode.",
            "0 seconds",
            0,
        )

    # perf_counter is monotonic, so the measurement is immune to clock changes.
    start = time.perf_counter()
    response, generated_tokens = generate_response(
        messages,
        max_length=max_length,
        temperature=temperature,
        top_p=top_p,
    )
    duration = round(time.perf_counter() - start, 2)
    print(f"Processing finished in {duration} seconds.")

    if save_to_db and db_initialized:
        test_id = save_test_result(
            analysis_mode=analysis_mode,
            system_prompt=system_prompt,
            input_content=user_content or "",
            model_response=response,
            generation_time=duration,
            tokens_generated=generated_tokens,
            temperature=temperature,
            top_p=top_p,
            max_length=max_length,
        )
        if test_id:
            print(f"Test saved to database with ID: {test_id}")
        else:
            print("Failed to save test to database")

    return display_prompt, response, f"{duration} seconds", generated_tokens
# --- System Prompts (Defaults only, UI will hold the editable version) ---
# Each constant is loaded into the editable "System Prompt" textbox when its
# analysis mode is selected; the textbox content (possibly edited by the user)
# is what is actually sent to the model as the system message.

# Frequency mode: the user message carries the "Player Move Frequency Stats"
# text, so this prompt only describes how to derive the counter move.
DEFAULT_SYSTEM_PROMPT_FREQ = """You are an assistant that analyzes Rock-Paper-Scissors (RPS) player statistics. Your ONLY goal is to find the best single AI move to counter the player's MOST frequent move based on the provided frequency stats.
Follow these steps EXACTLY. Do NOT deviate.
Step 1: Identify Player's Most Frequent Move.
- Look ONLY at the 'Player Move Frequency Stats'.
- List the percentages: Rock (%), Paper (%), Scissors (%).
- State which move name has the highest percentage number.
Step 2: Determine the Counter Move using RPS Rules.
- REMEMBER THE RULES: Paper beats Rock. Rock beats Scissors. Scissors beats Paper.
- Based *only* on the move identified in Step 1, state the single move name that beats it according to the rules. State the rule you used (e.g., "Paper beats Rock").
Step 3: Explain the Counter Choice.
- Briefly state: "Playing [Counter Move from Step 2] is recommended because it directly beats the player's most frequent move, [Most Frequent Move from Step 1]."
Step 4: State Final Recommendation.
- State *only* the recommended AI move name from Step 2. Example: "Recommendation: Paper"
Base your analysis strictly on the provided frequencies and the stated RPS rules.
"""
# Markov mode: no user message is sent, so this prompt must be self-contained.
# NOTE(review): the transition matrix and the player's last move ("Paper") are
# hard-coded in the text; edit them in the UI to test other scenarios.
DEFAULT_SYSTEM_PROMPT_MARKOV = """You are analyzing a Rock-Paper-Scissors (RPS) game using a Markov transition matrix.
### TRANSITION MATRIX:
[
[0.20, 0.60, 0.20], # Row 0 (After Rock)
[0.30, 0.10, 0.60], # Row 1 (After Paper)
[0.50, 0.30, 0.20] # Row 2 (After Scissors)
]
### EXPLANATION:
- This matrix shows P(Next Move | Previous Move)
- Each row represents the previous move (0=Rock, 1=Paper, 2=Scissors)
- Each column represents the next move (0=Rock, 1=Paper, 2=Scissors)
- For example, entry [0,1]=0.60 means: after playing Rock, 60% chance of playing Paper next
### PLAYER INFORMATION:
- The player's last move was: Paper
- Our goal is to predict their most likely next move and determine our choice that counters the predicted move
### YOUR TASK:
1. Find the row in the matrix corresponding to the player's last move
2. From that row, identify which move has the highest probability value
3. That highest probability move is the player's predicted next move
4. Determine the optimal counter move using RPS rules:
* Rock beats Scissors
* Scissors beats Paper
* Paper beats Rock
### SHOW YOUR MATHEMATICAL WORK:
- Identify the correct row number for the player's last move
- Extract all probability values from that row
- Compare the numerical values to find the maximum
- Apply game rules to determine the counter move
### OUTPUT FORMAT:
Player's Last Move: [Move]
Probabilities: [List the probabilities]
Predicted Next Move: [Move with highest probability]
Optimal Counter: [Move that beats the predicted move]
"""
# Behavior mode: the user message carries the player's last outcome and move
# (the default text is DEFAULT_PLAYER_BEHAVIOR).
# NOTE(review): when the predicted behavior is "Change", step 2 leaves the
# choice between the two remaining moves to the model, so that prediction may
# vary between runs.
DEFAULT_SYSTEM_PROMPT_BEHAVIOR = """You are an RPS assistant analyzing player behavior after wins, losses, and ties. Predict the player's next move and give counter strategy based on the Behavioral probabilities.
**Behavioral Probabilities P(Change/not change | Win/Loss/Tie):**
* P(not change | Win) = 0.70
* P(Change | Win) = 0.30
* P(not change | Loss) = 0.25
* P(Change | Loss) = 0.75
* P(not change | Tie) = 0.50
* P(Change | Tie) = 0.50
**Input Provided by User:**
* Player's Last Outcome: [Win/Loss/Tie]
* Player's Last Move: [Rock/Paper/Scissors]
**Your Task:**
1. Based on the Player's Last Outcome, determine the **Predicted Behavior** by comparing P(not change | Win/Loss/Tie) and P(Change | Win/Loss/Tie).
2. Determine the **Player's Predicted Next Move**:
* If Predicted Behavior is "not change", predict the same move as Player's Last Move.
* If Predicted Behavior is "Change", predict a move different from Player's Last Move (randomly select between the two remaining options with equal probability).
3. Recommend the **AI Counter Move** that beats the predicted player move:
* Paper beats Rock
* Rock beats Scissors
* Scissors beats Paper
**Output Format:**
Predicted Behavior: [not change/Change] (Based on P(not change|Outcome)=[Prob], P(Change|Outcome)=[Prob])
Prediction Logic: [Brief explanation of your reasoning]
Predicted Player Move: [Rock/Paper/Scissors]
Recommended AI Counter: [Rock/Paper/Scissors]
"""
# --- Default Input Values ---
# Initial contents of the per-mode input textboxes (one "label: value" per line).
DEFAULT_PLAYER_STATS = (
    "Rock: 40%\n"
    "Paper: 30%\n"
    "Scissors: 30%"
)
DEFAULT_PLAYER_BEHAVIOR = (
    "Player's Last Outcome: Win\n"
    "Player's Last Move: Rock"
)
# --- Gradio Interface ---
# Per-mode UI configuration:
#   mode -> (show frequency inputs, show Markov inputs, show behavior inputs,
#            default system prompt)
# The keys double as the radio-button choices, in display order.
_MODE_UI_CONFIG = {
    "Frequency Only": (True, False, False, DEFAULT_SYSTEM_PROMPT_FREQ),
    "Markov Prediction Only": (False, True, False, DEFAULT_SYSTEM_PROMPT_MARKOV),
    "Behavior Analysis": (False, False, True, DEFAULT_SYSTEM_PROMPT_BEHAVIOR),
}


def _empty_details(message):
    """Return placeholder values for the 8 test-detail fields, showing ``message``."""
    return [message, "", "", "", 0, 0, 0, 0]


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Tab("Model Testing"):
        gr.Markdown(f"# {MODEL_ID} - RPS Strategy Tester with Test History")
        gr.Markdown("Test model advice using Frequency Stats, Markov Predictions, or Win/Loss/Tie Behavior Analysis.")

        # Mode selector; choices come from the per-mode configuration table.
        analysis_mode_selector = gr.Radio(
            label="Select Analysis Mode",
            choices=list(_MODE_UI_CONFIG),
            value="Frequency Only",  # Default mode
        )

        # Editable system prompt; replaced with the mode default on mode change.
        system_prompt_input = gr.Textbox(
            label="System Prompt (Edit based on selected mode)",
            value=DEFAULT_SYSTEM_PROMPT_FREQ,  # Start with frequency prompt
            lines=15,
        )

        # Per-mode input sections (only the selected mode's group is visible).
        with gr.Group(visible=True) as frequency_inputs:
            gr.Markdown("### Frequency Analysis Inputs")
            player_stats_input = gr.Textbox(
                label="Player Move Frequency Stats (Long-Term)", value=DEFAULT_PLAYER_STATS, lines=4,
                info="Overall player move distribution.",
            )
        with gr.Group(visible=False) as markov_inputs:
            gr.Markdown("### Markov Prediction Analysis Inputs")
            gr.Markdown("*Use the System Prompt field to directly input your Markov analysis instructions.*")
        with gr.Group(visible=False) as behavior_inputs:
            gr.Markdown("### Win/Loss/Tie Behavior Analysis Inputs")
            player_behavior_input = gr.Textbox(
                label="Player's Last Outcome and Move", value=DEFAULT_PLAYER_BEHAVIOR, lines=4,
                info="Enter the last outcome (Win/Loss/Tie) and move (Rock/Paper/Scissors).",
            )

        # General inputs / parameters
        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Generation Parameters")
                max_length_slider = gr.Slider(minimum=50, maximum=1024, value=300, step=16, label="Max New Tokens")
                temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
                save_to_db_checkbox = gr.Checkbox(
                    label="Save this test to database",
                    value=True,
                    info="Store input and output in SQLite database for later reference",
                )
                submit_btn = gr.Button("Generate Response", variant="primary")

        # Outputs
        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Performance Metrics")
                time_output = gr.Textbox(label="Generation Time", interactive=False)
                tokens_output = gr.Number(label="Generated Tokens", interactive=False)
            with gr.Column():
                gr.Markdown("""
#### Testing Tips
- Select the desired **Analysis Mode**.
- Fill in the inputs for the **selected mode only**.
- **Edit the System Prompt** above as needed for testing.
- Use low **Temperature** for factual analysis.
""")
        with gr.Row():
            final_prompt_display = gr.Textbox(
                label="Formatted Input Sent to Model (via Chat Template)", lines=20
            )
            response_display = gr.Textbox(
                label="Model Response", lines=20, show_copy_button=True
            )

    with gr.Tab("Test History"):
        gr.Markdown("### Saved Test Results")
        refresh_btn = gr.Button("Refresh History")
        test_history_df = gr.Dataframe(
            headers=["Test ID", "Analysis Mode", "Timestamp", "Generation Time", "Tokens"],
            label="Recent Tests",
            interactive=False,
        )
        test_id_input = gr.Number(
            label="Test ID",
            precision=0,
            info="Enter a Test ID to load details",
        )
        load_test_btn = gr.Button("Load Test")
        with gr.Group():
            test_mode_display = gr.Textbox(label="Analysis Mode", interactive=False)
            test_prompt_display = gr.Textbox(label="System Prompt", interactive=False, lines=8)
            test_input_display = gr.Textbox(label="Input Content", interactive=False, lines=4)
            test_response_display = gr.Textbox(label="Model Response", interactive=False, lines=8)
            with gr.Row():
                test_time_display = gr.Number(label="Generation Time (s)", interactive=False)
                test_tokens_display = gr.Number(label="Tokens Generated", interactive=False)
                test_temp_display = gr.Number(label="Temperature", interactive=False)
                test_topp_display = gr.Number(label="Top P", interactive=False)

    # --- Event Handlers ---
    def update_ui_visibility_and_prompt(mode):
        """Show the input group for ``mode`` and load its default system prompt.

        Unknown modes fall back to the "Frequency Only" layout and prompt.
        """
        show_freq, show_markov, show_behavior, prompt = _MODE_UI_CONFIG.get(
            mode, _MODE_UI_CONFIG["Frequency Only"]
        )
        return {
            frequency_inputs: gr.update(visible=show_freq),
            markov_inputs: gr.update(visible=show_markov),
            behavior_inputs: gr.update(visible=show_behavior),
            system_prompt_input: gr.update(value=prompt),
        }

    def update_test_history():
        """Return rows (first 5 columns) for up to 20 recent saved tests."""
        if not db_initialized:
            return [["N/A", "Database Not Available", "N/A", 0, 0]]
        # Treat a missing result like an empty history instead of crashing.
        history = get_test_history(limit=20) or []
        return [[h[0], h[1], h[2], h[3], h[4]] for h in history]

    def load_test_details(test_id):
        """Return the 8 detail-field values for the saved test ``test_id``."""
        if not db_initialized:
            return _empty_details("Database Not Available")
        if test_id is None:
            # The Test ID box is empty; don't query the database with a null ID.
            return _empty_details("Enter a Test ID to load")
        test = get_test_details(int(test_id))
        if not test:
            return _empty_details("Test not found")
        return [
            test["analysis_mode"],
            test["system_prompt"],
            test["input_content"] or "",
            test["model_response"],
            test["generation_time"],
            test["tokens_generated"],
            test["temperature"],
            test["top_p"],
        ]

    analysis_mode_selector.change(
        fn=update_ui_visibility_and_prompt,
        inputs=analysis_mode_selector,
        outputs=[frequency_inputs, markov_inputs, behavior_inputs, system_prompt_input],
    )

    submit_btn.click(
        process_input,
        inputs=[
            analysis_mode_selector,
            player_stats_input,
            player_behavior_input,
            system_prompt_input,  # The visible (possibly edited) system prompt
            max_length_slider,
            temperature_slider,
            top_p_slider,
            save_to_db_checkbox,
        ],
        outputs=[
            final_prompt_display, response_display,
            time_output, tokens_output,
        ],
    )

    refresh_btn.click(
        update_test_history,
        outputs=[test_history_df],
    )
    load_test_btn.click(
        load_test_details,
        inputs=[test_id_input],
        outputs=[
            test_mode_display, test_prompt_display, test_input_display,
            test_response_display, test_time_display, test_tokens_display,
            test_temp_display, test_topp_display,
        ],
    )

    # Populate the history table when the page loads.
    demo.load(update_test_history, outputs=[test_history_df])

# --- Launch the demo ---
if __name__ == "__main__":
    demo.launch()