# Spaces: Running on Zero (Hugging Face Space status banner; this app targets ZeroGPU hardware).
import gradio as gr | |
import torch | |
import time | |
import spaces | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from db import init_db, save_test_result, get_test_history, get_test_details | |
# --- Initialize Database ---
# The result is kept as a module-level flag; it is checked before any later
# read from or write to the test-history store.
db_initialized = init_db()
if not db_initialized:
    print("WARNING: Database initialization failed. Test history will not be saved.")

# --- Configuration ---
MODEL_ID = "Qwen/Qwen2.5-Math-1.5B"  # Replace with actual ID if found

# --- Load Model and Tokenizer ---
# Both objects are loaded once at import time and shared by every request.
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto", device_map="auto")
print("Model loaded successfully.")
# --- Generation Function (Returns response and token count) ---
def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a response and return it along with the number of generated tokens.

    Args:
        messages: List of chat messages (dicts with "role" and "content")
            rendered through the tokenizer's chat template.
        max_length: Maximum number of new tokens to generate. Coerced to
            ``int`` because UI sliders may deliver the value as a float.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Returns:
        A ``(text, token_count)`` tuple. On success ``text`` is the decoded,
        stripped completion and ``token_count`` the number of tokens produced
        after the prompt. If anything raises, ``text`` is an
        ``"An error occurred: ..."`` message and ``token_count`` is whatever
        had been counted so far (0 unless decoding itself failed).
    """
    num_generated_tokens = 0
    try:
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
        input_ids_len = model_inputs.input_ids.shape[-1]

        generation_kwargs = {
            "max_new_tokens": int(max_length),
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }

        print("Generating response...")
        with torch.no_grad():
            # Unpack the whole encoding so the attention mask travels with the
            # input ids; with pad_token_id set to the EOS id the mask cannot
            # be inferred reliably from the ids alone.
            generated_ids = model.generate(**model_inputs, **generation_kwargs)

        # Keep only the continuation: everything after the prompt tokens.
        output_ids = generated_ids[0, input_ids_len:]
        num_generated_tokens = len(output_ids)
        response = tokenizer.decode(output_ids, skip_special_tokens=True)
        print("Generation complete.")
        return response.strip(), num_generated_tokens
    except Exception as e:
        # Deliberate best-effort boundary: the UI shows the error text in the
        # response box instead of crashing the request.
        print(f"Error during generation: {e}")
        return f"An error occurred: {str(e)}", num_generated_tokens
# Keep ZeroGPU decorator
@spaces.GPU
def process_input(
    analysis_mode,  # Mode selector
    player_stats,
    player_behavior_input,
    system_prompt,  # Single system prompt from UI
    max_length,
    temperature,
    top_p,
    save_to_db=True,  # New parameter to toggle database saving
):
    """Process inputs based on selected analysis mode using the provided system prompt.

    Args:
        analysis_mode: "Frequency Only", "Behavior Analysis", or any other
            value (treated as the Markov mode, which sends no user message).
        player_stats: Frequency statistics text, used in "Frequency Only" mode.
        player_behavior_input: Last outcome/move text, used in
            "Behavior Analysis" mode.
        system_prompt: System prompt taken verbatim from the UI; skipped when
            empty or whitespace-only.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        save_to_db: When true (and the database initialised), persist the run.

    Returns:
        ``(display_prompt, response, "<duration> seconds", generated_tokens)``.
    """
    print(f"GPU requested via decorator, starting processing in mode: {analysis_mode}")

    # Create the messages list using the system_prompt from the UI directly
    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})

    # Add content based on analysis mode
    if analysis_mode == "Frequency Only":
        user_content = f"Player Move Frequency Stats (Long-Term):\n{player_stats}"
        messages.append({"role": "user", "content": user_content})
    elif analysis_mode == "Behavior Analysis":
        user_content = player_behavior_input
        messages.append({"role": "user", "content": user_content})
    else:
        # Markov Prediction only mode: the system prompt carries everything,
        # so no (empty) user message is appended.
        user_content = ""

    # --- Time Measurement ---
    # perf_counter is monotonic, so the duration cannot be skewed by
    # wall-clock adjustments during generation.
    start_time = time.perf_counter()
    response, generated_tokens = generate_response(
        messages,
        max_length=max_length,
        temperature=temperature,
        top_p=top_p,
    )
    duration = round(time.perf_counter() - start_time, 2)

    # For display purposes - show what was actually sent to the model
    if user_content:
        display_prompt = f"Selected Mode: {analysis_mode}\nSystem Prompt:\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
    else:
        display_prompt = f"Selected Mode: {analysis_mode}\nSystem Prompt:\n{system_prompt}"
    print(f"Processing finished in {duration} seconds.")

    # Save to database if requested and if database is available
    if save_to_db and db_initialized:
        test_id = save_test_result(
            analysis_mode=analysis_mode,
            system_prompt=system_prompt,
            input_content=user_content if user_content else "",
            model_response=response,
            generation_time=duration,
            tokens_generated=generated_tokens,
            temperature=temperature,
            top_p=top_p,
            max_length=max_length,
        )
        if test_id:
            print(f"Test saved to database with ID: {test_id}")
        else:
            print("Failed to save test to database")

    # Return all results including time and tokens
    return display_prompt, response, f"{duration} seconds", generated_tokens
# --- System Prompts (Defaults only, UI will hold the editable version) ---
# Default prompt for the "Frequency Only" mode: counter the most frequent move.
DEFAULT_SYSTEM_PROMPT_FREQ = """You are an assistant that analyzes Rock-Paper-Scissors (RPS) player statistics. Your ONLY goal is to find the best single AI move to counter the player's MOST frequent move based on the provided frequency stats.
Follow these steps EXACTLY. Do NOT deviate.
Step 1: Identify Player's Most Frequent Move.
- Look ONLY at the 'Player Move Frequency Stats'.
- List the percentages: Rock (%), Paper (%), Scissors (%).
- State which move name has the highest percentage number.
Step 2: Determine the Counter Move using RPS Rules.
- REMEMBER THE RULES: Paper beats Rock. Rock beats Scissors. Scissors beats Paper.
- Based *only* on the move identified in Step 1, state the single move name that beats it according to the rules. State the rule you used (e.g., "Paper beats Rock").
Step 3: Explain the Counter Choice.
- Briefly state: "Playing [Counter Move from Step 2] is recommended because it directly beats the player's most frequent move, [Most Frequent Move from Step 1]."
Step 4: State Final Recommendation.
- State *only* the recommended AI move name from Step 2. Example: "Recommendation: Paper"
Base your analysis strictly on the provided frequencies and the stated RPS rules.
"""

# Default prompt for the "Markov Prediction Only" mode. The transition matrix
# and the player's last move are embedded in the prompt itself, because this
# mode sends no separate user message.
DEFAULT_SYSTEM_PROMPT_MARKOV = """You are analyzing a Rock-Paper-Scissors (RPS) game using a Markov transition matrix.
### TRANSITION MATRIX:
[
[0.20, 0.60, 0.20], # Row 0 (After Rock)
[0.30, 0.10, 0.60], # Row 1 (After Paper)
[0.50, 0.30, 0.20] # Row 2 (After Scissors)
]
### EXPLANATION:
- This matrix shows P(Next Move | Previous Move)
- Each row represents the previous move (0=Rock, 1=Paper, 2=Scissors)
- Each column represents the next move (0=Rock, 1=Paper, 2=Scissors)
- For example, entry [0,1]=0.60 means: after playing Rock, 60% chance of playing Paper next
### PLAYER INFORMATION:
- The player's last move was: Paper
- Our goal is to predict their most likely next move and determine our choice that counters the predicted move
### YOUR TASK:
1. Find the row in the matrix corresponding to the player's last move
2. From that row, identify which move has the highest probability value
3. That highest probability move is the player's predicted next move
4. Determine the optimal counter move using RPS rules:
* Rock beats Scissors
* Scissors beats Paper
* Paper beats Rock
### SHOW YOUR MATHEMATICAL WORK:
- Identify the correct row number for the player's last move
- Extract all probability values from that row
- Compare the numerical values to find the maximum
- Apply game rules to determine the counter move
### OUTPUT FORMAT:
Player's Last Move: [Move]
Probabilities: [List the probabilities]
Predicted Next Move: [Move with highest probability]
Optimal Counter: [Move that beats the predicted move]
"""

# Default prompt for the "Behavior Analysis" mode: predict change / no change
# after a win, loss, or tie, then recommend the counter move.
DEFAULT_SYSTEM_PROMPT_BEHAVIOR = """You are an RPS assistant analyzing player behavior after wins, losses, and ties. Predict the player's next move and give counter strategy based on the Behavioral probabilities.
**Behavioral Probabilities P(Change/not change | Win/Loss/Tie):**
* P(not change | Win) = 0.70
* P(Change | Win) = 0.30
* P(not change | Loss) = 0.25
* P(Change | Loss) = 0.75
* P(not change | Tie) = 0.50
* P(Change | Tie) = 0.50
**Input Provided by User:**
* Player's Last Outcome: [Win/Loss/Tie]
* Player's Last Move: [Rock/Paper/Scissors]
**Your Task:**
1. Based on the Player's Last Outcome, determine the **Predicted Behavior** by comparing P(not change | Win/Loss/Tie) and P(Change | Win/Loss/Tie).
2. Determine the **Player's Predicted Next Move**:
* If Predicted Behavior is "not change", predict the same move as Player's Last Move.
* If Predicted Behavior is "Change", predict a move different from Player's Last Move (randomly select between the two remaining options with equal probability).
3. Recommend the **AI Counter Move** that beats the predicted player move:
* Paper beats Rock
* Rock beats Scissors
* Scissors beats Paper
**Output Format:**
Predicted Behavior: [not change/Change] (Based on P(not change|Outcome)=[Prob], P(Change|Outcome)=[Prob])
Prediction Logic: [Brief explanation of your reasoning]
Predicted Player Move: [Rock/Paper/Scissors]
Recommended AI Counter: [Rock/Paper/Scissors]
"""

# --- Default Input Values ---
DEFAULT_PLAYER_STATS = "Rock: 40%\nPaper: 30%\nScissors: 30%"
DEFAULT_PLAYER_BEHAVIOR = "Player's Last Outcome: Win\nPlayer's Last Move: Rock"
# --- Gradio Interface ---
# NOTE(review): the container nesting below was reconstructed from the order
# of the components; confirm the exact Row/Column placement against the
# deployed layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Tab("Model Testing"):
        gr.Markdown(f"# {MODEL_ID} - RPS Strategy Tester with Test History")
        gr.Markdown("Test model advice using Frequency Stats, Markov Predictions, or Win/Loss/Tie Behavior Analysis.")

        # Mode Selector - now with three options
        analysis_mode_selector = gr.Radio(
            label="Select Analysis Mode",
            choices=["Frequency Only", "Markov Prediction Only", "Behavior Analysis"],
            value="Frequency Only",  # Default mode
        )

        # --- Visible System Prompt Textbox ---
        system_prompt_input = gr.Textbox(
            label="System Prompt (Edit based on selected mode)",
            value=DEFAULT_SYSTEM_PROMPT_FREQ,  # Start with frequency prompt
            lines=15,
        )

        # Input Sections (conditionally visible)
        with gr.Group(visible=True) as frequency_inputs:  # Visible by default
            gr.Markdown("### Frequency Analysis Inputs")
            player_stats_input = gr.Textbox(
                label="Player Move Frequency Stats (Long-Term)", value=DEFAULT_PLAYER_STATS, lines=4,
                info="Overall player move distribution.",
            )
        with gr.Group(visible=False) as markov_inputs:  # Hidden by default
            gr.Markdown("### Markov Prediction Analysis Inputs")
            gr.Markdown("*Use the System Prompt field to directly input your Markov analysis instructions.*")
        # New behavior analysis inputs
        with gr.Group(visible=False) as behavior_inputs:
            gr.Markdown("### Win/Loss/Tie Behavior Analysis Inputs")
            player_behavior_input = gr.Textbox(
                label="Player's Last Outcome and Move", value=DEFAULT_PLAYER_BEHAVIOR, lines=4,
                info="Enter the last outcome (Win/Loss/Tie) and move (Rock/Paper/Scissors).",
            )

        # General Inputs / Parameters / Outputs
        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Generation Parameters")
                max_length_slider = gr.Slider(minimum=50, maximum=1024, value=300, step=16, label="Max New Tokens")
                temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
                # Add a checkbox to control saving to database
                save_to_db_checkbox = gr.Checkbox(
                    label="Save this test to database",
                    value=True,
                    info="Store input and output in SQLite database for later reference",
                )
        submit_btn = gr.Button("Generate Response", variant="primary")

        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Performance Metrics")
                time_output = gr.Textbox(label="Generation Time", interactive=False)
                tokens_output = gr.Number(label="Generated Tokens", interactive=False)
            with gr.Column():
                gr.Markdown("""
#### Testing Tips
- Select the desired **Analysis Mode**.
- Fill in the inputs for the **selected mode only**.
- **Edit the System Prompt** above as needed for testing.
- Use low **Temperature** for factual analysis.
""")

        with gr.Row():
            final_prompt_display = gr.Textbox(
                label="Formatted Input Sent to Model (via Chat Template)", lines=20
            )
            response_display = gr.Textbox(
                label="Model Response", lines=20, show_copy_button=True
            )

    # Add a new tab for test history
    with gr.Tab("Test History"):
        gr.Markdown("### Saved Test Results")
        refresh_btn = gr.Button("Refresh History")
        # Display test history as a dataframe
        test_history_df = gr.Dataframe(
            headers=["Test ID", "Analysis Mode", "Timestamp", "Generation Time", "Tokens"],
            label="Recent Tests",
            interactive=False,
        )
        # Add a number input to load a specific test
        test_id_input = gr.Number(
            label="Test ID",
            precision=0,
            info="Enter a Test ID to load details",
        )
        load_test_btn = gr.Button("Load Test")
        # Display test details
        with gr.Group():
            test_mode_display = gr.Textbox(label="Analysis Mode", interactive=False)
            test_prompt_display = gr.Textbox(label="System Prompt", interactive=False, lines=8)
            test_input_display = gr.Textbox(label="Input Content", interactive=False, lines=4)
            test_response_display = gr.Textbox(label="Model Response", interactive=False, lines=8)
            with gr.Row():
                test_time_display = gr.Number(label="Generation Time (s)", interactive=False)
                test_tokens_display = gr.Number(label="Tokens Generated", interactive=False)
                test_temp_display = gr.Number(label="Temperature", interactive=False)
                test_topp_display = gr.Number(label="Top P", interactive=False)

    # --- Event Handlers ---
    def update_ui_visibility_and_prompt(mode):
        """Show the input group for ``mode`` and load that mode's default prompt.

        Any unrecognised mode falls back to the "Frequency Only" layout and
        prompt. Returns a dict of component -> ``gr.update`` for the three
        input groups and the system prompt textbox.
        """
        default_prompts = {
            "Frequency Only": DEFAULT_SYSTEM_PROMPT_FREQ,
            "Markov Prediction Only": DEFAULT_SYSTEM_PROMPT_MARKOV,
            "Behavior Analysis": DEFAULT_SYSTEM_PROMPT_BEHAVIOR,
        }
        active_mode = mode if mode in default_prompts else "Frequency Only"
        return {
            frequency_inputs: gr.update(visible=active_mode == "Frequency Only"),
            markov_inputs: gr.update(visible=active_mode == "Markov Prediction Only"),
            behavior_inputs: gr.update(visible=active_mode == "Behavior Analysis"),
            system_prompt_input: gr.update(value=default_prompts[active_mode]),
        }

    def update_test_history():
        """Return up to 20 recent tests as rows for the history dataframe.

        Each row holds the first five fields of a history record, matching the
        dataframe headers. A placeholder row is returned when the database is
        unavailable.
        """
        if not db_initialized:
            return [["N/A", "Database Not Available", "N/A", 0, 0]]
        # Treat a None result as "no history" instead of failing on iteration.
        history = get_test_history(limit=20) or []
        return [[h[0], h[1], h[2], h[3], h[4]] for h in history]

    def load_test_details(test_id):
        """Return the eight detail fields for ``test_id``.

        Output order matches the detail widgets: mode, system prompt, input
        content, response, generation time, tokens, temperature, top-p.
        Placeholder rows are returned when the database is unavailable, the
        Test ID field is empty, or no such test exists.
        """
        if not db_initialized:
            return ["Database Not Available", "", "", "", 0, 0, 0, 0]
        not_found = ["Test not found", "", "", "", 0, 0, 0, 0]
        # An empty Number field arrives as None; do not query the DB with it.
        if test_id is None:
            return not_found
        test = get_test_details(int(test_id))
        if not test:
            return not_found
        return [
            test["analysis_mode"],
            test["system_prompt"],
            test["input_content"] or "",
            test["model_response"],
            test["generation_time"],
            test["tokens_generated"],
            test["temperature"],
            test["top_p"],
        ]

    # Link the radio button change to the UI update function
    analysis_mode_selector.change(
        fn=update_ui_visibility_and_prompt,
        inputs=analysis_mode_selector,
        outputs=[frequency_inputs, markov_inputs, behavior_inputs, system_prompt_input],
    )

    # Handle button click - Pass the single visible system prompt
    submit_btn.click(
        process_input,
        inputs=[
            analysis_mode_selector,
            player_stats_input,
            player_behavior_input,
            system_prompt_input,  # Pass the visible system prompt textbox
            max_length_slider,
            temperature_slider,
            top_p_slider,
            save_to_db_checkbox,  # Pass the checkbox value
        ],
        outputs=[
            final_prompt_display, response_display,
            time_output, tokens_output,
        ],
    )

    # Connect buttons for test history tab
    refresh_btn.click(
        update_test_history,
        outputs=[test_history_df],
    )
    load_test_btn.click(
        load_test_details,
        inputs=[test_id_input],
        outputs=[
            test_mode_display, test_prompt_display, test_input_display,
            test_response_display, test_time_display, test_tokens_display,
            test_temp_display, test_topp_display,
        ],
    )

    # Initialize history on page load
    demo.load(update_test_history, outputs=[test_history_df])
# --- Launch the demo ---
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()