NikhilJoson committed (verified)
Commit 5c3cb3b
1 Parent(s): 28bb0de

Update app.py

Files changed (1):
  app.py  +40 -23
app.py CHANGED
@@ -18,7 +18,9 @@ model_path = "deepseek-ai/Janus-Pro-7B"
 config = AutoConfig.from_pretrained(model_path)
 language_config = config.language_config
 language_config._attn_implementation = 'eager'
-vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, language_config=language_config, trust_remote_code=True)
+vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
+                                              language_config=language_config,
+                                              trust_remote_code=True)
 if torch.cuda.is_available():
     vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
 else:
@@ -33,7 +35,16 @@ sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu
 sr_model.load_weights(f'weights/RealESRGAN_x2.pth', download=False)
 
 # Patterns for detecting image generation requests
-GENERATION_PATTERNS = [r"generate (.+)", r"create (.+)", r"draw (.+)", r"make (.+)", r"show (.+)", r"visualize (.+)", r"imagine (.+)", r"picture (.+)",]
+GENERATION_PATTERNS = [
+    r"generate (.+)",
+    r"create (.+)",
+    r"draw (.+)",
+    r"make (.+)",
+    r"show (.+)",
+    r"visualize (.+)",
+    r"imagine (.+)",
+    r"picture (.+)",
+]
 
 def is_generation_request(message):
     """Determine if a message is requesting image generation"""
@@ -82,9 +93,7 @@ def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_wei
     if chat_history and len(chat_history) > 0:
         # Get the last few turns of conversation for context (limit to last 3 turns)
         recent_context = chat_history[-3:] if len(chat_history) > 3 else chat_history
-        # context_text = " ".join([f"{user}: {user_msg}" for user_msg, _ in recent_context])
-        context_text = " ".join([f"{user_msg}" for user_msg, _ in recent_context])
-
+        context_text = " ".join([f"{user_msg}" for user_msg, _ in recent_context])
 
         # Only use context if it's not too long
         if len(context_text) < 200:  # Arbitrary length limit
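For reference, the tuple-style gr.Chatbot history iterated here stores one (user_msg, assistant_msg) pair per turn, which is why each entry is unpacked as user_msg, _ (the sample values below are made up):

    chat_history = [("show me a sunset", "Here is a sunset ..."),
                    ("make it more orange", "Done ...")]
    recent_context = chat_history[-3:] if len(chat_history) > 3 else chat_history
    context_text = " ".join([f"{user_msg}" for user_msg, _ in recent_context])
    print(context_text)  # "show me a sunset make it more orange"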
@@ -157,10 +166,18 @@ def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_wei
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
-    outputs = vl_gpt.language_model.generate(inputs_embeds=inputs_embeds, attention_mask=prepare_inputs.attention_mask,
-                                             pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id,
-                                             eos_token_id=tokenizer.eos_token_id, max_new_tokens=512, temperature=temperature, top_p=top_p,
-                                             do_sample=False if temperature == 0 else True, use_cache=True,)
+    outputs = vl_gpt.language_model.generate(
+        inputs_embeds=inputs_embeds,
+        attention_mask=prepare_inputs.attention_mask,
+        pad_token_id=tokenizer.eos_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        max_new_tokens=512,
+        do_sample=False if temperature == 0 else True,
+        use_cache=True,
+        temperature=temperature,
+        top_p=top_p,
+    )
 
     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
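The reordered keyword arguments are behavior-neutral; the functional detail worth noting is do_sample=False if temperature == 0 else True, which falls back to greedy decoding at temperature 0, where sampling would otherwise divide the logits by zero. Isolated as a sketch (the helper name is hypothetical):

    def decoding_kwargs(temperature: float, top_p: float) -> dict:
        # temperature == 0 -> greedy decoding; sampling would be undefined
        if temperature == 0:
            return {"do_sample": False}
        return {"do_sample": True, "temperature": temperature, "top_p": top_p}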
 
@@ -172,7 +189,7 @@ def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_wei
 
 
 def generate(input_ids, width, height, temperature: float = 1, parallel_size: int = 5, cfg_weight: float = 5,
-             image_token_num_per_image: int = 1024, patch_size: int = 16, progress=gr.Progress(track_tqdm=True)):
+             image_token_num_per_image: int = 576, patch_size: int = 16, progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()
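The new default of 576 tokens is consistent with a 384x384 output at patch_size=16, since (384 // 16) ** 2 == 576, while the old 1024 corresponds to 512x512; the 384x384 resolution is inferred from Janus-Pro's usual generation size, not stated in this hunk:

    width = height = 384   # assumed Janus-Pro generation resolution
    patch_size = 16
    tokens = (width // patch_size) * (height // patch_size)
    assert tokens == 576   # old default: (512 // 16) ** 2 == 1024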
 
@@ -187,9 +204,7 @@ def generate(input_ids, width, height, temperature: float = 1, parallel_size: in
     pkv = None
     for i in range(image_token_num_per_image):
         with torch.no_grad():
-            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
-                                                  use_cache=True,
-                                                  past_key_values=pkv)
+            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
             pkv = outputs.past_key_values
             hidden_states = outputs.last_hidden_state
             logits = vl_gpt.gen_head(hidden_states[:, -1, :])
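The unchanged lines that follow gen_head apply classifier-free guidance across the paired conditional/unconditional batch rows; in the standard Janus demo code that step looks roughly like this (shown for context only, not part of this commit):

    logit_cond = logits[0::2, :]    # rows conditioned on the prompt
    logit_uncond = logits[1::2, :]  # rows from the unconditional branch
    logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
    probs = torch.softmax(logits / temperature, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)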
@@ -221,7 +236,7 @@ def unpack(dec, width, height, parallel_size=5):
 
 
 @torch.inference_mode()
-@spaces.GPU(duration=180)  # Specify a duration to avoid timeout
+@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
 def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache and avoid tracking gradients
     torch.cuda.empty_cache()
@@ -238,7 +253,8 @@ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=
     messages = [{'role': '<|User|>', 'content': prompt},
                 {'role': '<|Assistant|>', 'content': ''}]
     text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-                                                                       sft_format=vl_chat_processor.sft_format, system_prompt='')
+                                                                       sft_format=vl_chat_processor.sft_format,
+                                                                       system_prompt='')
     text = text + vl_chat_processor.image_start_tag
 
     input_ids = torch.LongTensor(tokenizer.encode(text))
@@ -276,10 +292,18 @@ def add_image_to_chat(image, chat_history):
 def clear_chat(image):
     return [], image, None
 
+
 
 # Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# Janus Pro 7B - Unified Chat Interface")
+    gr.Markdown("""
+    ### Tips:
+    1. Upload an image to discuss it
+    2. Type commands like "generate [description]" to create images
+    3. Continue chatting about uploaded or generated images
+    4. Use natural language like "show me a sunset" or "create a portrait"
+    """)
 
     # State variables to maintain context
     chat_history = gr.State([])
@@ -297,17 +321,10 @@ with gr.Blocks() as demo:
 
     with gr.Accordion("Image Generation Options", open=False):
         cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
-        t2i_temperature_input = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="Temperature")
+        t2i_temperature_input = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature")
 
     clear_button = gr.Button("Clear Chat")
 
-    gr.Markdown("""
-    ### Tips:
-    1. Upload an image to discuss it
-    2. Type commands like "generate [description]" to create images
-    3. Continue chatting about uploaded or generated images
-    4. Use natural language like "show me a sunset" or "create a portrait"
-    """)
 
     with gr.Column(scale=2):
         chat_interface = gr.Chatbot(label="Chat History", height=500)
 