import gradio as gr
import torch
import time  # Used to time generation
from transformers import AutoModelForCausalLM, AutoTokenizer
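# NOTE: device_map="auto" (used below) requires the `accelerate` package to be
# installed alongside transformers; torch_dtype="auto" loads the weights in the
# dtype recorded in the checkpoint's config (e.g. bfloat16 for this model).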
# --- Configuration ---
MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"
# --- Load Model and Tokenizer ---
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto"
)
print("Model loaded successfully.")
# --- Generation Function (Updated to return token count) ---
def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a response and return it along with the number of generated tokens."""
    num_generated_tokens = 0
    try:
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
        input_ids_len = model_inputs.input_ids.shape[-1]  # Length of input tokens
        generation_kwargs = {
            "max_new_tokens": max_length,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }
        print("Generating response...")
        with torch.no_grad():
            # Generate the continuation; only the token count is needed downstream,
            # so no extra outputs (e.g. output_scores) are requested. The attention
            # mask is passed explicitly to avoid the padding-ambiguity warning.
            generated_ids = model.generate(
                model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
                **generation_kwargs
            )
        # Count only the newly generated tokens (strip the prompt tokens)
        output_ids = generated_ids[0, input_ids_len:]
        num_generated_tokens = len(output_ids)
        response = tokenizer.decode(output_ids, skip_special_tokens=True)
        print("Generation complete.")
        return response.strip(), num_generated_tokens  # Return response and token count
    except Exception as e:
        print(f"Error during generation: {e}")
        return f"An error occurred: {str(e)}", num_generated_tokens  # Return error and token count
# --- Input Processing Function (Updated for Time/Token outputs) ---
def process_input(
    player_stats,
    ai_stats,
    system_prompt,
    user_query,
    max_length,
    temperature,
    top_p
):
    """Process inputs, generate response, and return display info, response, time, and token count."""
    # Construct the user message content
    user_content = f"Player Move Frequency Stats:\n{player_stats}\n\n"
    if ai_stats and ai_stats.strip():
        user_content += f"AI Move Frequency Stats:\n{ai_stats}\n\n"
    user_content += f"User Query:\n{user_query}"
    # Create the messages list
    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_content})
    # --- Time Measurement Start ---
    start_time = time.time()
    # Generate response from the model
    response, generated_tokens = generate_response(  # Capture token count
        messages,
        max_length=max_length,
        temperature=temperature,
        top_p=top_p
    )
    # --- Time Measurement End ---
    end_time = time.time()
    duration = round(end_time - start_time, 2)  # Calculate duration
    # For display purposes
    display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
    # Return all results including time and tokens
    return display_prompt, response, f"{duration} seconds", generated_tokens
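# Note (illustrative): a throughput figure could be derived from the two metrics
# returned above, e.g. tokens_per_second = generated_tokens / max(duration, 1e-6),
# if the interface were extended with an extra output.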
# --- Gradio Interface (Added Time/Token displays, refined System Prompt) ---
# Refined default system prompt for better reasoning
DEFAULT_SYSTEM_PROMPT = """You are an expert Rock-Paper-Scissors (RPS) strategist focusing on statistical analysis.
Your task is to recommend the optimal AI move based *only* on the provided move frequency statistics for the player.
Follow these steps:
1. **Identify Player's Most Frequent Move:** Note the move (Rock, Paper, or Scissors) the player uses most often according to the stats.
2. **Determine Best Counter:** Identify the RPS move that directly beats the player's most frequent move (Rock beats Scissors, Scissors beats Paper, Paper beats Rock).
3. **Justify Recommendation:** Explain *why* this counter-move is statistically optimal. You can mention the expected outcome. For example: 'Playing Paper counters the player's most frequent move, Rock (40% frequency). This offers the highest probability of winning against the player's likely action.' Avoid irrelevant justifications based on the AI's own move frequencies.
4. **State Recommendation:** Clearly state the recommended move (Rock, Paper, or Scissors).
Base your analysis strictly on the provided frequencies and standard RPS rules."""
# Default example stats and query
DEFAULT_PLAYER_STATS = "Rock: 40%\nPaper: 30%\nScissors: 30%"
DEFAULT_AI_STATS = "" # Keep AI stats optional and clear by default
DEFAULT_USER_QUERY = "Based *only* on the player's move frequencies, what single move should the AI make next to maximize its statistical chance of winning? Explain your reasoning clearly step-by-step as instructed."
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {MODEL_ID} - RPS Frequency Analysis Tester")
    gr.Markdown("Test model advice based on Player/AI move frequencies. Includes Generation Time and Token Count.")
    with gr.Row():
        with gr.Column(scale=2):  # Input column
            player_stats_input = gr.Textbox(
                label="Player Move Frequency Stats", value=DEFAULT_PLAYER_STATS, lines=4,
                info="Enter the player's move frequencies (e.g., Rock: 50%, Paper: 30%, Scissors: 20%)."
            )
            ai_stats_input = gr.Textbox(
                label="AI Move Frequency Stats (Optional)", value=DEFAULT_AI_STATS, lines=4,
                info="Optionally, enter the AI's own move frequencies."
            )
            user_query_input = gr.Textbox(
                label="Your Query / Instruction", value=DEFAULT_USER_QUERY, lines=3,
                info="Ask the specific question based on the stats."
            )
            system_prompt_input = gr.Textbox(
                label="System Prompt", value=DEFAULT_SYSTEM_PROMPT,  # Set default value
                lines=12  # Adjusted lines
            )
        with gr.Column(scale=1):  # Params/Output column
            gr.Markdown("## Generation Parameters")
            max_length_slider = gr.Slider(minimum=50, maximum=1024, value=300, step=16, label="Max New Tokens")
            temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Temperature")  # Lowered default further
            top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
            submit_btn = gr.Button("Generate Response", variant="primary")
            gr.Markdown("## Performance Metrics")
            # Outputs for Time and Tokens
            time_output = gr.Textbox(label="Generation Time", interactive=False)
            tokens_output = gr.Number(label="Generated Tokens", interactive=False)  # Use Number for token count
            gr.Markdown("""
            ## Testing Tips
            - Focus on player stats for optimal counter strategy.
            - Use the refined **System Prompt** for better reasoning guidance.
            - Lower **Temperature** encourages more direct, statistical answers.
            """)
    with gr.Row():
        # Display the final prompt and model response side by side
        final_prompt_display = gr.Textbox(
            label="Formatted Input Sent to Model (via Chat Template)", lines=20  # Increased lines
        )
        response_display = gr.Textbox(
            label="Model Response", lines=20, show_copy_button=True  # Increased lines
        )
    # Handle button click - Updated inputs and outputs list
    submit_btn.click(
        process_input,
        inputs=[
            player_stats_input, ai_stats_input, system_prompt_input,
            user_query_input, max_length_slider, temperature_slider, top_p_slider
        ],
        outputs=[
            final_prompt_display, response_display,
            time_output, tokens_output  # Added new outputs
        ],
        api_name="generate_rps_frequency_analysis_v2"  # Updated api_name
    )
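# Because an api_name is set above, the handler is also reachable
# programmatically. A minimal client-side sketch (assumes the gradio_client
# package and the app running at the default local URL):
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   result = client.predict(
#       DEFAULT_PLAYER_STATS, DEFAULT_AI_STATS, DEFAULT_SYSTEM_PROMPT,
#       DEFAULT_USER_QUERY, 300, 0.4, 0.9,
#       api_name="/generate_rps_frequency_analysis_v2",
#   )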
# --- Launch the demo ---
if __name__ == "__main__":
demo.launch() |