reab5555 commited on
Commit
a63d0d6
·
verified ·
1 Parent(s): f05ca8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -6,8 +6,7 @@ from transformers import Owlv2Processor, Owlv2ForObjectDetection
6
  import numpy as np
7
  import os
8
 
9
- # Check if CUDA is available, otherwise use CPU
10
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
 
12
  processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
13
  model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
@@ -33,7 +32,7 @@ def detect_objects_in_frame(image, target):
33
  boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
34
 
35
  for box, score, label in zip(boxes, scores, labels):
36
- if score.item() >= 0.25:
37
  box = [round(i, 2) for i in box.tolist()]
38
  object_label = text[label]
39
  confidence = round(score.item(), 3)
@@ -64,7 +63,7 @@ def process_video(video_path, target, progress=gr.Progress()):
64
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
65
  out = cv2.VideoWriter(output_path, fourcc, original_fps, (int(cap.get(3)), int(cap.get(4))))
66
 
67
- batch_size = 16
68
  frames = []
69
 
70
  for frame in progress.tqdm(range(frame_count)):
 
6
  import numpy as np
7
  import os
8
 
9
+ device = 'cuda'
 
10
 
11
  processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
12
  model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
 
32
  boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
33
 
34
  for box, score, label in zip(boxes, scores, labels):
35
+ if score.item() >= 0.5:
36
  box = [round(i, 2) for i in box.tolist()]
37
  object_label = text[label]
38
  confidence = round(score.item(), 3)
 
63
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
64
  out = cv2.VideoWriter(output_path, fourcc, original_fps, (int(cap.get(3)), int(cap.get(4))))
65
 
66
+ batch_size = 64
67
  frames = []
68
 
69
  for frame in progress.tqdm(range(frame_count)):