NikhilJoson committed on
Commit e0402e9 · verified · 1 Parent(s): 7afe207

Update app.py

Files changed (1)
  1. app.py +46 -278
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoConfig, AutoModelForCausalLM, pipeline as translation_pipeline
4
  from janus.models import MultiModalityCausalLM, VLChatProcessor
5
  from janus.utils.io import load_pil_images
6
  from PIL import Image
@@ -11,290 +11,58 @@ from Upsample import RealESRGAN
11
  import spaces # Import spaces for ZeroGPU compatibility
12
  import re
13
 
14
- # Initialize the translation pipeline (Korean → English)
15
- translator = translation_pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
16
 
17
- def translate_if_korean(prompt: str) -> str:
18
- """Translate the prompt to English if it contains Korean."""
19
- if re.search(r'[ㄱ-ㅎㅏ-ㅣ가-힣]', prompt):
20
- try:
21
- translation = translator(prompt)[0]['translation_text']
22
- return translation
23
- except Exception as e:
24
- print(f"Translation error: {e}")
25
- return prompt
26
- return prompt
27
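For reference, the helper removed above routes any prompt containing Hangul through the Helsinki-NLP ko-en translator and returns other prompts unchanged; a hypothetical usage sketch (outputs illustrative only):

```python
# Hypothetical usage of the removed translate_if_korean helper
translate_if_korean("A cat in a spacesuit on the moon")   # no Hangul -> returned unchanged
translate_if_korean("고양이가 우주복을 입고 달에 있는 모습")    # Hangul detected -> translated to English
```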
 
28
- # Load model and processor
29
- model_path = "deepseek-ai/Janus-Pro-7B"
30
- config = AutoConfig.from_pretrained(model_path)
31
- language_config = config.language_config
32
- language_config._attn_implementation = 'eager'
33
- vl_gpt = AutoModelForCausalLM.from_pretrained(
34
- model_path,
35
- language_config=language_config,
36
- trust_remote_code=True
37
- )
38
- if torch.cuda.is_available():
39
- vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
40
- else:
41
- vl_gpt = vl_gpt.to(torch.float16)
42
 
43
- vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
44
- tokenizer = vl_chat_processor.tokenizer
45
- cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
46
 
47
- # SR model
48
- sr_model = RealESRGAN(torch.device(cuda_device), scale=2)
49
- sr_model.load_weights('weights/RealESRGAN_x2.pth', download=False)
50
-
51
- @torch.inference_mode()
52
- @spaces.GPU(duration=120)
53
- def multimodal_understanding(image, question, seed, top_p, temperature):
54
- # (Omitted) body of the existing multimodal understanding function, kept as-is...
55
- torch.cuda.empty_cache()
56
- torch.manual_seed(seed)
57
- np.random.seed(seed)
58
- torch.cuda.manual_seed(seed)
59
-
60
- conversation = [
61
- {
62
- "role": "<|User|>",
63
- "content": f"<image_placeholder>\n{question}",
64
- "images": [image],
65
- },
66
- {"role": "<|Assistant|>", "content": ""},
67
- ]
68
-
69
- pil_images = [Image.fromarray(image)] if isinstance(image, np.ndarray) else [image]
70
- prepare_inputs = vl_chat_processor(
71
- conversations=conversation, images=pil_images, force_batchify=True
72
- ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
73
-
74
- inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
75
 
76
- outputs = vl_gpt.language_model.generate(
77
- inputs_embeds=inputs_embeds,
78
- attention_mask=prepare_inputs.attention_mask,
79
- pad_token_id=tokenizer.eos_token_id,
80
- bos_token_id=tokenizer.bos_token_id,
81
- eos_token_id=tokenizer.eos_token_id,
82
- max_new_tokens=512,
83
- do_sample=False if temperature == 0 else True,
84
- use_cache=True,
85
- temperature=temperature,
86
- top_p=top_p,
87
- )
88
 
89
- answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
90
- return answer
91
-
92
- def generate(input_ids, width, height, temperature: float = 1,
93
- parallel_size: int = 5, cfg_weight: float = 5,
94
- image_token_num_per_image: int = 576, patch_size: int = 16):
95
- torch.cuda.empty_cache()
96
-
97
- tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
98
- for i in range(parallel_size * 2):
99
- tokens[i, :] = input_ids
100
- if i % 2 != 0:
101
- tokens[i, 1:-1] = vl_chat_processor.pad_id
102
- inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
103
- generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
104
-
105
- pkv = None
106
- for i in range(image_token_num_per_image):
107
- with torch.no_grad():
108
- outputs = vl_gpt.language_model.model(
109
- inputs_embeds=inputs_embeds,
110
- use_cache=True,
111
- past_key_values=pkv
112
- )
113
- pkv = outputs.past_key_values
114
- hidden_states = outputs.last_hidden_state
115
- logits = vl_gpt.gen_head(hidden_states[:, -1, :])
116
- logit_cond = logits[0::2, :]
117
- logit_uncond = logits[1::2, :]
118
- logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
119
- probs = torch.softmax(logits / temperature, dim=-1)
120
- next_token = torch.multinomial(probs, num_samples=1)
121
- generated_tokens[:, i] = next_token.squeeze(dim=-1)
122
- next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
123
-
124
- img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
125
- inputs_embeds = img_embeds.unsqueeze(dim=1)
126
 
127
- patches = vl_gpt.gen_vision_model.decode_code(
128
- generated_tokens.to(dtype=torch.int),
129
- shape=[parallel_size, 8, width // patch_size, height // patch_size]
130
- )
131
- return generated_tokens.to(dtype=torch.int), patches
132
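The doubled batch in the removed `generate` implements classifier-free guidance: even rows see the full prompt, odd rows are padded to an unconditional prompt, and the two logit streams are blended before sampling:

```python
# Classifier-free guidance blend used in the loop above
logit_cond = logits[0::2, :]     # rows conditioned on the prompt
logit_uncond = logits[1::2, :]   # rows with the prompt padded out
logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
```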
-
133
- def unpack(dec, width, height, parallel_size=5):
134
- dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
135
- dec = np.clip((dec + 1) / 2 * 255, 0, 255)
136
- visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
137
- visual_img[:, :, :] = dec
138
- return visual_img
139
-
140
- @torch.inference_mode()
141
- @spaces.GPU(duration=120)
142
- def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
143
- # Translate the input prompt to English if it contains Korean
144
- prompt = translate_if_korean(prompt)
145
 
146
- torch.cuda.empty_cache()
147
- if seed is not None:
148
- torch.manual_seed(seed)
149
- torch.cuda.manual_seed(seed)
150
- np.random.seed(seed)
151
- width = 384
152
- height = 384
153
- parallel_size = 5
154
 
155
- with torch.no_grad():
156
- messages = [{'role': '<|User|>', 'content': prompt},
157
- {'role': '<|Assistant|>', 'content': ''}]
158
- text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
159
- conversations=messages,
160
- sft_format=vl_chat_processor.sft_format,
161
- system_prompt=''
162
- )
163
- text = text + vl_chat_processor.image_start_tag
164
- input_ids = torch.LongTensor(tokenizer.encode(text))
165
- output, patches = generate(
166
- input_ids,
167
- width // 16 * 16,
168
- height // 16 * 16,
169
- cfg_weight=guidance,
170
- parallel_size=parallel_size,
171
- temperature=t2i_temperature
172
- )
173
- images = unpack(
174
- patches,
175
- width // 16 * 16,
176
- height // 16 * 16,
177
- parallel_size=parallel_size
178
- )
179
-
180
- stime = time.time()
181
- ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
182
- print(f'upsample time: {time.time() - stime}')
183
- return ret_images
184
-
185
- @spaces.GPU(duration=60)
186
- def image_upsample(img: Image.Image) -> Image.Image:
187
- if img is None:
188
- raise Exception("Image not uploaded")
189
- width, height = img.size
190
- if width >= 5000 or height >= 5000:
191
- raise Exception("The image is too large.")
192
- global sr_model
193
- result = sr_model.predict(img.convert('RGB'))
194
- return result
195
-
196
- # Custom CSS for a sleek, modern and highly readable interface
197
- custom_css = """
198
- body {
199
- background: #f0f2f5;
200
- font-family: 'Segoe UI', sans-serif;
201
- color: #333;
202
- }
203
- h1, h2, h3 {
204
- font-weight: 600;
205
- }
206
- .gradio-container {
207
- padding: 20px;
208
- }
209
- header {
210
- text-align: center;
211
- padding: 20px;
212
- margin-bottom: 20px;
213
- }
214
- header h1 {
215
- font-size: 3em;
216
- color: #2c3e50;
217
- }
218
- .gr-button {
219
- background-color: #3498db !important;
220
- color: #fff !important;
221
- border: none !important;
222
- padding: 10px 20px !important;
223
- border-radius: 5px !important;
224
- font-size: 1em !important;
225
- }
226
- .gr-button:hover {
227
- background-color: #2980b9 !important;
228
- }
229
- .gr-input, .gr-slider, .gr-number, .gr-textbox {
230
- border-radius: 5px;
231
- }
232
- .gr-gallery-item {
233
- border-radius: 10px;
234
- overflow: hidden;
235
- box-shadow: 0 2px 10px rgba(0,0,0,0.1);
236
- }
237
- """
238
-
239
- # Gradio Interface
240
- with gr.Blocks(css=custom_css, title="Multimodal & T2I") as demo:
241
- with gr.Column(variant="panel"):
242
- gr.Markdown("<header><h1>Chat With Janus-Pro-7B</h1></header>")
243
-
244
- with gr.Tabs():
245
- with gr.TabItem("Multimodal Understanding"):
246
- gr.Markdown("### Chat with Images")
247
- with gr.Row():
248
- image_input = gr.Image(label="Upload Image", type="numpy")
249
- with gr.Column():
250
- question_input = gr.Textbox(label="Question", placeholder="Enter your question about the image here...", lines=4)
251
- und_seed_input = gr.Number(label="Seed", precision=0, value=42)
252
- top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="Top_p")
253
- temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature")
254
- understanding_button = gr.Button("Chat", elem_id="understanding-button")
255
- understanding_output = gr.Textbox(label="Response", lines=6)
256
- with gr.Accordion("Examples", open=False):
257
- gr.Examples(
258
- label="Multimodal Understanding Examples",
259
- examples=[
260
- ["explain this meme", "doge.png"]
261
- ],
262
- inputs=[question_input, image_input],
263
- )
264
- understanding_button.click(
265
- multimodal_understanding,
266
- inputs=[image_input, question_input, und_seed_input, top_p, temperature],
267
- outputs=understanding_output,
268
- )
269
-
270
- with gr.TabItem("Text-to-Image Generation"):
271
- gr.Markdown("### Generate Images from Text")
272
- with gr.Row():
273
- prompt_input = gr.Textbox(label="Prompt", placeholder="Enter detailed prompt for image generation...", lines=4)
274
- with gr.Row():
275
- seed_input = gr.Number(label="Seed (Optional)", precision=0, value=1234)
276
- cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
277
- t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="Temperature")
278
- generation_button = gr.Button("Generate Images", elem_id="generation-button")
279
- image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
280
- with gr.Accordion("Examples", open=False):
281
- gr.Examples(
282
- label="Text-to-Image Examples",
283
- examples=[
284
- "Master shifu racoon wearing drip attire as a street gangster.",
285
- "The face of a beautiful girl",
286
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
287
- "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting, immortal, fluffy, shiny mane, petals, fairyism, unreal engine 5 and Octane Render, highly detailed, photorealistic, cinematic, natural colors.",
288
- "고양이가 우주복을 입고 달에 있는 모습"
289
- ],
290
- inputs=prompt_input,
291
- )
292
- generation_button.click(
293
- fn=generate_image,
294
- inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
295
- outputs=image_output,
296
- )
297
-
298
- gr.Markdown("<footer style='text-align:center; padding:20px 0;'>Join our community on <a href='https://discord.gg/openfreeai' target='_blank'>Discord</a></footer>")
299
 
300
- demo.launch(share=True)
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoModelForVision2Seq, pipeline
4
  from janus.models import MultiModalityCausalLM, VLChatProcessor
5
  from janus.utils.io import load_pil_images
6
  from PIL import Image
 
11
  import spaces # Import spaces for ZeroGPU compatibility
12
  import re
13
 
14
 
15
+ # Load Janus Pro models for vision and text tasks
16
+ vision_model = AutoModelForVision2Seq.from_pretrained("deepseek-ai/janus-pro-7b", torch_dtype=torch.float16, device_map="auto")
17
+ text_model = AutoModelForCausalLM.from_pretrained("deepseek-ai/janus-pro-7b", torch_dtype=torch.float16, device_map="auto")
18
+ processor = AutoProcessor.from_pretrained("deepseek-ai/janus-pro-7b")
19
+ image_pipe = pipeline("text-to-image", model="deepseek-ai/janus-pro-7b")
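Note that `transformers.pipeline` has no built-in "text-to-image" task, and `AutoModelForVision2Seq`/`AutoProcessor` may not resolve Janus-Pro's custom architecture, so the four loads above can fail at runtime. A minimal sketch of the loading path the removed code used (assuming the `janus` package is installed):

```python
import torch
from transformers import AutoModelForCausalLM
from janus.models import VLChatProcessor

model_path = "deepseek-ai/Janus-Pro-7B"
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda() if torch.cuda.is_available() else vl_gpt.to(torch.float16)
```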
20
 
21
+ last_uploaded_image = None
22
 
23
+ def detect_image_request(user_input):
24
+ image_keywords = ["generate an image", "create an image", "show me a picture", "draw", "visualize","generate image",
25
+ "image generation", "get me an image", "get an image", "need an image", "need image",]
26
+ return any(re.search(keyword, user_input, re.IGNORECASE) for keyword in image_keywords)
27
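Since none of the keywords contain regex metacharacters, `detect_image_request` amounts to a case-insensitive substring check, for example:

```python
# Illustrative behaviour of detect_image_request
assert detect_image_request("Please generate an image of a red fox")
assert detect_image_request("Can you DRAW a castle at sunset?")
assert not detect_image_request("What is shown in this photo?")
```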
 
28
+ def chatbot(history, image=None, user_input=""):
29
+ global last_uploaded_image
30
 
31
+ if image:
32
+ last_uploaded_image = image # Store the latest uploaded image
33
 
34
+ if detect_image_request(user_input):
35
+ image = image_pipe(user_input)
36
+ history.append((user_input, "[Generated Image]"))
37
+ return history, "", image[0]["image"]
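If the "text-to-image" pipeline above is unavailable, a hedged alternative is to keep the `generate`/`unpack` helpers deleted in this commit and wrap them, roughly:

```python
# Hypothetical wrapper around the removed generate()/unpack() helpers
def janus_text_to_image(prompt, width=384, height=384, parallel_size=1):
    messages = [{"role": "<|User|>", "content": prompt},
                {"role": "<|Assistant|>", "content": ""}]
    text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=messages,
        sft_format=vl_chat_processor.sft_format,
        system_prompt="",
    ) + vl_chat_processor.image_start_tag
    input_ids = torch.LongTensor(tokenizer.encode(text))
    _, patches = generate(input_ids, width, height, parallel_size=parallel_size)
    images = unpack(patches, width, height, parallel_size=parallel_size)
    return Image.fromarray(images[0])
```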
38
 
39
+ if last_uploaded_image:
40
+ inputs = processor(images=last_uploaded_image, return_tensors="pt").to("cuda")
41
+ output = vision_model.generate(**inputs)
42
+ response = processor.decode(output[0], skip_special_tokens=True)
43
+ else:
44
+ response = text_model.generate(user_input)
45
+ response = processor.decode(response[0], skip_special_tokens=True)
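Two caveats for the branch above: `generate` expects token ids rather than a raw string, and the vision path encodes only the image, dropping `user_input`. A minimal text-only sketch (assuming the processor exposes its underlying tokenizer):

```python
# Hypothetical text-only path: tokenize the prompt before generating
inputs = processor.tokenizer(user_input, return_tensors="pt").to(text_model.device)
output_ids = text_model.generate(**inputs, max_new_tokens=512)
response = processor.tokenizer.decode(output_ids[0], skip_special_tokens=True)
```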
46
 
47
+ history.append((user_input, response))
48
+ return history, "", None
49
+
50
+ def reset_chat():
51
+ global last_uploaded_image
52
+ last_uploaded_image = None
53
+ return [], ""
54
+
55
+ with gr.Blocks() as demo:
56
+ gr.Markdown("# Janus Pro Chatbot with Vision & Image Generation")
57
+ chatbot_interface = gr.Chatbot()
58
+ with gr.Row():
59
+ image_input = gr.Image(type="pil", label="Upload Image")
60
+ text_input = gr.Textbox(label="Type your message")
61
+ send_button = gr.Button("Send")
62
+ reset_button = gr.Button("Reset Chat")
63
+ image_output = gr.Image(label="Generated Image")
64
 
65
+ send_button.click(chatbot, [chatbot_interface, image_input, text_input], [chatbot_interface, text_input, image_output])
66
+ reset_button.click(reset_chat, [], [chatbot_interface, text_input])
67
 
68
+ demo.launch()