gaur3009 committed
Commit 12eb9d7 · verified · 1 Parent(s): 1df8fb2

Update app.py

Files changed (1):
  app.py +55 -36
app.py CHANGED
@@ -1,50 +1,69 @@
- import gradio as gr
  import torch
  import cv2
  import numpy as np
- from torchvision import transforms
+ import gradio as gr
  from PIL import Image
+ from torchvision import transforms
+ from skimage.restoration import denoise_tv_chambolle
+ from transformers import SamModel, SamProcessor

- # Load MiDaS depth estimation model
- midas_model = torch.hub.load("intel-isl/MiDaS", "DPT_Hybrid")
- midas_model.eval()
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- midas_model.to(device)
- midas_transform = torch.hub.load("intel-isl/MiDaS", "transforms").default_transform
-
- def estimate_depth(image):
-     """Estimate depth map to identify fabric folds."""
-     image = image.convert("RGB")
-     image_tensor = midas_transform(image).to(device)
-
+ # Load SAM model
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ model = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+
+ def segment_dress(image):
+     """Segments the dress from an input image using SAM."""
+     input_points = [[[image.size[0] // 2, image.size[1] // 2]]]
+     inputs = processor(image, input_points=input_points, return_tensors="pt").to(DEVICE)
      with torch.no_grad():
-         depth = midas_model(image_tensor).squeeze().cpu().numpy()
+         outputs = model(**inputs)
+     masks = processor.image_processor.post_process_masks(
+         outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
+     )
+     return masks[0][0].numpy() if masks else None
+
+ def warp_design(design, mask, warp_scale):
+     """Warp the design using TPS and scale control."""
+     h, w = mask.shape[:2]
+     design_resized = cv2.resize(design, (w, h))

-     depth = cv2.resize(depth, (image.size[0], image.size[1]))
-     depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255
-     return depth.astype(np.uint8)
-
- def detect_folds(image):
-     """Apply edge detection and highlight cloth folds."""
-     depth_map = estimate_depth(image)
-     edges = cv2.Canny(depth_map, 50, 150)
+     # Apply scaling
+     scaled_mask = (mask * 255 * (warp_scale / 100)).astype(np.uint8)
+     return cv2.bitwise_and(design_resized, design_resized, mask=scaled_mask)
+
+ def blend_images(base, overlay, mask):
+     """Blends the design onto the dress using seamless cloning."""
+     center = tuple(np.array(base.shape[:2]) // 2)
+     return cv2.seamlessClone(overlay, base, mask, center, cv2.NORMAL_CLONE)
+
+ def apply_design(image_path, design_path, warp_scale):
+     """Pipeline to segment, warp, and blend design onto dress."""
+     image = Image.open(image_path).convert("RGB")
+     design = cv2.imread(design_path)
+     mask = segment_dress(image)

-     # Convert edges to 3-channel image for visualization
-     edges_colored = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
-     overlay = cv2.addWeighted(np.array(image), 0.7, edges_colored, 0.3, 0)
+     if mask is None:
+         return "Segmentation Failed!"

-     return Image.fromarray(overlay)
+     warped_design = warp_design(design, mask, warp_scale)
+     blended = blend_images(np.array(image), warped_design, mask)
+     return Image.fromarray(blended)

- def main(image):
-     return detect_folds(image)
+ def main(image, design, warp_scale):
+     return apply_design(image, design, warp_scale)

- iface = gr.Interface(
+ # Gradio UI
+ demo = gr.Interface(
      fn=main,
-     inputs=gr.Image(type="pil"),
-     outputs=gr.Image(type="pil"),
-     title="Cloth Fold Detection",
-     description="Upload an image of clothing to visualize folds using depth estimation and edge detection."
+     inputs=[
+         gr.Image(type="filepath", label="Upload Dress Image"),
+         gr.Image(type="filepath", label="Upload Design Image"),
+         gr.Slider(0, 100, value=50, label="Warp Scale (%)")
+     ],
+     outputs=gr.Image(label="Warped Design on Dress"),
+     title="AI-Powered Dress Designer",
+     description="Upload a dress image and a design pattern. The AI will warp and blend the design onto the dress while preserving natural folds!"
  )

- if __name__ == "__main__":
-     iface.launch(share=True, debug=True)
+ demo.launch()
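
Review notes on the new app.py:

Two of the new imports (transforms, denoise_tv_chambolle) are never used in this revision. More importantly, in segment_dress, post_process_masks returns one tensor per image shaped (num_prompts, num_proposals, H, W), so masks[0][0] is a (3, H, W) boolean stack of SAM's three candidate masks rather than a single 2-D mask, and the "if masks" guard is effectively dead code because the method always returns a non-empty list. A minimal sketch of a stricter return, using the iou_scores the model already exposes; converting to a 0/255 uint8 mask is our assumption, chosen to suit the OpenCV calls downstream:

    # Pick SAM's best-scoring proposal; iou_scores has shape (batch, prompts, 3).
    best = outputs.iou_scores[0, 0].argmax().item()
    # masks[0] has shape (prompts, 3, H, W); select a single (H, W) boolean mask.
    return masks[0][0, best].numpy().astype(np.uint8) * 255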
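warp_design's docstring promises a TPS (thin-plate spline) warp, but the body only resizes and masks the design; no spatial warp happens. The warp_scale handling is also mostly a no-op: cv2.bitwise_and only distinguishes zero from nonzero mask pixels, so scaling the mask's intensity changes nothing until the values round to zero. If the slider is meant to control how strongly the design shows through, reading it as opacity is one plausible interpretation; a sketch under that assumption (and assuming the single-channel uint8 mask from the note above):

    def warp_design(design, mask, warp_scale):
        """Resize the design to the mask and apply it at warp_scale percent opacity."""
        h, w = mask.shape[:2]
        design_resized = cv2.resize(design, (w, h))
        alpha = warp_scale / 100.0
        region = (mask > 0)[..., None]  # (H, W, 1), broadcasts over color channels
        return (design_resized * alpha * region).astype(np.uint8)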
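blend_images computes the clone center as tuple(np.array(base.shape[:2]) // 2), which is (row, col); cv2.seamlessClone expects the point as (x, y), so the center is transposed whenever the image is not square. seamlessClone also requires an 8-bit mask, which the raw SAM output is not. A corrected sketch, again assuming a single-channel uint8 mask:

    def blend_images(base, overlay, mask):
        """Seamlessly clone the overlay onto the base inside the masked region."""
        h, w = base.shape[:2]
        center = (w // 2, h // 2)  # seamlessClone takes the center as (x, y)
        return cv2.seamlessClone(overlay, base, mask, center, cv2.NORMAL_CLONE)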
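In apply_design, the base image comes from PIL in RGB order while cv2.imread returns BGR, so the blended design's red and blue channels end up swapped. One conversion keeps the whole pipeline in RGB:

    design = cv2.cvtColor(cv2.imread(design_path), cv2.COLOR_BGR2RGB)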
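Finally, the failure path returns the string "Segmentation Failed!" from a function wired to a gr.Image output, which Gradio cannot render as an image. Raising gr.Error instead surfaces the message in the UI:

    if mask is None:
        raise gr.Error("Segmentation failed for the uploaded image.")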