markany-yhkwon commited on
Commit
4a72a2c
·
1 Parent(s): 5bccb37
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -33,13 +33,13 @@ ckpt_repo_id = "ShilongLiu/GroundingDINO"
33
  ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
34
 
35
 
36
- def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
37
  args = SLConfig.fromfile(model_config_path)
38
  model = build_model(args)
39
  args.device = device
40
 
41
  cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
42
- checkpoint = torch.load(cache_file, map_location='cpu')
43
  log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
44
  print("Model loaded from {} \n => {}".format(cache_file, log))
45
  _ = model.eval()
@@ -72,7 +72,7 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold)
72
  image_pil: Image = image_transform_grounding_for_vis(init_image)
73
 
74
  # run grounding
75
- boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
76
  annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
77
  image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
78
 
@@ -96,7 +96,7 @@ if __name__ == "__main__":
96
  with gr.Blocks(css=css) as demo:
97
  gr.Markdown("<h1><center>Grounding DINO<h1><center>")
98
  gr.Markdown("<h3><center>Open-World Detection with <a href='https://github.com/IDEA-Research/GroundingDINO'>Grounding DINO</a><h3><center>")
99
- gr.Markdown("<h3><center>Note the model runs on CPU, so it may take a while to run the model.<h3><center>")
100
 
101
  with gr.Row():
102
  with gr.Column():
 
33
  ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
34
 
35
 
36
+ def load_model_hf(model_config_path, repo_id, filename, device='cuda'):
37
  args = SLConfig.fromfile(model_config_path)
38
  model = build_model(args)
39
  args.device = device
40
 
41
  cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
42
+ checkpoint = torch.load(cache_file, map_location=device)
43
  log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
44
  print("Model loaded from {} \n => {}".format(cache_file, log))
45
  _ = model.eval()
 
72
  image_pil: Image = image_transform_grounding_for_vis(init_image)
73
 
74
  # run grounding
75
+ boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cuda')
76
  annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
77
  image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
78
 
 
96
  with gr.Blocks(css=css) as demo:
97
  gr.Markdown("<h1><center>Grounding DINO<h1><center>")
98
  gr.Markdown("<h3><center>Open-World Detection with <a href='https://github.com/IDEA-Research/GroundingDINO'>Grounding DINO</a><h3><center>")
99
+ gr.Markdown("<h3><center>Running on GPU for faster inference<h3><center>")
100
 
101
  with gr.Row():
102
  with gr.Column():