NikhilJoson committed (verified)
Commit 1745cd1 · 1 Parent(s): b3020f6

Update app.py

Files changed (1):
  1. app.py +42 -25
app.py CHANGED
@@ -85,22 +85,12 @@ def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_wei
     is_gen_request, extracted_prompt = is_generation_request(message)
 
     if is_gen_request:
-        # Prepare a more detailed prompt by considering context from the conversation
+        # Extract the prompt directly
         context_prompt = extracted_prompt
 
-        # Optionally, enhance the prompt with context from previous messages
-        if chat_history and len(chat_history) > 0:
-            # Get the last few turns of conversation for context (limit to last 3 turns)
-            recent_context = chat_history[-3:] if len(chat_history) > 3 else chat_history
-            context_text = " ".join([f"User: {user_msg}" for user_msg, _ in recent_context])
-            #context_text = " ".join([f"{user}: {user_msg}" for user_msg, _ in recent_context])
-
-            # Only use context if it's not too long
-            if len(context_text) < 200:  # Arbitrary length limit
-                context_prompt = f"{context_text}. {extracted_prompt}"
-
-        # Generate images
-        generated_images = generate_image(prompt=context_prompt, seed=seed, guidance=cfg_weight, t2i_temperature=t2i_temperature)
+        # Generate images with full conversation history
+        generated_images = generate_image(prompt=context_prompt, conversation_history=chat_history,  # Pass the full chat history
+                                          seed=seed, guidance=cfg_weight, t2i_temperature=t2i_temperature)
 
         # Create a response that includes the generated images
         response = f"I've generated the following images based on: '{extracted_prompt}'"
@@ -111,7 +101,7 @@ def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_wei
         # Return the message, updated history, maintained image context, and generated images
         return "", chat_history, image, generated_images
 
-    # Regular chat flow (no image generation)
+    # Rest of the function remains the same...
     # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
@@ -155,9 +145,8 @@ def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_wei
     if image is not None:
         pil_images = [Image.fromarray(image)]
 
-    prepare_inputs = vl_chat_processor(
-        conversations=conversation, images=pil_images, force_batchify=True
-    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
+    prepare_inputs = vl_chat_processor(conversations=conversation, images=pil_images, force_batchify=True
+                                       ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
@@ -224,7 +213,8 @@ def unpack(dec, width, height, parallel_size=5):
 
 @torch.inference_mode()
 @spaces.GPU(duration=120)  # Specify a duration to avoid timeout
-def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress(track_tqdm=True)):
+def generate_image(prompt, conversation_history=None,  # Add conversation history parameter
+                   seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache and avoid tracking gradients
     torch.cuda.empty_cache()
     # Set the seed for reproducible results
@@ -236,18 +226,45 @@ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=
     height = 384
     parallel_size = 1
 
+    # Prepare a richer context-aware prompt
+    full_prompt = prompt
+
+    # Add conversation history context if available
+    if conversation_history and len(conversation_history) > 0:
+        # Build a context string from the last few conversation turns
+        # Limit to last 3-5 turns to keep prompt manageable
+        recent_turns = conversation_history[-5:] if len(conversation_history) > 5 else conversation_history
+
+        context_parts = []
+        for user_msg, assistant_msg in recent_turns:
+            if user_msg and user_msg.strip():
+                context_parts.append(f"User: {user_msg}")
+            if assistant_msg and assistant_msg.strip():
+                context_parts.append(f"Assistant: {assistant_msg}")
+
+        conversation_context = "\n".join(context_parts)
+
+        # Combine conversation context with the prompt
+        full_prompt = f"Based on this conversation:\n{conversation_context}\n\nGenerate: {prompt}"
+
     with torch.no_grad():
-        messages = [{'role': '<|User|>', 'content': prompt},
+        messages = [{'role': '<|User|>', 'content': full_prompt},
                     {'role': '<|Assistant|>', 'content': ''}]
         text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-                                                                           sft_format=vl_chat_processor.sft_format,
-                                                                           system_prompt='')
+                                                                           sft_format=vl_chat_processor.sft_format, system_prompt='')
         text = text + vl_chat_processor.image_start_tag
 
         input_ids = torch.LongTensor(tokenizer.encode(text))
-        output, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance,
-                                   parallel_size=parallel_size, temperature=t2i_temperature)
-        images = unpack(patches, width // 16 * 16, height // 16 * 16, parallel_size=parallel_size)
+        output, patches = generate(input_ids,
+                                   width // 16 * 16,
+                                   height // 16 * 16,
+                                   cfg_weight=guidance,
+                                   parallel_size=parallel_size,
+                                   temperature=t2i_temperature)
+        images = unpack(patches,
+                        width // 16 * 16,
+                        height // 16 * 16,
+                        parallel_size=parallel_size)
 
         stime = time.time()
         ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
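For reference, a standalone sketch (not part of the commit) that mirrors the new context-building block in generate_image(), so the assembled prompt can be inspected without loading the model; the sample history is invented:

# Standalone sketch mirroring the committed context-building logic (sample data invented).
def build_full_prompt(prompt, conversation_history=None):
    full_prompt = prompt
    if conversation_history and len(conversation_history) > 0:
        # Keep only the last 5 turns, as in the committed code
        recent_turns = conversation_history[-5:] if len(conversation_history) > 5 else conversation_history
        context_parts = []
        for user_msg, assistant_msg in recent_turns:
            if user_msg and user_msg.strip():
                context_parts.append(f"User: {user_msg}")
            if assistant_msg and assistant_msg.strip():
                context_parts.append(f"Assistant: {assistant_msg}")
        conversation_context = "\n".join(context_parts)
        full_prompt = f"Based on this conversation:\n{conversation_context}\n\nGenerate: {prompt}"
    return full_prompt

history = [("Draw a lighthouse at dusk", "Here it is."),
           ("Add a sailboat in the distance", "")]
print(build_full_prompt("a lighthouse at dusk with a sailboat", history))
# Output:
# Based on this conversation:
# User: Draw a lighthouse at dusk
# Assistant: Here it is.
# User: Add a sailboat in the distance
#
# Generate: a lighthouse at dusk with a sailboat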