wjm55 commited on
Commit
1955b0a
·
1 Parent(s): 066ecb2

refactor init_model function to accept model_id parameter and update predict endpoint to use dynamic model initialization; added supervision library to requirements

Browse files
Files changed (2) hide show
  1. app.py +30 -7
  2. requirements.txt +1 -0
app.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  from huggingface_hub import hf_hub_download
8
  from ultralytics import YOLO
9
  import requests
10
-
11
 
12
  ###
13
 
@@ -19,9 +19,8 @@ import requests
19
  #wget -q https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt
20
 
21
 
22
- def init_model():
23
  is_pf=True
24
- model_id = "yoloe-11s"
25
  # Create a models directory if it doesn't exist
26
  os.makedirs("models", exist_ok=True)
27
  filename = f"{model_id}-seg.pt" if not is_pf else f"{model_id}-seg-pf.pt"
@@ -34,11 +33,16 @@ def init_model():
34
 
35
  app = FastAPI()
36
 
37
- # Initialize model at startup
38
- model = init_model()
39
 
40
  @app.post("/predict")
41
- async def predict(image_url: str, texts: str = "hat"):
 
 
 
 
 
 
 
42
  # Set classes to filter
43
  class_list = [text.strip() for text in texts.split(',')]
44
 
@@ -51,7 +55,7 @@ async def predict(image_url: str, texts: str = "hat"):
51
  model.set_classes(class_list, text_embeddings)
52
 
53
  # Run inference with the PIL Image
54
- results = model.predict(source=image, conf=0.25, iou=0.7)
55
 
56
  # Extract detection results
57
  result = results[0]
@@ -66,6 +70,25 @@ async def predict(image_url: str, texts: str = "hat"):
66
  }
67
  detections.append(detection)
68
  print(detections)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  return {"detections": detections}
70
 
71
  if __name__ == "__main__":
 
7
  from huggingface_hub import hf_hub_download
8
  from ultralytics import YOLO
9
  import requests
10
+ import supervision as sv
11
 
12
  ###
13
 
 
19
  #wget -q https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt
20
 
21
 
22
+ def init_model(model_id: str):
23
  is_pf=True
 
24
  # Create a models directory if it doesn't exist
25
  os.makedirs("models", exist_ok=True)
26
  filename = f"{model_id}-seg.pt" if not is_pf else f"{model_id}-seg-pf.pt"
 
33
 
34
  app = FastAPI()
35
 
 
 
36
 
37
  @app.post("/predict")
38
+ async def predict(image_url: str,
39
+ texts: str = "hat",
40
+ model_id: str = "yoloe-11l",
41
+ conf: float = 0.25,
42
+ iou: float = 0.7
43
+ ):
44
+ # Initialize model at startup
45
+ model = init_model(model_id)
46
  # Set classes to filter
47
  class_list = [text.strip() for text in texts.split(',')]
48
 
 
55
  model.set_classes(class_list, text_embeddings)
56
 
57
  # Run inference with the PIL Image
58
+ results = model.predict(source=image, conf=conf, iou=iou)
59
 
60
  # Extract detection results
61
  result = results[0]
 
70
  }
71
  detections.append(detection)
72
  print(detections)
73
+ # detections = sv.Detections.from_ultralytics(results[0])
74
+
75
+ # resolution_wh = image.size
76
+ # thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh)
77
+ # text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh)
78
+
79
+ # labels = [
80
+ # f"{class_name} {confidence:.2f}"
81
+ # for class_name, confidence
82
+ # in zip(detections['class_name'], detections.confidence)
83
+ # ]
84
+
85
+ # annotated_image = image.copy()
86
+ # annotated_image = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX, opacity=0.4).annotate(
87
+ # scene=annotated_image, detections=detections)
88
+ # annotated_image = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=thickness).annotate(
89
+ # scene=annotated_image, detections=detections)
90
+ # annotated_image = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX, text_scale=text_scale, smart_position=True).annotate(
91
+ # scene=annotated_image, detections=detections, labels=labels)
92
  return {"detections": detections}
93
 
94
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  fastapi
2
  uvicorn[standard]
 
3
  git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/CLIP
4
  git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/ml-mobileclip
5
  git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/lvis-api
 
1
  fastapi
2
  uvicorn[standard]
3
+ supervision
4
  git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/CLIP
5
  git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/ml-mobileclip
6
  git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/lvis-api