NikhilJoson committed
Commit 6c390aa · verified · 1 Parent(s): 2d9bb6e

Update app.py

Files changed (1):
  1. app.py (+152 -132)
app.py CHANGED
@@ -8,6 +8,7 @@ from PIL import Image
 import numpy as np
 import os
 import time
+import re
 from Upsample import RealESRGAN
 import spaces  # Import spaces for ZeroGPU compatibility
 
@@ -17,9 +18,7 @@ model_path = "deepseek-ai/Janus-Pro-7B"
 config = AutoConfig.from_pretrained(model_path)
 language_config = config.language_config
 language_config._attn_implementation = 'eager'
-vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
-                                              language_config=language_config,
-                                              trust_remote_code=True)
+vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, language_config=language_config, trust_remote_code=True)
 if torch.cuda.is_available():
     vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
 else:
@@ -33,13 +32,80 @@ cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu'), scale=2)
 sr_model.load_weights(f'weights/RealESRGAN_x2.pth', download=False)
 
+# Patterns for detecting image generation requests
+GENERATION_PATTERNS = [r"generate (.+)", r"create (.+)", r"draw (.+)", r"make (.+)", r"show (.+)", r"visualize (.+)", r"imagine (.+)", r"picture (.+)"]
+
+def is_generation_request(message):
+    """Determine if a message is requesting image generation"""
+    message = message.lower().strip()
+
+    # Check if message explicitly mentions image generation
+    for pattern in GENERATION_PATTERNS:
+        match = re.match(pattern, message, re.IGNORECASE)
+        if match:
+            return True, match.group(1)
+
+    # Check for specific keywords suggesting image generation
+    image_keywords = ["image", "picture", "photo", "artwork", "illustration", "painting", "drawing"]
+    generation_verbs = ["generate", "create", "make", "produce", "show me", "draw"]
+
+    for verb in generation_verbs:
+        for keyword in image_keywords:
+            if f"{verb} {keyword}" in message or f"{verb} an {keyword}" in message or f"{verb} a {keyword}" in message:
+                # Extract the prompt (everything after the keyword)
+                pattern = f"{verb}\\s+(?:an?\\s+)?{keyword}\\s+(?:of|showing|depicting|with)?\\s*(.*)"
+                match = re.search(pattern, message, re.IGNORECASE)
+                if match and match.group(1):
+                    return True, match.group(1)
+                else:
+                    # If we can't extract a specific prompt, use the whole message
+                    return True, message
+
+    return False, None
+
+
 @torch.inference_mode()
 @spaces.GPU(duration=120)
-# Multimodal Chat function with conversation history
-def multimodal_chat(image, message, chat_history, seed, top_p, temperature, progress=gr.Progress(track_tqdm=True)):
+# Unified chat function that handles both image understanding and generation
+def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_weight, t2i_temperature, progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()
 
+    # Check if this is an image generation request
+    is_gen_request, extracted_prompt = is_generation_request(message)
+
+    if is_gen_request:
+        # Prepare a more detailed prompt by considering context from the conversation
+        context_prompt = extracted_prompt
+
+        # Optionally, enhance the prompt with context from previous messages
+        if chat_history and len(chat_history) > 0:
+            # Get the last few turns of conversation for context (limit to last 3 turns)
+            recent_context = chat_history[-3:] if len(chat_history) > 3 else chat_history
+            context_text = " ".join([f"User: {user_msg}" for user_msg, _ in recent_context])
+
+            # Only use context if it's not too long
+            if len(context_text) < 200:  # Arbitrary length limit
+                context_prompt = f"{context_text}. {extracted_prompt}"
+
+        # Generate images
+        generated_images = generate_image(
+            prompt=context_prompt,
+            seed=seed,
+            guidance=cfg_weight,
+            t2i_temperature=t2i_temperature
+        )
+
+        # Create a response that includes the generated images
+        response = f"I've generated the following images based on: '{extracted_prompt}'"
+
+        # Add the images to the chat as the bot's response
+        chat_history.append((message, response))
+
+        # Return the message, updated history, maintained image context, and generated images
+        return "", chat_history, image, generated_images
+
+    # Regular chat flow (no image generation)
     # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
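A note on the routing logic added above: it is purely lexical. re.match anchors each GENERATION_PATTERNS entry at the start of the lowercased message, and the verb/keyword fallback scans for phrasings like "create a picture of ...". A minimal smoke test of the intended behavior (a sketch; the import assumes app.py is importable as app):

    # Sketch: exercising the intent router added in this commit (assumed import path).
    from app import is_generation_request

    # A leading verb captures everything after it as the prompt.
    assert is_generation_request("generate a red fox in the snow") == (True, "a red fox in the snow")

    # Verb + keyword phrasing extracts the trailing description.
    ok, prompt = is_generation_request("Please create a picture of a sunset")
    assert ok and "sunset" in prompt

    # Plain questions fall through to the regular chat path.
    assert is_generation_request("What's in this image?") == (False, None)

Note that the patterns are greedy: any message beginning with "show", "make", or "picture" is routed to generation, so conversational phrasings like "show me how this works" will also trigger a generation pass.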
@@ -89,18 +155,10 @@ def multimodal_chat(image, message, chat_history, seed, top_p, temperature, prog
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
-    outputs = vl_gpt.language_model.generate(
-        inputs_embeds=inputs_embeds,
-        attention_mask=prepare_inputs.attention_mask,
-        pad_token_id=tokenizer.eos_token_id,
-        bos_token_id=tokenizer.bos_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-        max_new_tokens=512,
-        do_sample=False if temperature == 0 else True,
-        use_cache=True,
-        temperature=temperature,
-        top_p=top_p,
-    )
+    outputs = vl_gpt.language_model.generate(inputs_embeds=inputs_embeds, attention_mask=prepare_inputs.attention_mask,
+                                             pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id,
+                                             eos_token_id=tokenizer.eos_token_id, max_new_tokens=512, temperature=temperature, top_p=top_p,
+                                             do_sample=False if temperature == 0 else True, use_cache=True)
 
     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
 
@@ -108,18 +166,11 @@ def multimodal_chat(image, message, chat_history, seed, top_p, temperature, prog
 
     chat_history.append((message, answer))
 
     # Keep the last uploaded image in context
-    return "", chat_history, image
-
-
-def generate(input_ids,
-             width,
-             height,
-             temperature: float = 1,
-             parallel_size: int = 5,
-             cfg_weight: float = 5,
-             image_token_num_per_image: int = 576,
-             patch_size: int = 16,
-             progress=gr.Progress(track_tqdm=True)):
+    return "", chat_history, image, None
+
+
+def generate(input_ids, width, height, temperature: float = 1, parallel_size: int = 5, cfg_weight: float = 5,
+             image_token_num_per_image: int = 1024, patch_size: int = 16, progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()
 
@@ -152,7 +203,7 @@ def generate(input_ids,
 
     inputs_embeds = img_embeds.unsqueeze(dim=1)
 
     patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
-                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])
+                                                 shape=[parallel_size, 8, width // patch_size, height // patch_size])
 
     return generated_tokens.to(dtype=torch.int), patches
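The image_token_num_per_image default moving from 576 to 1024 tracks the resolution change later in this commit (384x384 to 512x512): decode_code reshapes one generated token per 16-pixel patch, so the token budget must equal (width // patch_size) * (height // patch_size). A quick arithmetic check:

    # One token per 16x16 latent patch, so the token budget follows the image size.
    patch_size = 16
    assert (384 // patch_size) * (384 // patch_size) == 576   # old 384x384 default
    assert (512 // patch_size) * (512 // patch_size) == 1024  # new 512x512 default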
 
@@ -168,12 +219,8 @@ def unpack(dec, width, height, parallel_size=5):
 
 
 @torch.inference_mode()
-@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
-def generate_image(prompt,
-                   seed=None,
-                   guidance=5,
-                   t2i_temperature=1.0,
-                   progress=gr.Progress(track_tqdm=True)):
+@spaces.GPU(duration=180)  # Specify a duration to avoid timeout
+def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache and avoid tracking gradients
     torch.cuda.empty_cache()
     # Set the seed for reproducible results
@@ -181,31 +228,22 @@
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
-    width = 384
-    height = 384
+    width = 512
+    height = 512
     parallel_size = 4
 
     with torch.no_grad():
         messages = [{'role': '<|User|>', 'content': prompt},
                     {'role': '<|Assistant|>', 'content': ''}]
         text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-                                                                           sft_format=vl_chat_processor.sft_format,
-                                                                           system_prompt='')
+                                                                           sft_format=vl_chat_processor.sft_format, system_prompt='')
         text = text + vl_chat_processor.image_start_tag
 
         input_ids = torch.LongTensor(tokenizer.encode(text))
-        output, patches = generate(input_ids,
-                                   width // 16 * 16,
-                                   height // 16 * 16,
-                                   cfg_weight=guidance,
-                                   parallel_size=parallel_size,
-                                   temperature=t2i_temperature)
-        images = unpack(patches,
-                        width // 16 * 16,
-                        height // 16 * 16,
-                        parallel_size=parallel_size)
-
-        # return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]
+        output, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance,
+                                   parallel_size=parallel_size, temperature=t2i_temperature)
+        images = unpack(patches, width // 16 * 16, height // 16 * 16, parallel_size=parallel_size)
+
         stime = time.time()
         ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
         print(f'upsample time: {time.time() - stime}')
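A small reading note on the calls above: width // 16 * 16 floors each dimension to the nearest multiple of the 16-pixel patch size, so only patch-aligned sizes ever reach generate and unpack. For example:

    # Floor-to-multiple: x // 16 * 16 rounds down to a multiple of the patch size.
    for x in (512, 500, 384):
        print(x, '->', x // 16 * 16)   # 512 -> 512, 500 -> 496, 384 -> 384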
@@ -219,7 +257,7 @@ def image_upsample(img: Image.Image) -> Image.Image:
 
     width, height = img.size
 
-    if width >= 5000 or height >= 5000:
+    if width >= 4096 or height >= 4096:
         raise Exception("The image is too large.")
 
     global sr_model
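The size guard is also tightened from 5000 to 4096 pixels per side before the 2x RealESRGAN pass (a 4095-pixel input still yields an 8190-pixel output). A self-contained sketch of the same predicate; within_limit is a hypothetical helper for illustration, not part of app.py:

    from PIL import Image

    def within_limit(img: Image.Image, limit: int = 4096) -> bool:
        # Mirrors the guard above: both sides must stay strictly below the limit.
        width, height = img.size
        return width < limit and height < limit

    assert within_limit(Image.new('RGB', (4095, 2048)))      # accepted
    assert not within_limit(Image.new('RGB', (4096, 2048)))  # rejected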
@@ -234,7 +272,7 @@ def add_image_to_chat(image, chat_history):
 
 # Helper function to clear chat history but maintain the image
 def clear_chat(image):
-    return [], image
+    return [], image, None
 
 
 # Gradio interface
@@ -242,98 +280,80 @@ css = '''
 .gradio-container {max-width: 960px !important}
 '''
 with gr.Blocks(css=css) as demo:
-    gr.Markdown("# Janus Pro 7B")
-
-    with gr.Tab("Multimodal Chat"):
-        gr.Markdown(value="## Multimodal Chat")
-
-        # State variables to maintain context
-        chat_history = gr.State([])
-        current_image = gr.State(None)
-
-        with gr.Row():
-            with gr.Column(scale=1):
-                image_input = gr.Image(label="Upload Image (only needed once)")
-                upload_button = gr.Button("Add Image to Chat")
-
-                with gr.Accordion("Advanced options", open=False):
-                    und_seed_input = gr.Number(label="Seed", precision=0, value=42)
-                    top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
-                    temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
-
-                clear_button = gr.Button("Clear Chat")
-
-            with gr.Column(scale=2):
-                chat_interface = gr.Chatbot(label="Chat History", height=500)
-                message_input = gr.Textbox(label="Your message", placeholder="Ask about the image or continue the conversation...", lines=2)
-                chat_button = gr.Button("Send")
-
-        examples_chat = gr.Examples(
-            label="Multimodal Chat examples",
-            examples=[
-                [
-                    "explain this meme",
-                    "doge.png",
-                ],
-                [
-                    "Convert the formula into latex code.",
-                    "equation.png",
-                ],
-            ],
-            inputs=[message_input, image_input],
-        )
-
-    with gr.Tab("Text-to-Image Generation"):
-        gr.Markdown(value="## Text-to-Image Generation")
-
-        prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)")
-
-        generation_button = gr.Button("Generate Images")
-
-        image_output = gr.Gallery(label="Generated Images", columns=4, rows=1)
-
-        with gr.Accordion("Advanced options", open=False):
-            with gr.Row():
+    gr.Markdown("# Janus Pro 7B - Unified Chat Interface")
+
+    # State variables to maintain context
+    chat_history = gr.State([])
+    current_image = gr.State(None)
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            image_input = gr.Image(label="Upload Image (optional)")
+            upload_button = gr.Button("Add Image to Chat")
+
+            with gr.Accordion("Chat Options", open=False):
+                und_seed_input = gr.Number(label="Seed", precision=0, value=42)
+                top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
+                temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
+
+            with gr.Accordion("Image Generation Options", open=False):
                 cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
-            t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
-            seed_input = gr.Number(label="Seed (Optional)", precision=0, value=1234)
-
-        examples_t2i = gr.Examples(
-            label="Text to image generation examples.",
-            examples=[
-                "Master shifu racoon wearing drip attire as a street gangster.",
-                "The face of a beautiful girl",
-                "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-                "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
-                "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
-            ],
-            inputs=prompt_input,
-        )
+                t2i_temperature_input = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="Temperature")
+
+            clear_button = gr.Button("Clear Chat")
+
+            gr.Markdown("""
+            ### Tips:
+            1. Upload an image to discuss it
+            2. Type commands like "generate [description]" to create images
+            3. Continue chatting about uploaded or generated images
+            4. Use natural language like "show me a sunset" or "create a portrait"
+            """)
+
+        with gr.Column(scale=2):
+            chat_interface = gr.Chatbot(label="Chat History", height=500)
+            message_input = gr.Textbox(
+                label="Your message",
+                placeholder="Ask about an image, continue chatting, or generate new images by typing 'generate [description]'",
+                lines=2
+            )
+            chat_button = gr.Button("Send")
+            generated_images = gr.Gallery(label="Generated Images", visible=True, columns=2, rows=2)
 
     # Chat interface interactions
-    upload_button.click(
-        add_image_to_chat,
-        inputs=[image_input, chat_history],
-        outputs=[current_image, chat_history]
-    )
+    upload_button.click(add_image_to_chat, inputs=[image_input, chat_history], outputs=[current_image, chat_history])
 
     chat_button.click(
-        multimodal_chat,
-        inputs=[current_image, message_input, chat_interface, und_seed_input, top_p, temperature],
-        outputs=[message_input, chat_interface, current_image]
+        unified_chat,
+        inputs=[current_image, message_input, chat_interface, und_seed_input, top_p, temperature, cfg_weight_input, t2i_temperature_input],
+        outputs=[message_input, chat_interface, current_image, generated_images]
+    )
+
+    # Also trigger on Enter key
+    message_input.submit(
+        unified_chat,
+        inputs=[current_image, message_input, chat_interface, und_seed_input, top_p, temperature, cfg_weight_input, t2i_temperature_input],
+        outputs=[message_input, chat_interface, current_image, generated_images]
     )
 
     clear_button.click(
         clear_chat,
         inputs=[current_image],
-        outputs=[chat_interface, current_image]
+        outputs=[chat_interface, current_image, generated_images]
    )
 
-    # T2I interface interactions
-    generation_button.click(
-        fn=generate_image,
-        inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
-        outputs=image_output
+    # Examples for the unified interface
+    examples = gr.Examples(
+        label="Example queries",
+        examples=[
+            ["What's in this image?"],
+            ["Generate a cute kitten with big eyes"],
+            ["Show me a mountain landscape at sunset"],
+            ["Can you explain what's happening in this picture?"],
+            ["Create an astronaut riding a horse"],
+            ["Generate a futuristic cityscape with flying cars"],
+        ],
+        inputs=message_input,
    )
 
 demo.launch(share=True)
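A design note on the event wiring above: chat_button.click and message_input.submit are bound to the same unified_chat callback with identical inputs and outputs, so the Send button and the Enter key cannot drift apart. A minimal standalone sketch of that pattern (illustrative names, tuple-style Chatbot history as used in this file):

    import gradio as gr

    def respond(message, history):
        # Toy handler standing in for unified_chat.
        return "", history + [(message, f"echo: {message}")]

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        box = gr.Textbox()
        send = gr.Button("Send")
        # Register the same handler for both triggers in one place.
        for trigger in (send.click, box.submit):
            trigger(respond, inputs=[box, chatbot], outputs=[box, chatbot])

    demo.launch()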