Update app.py

app.py CHANGED
@@ -18,7 +18,9 @@ model_path = "deepseek-ai/Janus-Pro-7B"
 config = AutoConfig.from_pretrained(model_path)
 language_config = config.language_config
 language_config._attn_implementation = 'eager'
-vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
+vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
+                                              language_config=language_config,
+                                              trust_remote_code=True)
 if torch.cuda.is_available():
     vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
 else:
@@ -33,7 +35,16 @@ sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu
 sr_model.load_weights(f'weights/RealESRGAN_x2.pth', download=False)
 
 # Patterns for detecting image generation requests
-GENERATION_PATTERNS = [
+GENERATION_PATTERNS = [
+    r"generate (.+)",
+    r"create (.+)",
+    r"draw (.+)",
+    r"make (.+)",
+    r"show (.+)",
+    r"visualize (.+)",
+    r"imagine (.+)",
+    r"picture (.+)",
+]
 
 def is_generation_request(message):
     """Determine if a message is requesting image generation"""
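Note: this hunk only introduces the pattern list; the body of is_generation_request is unchanged and not shown in the diff. For readers following along, a minimal sketch of how such a list is typically matched (an assumption, not the file's actual implementation):

    import re

    def is_generation_request(message):
        """Determine if a message is requesting image generation"""
        # Case-insensitive search; any pattern hit counts as a generation request.
        return any(re.search(p, message, re.IGNORECASE) for p in GENERATION_PATTERNS)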
@@ -82,9 +93,7 @@ def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_wei
     if chat_history and len(chat_history) > 0:
         # Get the last few turns of conversation for context (limit to last 3 turns)
         recent_context = chat_history[-3:] if len(chat_history) > 3 else chat_history
-
-        context_text = " ".join([f"{user_msg}" for user_msg, _ in recent_context])
-
+        context_text = " ".join([f"User: {user_msg}" for user_msg, _ in recent_context])
 
         # Only use context if it's not too long
         if len(context_text) < 200:  # Arbitrary length limit
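Note: chat_history holds (user_msg, bot_msg) tuples in gr.Chatbot's classic tuple format, so the new line keeps only the user side of the last three turns. A hypothetical example of the string it builds:

    chat_history = [("describe this image", "It shows a cat."),
                    ("what breed is it?", "Likely a tabby.")]
    # -> context_text == "User: describe this image User: what breed is it?"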
@@ -157,10 +166,18 @@ def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_wei
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
-    outputs = vl_gpt.language_model.generate(
-
-
-
+    outputs = vl_gpt.language_model.generate(
+        inputs_embeds=inputs_embeds,
+        attention_mask=prepare_inputs.attention_mask,
+        pad_token_id=tokenizer.eos_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        max_new_tokens=512,
+        do_sample=False if temperature == 0 else True,
+        use_cache=True,
+        temperature=temperature,
+        top_p=top_p,
+    )
 
     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
 
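Note: in the new generate() call, temperature == 0 is used as a sentinel for greedy decoding (sampling would otherwise divide logits by zero). The ternary is equivalent to:

    do_sample = temperature > 0  # greedy decoding when the temperature slider sits at 0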
@@ -172,7 +189,7 @@ def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_wei
 
 
 def generate(input_ids, width, height, temperature: float = 1, parallel_size: int = 5, cfg_weight: float = 5,
-             image_token_num_per_image: int =
+             image_token_num_per_image: int = 576, patch_size: int = 16, progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()
 
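Note: the new defaults are mutually consistent if one assumes Janus-Pro's usual 384×384 output resolution (not stated in this diff): with patch_size = 16, each image is decoded from a 24×24 grid of image tokens:

    # (384 // 16) ** 2 == 24 ** 2 == 576 == image_token_num_per_image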
@@ -187,9 +204,7 @@ def generate(input_ids, width, height, temperature: float = 1, parallel_size: in
     pkv = None
     for i in range(image_token_num_per_image):
         with torch.no_grad():
-            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
-                                                  use_cache=True,
-                                                  past_key_values=pkv)
+            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
         pkv = outputs.past_key_values
         hidden_states = outputs.last_hidden_state
         logits = vl_gpt.gen_head(hidden_states[:, -1, :])
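Note: this hunk only reflows the forward call onto one line. The loop itself is the standard KV-cache idiom: pkv carries past_key_values, so each iteration feeds a single new embedding rather than the whole prefix. The step that turns logits into the next image token falls outside the hunk; a hypothetical sketch of the usual classifier-free-guidance mixing, assuming conditional rows at even and unconditional rows at odd batch indices as in DeepSeek's Janus demos:

    cond_logits = logits[0::2, :]      # rows conditioned on the prompt
    uncond_logits = logits[1::2, :]    # unconditional rows
    guided = uncond_logits + cfg_weight * (cond_logits - uncond_logits)
    probs = torch.softmax(guided / temperature, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)  # one image token per sample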
@@ -221,7 +236,7 @@ def unpack(dec, width, height, parallel_size=5):
 
 
 @torch.inference_mode()
-@spaces.GPU(duration=
+@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
 def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache and avoid tracking gradients
    torch.cuda.empty_cache()
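Note: on ZeroGPU Spaces, @spaces.GPU(duration=...) requests a GPU slice for at most that many seconds; the default window (60 s at the time of writing) can expire during long image generations, which is what the explicit duration=120 guards against.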
@@ -238,7 +253,8 @@ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=
     messages = [{'role': '<|User|>', 'content': prompt},
                 {'role': '<|Assistant|>', 'content': ''}]
     text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-
+                                                                       sft_format=vl_chat_processor.sft_format,
+                                                                       system_prompt='')
     text = text + vl_chat_processor.image_start_tag
 
     input_ids = torch.LongTensor(tokenizer.encode(text))
@@ -276,10 +292,18 @@ def add_image_to_chat(image, chat_history):
 def clear_chat(image):
     return [], image, None
 
+
 
 # Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# Janus Pro 7B - Unified Chat Interface")
+    gr.Markdown("""
+    ### Tips:
+    1. Upload an image to discuss it
+    2. Type commands like "generate [description]" to create images
+    3. Continue chatting about uploaded or generated images
+    4. Use natural language like "show me a sunset" or "create a portrait"
+    """)
 
     # State variables to maintain context
     chat_history = gr.State([])
@@ -297,17 +321,10 @@ with gr.Blocks() as demo:
 
     with gr.Accordion("Image Generation Options", open=False):
         cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
-        t2i_temperature_input = gr.Slider(minimum=0, maximum=1, value=1
+        t2i_temperature_input = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature")
 
     clear_button = gr.Button("Clear Chat")
 
-    gr.Markdown("""
-    ### Tips:
-    1. Upload an image to discuss it
-    2. Type commands like "generate [description]" to create images
-    3. Continue chatting about uploaded or generated images
-    4. Use natural language like "show me a sunset" or "create a portrait"
-    """)
 
     with gr.Column(scale=2):
         chat_interface = gr.Chatbot(label="Chat History", height=500)