amoghrrao commited on
Commit
14674d7
·
verified ·
1 Parent(s): 6277f5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -20
app.py CHANGED
@@ -7,42 +7,64 @@ from torchvision import transforms
7
  from transformers import AutoProcessor, AutoModelForImageSegmentation, AutoModelForDepthEstimation
8
 
9
  def load_segmentation_model():
10
- model_name = "ZhengPeng7/BiRefNet"
11
- model = AutoModelForImageSegmentation.from_pretrained(model_name, trust_remote_code=True)
12
- return model
 
 
 
 
 
 
 
13
 
14
  def load_depth_model():
15
- model_name = "depth-anything/Depth-Anything-V2-Metric-Indoor-Base-hf"
16
- processor = AutoProcessor.from_pretrained(model_name)
17
- model = AutoModelForDepthEstimation.from_pretrained(model_name)
18
- return processor, model
 
 
 
 
 
 
 
19
 
20
  def process_segmentation_image(image):
21
  transform = transforms.Compose([
22
  transforms.Resize((512, 512)),
23
  transforms.ToTensor(),
24
  ])
25
- input_tensor = transform(image).unsqueeze(0)
26
  return image, input_tensor
27
 
28
  def process_depth_image(image, processor):
29
  image = image.resize((512, 512))
30
- inputs = processor(images=image, return_tensors="pt")
31
  return image, inputs
32
 
33
  def segment_image(image, input_tensor, model):
34
- with torch.no_grad():
35
- outputs = model(input_tensor)
36
- output_tensor = outputs[0] if isinstance(outputs, list) else outputs.logits
37
- mask = torch.sigmoid(output_tensor).squeeze().cpu().numpy()
38
- mask = (mask > 0.5).astype(np.uint8) * 255
39
- return mask
 
 
 
 
40
 
41
  def estimate_depth(inputs, model):
42
- with torch.no_grad():
43
- outputs = model(**inputs)
44
- depth_map = outputs.predicted_depth.squeeze().cpu().numpy()
45
- return depth_map
 
 
 
 
46
 
47
  def normalize_depth_map(depth_map):
48
  min_val = np.min(depth_map)
@@ -73,6 +95,9 @@ def process_image_pipeline(image):
73
  segmentation_model = load_segmentation_model()
74
  depth_processor, depth_model = load_depth_model()
75
 
 
 
 
76
  _, input_tensor = process_segmentation_image(image)
77
  _, inputs = process_depth_image(image, depth_processor)
78
 
@@ -83,6 +108,9 @@ def process_image_pipeline(image):
83
 
84
  return Image.fromarray(segmentation_mask), blurred_image, gaussian_blur_image
85
 
 
 
 
86
  iface = gr.Interface(
87
  fn=process_image_pipeline,
88
  inputs=gr.Image(type="pil"),
@@ -96,4 +124,4 @@ iface = gr.Interface(
96
  )
97
 
98
  if __name__ == "__main__":
99
- iface.launch(share=True)
 
7
  from transformers import AutoProcessor, AutoModelForImageSegmentation, AutoModelForDepthEstimation
8
 
9
  def load_segmentation_model():
10
+ try:
11
+ print("Loading segmentation model...")
12
+ model_name = "ZhengPeng7/BiRefNet"
13
+ model = AutoModelForImageSegmentation.from_pretrained(model_name, trust_remote_code=True)
14
+ model.to(device)
15
+ print("Segmentation model loaded successfully.")
16
+ return model
17
+ except Exception as e:
18
+ print(f"Error loading segmentation model: {e}")
19
+ return None
20
 
21
  def load_depth_model():
22
+ try:
23
+ print("Loading depth estimation model...")
24
+ model_name = "depth-anything/Depth-Anything-V2-Metric-Indoor-Base-hf"
25
+ processor = AutoProcessor.from_pretrained(model_name)
26
+ model = AutoModelForDepthEstimation.from_pretrained(model_name)
27
+ model.to(device)
28
+ print("Depth estimation model loaded successfully.")
29
+ return processor, model
30
+ except Exception as e:
31
+ print(f"Error loading depth estimation model: {e}")
32
+ return None, None
33
 
34
  def process_segmentation_image(image):
35
  transform = transforms.Compose([
36
  transforms.Resize((512, 512)),
37
  transforms.ToTensor(),
38
  ])
39
+ input_tensor = transform(image).unsqueeze(0).to(device)
40
  return image, input_tensor
41
 
42
  def process_depth_image(image, processor):
43
  image = image.resize((512, 512))
44
+ inputs = processor(images=image, return_tensors="pt").to(device)
45
  return image, inputs
46
 
47
  def segment_image(image, input_tensor, model):
48
+ try:
49
+ with torch.no_grad():
50
+ outputs = model(input_tensor)
51
+ output_tensor = outputs[0] if isinstance(outputs, list) else outputs.logits
52
+ mask = torch.sigmoid(output_tensor).squeeze().cpu().numpy()
53
+ mask = (mask > 0.5).astype(np.uint8) * 255
54
+ return mask
55
+ except Exception as e:
56
+ print(f"Error during segmentation: {e}")
57
+ return np.zeros((512, 512), dtype=np.uint8)
58
 
59
  def estimate_depth(inputs, model):
60
+ try:
61
+ with torch.no_grad():
62
+ outputs = model(**inputs)
63
+ depth_map = outputs.predicted_depth.squeeze().cpu().numpy()
64
+ return depth_map
65
+ except Exception as e:
66
+ print(f"Error during depth estimation: {e}")
67
+ return np.zeros((512, 512), dtype=np.float32)
68
 
69
  def normalize_depth_map(depth_map):
70
  min_val = np.min(depth_map)
 
95
  segmentation_model = load_segmentation_model()
96
  depth_processor, depth_model = load_depth_model()
97
 
98
+ if segmentation_model is None or depth_model is None:
99
+ return Image.fromarray(np.zeros((512, 512), dtype=np.uint8)), image, image
100
+
101
  _, input_tensor = process_segmentation_image(image)
102
  _, inputs = process_depth_image(image, depth_processor)
103
 
 
108
 
109
  return Image.fromarray(segmentation_mask), blurred_image, gaussian_blur_image
110
 
111
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
112
+ print(f"Using device: {device}")
113
+
114
  iface = gr.Interface(
115
  fn=process_image_pipeline,
116
  inputs=gr.Image(type="pil"),
 
124
  )
125
 
126
  if __name__ == "__main__":
127
+ iface.launch(share=False)