Spaces: Running on Zero
File size: 8,634 Bytes
import gradio as gr
import torch
import time
import spaces # Import the spaces library
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Configuration ---
MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"

# --- Load Model and Tokenizer ---
# Note: Model loading happens when the Space starts.
# device_map="auto" will attempt to use the GPU when allocated by @spaces.GPU
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto"  # Keep this; it helps distribute the model within the allocated GPU(s)
)
print("Model loaded successfully.")

# --- Generation Function (Returns response and token count) ---
# This function will run on the GPU allocated via the decorator on process_input
def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a response and return it along with the number of generated tokens."""
    num_generated_tokens = 0
    try:
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # Ensure model_inputs are sent to the device the model is on.
        # device_map="auto" handles this, but explicitly checking model.device is safer.
        device = model.device
        model_inputs = tokenizer([prompt_text], return_tensors="pt").to(device)
        input_ids_len = model_inputs.input_ids.shape[-1]

        generation_kwargs = {
            "max_new_tokens": max_length,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }

        print("Generating response...")
        with torch.no_grad():
            # Generate response
            generated_ids = model.generate(model_inputs.input_ids, **generation_kwargs)

        # Count only the newly generated tokens (everything after the prompt)
        output_ids = generated_ids[0, input_ids_len:]
        num_generated_tokens = len(output_ids)

        response = tokenizer.decode(output_ids, skip_special_tokens=True)
        print("Generation complete.")
        return response.strip(), num_generated_tokens
    except Exception as e:
        print(f"Error during generation: {e}")
        # Ensure the error message is returned even if tokens couldn't be counted
        return f"An error occurred: {str(e)}", num_generated_tokens

# --- Input Processing Function (Decorated for ZeroGPU) ---
@spaces.GPU  # Add the ZeroGPU decorator here
def process_input(
    player_stats,
    ai_stats,
    system_prompt,
    user_query,
    max_length,
    temperature,
    top_p
):
    """Process inputs, generate response, and return display info, response, time, and token count."""
    print("GPU requested via decorator, starting processing...")  # Log message

    # Construct the user message content
    user_content = f"Player Move Frequency Stats:\n{player_stats}\n\n"
    if ai_stats and ai_stats.strip():
        user_content += f"AI Move Frequency Stats:\n{ai_stats}\n\n"
    user_content += f"User Query:\n{user_query}"

    # Create the messages list
    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_content})

    # --- Time Measurement Start ---
    start_time = time.time()

    # Generate response from the model
    response, generated_tokens = generate_response(
        messages,
        max_length=max_length,
        temperature=temperature,
        top_p=top_p
    )

    # --- Time Measurement End ---
    end_time = time.time()
    duration = round(end_time - start_time, 2)

    # For display purposes
    display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
    print(f"Processing finished in {duration} seconds.")  # Log message

    # Return all results, including time and token count
    return display_prompt, response, f"{duration} seconds", generated_tokens
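
# The four returned values map one-to-one onto the Gradio outputs wired up below:
# final_prompt_display, response_display, time_output, tokens_output.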

# --- Gradio Interface ---
DEFAULT_SYSTEM_PROMPT = """You are an expert Rock-Paper-Scissors (RPS) strategist focusing on statistical analysis.
Your task is to recommend the optimal AI move based *only* on the provided move frequency statistics for the player.
Follow these steps:
1. **Identify Player's Most Frequent Move:** Note the move (Rock, Paper, or Scissors) the player uses most often according to the stats.
2. **Determine Best Counter:** Identify the RPS move that directly beats the player's most frequent move (Rock beats Scissors, Scissors beats Paper, Paper beats Rock).
3. **Justify Recommendation:** Explain *why* this counter-move is statistically optimal. You can mention the expected outcome. For example: 'Playing Paper counters the player's most frequent move, Rock (40% frequency). This offers the highest probability of winning against the player's likely action.' Avoid irrelevant justifications based on the AI's own move frequencies.
4. **State Recommendation:** Clearly state the recommended move (Rock, Paper, or Scissors).
Base your analysis strictly on the provided frequencies and standard RPS rules."""
DEFAULT_PLAYER_STATS = "Rock: 40%\nPaper: 30%\nScissors: 30%"
DEFAULT_AI_STATS = ""
DEFAULT_USER_QUERY = "Based *only* on the player's move frequencies, what single move should the AI make next to maximize its statistical chance of winning? Explain your reasoning clearly step-by-step as instructed."
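
# Illustrative sketch (not used by the app): the deterministic counter logic that the
# system prompt asks the model to reason through, assuming stats formatted one move per
# line as in DEFAULT_PLAYER_STATS above. The helper name below is hypothetical.
def _best_counter_sketch(player_stats_text):
    counters = {"Rock": "Paper", "Paper": "Scissors", "Scissors": "Rock"}
    freqs = {}
    for line in player_stats_text.splitlines():
        move, _, pct = line.partition(":")
        try:
            freqs[move.strip().capitalize()] = float(pct.strip().rstrip("%"))
        except ValueError:
            continue  # ignore malformed lines
    most_frequent = max(freqs, key=freqs.get)  # e.g. "Rock" at 40%
    return counters[most_frequent]             # e.g. "Paper", which beats Rock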

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {MODEL_ID} - RPS Frequency Analysis Tester")
    gr.Markdown("Test model advice based on Player/AI move frequencies. Includes Generation Time and Token Count.")

    with gr.Row():
        with gr.Column(scale=2):  # Input column
            player_stats_input = gr.Textbox(
                label="Player Move Frequency Stats", value=DEFAULT_PLAYER_STATS, lines=4,
                info="Enter player's move frequencies (e.g., Rock: 50%, Paper: 30%, Scissors: 20%)."
            )
            ai_stats_input = gr.Textbox(
                label="AI Move Frequency Stats (Optional)", value=DEFAULT_AI_STATS, lines=4,
                info="Optionally, enter AI's own move frequencies."
            )
            user_query_input = gr.Textbox(
                label="Your Query / Instruction", value=DEFAULT_USER_QUERY, lines=3,
                info="Ask the specific question based on the stats."
            )
            system_prompt_input = gr.Textbox(
                label="System Prompt", value=DEFAULT_SYSTEM_PROMPT,
                lines=12
            )
        with gr.Column(scale=1):  # Params/Output column
            gr.Markdown("## Generation Parameters")
            max_length_slider = gr.Slider(minimum=50, maximum=1024, value=300, step=16, label="Max New Tokens")
            temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Temperature")
            top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
            submit_btn = gr.Button("Generate Response", variant="primary")

            gr.Markdown("## Performance Metrics")
            time_output = gr.Textbox(label="Generation Time", interactive=False)
            tokens_output = gr.Number(label="Generated Tokens", interactive=False)

            gr.Markdown("""
            ## Testing Tips
            - Focus on player stats for optimal counter strategy.
            - Use the refined **System Prompt** for better reasoning guidance.
            - Lower **Temperature** encourages more direct, statistical answers.
            """)

    with gr.Row():
        final_prompt_display = gr.Textbox(
            label="Formatted Input Sent to Model (via Chat Template)", lines=20
        )
        response_display = gr.Textbox(
            label="Model Response", lines=20, show_copy_button=True
        )

    submit_btn.click(
        process_input,
        inputs=[
            player_stats_input, ai_stats_input, system_prompt_input,
            user_query_input, max_length_slider, temperature_slider, top_p_slider
        ],
        outputs=[
            final_prompt_display, response_display,
            time_output, tokens_output
        ],
        api_name="generate_rps_frequency_analysis_v2"
    )
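
    # Because api_name is set, this endpoint can also be called programmatically.
    # Illustrative example (assumes the gradio_client package and a deployed Space ID):
    #   from gradio_client import Client
    #   client = Client("<user>/<space-name>")
    #   client.predict(DEFAULT_PLAYER_STATS, "", DEFAULT_SYSTEM_PROMPT, DEFAULT_USER_QUERY,
    #                  300, 0.4, 0.9, api_name="/generate_rps_frequency_analysis_v2")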

# --- Launch the demo ---
if __name__ == "__main__":
    # share=True can be useful for local testing, but it is not required when the app
    # is deployed on the Hugging Face Spaces platform (where the ZeroGPU allocation applies).
    demo.launch()
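
    # Deployment note (assumption, not part of the original file): a ZeroGPU Space would
    # typically list gradio, torch, transformers, and accelerate (needed for device_map="auto")
    # in requirements.txt; the `spaces` package itself is preinstalled on ZeroGPU hardware.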