Spaces: Running on Zero
Update app.py
app.py CHANGED
Before: lines removed from app.py are marked with "-".

@@ -8,6 +8,7 @@ from PIL import Image
 import numpy as np
 import os
 import time
 from Upsample import RealESRGAN
 import spaces  # Import spaces for ZeroGPU compatibility

@@ -17,9 +18,7 @@ model_path = "deepseek-ai/Janus-Pro-7B"
 config = AutoConfig.from_pretrained(model_path)
 language_config = config.language_config
 language_config._attn_implementation = 'eager'
-vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
-                                              language_config=language_config,
-                                              trust_remote_code=True)
 if torch.cuda.is_available():
     vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
 else:

@@ -33,13 +32,80 @@ cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu'), scale=2)
 sr_model.load_weights(f'weights/RealESRGAN_x2.pth', download=False)

 @torch.inference_mode()
 @spaces.GPU(duration=120)
-#
-def
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()

     # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)

@@ -89,18 +155,10 @@ def multimodal_chat(image, message, chat_history, seed, top_p, temperature, prog
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

-    outputs = vl_gpt.language_model.generate(
-
-        bos_token_id=tokenizer.bos_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-        max_new_tokens=512,
-        do_sample=False if temperature == 0 else True,
-        use_cache=True,
-        temperature=temperature,
-        top_p=top_p,
-    )

     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)

@@ -108,18 +166,11 @@ def multimodal_chat(image, message, chat_history, seed, top_p, temperature, prog
     chat_history.append((message, answer))

     # Keep the last uploaded image in context
-    return "", chat_history, image
-
-def generate(input_ids,
-
-             height,
-             temperature: float = 1,
-             parallel_size: int = 5,
-             cfg_weight: float = 5,
-             image_token_num_per_image: int = 576,
-             patch_size: int = 16,
-             progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()

@@ -152,7 +203,7 @@ def generate(input_ids,
     inputs_embeds = img_embeds.unsqueeze(dim=1)

     patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
-

     return generated_tokens.to(dtype=torch.int), patches

@@ -168,12 +219,8 @@ def unpack(dec, width, height, parallel_size=5):

 @torch.inference_mode()
-@spaces.GPU(duration=
-def generate_image(prompt,
-                   seed=None,
-                   guidance=5,
-                   t2i_temperature=1.0,
-                   progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache and avoid tracking gradients
     torch.cuda.empty_cache()
     # Set the seed for reproducible results

@@ -181,31 +228,22 @@ def generate_image(prompt,
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
-    width =
-    height =
     parallel_size = 4

     with torch.no_grad():
         messages = [{'role': '<|User|>', 'content': prompt},
                     {'role': '<|Assistant|>', 'content': ''}]
         text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-
-                                                                           system_prompt='')
         text = text + vl_chat_processor.image_start_tag

         input_ids = torch.LongTensor(tokenizer.encode(text))
-        output, patches = generate(input_ids,
-
-                                   parallel_size=parallel_size,
-                                   temperature=t2i_temperature)
-        images = unpack(patches,
-                        width // 16 * 16,
-                        height // 16 * 16,
-                        parallel_size=parallel_size)
-
-        # return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]
         stime = time.time()
         ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
         print(f'upsample time: {time.time() - stime}')

@@ -219,7 +257,7 @@ def image_upsample(img: Image.Image) -> Image.Image:
     width, height = img.size

-    if width >=
         raise Exception("The image is too large.")

     global sr_model

@@ -234,7 +272,7 @@ def add_image_to_chat(image, chat_history):
 # Helper function to clear chat history but maintain the image
 def clear_chat(image):
-    return [], image

 # Gradio interface

@@ -242,98 +280,80 @@ css = '''
 .gradio-container {max-width: 960px !important}
 '''
 with gr.Blocks(css=css) as demo:
-    gr.Markdown("# Janus Pro 7B")
-
-    with gr.Tab("Multimodal Chat"):
-        gr.Markdown(value="## Multimodal Chat")
-
-        # State variables to maintain context
-        chat_history = gr.State([])
-        current_image = gr.State(None)
-
-        with gr.Row():
-            with gr.Column(scale=1):
-                image_input = gr.Image(label="Upload Image (only needed once)")
-                upload_button = gr.Button("Add Image to Chat")
-
-                with gr.Accordion("Advanced options", open=False):
-                    und_seed_input = gr.Number(label="Seed", precision=0, value=42)
-                    top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
-                    temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
-
-                clear_button = gr.Button("Clear Chat")
-
-            with gr.Column(scale=2):
-                chat_interface = gr.Chatbot(label="Chat History", height=500)
-                message_input = gr.Textbox(label="Your message", placeholder="Ask about the image or continue the conversation...", lines=2)
-                chat_button = gr.Button("Send")
-
-        examples_chat = gr.Examples(
-            label="Multimodal Chat examples",
-            examples=[
-                [
-                    "explain this meme",
-                    "doge.png",
-                ],
-                [
-                    "Convert the formula into latex code.",
-                    "equation.png",
-                ],
-            ],
-            inputs=[message_input, image_input],
-        )
-
-    with gr.Tab("Text-to-Image Generation"):
-        gr.Markdown(value="## Text-to-Image Generation")
-
-        prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)")

-
         cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
-

     # Chat interface interactions
-    upload_button.click(
-        add_image_to_chat,
-        inputs=[image_input, chat_history],
-        outputs=[current_image, chat_history]
-    )

     chat_button.click(
-
-        inputs=[current_image, message_input, chat_interface, und_seed_input, top_p, temperature],
-        outputs=[message_input, chat_interface, current_image]
     )

     clear_button.click(
         clear_chat,
         inputs=[current_image],
-        outputs=[chat_interface, current_image]
     )

-    #
-
     )

 demo.launch(share=True)

After: the updated app.py, with added lines marked with "+".

 import numpy as np
 import os
 import time
+import re
 from Upsample import RealESRGAN
 import spaces  # Import spaces for ZeroGPU compatibility

 config = AutoConfig.from_pretrained(model_path)
 language_config = config.language_config
 language_config._attn_implementation = 'eager'
+vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, language_config=language_config, trust_remote_code=True)
 if torch.cuda.is_available():
     vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
 else:

 sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu'), scale=2)
 sr_model.load_weights(f'weights/RealESRGAN_x2.pth', download=False)

+# Patterns for detecting image generation requests
+GENERATION_PATTERNS = [r"generate (.+)", r"create (.+)", r"draw (.+)", r"make (.+)", r"show (.+)", r"visualize (.+)", r"imagine (.+)", r"picture (.+)",]
+
+def is_generation_request(message):
+    """Determine if a message is requesting image generation"""
+    message = message.lower().strip()
+
+    # Check if message explicitly mentions image generation
+    for pattern in GENERATION_PATTERNS:
+        match = re.match(pattern, message, re.IGNORECASE)
+        if match:
+            return True, match.group(1)
+
+    # Check for specific keywords suggesting image generation
+    image_keywords = ["image", "picture", "photo", "artwork", "illustration", "painting", "drawing"]
+    generation_verbs = ["generate", "create", "make", "produce", "show me", "draw"]
+
+    for verb in generation_verbs:
+        for keyword in image_keywords:
+            if f"{verb} {keyword}" in message or f"{verb} an {keyword}" in message or f"{verb} a {keyword}" in message:
+                # Extract the prompt (everything after the keyword)
+                pattern = f"{verb}\\s+(?:an?\\s+)?{keyword}\\s+(?:of|showing|depicting|with)?\\s*(.*)"
+                match = re.search(pattern, message, re.IGNORECASE)
+                if match and match.group(1):
+                    return True, match.group(1)
+                else:
+                    # If we can't extract a specific prompt, use the whole message
+                    return True, message
+
+    return False, None
+
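The router returns a (flag, prompt) pair so the caller can branch without re-parsing the message. For reference, a trimmed standalone sketch of the same routing (shorter pattern list, illustrative inputs only):

import re

GENERATION_PATTERNS = [r"generate (.+)", r"create (.+)", r"draw (.+)"]

def is_generation_request(message):
    # Match the whole lowercased message against the explicit command patterns.
    message = message.lower().strip()
    for pattern in GENERATION_PATTERNS:
        match = re.match(pattern, message, re.IGNORECASE)
        if match:
            return True, match.group(1)
    return False, None

print(is_generation_request("Generate a red fox in the snow"))  # (True, 'a red fox in the snow')
print(is_generation_request("What breed is this dog?"))          # (False, None)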
 @torch.inference_mode()
 @spaces.GPU(duration=120)
+# Unified chat function that handles both image understanding and generation
+def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_weight, t2i_temperature, progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()

+    # Check if this is an image generation request
+    is_gen_request, extracted_prompt = is_generation_request(message)
+
+    if is_gen_request:
+        # Prepare a more detailed prompt by considering context from the conversation
+        context_prompt = extracted_prompt
+
+        # Optionally, enhance the prompt with context from previous messages
+        if chat_history and len(chat_history) > 0:
+            # Get the last few turns of conversation for context (limit to last 3 turns)
+            recent_context = chat_history[-3:] if len(chat_history) > 3 else chat_history
+            context_text = " ".join([f"User: {user_msg}" for user_msg, _ in recent_context])
+
+            # Only use context if it's not too long
+            if len(context_text) < 200:  # Arbitrary length limit
+                context_prompt = f"{context_text}. {extracted_prompt}"
+
+        # Generate images
+        generated_images = generate_image(
+            prompt=context_prompt,
+            seed=seed,
+            guidance=cfg_weight,
+            t2i_temperature=t2i_temperature
+        )
+
+        # Create a response that includes the generated images
+        response = f"I've generated the following images based on: '{extracted_prompt}'"
+
+        # Add the images to the chat as the bot's response
+        chat_history.append((message, response))
+
+        # Return the message, updated history, maintained image context, and generated images
+        return "", chat_history, image, generated_images
+
+    # Regular chat flow (no image generation)
     # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)

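To see what the optional context enrichment above produces, here is a small trace with made-up history (the "User:" prefix mirrors the join fixed above; all values are illustrative):

chat_history = [("hi", "hello"), ("describe the photo", "it's a cat"),
                ("make it cartoonish", "sure"), ("now with a hat", "done")]
extracted_prompt = "a cartoon cat wearing a hat"

# Keep at most the last 3 turns, then prepend them to the extracted prompt.
recent_context = chat_history[-3:] if len(chat_history) > 3 else chat_history
context_text = " ".join([f"User: {user_msg}" for user_msg, _ in recent_context])
context_prompt = f"{context_text}. {extracted_prompt}" if len(context_text) < 200 else extracted_prompt

print(context_prompt)
# User: describe the photo User: make it cartoonish User: now with a hat. a cartoon cat wearing a hat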
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

+    outputs = vl_gpt.language_model.generate(inputs_embeds=inputs_embeds, attention_mask=prepare_inputs.attention_mask,
+                                             pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id,
+                                             eos_token_id=tokenizer.eos_token_id, max_new_tokens=512, temperature=temperature, top_p=top_p,
+                                             do_sample=False if temperature == 0 else True, use_cache=True,)

     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)

     chat_history.append((message, answer))

     # Keep the last uploaded image in context
+    return "", chat_history, image, None
+
+
+def generate(input_ids, width, height, temperature: float = 1, parallel_size: int = 5, cfg_weight: float = 5,
+             image_token_num_per_image: int = 1024, patch_size: int = 16, progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()

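Both exits of unified_chat now return four values (textbox, chat history, current image, gallery); the text-only path fills the gallery slot with None so its arity matches the outputs list wired up in the UI below. A minimal sketch of that convention, with illustrative component names and assuming the tuple-style Chatbot history used in this app:

import gradio as gr

def reply(message, history):
    # Text-only answer: the third slot (gallery) is None so the
    # return arity always matches the outputs= list.
    history = history + [(message, "a text reply")]
    return "", history, None

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    gallery = gr.Gallery()
    box = gr.Textbox()
    box.submit(reply, inputs=[box, chatbot], outputs=[box, chatbot, gallery])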
     inputs_embeds = img_embeds.unsqueeze(dim=1)

     patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
+                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])

     return generated_tokens.to(dtype=torch.int), patches

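The new default of 1024 image tokens is consistent with the decode shape above: at patch_size 16 a 512x512 canvas is a 32x32 token grid, while the previous default of 576 tokens corresponds to a 24x24 grid (i.e. a 384x384 canvas). A quick check:

def image_token_count(width, height, patch_size=16):
    # Number of visual tokens the generator must produce per image.
    return (width // patch_size) * (height // patch_size)

print(image_token_count(512, 512))  # 1024 -> new default
print(image_token_count(384, 384))  # 576  -> the old default token count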
 @torch.inference_mode()
+@spaces.GPU(duration=180)  # Specify a duration to avoid timeout
+def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache and avoid tracking gradients
     torch.cuda.empty_cache()
     # Set the seed for reproducible results

     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
+    width = 512
+    height = 512
     parallel_size = 4

     with torch.no_grad():
         messages = [{'role': '<|User|>', 'content': prompt},
                     {'role': '<|Assistant|>', 'content': ''}]
         text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
+                                                                           sft_format=vl_chat_processor.sft_format, system_prompt='')
         text = text + vl_chat_processor.image_start_tag

         input_ids = torch.LongTensor(tokenizer.encode(text))
+        output, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance,
+                                   parallel_size=parallel_size, temperature=t2i_temperature)
+        images = unpack(patches, width // 16 * 16, height // 16 * 16, parallel_size=parallel_size)
+
         stime = time.time()
         ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
         print(f'upsample time: {time.time() - stime}')

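The generation canvas is rounded down to a multiple of 16 before being handed to generate(), and each decoded image then goes through the x2 RealESRGAN pass loaded at the top of the file, so gallery images come back at twice the generation size. A small sizing helper (names are illustrative, assuming the x2 weights above):

def canvas_and_output_size(width, height, patch_size=16, sr_scale=2):
    # Round down to a multiple of the patch size, then apply the x2 upsampler.
    gen_w, gen_h = width // patch_size * patch_size, height // patch_size * patch_size
    return (gen_w, gen_h), (gen_w * sr_scale, gen_h * sr_scale)

print(canvas_and_output_size(512, 512))  # ((512, 512), (1024, 1024))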
     width, height = img.size

+    if width >= 4096 or height >= 4096:
         raise Exception("The image is too large.")

     global sr_model

 # Helper function to clear chat history but maintain the image
 def clear_chat(image):
+    return [], image, None

 # Gradio interface

 .gradio-container {max-width: 960px !important}
 '''
 with gr.Blocks(css=css) as demo:
+    gr.Markdown("# Janus Pro 7B - Unified Chat Interface")

+    # State variables to maintain context
+    chat_history = gr.State([])
+    current_image = gr.State(None)

+    with gr.Row():
+        with gr.Column(scale=1):
+            image_input = gr.Image(label="Upload Image (optional)")
+            upload_button = gr.Button("Add Image to Chat")
+
+            with gr.Accordion("Chat Options", open=False):
+                und_seed_input = gr.Number(label="Seed", precision=0, value=42)
+                top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
+                temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
+
+            with gr.Accordion("Image Generation Options", open=False):
                 cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
+                t2i_temperature_input = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="Temperature")
+
+            clear_button = gr.Button("Clear Chat")
+
+            gr.Markdown("""
+            ### Tips:
+            1. Upload an image to discuss it
+            2. Type commands like "generate [description]" to create images
+            3. Continue chatting about uploaded or generated images
+            4. Use natural language like "show me a sunset" or "create a portrait"
+            """)
+
+        with gr.Column(scale=2):
+            chat_interface = gr.Chatbot(label="Chat History", height=500)
+            message_input = gr.Textbox(
+                label="Your message",
+                placeholder="Ask about an image, continue chatting, or generate new images by typing 'generate [description]'",
+                lines=2
+            )
+            chat_button = gr.Button("Send")
+            generated_images = gr.Gallery(label="Generated Images", visible=True, columns=2, rows=2)

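generate_image returns the upsampled PIL images (ret_images above), and a gr.Gallery can display such a list directly, which is why unified_chat hands its result straight to the generated_images output. A tiny self-contained sketch of that pattern (placeholder images and component arguments are illustrative):

from PIL import Image
import gradio as gr

def show_samples():
    # A Gallery accepts a list of PIL images; these solid-colour
    # placeholders stand in for the model's four parallel samples.
    return [Image.new("RGB", (64, 64), c) for c in ["red", "green", "blue", "gold"]]

with gr.Blocks() as demo:
    gallery = gr.Gallery(label="Generated Images", columns=2, rows=2)
    gr.Button("Show").click(show_samples, outputs=gallery)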
     # Chat interface interactions
+    upload_button.click(add_image_to_chat, inputs=[image_input, chat_history], outputs=[current_image, chat_history])

     chat_button.click(
+        unified_chat,
+        inputs=[current_image, message_input, chat_interface, und_seed_input, top_p, temperature, cfg_weight_input, t2i_temperature_input],
+        outputs=[message_input, chat_interface, current_image, generated_images]
+    )
+
+    # Also trigger on Enter key
+    message_input.submit(
+        unified_chat,
+        inputs=[current_image, message_input, chat_interface, und_seed_input, top_p, temperature, cfg_weight_input, t2i_temperature_input],
+        outputs=[message_input, chat_interface, current_image, generated_images]
     )

     clear_button.click(
         clear_chat,
         inputs=[current_image],
+        outputs=[chat_interface, current_image, generated_images]
     )

+    # Examples for the unified interface
+    examples = gr.Examples(
+        label="Example queries",
+        examples=[
+            ["What's in this image?"],
+            ["Generate a cute kitten with big eyes"],
+            ["Show me a mountain landscape at sunset"],
+            ["Can you explain what's happening in this picture?"],
+            ["Create an astronaut riding a horse"],
+            ["Generate a futuristic cityscape with flying cars"],
+        ],
+        inputs=message_input,
     )

 demo.launch(share=True)