Bismay commited on
Commit
25b137d
·
1 Parent(s): eb5e343

fix gallery bug

Browse files
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -244,7 +244,7 @@ class ClothingInpainter:
244
 
245
  return images_output
246
 
247
- def process_segmentation(image, dilation_iterations=2):
248
  try:
249
  if image is None:
250
  raise gr.Error("Please upload an image")
@@ -260,13 +260,25 @@ def process_segmentation(image, dilation_iterations=2):
260
  if not all_masks:
261
  logger.error("No clothing detected in the image")
262
  raise gr.Error("No clothing detected in the image. Please try a different image.")
 
263
  inpainter.last_mask = all_masks
264
- # Only show main clothing parts for selection
 
 
 
 
 
 
 
 
 
 
 
265
  main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
266
  masks = {k: v for k, v in all_masks.items() if k in main_parts}
267
- vis_image = inpainter.visualize_segmentation(image, masks, selected_parts=None)
268
- detected_parts = [k for k in masks.keys()]
269
- return vis_image, gr.update(choices=detected_parts, value=[])
270
  except gr.Error as e:
271
  raise e
272
  except Exception as e:
@@ -360,9 +372,10 @@ def create_interface():
360
  # Input Section
361
  with gr.Group():
362
  gr.Markdown("### Input Image")
 
363
  with gr.Row():
364
- input_image = gr.Image(type="pil", label="Upload Image")
365
- example_btn = gr.Button("Load Example Image", variant="secondary")
366
 
367
  # Clothing Selection
368
  gr.Markdown("### Select Clothing Parts")
 
244
 
245
  return images_output
246
 
247
+ def process_segmentation(image, clothing_parts=None, prompt=None, dilation_iterations=2):
248
  try:
249
  if image is None:
250
  raise gr.Error("Please upload an image")
 
260
  if not all_masks:
261
  logger.error("No clothing detected in the image")
262
  raise gr.Error("No clothing detected in the image. Please try a different image.")
263
+
264
  inpainter.last_mask = all_masks
265
+
266
+ # If clothing parts are selected and prompt is provided, generate variations
267
+ if clothing_parts and prompt:
268
+ # Convert selected_parts to lowercase/dash format
269
+ selected_parts = [p.lower() for p in clothing_parts]
270
+ prompt_dict = {'pos': prompt}
271
+
272
+ # Generate inpainted images
273
+ images = inpainter.inpaint(prompt_dict, image, selected_parts, dilation_iterations)
274
+ return images
275
+
276
+ # Otherwise, just show the segmentation visualization
277
  main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
278
  masks = {k: v for k, v in all_masks.items() if k in main_parts}
279
+ vis_image = inpainter.visualize_segmentation(image, masks, selected_parts=clothing_parts)
280
+ return [vis_image]
281
+
282
  except gr.Error as e:
283
  raise e
284
  except Exception as e:
 
372
  # Input Section
373
  with gr.Group():
374
  gr.Markdown("### Input Image")
375
+ input_image = gr.Image(type="pil", label="Upload Image")
376
  with gr.Row():
377
+ example_btn = gr.Button("Load Example Image", variant="secondary", size="sm")
378
+ gr.Markdown("or upload your own image above", size="sm")
379
 
380
  # Clothing Selection
381
  gr.Markdown("### Select Clothing Parts")