import gradio as gr
import torch
import time
import spaces # Import the spaces library
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Configuration ---
MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"

# --- Load Model and Tokenizer ---
# Note: Model loading happens when the Space starts.
# device_map="auto" will attempt to use the GPU when allocated by @spaces.GPU
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto" # Keep this, it helps distribute within the allocated GPU(s)
)
print("Model loaded successfully.")


# --- Generation Function (Returns response and token count) ---
# This function will run on the GPU allocated via the decorator on process_input
def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a response and return it along with the number of generated tokens."""
    num_generated_tokens = 0
    try:
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
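        # apply_chat_template converts the messages list into a single prompt string in
        # the model's chat format; add_generation_prompt=True appends the assistant
        # turn header so the model continues as the assistant.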
        # Ensure model_inputs are sent to the correct device the model is on
        # device_map='auto' handles this, but explicitly checking model.device is safer
        device = model.device
        model_inputs = tokenizer([prompt_text], return_tensors="pt").to(device)
        input_ids_len = model_inputs.input_ids.shape[-1]

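        # do_sample=True makes temperature/top_p take effect; pad_token_id falls back
        # to the EOS token, a common choice for decoder-only chat models.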
        generation_kwargs = {
            "max_new_tokens": max_length,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }

        print("Generating response...")
        with torch.no_grad():
            # Generate, passing the attention mask explicitly alongside the input ids
            generated_ids = model.generate(
                model_inputs.input_ids, attention_mask=model_inputs.attention_mask, **generation_kwargs
            )

        # Calculate generated tokens
        output_ids = generated_ids[0, input_ids_len:]
        num_generated_tokens = len(output_ids)

        response = tokenizer.decode(output_ids, skip_special_tokens=True)
        print("Generation complete.")
        return response.strip(), num_generated_tokens

    except Exception as e:
        print(f"Error during generation: {e}")
        # Ensure error message is returned correctly even if tokens couldn't be counted
        return f"An error occurred: {str(e)}", num_generated_tokens

# --- Input Processing Function (Decorated for ZeroGPU) ---
@spaces.GPU  # Request ZeroGPU hardware for the duration of this call
def process_input(
    player_stats,
    ai_stats,
    system_prompt,
    user_query,
    max_length,
    temperature,
    top_p
):
    """Process inputs, generate response, and return display info, response, time, and token count."""
    print("GPU requested via decorator, starting processing...") # Add a log message
    # Construct the user message content
    user_content = f"Player Move Frequency Stats:\n{player_stats}\n\n"
    if ai_stats and ai_stats.strip():
         user_content += f"AI Move Frequency Stats:\n{ai_stats}\n\n"
    user_content += f"User Query:\n{user_query}"

    # Create the messages list
    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_content})

    # --- Time Measurement Start ---
    start_time = time.time()

    # Generate response from the model
    response, generated_tokens = generate_response(
        messages,
        max_length=max_length,
        temperature=temperature,
        top_p=top_p
    )

    # --- Time Measurement End ---
    end_time = time.time()
    duration = round(end_time - start_time, 2)

    # For display purposes
    display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"

    print(f"Processing finished in {duration} seconds.") # Add a log message
    # Return all results including time and tokens
    return display_prompt, response, f"{duration} seconds", generated_tokens

# --- Gradio Interface ---

DEFAULT_SYSTEM_PROMPT = """You are an expert Rock-Paper-Scissors (RPS) strategist focusing on statistical analysis.
Your task is to recommend the optimal AI move based *only* on the provided move frequency statistics for the player.

Follow these steps:
1.  **Identify Player's Most Frequent Move:** Note the move (Rock, Paper, or Scissors) the player uses most often according to the stats.
2.  **Determine Best Counter:** Identify the RPS move that directly beats the player's most frequent move (Rock beats Scissors, Scissors beats Paper, Paper beats Rock).
3.  **Justify Recommendation:** Explain *why* this counter-move is statistically optimal. You can mention the expected outcome. For example: 'Playing Paper counters the player's most frequent move, Rock (40% frequency). This offers the highest probability of winning against the player's likely action.' Avoid irrelevant justifications based on the AI's own move frequencies.
4.  **State Recommendation:** Clearly state the recommended move (Rock, Paper, or Scissors).

Base your analysis strictly on the provided frequencies and standard RPS rules."""

DEFAULT_PLAYER_STATS = "Rock: 40%\nPaper: 30%\nScissors: 30%"
DEFAULT_AI_STATS = ""
DEFAULT_USER_QUERY = "Based *only* on the player's move frequencies, what single move should the AI make next to maximize its statistical chance of winning? Explain your reasoning clearly step-by-step as instructed."

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {MODEL_ID} - RPS Frequency Analysis Tester")
    gr.Markdown("Test model advice based on Player/AI move frequencies. Includes Generation Time and Token Count.")

    with gr.Row():
        with gr.Column(scale=2): # Input column
            player_stats_input = gr.Textbox(
                label="Player Move Frequency Stats", value=DEFAULT_PLAYER_STATS, lines=4,
                info="Enter player's move frequencies (e.g., Rock: 50%, Paper: 30%, Scissors: 20%)."
            )
            ai_stats_input = gr.Textbox(
                label="AI Move Frequency Stats (Optional)", value=DEFAULT_AI_STATS, lines=4,
                info="Optionally, enter AI's own move frequencies."
            )
            user_query_input = gr.Textbox(
                label="Your Query / Instruction", value=DEFAULT_USER_QUERY, lines=3,
                info="Ask the specific question based on the stats."
            )
            system_prompt_input = gr.Textbox(
                label="System Prompt", value=DEFAULT_SYSTEM_PROMPT,
                lines=12
            )

        with gr.Column(scale=1): # Params/Output column
            gr.Markdown("## Generation Parameters")
            max_length_slider = gr.Slider(minimum=50, maximum=1024, value=300, step=16, label="Max New Tokens")
            temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Temperature")
            top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
            submit_btn = gr.Button("Generate Response", variant="primary")

            gr.Markdown("## Performance Metrics")
            time_output = gr.Textbox(label="Generation Time", interactive=False)
            tokens_output = gr.Number(label="Generated Tokens", interactive=False)

            gr.Markdown("""
            ## Testing Tips
            - Focus on player stats for optimal counter strategy.
            - Use the refined **System Prompt** for better reasoning guidance.
            - Lower **Temperature** encourages more direct, statistical answers.
            """)

    with gr.Row():
        final_prompt_display = gr.Textbox(
            label="Formatted Input Sent to Model (via Chat Template)", lines=20
        )
        response_display = gr.Textbox(
            label="Model Response", lines=20, show_copy_button=True
        )

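    # Wire the button: the order of `inputs` must match process_input's parameters,
    # and `outputs` must match its four return values.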
    submit_btn.click(
        process_input,
        inputs=[
            player_stats_input, ai_stats_input, system_prompt_input,
            user_query_input, max_length_slider, temperature_slider, top_p_slider
        ],
        outputs=[
            final_prompt_display, response_display,
            time_output, tokens_output
        ],
        api_name="generate_rps_frequency_analysis_v2"
    )

# --- Launch the demo ---
if __name__ == "__main__":
    # ZeroGPU allocation only happens on the Hugging Face Spaces platform; when run
    # locally the @spaces.GPU decorator has no effect. Pass share=True to demo.launch()
    # only if you want a public link during local testing.
    demo.launch()