File size: 8,398 Bytes
699d672
3b34d1e
24a51a6
d54daef
699d672
480da6f
 
 
 
 
 
699d672
480da6f
3ef427a
480da6f
699d672
480da6f
699d672
5f32a93
24a51a6
480da6f
24a51a6
 
480da6f
 
 
 
3ef427a
3b34d1e
480da6f
24a51a6
 
480da6f
 
 
 
 
 
 
24a51a6
480da6f
 
24a51a6
480da6f
24a51a6
 
480da6f
24a51a6
 
480da6f
 
24a51a6
480da6f
 
 
24a51a6
480da6f
24a51a6
23a7862
5f32a93
 
b053859
24a51a6
480da6f
 
23a7862
 
24a51a6
5f32a93
24a51a6
5f32a93
24a51a6
5f32a93
 
 
24a51a6
480da6f
24a51a6
480da6f
 
5f32a93
24a51a6
 
 
5f32a93
24a51a6
480da6f
23a7862
 
 
 
5f32a93
24a51a6
 
 
 
 
480da6f
 
24a51a6
 
 
 
3ef427a
24a51a6
 
 
3ef427a
24a51a6
 
 
 
 
3ef427a
24a51a6
 
 
5f32a93
24a51a6
 
3ef427a
480da6f
5f32a93
24a51a6
480da6f
699d672
24a51a6
5f32a93
24a51a6
 
699d672
5f32a93
24a51a6
 
5f32a93
 
24a51a6
 
5f32a93
 
24a51a6
 
23a7862
5f32a93
24a51a6
5f32a93
24a51a6
 
 
3ef427a
480da6f
24a51a6
 
 
 
 
5f32a93
 
24a51a6
 
 
5f32a93
 
 
24a51a6
 
 
 
 
 
 
5f32a93
24a51a6
23a7862
 
 
24a51a6
 
 
 
 
 
23a7862
24a51a6
3b34d1e
699d672
480da6f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import gradio as gr
import torch
import time # Import time module
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Configuration ---
# Hugging Face Hub identifier of the checkpoint; used for both the tokenizer
# and the model below.
MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"

# --- Load Model and Tokenizer ---
# NOTE: loading happens at import time, so importing this module triggers the
# tokenizer/model load.
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# NOTE(review): the "auto" values presumably let transformers choose the dtype
# stored in the checkpoint and the device placement -- confirm against the
# `from_pretrained` documentation for the installed version.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto"
)
print("Model loaded successfully.")


# --- Generation Function (Updated to return token count) ---
def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a chat reply and report how many new tokens were produced.

    Args:
        messages: Chat messages (dicts with "role" and "content" keys) that
            are rendered through the tokenizer's chat template.
        max_length: Upper bound on newly generated tokens; forwarded to
            ``generate`` as ``max_new_tokens`` (coerced to ``int`` because UI
            sliders may deliver floats).
        temperature: Sampling temperature. ``None`` or a value <= 0 switches
            to greedy decoding instead of failing inside ``generate``.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        A ``(text, num_generated_tokens)`` tuple. If anything raises, the text
        is ``"An error occurred: <message>"`` and the count is the value
        computed so far (0 unless generation itself succeeded).
    """
    num_generated_tokens = 0
    try:
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
        # Prompt length in tokens; everything after this index is new output.
        input_ids_len = model_inputs.input_ids.shape[-1]

        generation_kwargs = {
            "max_new_tokens": int(max_length),
            "pad_token_id": tokenizer.eos_token_id,
        }
        if temperature is not None and temperature > 0:
            generation_kwargs.update(
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
            )
        else:
            # Sampling with a non-positive temperature is invalid; decode greedily.
            generation_kwargs["do_sample"] = False

        print("Generating response...")
        with torch.no_grad():
            # Unpack the full encoding so the attention mask reaches generate();
            # passing input_ids alone drops it, which is ambiguous when the pad
            # token is the EOS token.
            generated_ids = model.generate(**model_inputs, **generation_kwargs)

        # Keep only the tokens produced after the prompt.
        output_ids = generated_ids[0, input_ids_len:]
        num_generated_tokens = len(output_ids)

        response = tokenizer.decode(output_ids, skip_special_tokens=True)
        print("Generation complete.")
        return response.strip(), num_generated_tokens

    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        print(f"Error during generation: {e}")
        return f"An error occurred: {str(e)}", num_generated_tokens

# --- Input Processing Function (Updated for Time/Token outputs) ---
def process_input(
    player_stats,
    ai_stats,
    system_prompt,
    user_query,
    max_length,
    temperature,
    top_p
):
    """Build the chat messages, run generation, and collect display values.

    Args:
        player_stats: Text describing the player's move frequencies.
        ai_stats: Optional text describing the AI's move frequencies; skipped
            when empty or whitespace-only.
        system_prompt: Optional system message; skipped when empty or
            whitespace-only.
        user_query: The question appended after the stats.
        max_length: Forwarded to ``generate_response`` as ``max_length``.
        temperature: Forwarded to ``generate_response``.
        top_p: Forwarded to ``generate_response``.

    Returns:
        A 4-tuple ``(display_prompt, response, duration_text, generated_tokens)``
        where ``duration_text`` is the elapsed generation time formatted as
        ``"<seconds> seconds"`` (rounded to two decimals).
    """
    # Assemble the user message from its sections.
    parts = [f"Player Move Frequency Stats:\n{player_stats}\n\n"]
    if ai_stats and ai_stats.strip():
        parts.append(f"AI Move Frequency Stats:\n{ai_stats}\n\n")
    parts.append(f"User Query:\n{user_query}")
    user_content = "".join(parts)

    # The system message is only included when it has visible content.
    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_content})

    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    response, generated_tokens = generate_response(
        messages,
        max_length=max_length,
        temperature=temperature,
        top_p=top_p
    )
    duration = round(time.perf_counter() - start_time, 2)

    # Human-readable summary of the inputs (not the chat-template rendering).
    display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"

    return display_prompt, response, f"{duration} seconds", generated_tokens

# --- Gradio Interface (Added Time/Token displays, refined System Prompt) ---

# Default system prompt pre-filled in the "System Prompt" textbox. It instructs
# the model to pick the counter to the player's most frequent move and to
# justify that choice from the player's statistics only.
DEFAULT_SYSTEM_PROMPT = """You are an expert Rock-Paper-Scissors (RPS) strategist focusing on statistical analysis.
Your task is to recommend the optimal AI move based *only* on the provided move frequency statistics for the player.

Follow these steps:
1.  **Identify Player's Most Frequent Move:** Note the move (Rock, Paper, or Scissors) the player uses most often according to the stats.
2.  **Determine Best Counter:** Identify the RPS move that directly beats the player's most frequent move (Rock beats Scissors, Scissors beats Paper, Paper beats Rock).
3.  **Justify Recommendation:** Explain *why* this counter-move is statistically optimal. You can mention the expected outcome. For example: 'Playing Paper counters the player's most frequent move, Rock (40% frequency). This offers the highest probability of winning against the player's likely action.' Avoid irrelevant justifications based on the AI's own move frequencies.
4.  **State Recommendation:** Clearly state the recommended move (Rock, Paper, or Scissors).

Base your analysis strictly on the provided frequencies and standard RPS rules."""

# Default values pre-filled in the remaining input textboxes.
DEFAULT_PLAYER_STATS = "Rock: 40%\nPaper: 30%\nScissors: 30%"
DEFAULT_AI_STATS = "" # Empty by default: the AI stats section is optional
DEFAULT_USER_QUERY = "Based *only* on the player's move frequencies, what single move should the AI make next to maximize its statistical chance of winning? Explain your reasoning clearly step-by-step as instructed."

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header: title (including the model id) and a short description.
    gr.Markdown(f"# {MODEL_ID} - RPS Frequency Analysis Tester")
    gr.Markdown("Test model advice based on Player/AI move frequencies. Includes Generation Time and Token Count.")

    with gr.Row():
        # Wider left column: the text fields that feed the prompt.
        with gr.Column(scale=2):
            player_stats_input = gr.Textbox(
                label="Player Move Frequency Stats",
                value=DEFAULT_PLAYER_STATS,
                lines=4,
                info="Enter player's move frequencies (e.g., Rock: 50%, Paper: 30%, Scissors: 20%).",
            )
            ai_stats_input = gr.Textbox(
                label="AI Move Frequency Stats (Optional)",
                value=DEFAULT_AI_STATS,
                lines=4,
                info="Optionally, enter AI's own move frequencies.",
            )
            user_query_input = gr.Textbox(
                label="Your Query / Instruction",
                value=DEFAULT_USER_QUERY,
                lines=3,
                info="Ask the specific question based on the stats.",
            )
            system_prompt_input = gr.Textbox(
                label="System Prompt",
                value=DEFAULT_SYSTEM_PROMPT,
                lines=12,
            )

        # Narrower right column: sampling controls, run button, and metrics.
        with gr.Column(scale=1):
            gr.Markdown("## Generation Parameters")
            max_length_slider = gr.Slider(
                minimum=50,
                maximum=1024,
                value=300,
                step=16,
                label="Max New Tokens",
            )
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.4,
                step=0.05,
                label="Temperature",
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top P",
            )
            submit_btn = gr.Button("Generate Response", variant="primary")

            gr.Markdown("## Performance Metrics")
            # Read-only fields filled from process_input's last two outputs.
            time_output = gr.Textbox(label="Generation Time", interactive=False)
            tokens_output = gr.Number(label="Generated Tokens", interactive=False)

            gr.Markdown("""
            ## Testing Tips
            - Focus on player stats for optimal counter strategy.
            - Use the refined **System Prompt** for better reasoning guidance.
            - Lower **Temperature** encourages more direct, statistical answers.
            """)

    with gr.Row():
        # Side-by-side view of what was submitted and what the model answered.
        final_prompt_display = gr.Textbox(
            label="Formatted Input Sent to Model (via Chat Template)",
            lines=20,
        )
        response_display = gr.Textbox(
            label="Model Response",
            lines=20,
            show_copy_button=True,
        )

    # Wire the button: the input order matches process_input's parameters and
    # the output order matches the 4-tuple it returns.
    submit_btn.click(
        process_input,
        inputs=[
            player_stats_input,
            ai_stats_input,
            system_prompt_input,
            user_query_input,
            max_length_slider,
            temperature_slider,
            top_p_slider,
        ],
        outputs=[
            final_prompt_display,
            response_display,
            time_output,
            tokens_output,
        ],
        api_name="generate_rps_frequency_analysis_v2",
    )

# --- Launch the demo ---
# Launch the Blocks app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()