Update app.py
app.py CHANGED
@@ -85,22 +85,12 @@ def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_wei
     is_gen_request, extracted_prompt = is_generation_request(message)
 
     if is_gen_request:
-        #
+        # Extract the prompt directly
         context_prompt = extracted_prompt
 
-        #
-
-
-        recent_context = chat_history[-3:] if len(chat_history) > 3 else chat_history
-        context_text = " ".join([f"User: {user_msg}" for user_msg, _ in recent_context])
-        #context_text = " ".join([f"{user}: {user_msg}" for user_msg, _ in recent_context])
-
-        # Only use context if it's not too long
-        if len(context_text) < 200:  # Arbitrary length limit
-            context_prompt = f"{context_text}. {extracted_prompt}"
-
-        # Generate images
-        generated_images = generate_image(prompt=context_prompt, seed=seed, guidance=cfg_weight, t2i_temperature=t2i_temperature)
+        # Generate images with full conversation history
+        generated_images = generate_image(prompt=context_prompt, conversation_history=chat_history,  # Pass the full chat history
+                                          seed=seed, guidance=cfg_weight, t2i_temperature=t2i_temperature)
 
         # Create a response that includes the generated images
         response = f"I've generated the following images based on: '{extracted_prompt}'"
@@ -111,7 +101,7 @@ def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_wei
         # Return the message, updated history, maintained image context, and generated images
         return "", chat_history, image, generated_images
 
-    #
+    # Rest of the function remains the same...
     # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
@@ -155,9 +145,8 @@ def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_wei
     if image is not None:
        pil_images = [Image.fromarray(image)]
 
-        prepare_inputs = vl_chat_processor(
-
-        ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
+        prepare_inputs = vl_chat_processor(conversations=conversation, images=pil_images, force_batchify=True
+        ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
@@ -224,7 +213,8 @@ def unpack(dec, width, height, parallel_size=5):
 
 @torch.inference_mode()
 @spaces.GPU(duration=120)  # Specify a duration to avoid timeout
-def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress(track_tqdm=True)):
+def generate_image(prompt, conversation_history=None,  # Add conversation history parameter
+                   seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache and avoid tracking gradients
     torch.cuda.empty_cache()
     # Set the seed for reproducible results
@@ -236,18 +226,45 @@ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=
     height = 384
     parallel_size = 1
 
+    # Prepare a richer context-aware prompt
+    full_prompt = prompt
+
+    # Add conversation history context if available
+    if conversation_history and len(conversation_history) > 0:
+        # Build a context string from the last few conversation turns
+        # Limit to last 3-5 turns to keep prompt manageable
+        recent_turns = conversation_history[-5:] if len(conversation_history) > 5 else conversation_history
+
+        context_parts = []
+        for user_msg, assistant_msg in recent_turns:
+            if user_msg and user_msg.strip():
+                context_parts.append(f"User: {user_msg}")
+            if assistant_msg and assistant_msg.strip():
+                context_parts.append(f"Assistant: {assistant_msg}")
+
+        conversation_context = "\n".join(context_parts)
+
+        # Combine conversation context with the prompt
+        full_prompt = f"Based on this conversation:\n{conversation_context}\n\nGenerate: {prompt}"
+
     with torch.no_grad():
-        messages = [{'role': '<|User|>', 'content':
+        messages = [{'role': '<|User|>', 'content': full_prompt},
                     {'role': '<|Assistant|>', 'content': ''}]
         text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-                                                                           sft_format=vl_chat_processor.sft_format,
-                                                                           system_prompt='')
+                                                                           sft_format=vl_chat_processor.sft_format, system_prompt='')
         text = text + vl_chat_processor.image_start_tag
 
         input_ids = torch.LongTensor(tokenizer.encode(text))
-        output, patches = generate(input_ids,
-
-
+        output, patches = generate(input_ids,
+                                   width // 16 * 16,
+                                   height // 16 * 16,
+                                   cfg_weight=guidance,
+                                   parallel_size=parallel_size,
+                                   temperature=t2i_temperature)
+        images = unpack(patches,
+                        width // 16 * 16,
+                        height // 16 * 16,
+                        parallel_size=parallel_size)
 
         stime = time.time()
         ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
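For reference, the context-aware prompt construction this commit adds to generate_image can be read in isolation. Below is a minimal, self-contained sketch of that logic, assuming conversation_history is a list of (user_message, assistant_message) tuples as produced by a Gradio Chatbot component; the helper name build_full_prompt is illustrative and does not exist in app.py.

# Illustrative sketch (not part of app.py): the context-aware prompt
# construction that generate_image now performs, pulled out as a helper.
# Assumes conversation_history is a list of (user_msg, assistant_msg)
# tuples, as produced by a Gradio Chatbot component.
def build_full_prompt(prompt, conversation_history=None):
    full_prompt = prompt
    if conversation_history:
        # Keep only the last 5 turns so the prompt stays manageable
        recent_turns = conversation_history[-5:]

        context_parts = []
        for user_msg, assistant_msg in recent_turns:
            if user_msg and user_msg.strip():
                context_parts.append(f"User: {user_msg}")
            if assistant_msg and assistant_msg.strip():
                context_parts.append(f"Assistant: {assistant_msg}")

        conversation_context = "\n".join(context_parts)
        full_prompt = f"Based on this conversation:\n{conversation_context}\n\nGenerate: {prompt}"
    return full_prompt

if __name__ == "__main__":
    history = [("Draw a cat wearing a hat", "Here is a cat wearing a hat."),
               ("Make the hat blue", "Done, the hat is now blue.")]
    print(build_full_prompt("a watercolor version of the same scene", history))

Limiting the context to the last few non-empty turns keeps the text-to-image prompt short while still carrying the referents (e.g. "the same scene") that the extracted prompt may depend on.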