NikhilJoson committed · verified
Commit 2d9bb6e · 1 Parent(s): d889a35

Update app.py

Files changed (1):
  1. app.py  +281 -40
app.py CHANGED
@@ -11,12 +11,15 @@ import time
from Upsample import RealESRGAN
import spaces # Import spaces for ZeroGPU compatibility

+
# Load model and processor
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
- vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, language_config=language_config, trust_remote_code=True)
+ vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
+                                               language_config=language_config,
+                                               trust_remote_code=True)
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
@@ -30,11 +33,10 @@ cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu'), scale=2)
sr_model.load_weights(f'weights/RealESRGAN_x2.pth', download=False)

- last_uploaded_image = None
-
@torch.inference_mode()
@spaces.GPU(duration=120)
- def multimodal_understanding(image, question, seed, top_p, temperature, progress=gr.Progress(track_tqdm=True)):
+ # Multimodal Chat function with conversation history
+ def multimodal_chat(image, message, chat_history, seed, top_p, temperature, progress=gr.Progress(track_tqdm=True)):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

@@ -43,56 +45,295 @@ def multimodal_understanding(image, question, seed, top_p, temperature, progress
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

-     conversation = [
-         {
-             "role": "<|User|>",
-             "content": f"<image_placeholder>\n{question}",
-             "images": [image],
-         },
-         {"role": "<|Assistant|>", "content": ""},
-     ]
+     # Process the conversation history and add current message
+     conversation = []
+
+     # Check if we have existing history
+     if chat_history:
+         # Add previous conversation turns
+         for user_msg, assistant_msg in chat_history:
+             conversation.append({
+                 "role": "<|User|>",
+                 "content": user_msg,
+                 "images": [],  # No images for previous turns
+             })
+             conversation.append({
+                 "role": "<|Assistant|>",
+                 "content": assistant_msg,
+             })
+
+     # Add the current user message with image (if provided)
+     user_content = message
+     images_list = []
+
+     # Only include image placeholder if image is provided or this is the first message
+     if image is not None:
+         user_content = f"<image_placeholder>\n{message}"
+         images_list = [image]
+
+     conversation.append({
+         "role": "<|User|>",
+         "content": user_content,
+         "images": images_list,
+     })
+     conversation.append({"role": "<|Assistant|>", "content": ""})
+
+     # Process images (if any)
+     pil_images = []
+     if image is not None:
+         pil_images = [Image.fromarray(image)]

-     pil_images = [Image.fromarray(image)]
-     prepare_inputs = vl_chat_processor(conversations=conversation, images=pil_images, force_batchify=True
-                                        ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
+     prepare_inputs = vl_chat_processor(
+         conversations=conversation, images=pil_images, force_batchify=True
+     ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)

    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

-     outputs = vl_gpt.language_model.generate(inputs_embeds=inputs_embeds, attention_mask=prepare_inputs.attention_mask,
-                                              pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id,
-                                              eos_token_id=tokenizer.eos_token_id, max_new_tokens=512, temperature=temperature, top_p=top_p,
-                                              do_sample=False if temperature == 0 else True, use_cache=True,)
+     outputs = vl_gpt.language_model.generate(
+         inputs_embeds=inputs_embeds,
+         attention_mask=prepare_inputs.attention_mask,
+         pad_token_id=tokenizer.eos_token_id,
+         bos_token_id=tokenizer.bos_token_id,
+         eos_token_id=tokenizer.eos_token_id,
+         max_new_tokens=512,
+         do_sample=False if temperature == 0 else True,
+         use_cache=True,
+         temperature=temperature,
+         top_p=top_p,
+     )

    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-     return answer
+
+     # Update chat history
+     chat_history.append((message, answer))
+
+     # Keep the last uploaded image in context
+     return "", chat_history, image
+
+
+ def generate(input_ids,
+              width,
+              height,
+              temperature: float = 1,
+              parallel_size: int = 5,
+              cfg_weight: float = 5,
+              image_token_num_per_image: int = 576,
+              patch_size: int = 16,
+              progress=gr.Progress(track_tqdm=True)):
+     # Clear CUDA cache before generating
+     torch.cuda.empty_cache()
+
+     tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
+     for i in range(parallel_size * 2):
+         tokens[i, :] = input_ids
+         if i % 2 != 0:
+             tokens[i, 1:-1] = vl_chat_processor.pad_id
+     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+     generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
+
+     pkv = None
+     for i in range(image_token_num_per_image):
+         with torch.no_grad():
+             outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
+                                                   use_cache=True,
+                                                   past_key_values=pkv)
+             pkv = outputs.past_key_values
+             hidden_states = outputs.last_hidden_state
+             logits = vl_gpt.gen_head(hidden_states[:, -1, :])
+             logit_cond = logits[0::2, :]
+             logit_uncond = logits[1::2, :]
+             logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+             probs = torch.softmax(logits / temperature, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+             generated_tokens[:, i] = next_token.squeeze(dim=-1)
+             next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+
+             img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
+             inputs_embeds = img_embeds.unsqueeze(dim=1)
+
+     patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
+                                                   shape=[parallel_size, 8, width // patch_size, height // patch_size])
+
+     return generated_tokens.to(dtype=torch.int), patches
+
+ def unpack(dec, width, height, parallel_size=5):
+     dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
+     dec = np.clip((dec + 1) / 2 * 255, 0, 255)
+
+     visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
+     visual_img[:, :, :] = dec
+
+     return visual_img
+
+
+ @torch.inference_mode()
+ @spaces.GPU(duration=120)  # Specify a duration to avoid timeout
+ def generate_image(prompt,
+                    seed=None,
+                    guidance=5,
+                    t2i_temperature=1.0,
+                    progress=gr.Progress(track_tqdm=True)):
+     # Clear CUDA cache and avoid tracking gradients
+     torch.cuda.empty_cache()
+     # Set the seed for reproducible results
+     if seed is not None:
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed(seed)
+         np.random.seed(seed)
+     width = 384
+     height = 384
+     parallel_size = 4
+
+     with torch.no_grad():
+         messages = [{'role': '<|User|>', 'content': prompt},
+                     {'role': '<|Assistant|>', 'content': ''}]
+         text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
+                                                                            sft_format=vl_chat_processor.sft_format,
+                                                                            system_prompt='')
+         text = text + vl_chat_processor.image_start_tag
+
+         input_ids = torch.LongTensor(tokenizer.encode(text))
+         output, patches = generate(input_ids,
+                                    width // 16 * 16,
+                                    height // 16 * 16,
+                                    cfg_weight=guidance,
+                                    parallel_size=parallel_size,
+                                    temperature=t2i_temperature)
+         images = unpack(patches,
+                         width // 16 * 16,
+                         height // 16 * 16,
+                         parallel_size=parallel_size)
+
+         # return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]
+         stime = time.time()
+         ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
+         print(f'upsample time: {time.time() - stime}')
+         return ret_images
+
+
+ @spaces.GPU(duration=60)
+ def image_upsample(img: Image.Image) -> Image.Image:
+     if img is None:
+         raise Exception("Image not uploaded")
+
+     width, height = img.size
+
+     if width >= 5000 or height >= 5000:
+         raise Exception("The image is too large.")
+
+     global sr_model
+     result = sr_model.predict(img.convert('RGB'))
+     return result
+
+
+ # Helper function to add uploaded image to the chat context
+ def add_image_to_chat(image, chat_history):
+     return image, chat_history
+
+
+ # Helper function to clear chat history but maintain the image
+ def clear_chat(image):
+     return [], image
+

# Gradio interface
css = '''
.gradio-container {max-width: 960px !important}
'''
with gr.Blocks(css=css) as demo:
-     gr.Markdown("# Janus Pro 7B Chat Interface")
-
-     chat_history = gr.Chatbot(label="Chat History")
-     message_input = gr.Textbox(label="Type your message here...")
-     image_input = gr.Image(label="Upload an image (optional)")
-
-     def respond(message, image):
-         global last_uploaded_image
-         if image is not None:
-             last_uploaded_image = image  # Update the last uploaded image
-             response = multimodal_understanding(image, message, seed=42, top_p=0.95, temperature=0.1)
-         elif last_uploaded_image is not None:
-             response = multimodal_understanding(last_uploaded_image, message, seed=42, top_p=0.95, temperature=0.1)
-         else:
-             response = "Please provide an image for multimodal understanding."
-
-         return message, response
-
-     def submit_message(message, image):
-         response = respond(message, image)
-         return message, response
-
-     message_input.submit(submit_message, inputs=[message_input, image_input], outputs=[message_input, chat_history])
+     gr.Markdown("# Janus Pro 7B")
+
+     with gr.Tab("Multimodal Chat"):
+         gr.Markdown(value="## Multimodal Chat")
+
+         # State variables to maintain context
+         chat_history = gr.State([])
+         current_image = gr.State(None)
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 image_input = gr.Image(label="Upload Image (only needed once)")
+                 upload_button = gr.Button("Add Image to Chat")
+
+                 with gr.Accordion("Advanced options", open=False):
+                     und_seed_input = gr.Number(label="Seed", precision=0, value=42)
+                     top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
+                     temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
+
+                 clear_button = gr.Button("Clear Chat")
+
+             with gr.Column(scale=2):
+                 chat_interface = gr.Chatbot(label="Chat History", height=500)
+                 message_input = gr.Textbox(label="Your message", placeholder="Ask about the image or continue the conversation...", lines=2)
+                 chat_button = gr.Button("Send")
+
+         examples_chat = gr.Examples(
+             label="Multimodal Chat examples",
+             examples=[
+                 [
+                     "explain this meme",
+                     "doge.png",
+                 ],
+                 [
+                     "Convert the formula into latex code.",
+                     "equation.png",
+                 ],
+             ],
+             inputs=[message_input, image_input],
+         )
+
+     with gr.Tab("Text-to-Image Generation"):
+         gr.Markdown(value="## Text-to-Image Generation")
+
+         prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)")
+
+         generation_button = gr.Button("Generate Images")
+
+         image_output = gr.Gallery(label="Generated Images", columns=4, rows=1)
+
+         with gr.Accordion("Advanced options", open=False):
+             with gr.Row():
+                 cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
+                 t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
+             seed_input = gr.Number(label="Seed (Optional)", precision=0, value=1234)
+
+         examples_t2i = gr.Examples(
+             label="Text to image generation examples.",
+             examples=[
+                 "Master shifu racoon wearing drip attire as a street gangster.",
+                 "The face of a beautiful girl",
+                 "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+                 "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
+                 "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
+             ],
+             inputs=prompt_input,
+         )
+
+     # Chat interface interactions
+     upload_button.click(
+         add_image_to_chat,
+         inputs=[image_input, chat_history],
+         outputs=[current_image, chat_history]
+     )
+
+     chat_button.click(
+         multimodal_chat,
+         inputs=[current_image, message_input, chat_interface, und_seed_input, top_p, temperature],
+         outputs=[message_input, chat_interface, current_image]
+     )
+
+     clear_button.click(
+         clear_chat,
+         inputs=[current_image],
+         outputs=[chat_interface, current_image]
+     )
+
+     # T2I interface interactions
+     generation_button.click(
+         fn=generate_image,
+         inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
+         outputs=image_output
+     )

demo.launch(share=True)
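Note on the new chat flow: multimodal_chat rebuilds the full Janus conversation list on every turn, replaying prior (user, assistant) pairs from the history and attaching the image placeholder only to the current message. Below is a standalone sketch of that mapping in pure Python, with no model calls; the role strings and <image_placeholder> tag follow the Janus chat template used in the diff, while build_conversation itself is an illustrative name, not part of the commit.

def build_conversation(chat_history, message, has_image):
    # Replay earlier turns without images, mirroring the loop in multimodal_chat.
    conversation = []
    for user_msg, assistant_msg in chat_history:
        conversation.append({"role": "<|User|>", "content": user_msg, "images": []})
        conversation.append({"role": "<|Assistant|>", "content": assistant_msg})
    # Attach the placeholder only when an image accompanies this turn.
    content = f"<image_placeholder>\n{message}" if has_image else message
    conversation.append({"role": "<|User|>", "content": content,
                         "images": []})  # the app puts the uploaded frame here
    conversation.append({"role": "<|Assistant|>", "content": ""})
    return conversation

# Example: one prior exchange, then a new question about an attached image.
print(build_conversation([("hi", "Hello!")], "What is in this image?", True))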
 
11
  from Upsample import RealESRGAN
12
  import spaces # Import spaces for ZeroGPU compatibility
13
 
14
+
15
  # Load model and processor
16
  model_path = "deepseek-ai/Janus-Pro-7B"
17
  config = AutoConfig.from_pretrained(model_path)
18
  language_config = config.language_config
19
  language_config._attn_implementation = 'eager'
20
+ vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
21
+ language_config=language_config,
22
+ trust_remote_code=True)
23
  if torch.cuda.is_available():
24
  vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
25
  else:
 
33
  sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu'), scale=2)
34
  sr_model.load_weights(f'weights/RealESRGAN_x2.pth', download=False)
35
 
 
 
36
  @torch.inference_mode()
37
  @spaces.GPU(duration=120)
38
+ # Multimodal Chat function with conversation history
39
+ def multimodal_chat(image, message, chat_history, seed, top_p, temperature, progress=gr.Progress(track_tqdm=True)):
40
  # Clear CUDA cache before generating
41
  torch.cuda.empty_cache()
42
 
 
45
  np.random.seed(seed)
46
  torch.cuda.manual_seed(seed)
47
 
48
+ # Process the conversation history and add current message
49
+ conversation = []
50
+
51
+ # Check if we have existing history
52
+ if chat_history:
53
+ # Add previous conversation turns
54
+ for user_msg, assistant_msg in chat_history:
55
+ conversation.append({
56
+ "role": "<|User|>",
57
+ "content": user_msg,
58
+ "images": [], # No images for previous turns
59
+ })
60
+ conversation.append({
61
+ "role": "<|Assistant|>",
62
+ "content": assistant_msg,
63
+ })
64
+
65
+ # Add the current user message with image (if provided)
66
+ user_content = message
67
+ images_list = []
68
+
69
+ # Only include image placeholder if image is provided or this is the first message
70
+ if image is not None:
71
+ user_content = f"<image_placeholder>\n{message}"
72
+ images_list = [image]
73
+
74
+ conversation.append({
75
+ "role": "<|User|>",
76
+ "content": user_content,
77
+ "images": images_list,
78
+ })
79
+ conversation.append({"role": "<|Assistant|>", "content": ""})
80
+
81
+ # Process images (if any)
82
+ pil_images = []
83
+ if image is not None:
84
+ pil_images = [Image.fromarray(image)]
85
 
86
+ prepare_inputs = vl_chat_processor(
87
+ conversations=conversation, images=pil_images, force_batchify=True
88
+ ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
89
 
90
  inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
91
 
92
+ outputs = vl_gpt.language_model.generate(
93
+ inputs_embeds=inputs_embeds,
94
+ attention_mask=prepare_inputs.attention_mask,
95
+ pad_token_id=tokenizer.eos_token_id,
96
+ bos_token_id=tokenizer.bos_token_id,
97
+ eos_token_id=tokenizer.eos_token_id,
98
+ max_new_tokens=512,
99
+ do_sample=False if temperature == 0 else True,
100
+ use_cache=True,
101
+ temperature=temperature,
102
+ top_p=top_p,
103
+ )
104
 
105
  answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
106
+
107
+ # Update chat history
108
+ chat_history.append((message, answer))
109
+
110
+ # Keep the last uploaded image in context
111
+ return "", chat_history, image
112
+
113
+
114
+ def generate(input_ids,
115
+ width,
116
+ height,
117
+ temperature: float = 1,
118
+ parallel_size: int = 5,
119
+ cfg_weight: float = 5,
120
+ image_token_num_per_image: int = 576,
121
+ patch_size: int = 16,
122
+ progress=gr.Progress(track_tqdm=True)):
123
+ # Clear CUDA cache before generating
124
+ torch.cuda.empty_cache()
125
+
126
+ tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
127
+ for i in range(parallel_size * 2):
128
+ tokens[i, :] = input_ids
129
+ if i % 2 != 0:
130
+ tokens[i, 1:-1] = vl_chat_processor.pad_id
131
+ inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
132
+ generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
133
+
134
+ pkv = None
135
+ for i in range(image_token_num_per_image):
136
+ with torch.no_grad():
137
+ outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
138
+ use_cache=True,
139
+ past_key_values=pkv)
140
+ pkv = outputs.past_key_values
141
+ hidden_states = outputs.last_hidden_state
142
+ logits = vl_gpt.gen_head(hidden_states[:, -1, :])
143
+ logit_cond = logits[0::2, :]
144
+ logit_uncond = logits[1::2, :]
145
+ logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
146
+ probs = torch.softmax(logits / temperature, dim=-1)
147
+ next_token = torch.multinomial(probs, num_samples=1)
148
+ generated_tokens[:, i] = next_token.squeeze(dim=-1)
149
+ next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
150
+
151
+ img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
152
+ inputs_embeds = img_embeds.unsqueeze(dim=1)
153
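The generate loop batches the conditional and unconditional streams together (even rows keep the prompt, odd rows are padded out with pad_id), so a single forward pass yields both logit sets for classifier-free guidance. A toy, runnable illustration of just the guidance-and-sample step, with a random tensor standing in for the gen_head output:

import torch

parallel_size, vocab_size = 4, 16384
cfg_weight = 5.0   # same default as generate() above
temperature = 1.0

# Stand-in for gen_head(hidden_states[:, -1, :]) on the 2*parallel_size batch.
logits = torch.randn(parallel_size * 2, vocab_size)
logit_cond = logits[0::2, :]    # even rows: prompt-conditioned
logit_uncond = logits[1::2, :]  # odd rows: unconditional (padded prompt)

# Classifier-free guidance: push logits away from the unconditional direction.
guided = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
probs = torch.softmax(guided / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # one image token per sample
print(next_token.shape)  # torch.Size([4, 1])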
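unpack assumes the VQ decoder emits NCHW tensors in [-1, 1]; converting them to displayable RGB is just a transpose, rescale, and clip. A minimal NumPy equivalent, with random data standing in for the decoder output:

import numpy as np

parallel_size, h, w = 4, 384, 384
dec = np.random.uniform(-1, 1, size=(parallel_size, 3, h, w)).astype(np.float32)

dec = dec.transpose(0, 2, 3, 1)                    # NCHW -> NHWC
visual = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)
print(visual.shape, visual.dtype)                  # (4, 384, 384, 3) uint8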
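The tab wiring keeps the uploaded image in gr.State so later turns can omit it, while the Chatbot's (user, assistant) tuple list doubles as the history fed back into multimodal_chat. A stripped-down sketch of that pattern, with an echo function standing in for the model call; all names here are illustrative:

import gradio as gr

def chat_fn(image, message, history):
    # Echo stand-in for multimodal_chat: reuse the stored image every turn.
    reply = f"(image attached: {image is not None}) You said: {message}"
    history.append((message, reply))
    return "", history, image

with gr.Blocks() as sketch:
    current_image = gr.State(None)  # survives across sends
    image_input = gr.Image(label="Upload once")
    chatbot = gr.Chatbot(label="Chat History")
    msg = gr.Textbox(label="Your message")
    gr.Button("Add Image to Chat").click(lambda img: img,
                                         inputs=image_input,
                                         outputs=current_image)
    msg.submit(chat_fn,
               inputs=[current_image, msg, chatbot],
               outputs=[msg, chatbot, current_image])

if __name__ == "__main__":
    sketch.launch()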