rui3000 committed
Commit 2e54946 · verified · Parent: 24a51a6

Update app.py

Files changed (1): app.py (+35 -28)

app.py CHANGED
Before (old version; removed lines are prefixed with "-"):

@@ -1,23 +1,27 @@
  import gradio as gr
  import torch
- import time # Import time module
  from transformers import AutoModelForCausalLM, AutoTokenizer

  # --- Configuration ---
  MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"

  # --- Load Model and Tokenizer ---
  print(f"Loading model: {MODEL_ID}")
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
  model = AutoModelForCausalLM.from_pretrained(
  MODEL_ID,
  torch_dtype="auto",
- device_map="auto"
  )
  print("Model loaded successfully.")


- # --- Generation Function (Updated to return token count) ---
  def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
  """Generate a response and return it along with the number of generated tokens."""
  num_generated_tokens = 0
@@ -27,8 +31,11 @@ def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
  tokenize=False,
  add_generation_prompt=True
  )
- model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
- input_ids_len = model_inputs.input_ids.shape[-1] # Length of input tokens

  generation_kwargs = {
  "max_new_tokens": max_length,
@@ -40,7 +47,7 @@ def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):

  print("Generating response...")
  with torch.no_grad():
- # Generate response - Ensure output_scores or similar isn't needed if just counting
  generated_ids = model.generate(model_inputs.input_ids, **generation_kwargs)

  # Calculate generated tokens
@@ -49,13 +56,15 @@ def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):

  response = tokenizer.decode(output_ids, skip_special_tokens=True)
  print("Generation complete.")
- return response.strip(), num_generated_tokens # Return response and token count

  except Exception as e:
  print(f"Error during generation: {e}")
- return f"An error occurred: {str(e)}", num_generated_tokens # Return error and token count

- # --- Input Processing Function (Updated for Time/Token outputs) ---
  def process_input(
  player_stats,
  ai_stats,
@@ -66,7 +75,7 @@ def process_input(
  top_p
  ):
  """Process inputs, generate response, and return display info, response, time, and token count."""
-
  # Construct the user message content
  user_content = f"Player Move Frequency Stats:\n{player_stats}\n\n"
  if ai_stats and ai_stats.strip():
@@ -83,7 +92,7 @@ def process_input(
  start_time = time.time()

  # Generate response from the model
- response, generated_tokens = generate_response( # Capture token count
  messages,
  max_length=max_length,
  temperature=temperature,
@@ -92,17 +101,17 @@ def process_input(

  # --- Time Measurement End ---
  end_time = time.time()
- duration = round(end_time - start_time, 2) # Calculate duration

  # For display purposes
  display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"

  # Return all results including time and tokens
  return display_prompt, response, f"{duration} seconds", generated_tokens

- # --- Gradio Interface (Added Time/Token displays, refined System Prompt) ---

- # Refined default system prompt for better reasoning
  DEFAULT_SYSTEM_PROMPT = """You are an expert Rock-Paper-Scissors (RPS) strategist focusing on statistical analysis.
  Your task is to recommend the optimal AI move based *only* on the provided move frequency statistics for the player.

@@ -114,9 +123,8 @@ Follow these steps:

  Base your analysis strictly on the provided frequencies and standard RPS rules."""

- # Default example stats and query
  DEFAULT_PLAYER_STATS = "Rock: 40%\nPaper: 30%\nScissors: 30%"
- DEFAULT_AI_STATS = "" # Keep AI stats optional and clear by default
  DEFAULT_USER_QUERY = "Based *only* on the player's move frequencies, what single move should the AI make next to maximize its statistical chance of winning? Explain your reasoning clearly step-by-step as instructed."

  with gr.Blocks(theme=gr.themes.Soft()) as demo:
@@ -138,21 +146,20 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
  info="Ask the specific question based on the stats."
  )
  system_prompt_input = gr.Textbox(
- label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, # Set default value
- lines=12 # Adjusted lines
  )

  with gr.Column(scale=1): # Params/Output column
  gr.Markdown("## Generation Parameters")
  max_length_slider = gr.Slider(minimum=50, maximum=1024, value=300, step=16, label="Max New Tokens")
- temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Temperature") # Lowered default further
  top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
  submit_btn = gr.Button("Generate Response", variant="primary")

  gr.Markdown("## Performance Metrics")
- # Outputs for Time and Tokens
  time_output = gr.Textbox(label="Generation Time", interactive=False)
- tokens_output = gr.Number(label="Generated Tokens", interactive=False) # Use Number for token count

  gr.Markdown("""
  ## Testing Tips
@@ -162,15 +169,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
  """)

  with gr.Row():
- # Display final prompt and model response (side-by-side)
  final_prompt_display = gr.Textbox(
- label="Formatted Input Sent to Model (via Chat Template)", lines=20 # Increased lines
  )
  response_display = gr.Textbox(
- label="Model Response", lines=20, show_copy_button=True # Increased lines
  )

- # Handle button click - Updated inputs and outputs list
  submit_btn.click(
  process_input,
  inputs=[
@@ -179,11 +184,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
  ],
  outputs=[
  final_prompt_display, response_display,
- time_output, tokens_output # Added new outputs
  ],
- api_name="generate_rps_frequency_analysis_v2" # Updated api_name
  )

  # --- Launch the demo ---
  if __name__ == "__main__":
- demo.launch()
After (new version; added lines are prefixed with "+"):

  import gradio as gr
  import torch
+ import time
+ import spaces # Import the spaces library
  from transformers import AutoModelForCausalLM, AutoTokenizer

  # --- Configuration ---
  MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"

  # --- Load Model and Tokenizer ---
+ # Note: Model loading happens when the Space starts.
+ # device_map="auto" will attempt to use the GPU when allocated by @spaces.GPU
  print(f"Loading model: {MODEL_ID}")
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
  model = AutoModelForCausalLM.from_pretrained(
  MODEL_ID,
  torch_dtype="auto",
+ device_map="auto" # Keep this, it helps distribute within the allocated GPU(s)
  )
  print("Model loaded successfully.")
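A quick sanity check, not part of this commit, can show what torch_dtype="auto" and device_map="auto" actually resolved to (model.device and model.dtype are standard transformers attributes; on a ZeroGPU Space this may still report CPU at startup, since the GPU is typically attached only inside @spaces.GPU-decorated calls):

    print(f"Model device: {model.device}, dtype: {model.dtype}")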

+ # --- Generation Function (Returns response and token count) ---
+ # This function will run on the GPU allocated via the decorator on process_input
  def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
  """Generate a response and return it along with the number of generated tokens."""
  num_generated_tokens = 0

  tokenize=False,
  add_generation_prompt=True
  )
+ # Ensure model_inputs are sent to the correct device the model is on
+ # device_map='auto' handles this, but explicitly checking model.device is safer
+ device = model.device
+ model_inputs = tokenizer([prompt_text], return_tensors="pt").to(device)
+ input_ids_len = model_inputs.input_ids.shape[-1]

  generation_kwargs = {
  "max_new_tokens": max_length,

  print("Generating response...")
  with torch.no_grad():
+ # Generate response
  generated_ids = model.generate(model_inputs.input_ids, **generation_kwargs)

  # Calculate generated tokens

  response = tokenizer.decode(output_ids, skip_special_tokens=True)
  print("Generation complete.")
+ return response.strip(), num_generated_tokens

  except Exception as e:
  print(f"Error during generation: {e}")
+ # Ensure error message is returned correctly even if tokens couldn't be counted
+ return f"An error occurred: {str(e)}", num_generated_tokens
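For reference, and not part of this commit, generate_response expects the same chat-style message dicts that tokenizer.apply_chat_template consumes; a minimal local sketch with hypothetical values:

    sample_messages = [
        {"role": "system", "content": "You are an expert RPS strategist."},
        {"role": "user", "content": "Player stats: Rock 40%, Paper 30%, Scissors 30%. Best AI move?"},
    ]
    reply, n_tokens = generate_response(sample_messages, max_length=128, temperature=0.4, top_p=0.9)
    print(f"{n_tokens} tokens generated:\n{reply}")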

+ # --- Input Processing Function (Decorated for ZeroGPU) ---
+ @spaces.GPU # Add the ZeroGPU decorator here
  def process_input(
  player_stats,
  ai_stats,

  top_p
  ):
  """Process inputs, generate response, and return display info, response, time, and token count."""
+ print("GPU requested via decorator, starting processing...") # Add a log message
  # Construct the user message content
  user_content = f"Player Move Frequency Stats:\n{player_stats}\n\n"
  if ai_stats and ai_stats.strip():

  start_time = time.time()

  # Generate response from the model
+ response, generated_tokens = generate_response(
  messages,
  max_length=max_length,
  temperature=temperature,

  # --- Time Measurement End ---
  end_time = time.time()
+ duration = round(end_time - start_time, 2)

  # For display purposes
  display_prompt = f"System Prompt (if used):\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"

+ print(f"Processing finished in {duration} seconds.") # Add a log message
  # Return all results including time and tokens
  return display_prompt, response, f"{duration} seconds", generated_tokens
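One related note, assuming the standard spaces package API rather than anything in this commit: spaces.GPU also accepts a duration argument, so a call that routinely needs more than the default ZeroGPU window could request a longer allocation. A hypothetical sketch:

    @spaces.GPU(duration=120)  # assumed spaces API: request a longer GPU window for this call
    def long_generation(prompt_text):
        # hypothetical helper, for illustration only; the real app decorates process_input above
        reply, _tokens = generate_response([{"role": "user", "content": prompt_text}], max_length=1024)
        return reply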

+ # --- Gradio Interface (No changes needed here) ---

  DEFAULT_SYSTEM_PROMPT = """You are an expert Rock-Paper-Scissors (RPS) strategist focusing on statistical analysis.
  Your task is to recommend the optimal AI move based *only* on the provided move frequency statistics for the player.

  Base your analysis strictly on the provided frequencies and standard RPS rules."""

  DEFAULT_PLAYER_STATS = "Rock: 40%\nPaper: 30%\nScissors: 30%"
+ DEFAULT_AI_STATS = ""
  DEFAULT_USER_QUERY = "Based *only* on the player's move frequencies, what single move should the AI make next to maximize its statistical chance of winning? Explain your reasoning clearly step-by-step as instructed."

  with gr.Blocks(theme=gr.themes.Soft()) as demo:

  info="Ask the specific question based on the stats."
  )
  system_prompt_input = gr.Textbox(
+ label="System Prompt", value=DEFAULT_SYSTEM_PROMPT,
+ lines=12
  )

  with gr.Column(scale=1): # Params/Output column
  gr.Markdown("## Generation Parameters")
  max_length_slider = gr.Slider(minimum=50, maximum=1024, value=300, step=16, label="Max New Tokens")
+ temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Temperature")
  top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
  submit_btn = gr.Button("Generate Response", variant="primary")

  gr.Markdown("## Performance Metrics")
  time_output = gr.Textbox(label="Generation Time", interactive=False)
+ tokens_output = gr.Number(label="Generated Tokens", interactive=False)

  gr.Markdown("""
  ## Testing Tips

  """)

  with gr.Row():
  final_prompt_display = gr.Textbox(
+ label="Formatted Input Sent to Model (via Chat Template)", lines=20
  )
  response_display = gr.Textbox(
+ label="Model Response", lines=20, show_copy_button=True
  )

  submit_btn.click(
  process_input,
  inputs=[

  ],
  outputs=[
  final_prompt_display, response_display,
+ time_output, tokens_output
  ],
+ api_name="generate_rps_frequency_analysis_v2"
  )

  # --- Launch the demo ---
  if __name__ == "__main__":
+ # Share=True is needed for ZeroGPU to work correctly if running locally for testing
+ # but usually not needed when deployed on HF Spaces platform.
+ demo.launch()
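
Once the Space is running, the endpoint registered through api_name can also be called programmatically. A minimal sketch using gradio_client, with a hypothetical Space id; view_api() lists the exact positional parameters, which follow the inputs= list wired to submit_btn.click:

    from gradio_client import Client

    client = Client("rui3000/rps-frequency-analysis")  # hypothetical Space id
    client.view_api()  # lists /generate_rps_frequency_analysis_v2 and its positional parameters
    # result = client.predict(..., api_name="/generate_rps_frequency_analysis_v2")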