huntrezz committed
Commit fd26002 · verified · 1 Parent(s): 893be2d

Update app.py

Files changed (1): app.py +7 -4
app.py CHANGED
@@ -18,9 +18,12 @@ parameters_to_prune = [
 prune.global_unstructured(
     parameters_to_prune,
     pruning_method=prune.L1Unstructured,
-    amount=0.4,  # Prune 40% of weights
+    amount=0.3,  # Prune 30% of weights
 )
 
+for module, _ in parameters_to_prune:
+    prune.remove(module, "weight")
+
 # Apply quantization after pruning
 model = torch.quantization.quantize_dynamic(
     model, {torch.nn.Linear, torch.nn.Conv2d}, dtype=torch.qint8
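
The added loop makes the pruning permanent: prune.global_unstructured only reparametrizes each weight as weight_orig * weight_mask behind a forward pre-hook, and prune.remove folds the mask back into a plain tensor, so the subsequent quantize_dynamic step sees ordinary weight tensors rather than the pruning reparametrization. A minimal sketch of that mechanism, using a throwaway nn.Linear rather than the app's model:

import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(8, 4)
params = [(layer, "weight")]

# Mask 30% of weights globally by L1 magnitude, as in the commit
prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=0.3)
print(hasattr(layer, "weight_orig"))  # True: weight is computed as weight_orig * weight_mask

# Fold the mask in permanently, as the added loop does
for module, name in params:
    prune.remove(module, name)
print(hasattr(layer, "weight_orig"))  # False: layer.weight is a plain, pruned tensor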
@@ -33,20 +36,20 @@ color_map = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_INFERNO)
 input_tensor = torch.zeros((1, 3, 128, 128), dtype=torch.float32, device=device)
 
 def preprocess_image(image):
-    return cv2.resize(image, (128, 128), interpolation=cv2.INTER_AREA).transpose(2, 0, 1).astype(np.float32) / 255.0
+    return cv2.resize(image, (128, 72), interpolation=cv2.INTER_AREA).transpose(2, 0, 1).astype(np.float32) / 255.0
 
 @torch.inference_mode()
 def process_frame(image):
     if image is None:
         return None
     preprocessed = preprocess_image(image)
-    input_tensor[0] = torch.from_numpy(preprocessed).to(device)
+    input_tensor = torch.from_numpy(preprocessed).unsqueeze(0).to(device)
 
     predicted_depth = model(input_tensor).predicted_depth
     depth_map = predicted_depth.squeeze().cpu().numpy()
     depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
     depth_map = (depth_map * 255).astype(np.uint8)
-    depth_map_colored = cv2.applyColorMap(depth_map, color_map)
+    depth_map_colored = cv2.applyColorMap(depth_map, cv2.COLORMAP_INFERNO)
 
     return cv2.cvtColor(depth_map_colored, cv2.COLOR_BGR2RGB)
 
 