# Hugging Face Space (page status when captured: Sleeping)
import time

import gradio as gr
import spaces  # Import the spaces library
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Configuration ---
# NOTE(review): this looks like a base (non-instruct) checkpoint, while the app
# formats prompts with tokenizer.apply_chat_template — confirm the intended ID.
MODEL_ID = "Qwen/Qwen2.5-Math-1.5B"  # Replace with actual ID if found

# --- Load Model and Tokenizer ---
# Loaded once at import time so every request reuses the same objects.
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)
print("Model loaded successfully.")
# --- Generation Function (Returns response and token count) ---
def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a reply for ``messages`` and count the newly generated tokens.

    Args:
        messages: List of chat messages (dicts with "role" and "content")
            that is rendered through the tokenizer's chat template.
        max_length: Upper bound on the number of *new* tokens to generate.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Returns:
        A ``(text, num_generated_tokens)`` tuple. If anything fails, the text
        is ``"An error occurred: <message>"`` and the count is whatever had
        been counted so far (0 unless generation itself completed).
    """
    num_generated_tokens = 0
    try:
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
        input_ids_len = model_inputs.input_ids.shape[-1]

        generation_kwargs = {
            # The value comes from a UI slider and may not be an int.
            "max_new_tokens": int(max_length),
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }

        print("Generating response...")
        with torch.no_grad():
            # Pass the whole encoding (input_ids + attention_mask): with the pad
            # token set to EOS the mask cannot be inferred from input_ids alone.
            generated_ids = model.generate(**model_inputs, **generation_kwargs)

        # Drop the prompt tokens so only the newly generated part is decoded.
        output_ids = generated_ids[0, input_ids_len:]
        num_generated_tokens = len(output_ids)
        response = tokenizer.decode(output_ids, skip_special_tokens=True)
        print("Generation complete.")
        return response.strip(), num_generated_tokens
    except Exception as e:
        # Top-level boundary for the UI: report the failure as text, never raise.
        print(f"Error during generation: {e}")
        return f"An error occurred: {str(e)}", num_generated_tokens
# --- Input Processing Function (Takes single system prompt) ---
# ZeroGPU decorator: requests a GPU for the duration of each call.
@spaces.GPU
def process_input(
    analysis_mode,  # Mode selector
    player_stats,
    player_last_move,
    markov_prediction_text,
    system_prompt,  # Single system prompt from UI
    user_query,
    max_length,
    temperature,
    top_p,
):
    """Build the chat messages for the selected mode and run the model.

    Args:
        analysis_mode: "Frequency Only" or "Markov Prediction Only".
        player_stats: Frequency statistics text (used in frequency mode).
        player_last_move: The player's last move (used in Markov mode).
        markov_prediction_text: Pre-computed prediction text (Markov mode).
        system_prompt: System prompt text; skipped when empty or whitespace.
        user_query: The question appended to the user content.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.

    Returns:
        ``(display_prompt, response, "<seconds> seconds", generated_tokens)``.
        For an unknown mode the tuple is
        ``("Invalid analysis mode selected.", "", "0 seconds", 0)``.
    """
    print(f"GPU requested via decorator, starting processing in mode: {analysis_mode}")

    # Construct user content based on mode
    if analysis_mode == "Frequency Only":
        user_content = (
            f"Player Move Frequency Stats (Long-Term):\n{player_stats}\n\n"
            f"User Query:\n{user_query}"
        )
    elif analysis_mode == "Markov Prediction Only":
        user_content = (
            f"Player's Last Move:\n{player_last_move}\n\n"
            f"Predicted Next Move (Short-Term Markov Analysis):\n{markov_prediction_text}\n\n"
            f"User Query:\n{user_query}"
        )
    else:
        return "Invalid analysis mode selected.", "", "0 seconds", 0

    # Create the messages list using the system_prompt from the UI
    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_content})

    # Time only the model call, not the prompt assembly above.
    start_time = time.time()
    response, generated_tokens = generate_response(
        messages,
        max_length=max_length,
        temperature=temperature,
        top_p=top_p,
    )
    duration = round(time.time() - start_time, 2)

    # For display purposes
    display_prompt = f"Selected Mode: {analysis_mode}\nSystem Prompt:\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
    print(f"Processing finished in {duration} seconds.")

    # Return all results including time and tokens
    return display_prompt, response, f"{duration} seconds", generated_tokens
# --- System Prompts (Defaults only, UI will hold the editable version) ---
# Default system prompt for "Frequency Only" mode: counter the most frequent move.
DEFAULT_SYSTEM_PROMPT_FREQ = """You are an assistant that analyzes Rock-Paper-Scissors (RPS) player statistics. Your ONLY goal is to find the best single AI move to counter the player's MOST frequent move based on the provided frequency stats.
Follow these steps EXACTLY. Do NOT deviate.
Step 1: Identify Player's Most Frequent Move.
- Look ONLY at the 'Player Move Frequency Stats'.
- List the percentages: Rock (%), Paper (%), Scissors (%).
- State which move name has the highest percentage number.
Step 2: Determine the Counter Move using RPS Rules.
- REMEMBER THE RULES: Paper beats Rock. Rock beats Scissors. Scissors beats Paper.
- Based *only* on the move identified in Step 1, state the single move name that beats it according to the rules. State the rule you used (e.g., "Paper beats Rock").
Step 3: Explain the Counter Choice.
- Briefly state: "Playing [Counter Move from Step 2] is recommended because it directly beats the player's most frequent move, [Most Frequent Move from Step 1]."
Step 4: State Final Recommendation.
- State *only* the recommended AI move name from Step 2. Example: "Recommendation: Paper"
Base your analysis strictly on the provided frequencies and the stated RPS rules.
"""
# *** UPDATED Markov System Prompt v2 ***
# Default system prompt for "Markov Prediction Only" mode: counter the predicted move.
DEFAULT_SYSTEM_PROMPT_MARKOV = """You are an RPS assistant using short-term pattern analysis (Markov prediction).
Your ONLY task is to recommend the AI move that beats the player's PREDICTED next move. Accuracy is critical.
Input Information Provided:
- Player's Predicted Next Move (from Markov analysis): [This is the key input!]
Instructions:
1. **Identify Prediction:** State the player's PREDICTED next move (Rock, Paper, or Scissors) based *only* on the 'Predicted Next Move' input.
2. **Find Counter:** Apply the RPS rules (Paper beats Rock, Rock beats Scissors, Scissors beats Paper). Determine the single move that correctly beats the PREDICTED move from Step 1. State *only* the name of this counter move. Double-check the rules.
3. **Recommend:** Clearly state the counter move found in Step 2 as the AI's recommended move.
Example Output Format:
1. Predicted Player Move: [Predicted move name]
2. Counter Move: [Counter move name]
3. Recommendation: Play [Counter move name].
"""
# --- Default Input Values ---
# Initial contents of the UI input fields.
DEFAULT_PLAYER_STATS = "Rock: 40%\nPaper: 30%\nScissors: 30%"
DEFAULT_PLAYER_LAST_MOVE = "Rock"
DEFAULT_MARKOV_PREDICTION = "Based on the last move (Rock), the player's most likely next move is Paper (60% probability)."
DEFAULT_USER_QUERY = "Based on the provided information for the selected analysis mode, what single move should the AI make next? Explain your reasoning step-by-step as instructed."
# --- Gradio Interface ---
# NOTE(review): the original nesting of this layout was lost when the file was
# captured; the structure below is reconstructed from the statement order.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {MODEL_ID} - RPS Strategy Tester")
    gr.Markdown("Test model advice using either Frequency Stats OR Short-Term (Markov) Predictions.")

    # Mode Selector
    analysis_mode_selector = gr.Radio(
        label="Select Analysis Mode",
        choices=["Frequency Only", "Markov Prediction Only"],
        value="Frequency Only",  # Default mode
    )

    # --- Visible System Prompt Textbox ---
    system_prompt_input = gr.Textbox(
        label="System Prompt (Edit based on selected mode)",
        value=DEFAULT_SYSTEM_PROMPT_FREQ,  # Start with frequency prompt
        lines=15,
    )

    # Input Sections (conditionally visible)
    with gr.Group(visible=True) as frequency_inputs:  # Visible by default
        gr.Markdown("### Frequency Analysis Inputs")
        player_stats_input = gr.Textbox(
            label="Player Move Frequency Stats (Long-Term)",
            value=DEFAULT_PLAYER_STATS,
            lines=4,
            info="Overall player move distribution.",
        )

    with gr.Group(visible=False) as markov_inputs:  # Hidden by default
        gr.Markdown("### Markov Prediction Analysis Inputs")
        player_last_move_input = gr.Dropdown(
            label="Player's Last Move",
            choices=["Rock", "Paper", "Scissors"],
            value=DEFAULT_PLAYER_LAST_MOVE,
            info="The player's most recent actual move.",
        )
        markov_prediction_input = gr.Textbox(
            label="Predicted Next Move (Short-Term Markov Analysis)",
            value=DEFAULT_MARKOV_PREDICTION,
            lines=3,
            info="Provide the pre-calculated prediction based on the last move (e.g., 'Player likely plays Paper (60%)').",
        )

    # General Inputs / Parameters / Outputs
    with gr.Row():
        with gr.Column(scale=2):
            user_query_input = gr.Textbox(
                label="Your Query / Instruction",
                value=DEFAULT_USER_QUERY,
                lines=3,
                info="Ask the specific question based on the selected mode's analysis.",
            )
        with gr.Column(scale=1):
            gr.Markdown("#### Generation Parameters")
            max_length_slider = gr.Slider(minimum=50, maximum=1024, value=300, step=16, label="Max New Tokens")
            temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Temperature")
            top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")

    # NOTE(review): placed at the top level of the layout — confirm whether it
    # originally lived inside the parameters column.
    submit_btn = gr.Button("Generate Response", variant="primary")

    with gr.Row():
        with gr.Column():
            gr.Markdown("#### Performance Metrics")
            time_output = gr.Textbox(label="Generation Time", interactive=False)
            tokens_output = gr.Number(label="Generated Tokens", interactive=False)
        with gr.Column():
            # Written as joined literals so the text does not depend on the
            # indentation of this source file.
            gr.Markdown(
                "\n#### Testing Tips\n"
                "- Select the desired **Analysis Mode**.\n"
                "- Fill in the inputs for the **selected mode only**.\n"
                "- **Edit the System Prompt** above as needed for testing.\n"
                "- Use low **Temperature** for factual analysis.\n"
            )

    with gr.Row():
        final_prompt_display = gr.Textbox(
            label="Formatted Input Sent to Model (via Chat Template)",
            lines=20,
        )
        response_display = gr.Textbox(
            label="Model Response",
            lines=20,
            show_copy_button=True,
        )

    # --- Event Handlers ---
    def update_ui_visibility_and_prompt(mode):
        """Show the input group for ``mode`` and load its default system prompt.

        Any value other than "Markov Prediction Only" is treated as frequency
        mode, so an unexpected value falls back to the frequency view.
        """
        markov_selected = mode == "Markov Prediction Only"
        default_prompt = (
            DEFAULT_SYSTEM_PROMPT_MARKOV if markov_selected else DEFAULT_SYSTEM_PROMPT_FREQ
        )
        return {
            frequency_inputs: gr.update(visible=not markov_selected),
            markov_inputs: gr.update(visible=markov_selected),
            system_prompt_input: gr.update(value=default_prompt),
        }

    # Link the radio button change to the UI update function
    analysis_mode_selector.change(
        fn=update_ui_visibility_and_prompt,  # Use the combined update function
        inputs=analysis_mode_selector,
        outputs=[frequency_inputs, markov_inputs, system_prompt_input],  # Components to update
    )

    # Handle button click - Pass the single visible system prompt.
    # The order of `inputs` must match the parameter order of process_input.
    submit_btn.click(
        process_input,
        inputs=[
            analysis_mode_selector,
            player_stats_input,
            player_last_move_input,
            markov_prediction_input,
            system_prompt_input,  # Pass the visible system prompt textbox
            user_query_input,
            max_length_slider,
            temperature_slider,
            top_p_slider,
        ],
        outputs=[
            final_prompt_display,
            response_display,
            time_output,
            tokens_output,
        ],
        api_name="generate_rps_selectable_analysis_v2",  # Updated api_name
    )
# --- Launch the demo ---
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()