Update app.py
app.py
CHANGED
@@ -71,20 +71,23 @@ def encode_image_to_base64(image_array):
     return base64.b64encode(buffered.getvalue()).decode('utf-8')
 
 
 def format_query_for_model(text_input, model_type="owlv2"):
     """Format query based on model requirements"""
-    # Extract objects (e.g., "
+    # Extract objects (e.g., "detect a lion" -> "lion")
     text = text_input.lower()
     words = [w.strip('.,?!') for w in text.split()
              if w not in ['count', 'detect', 'show', 'me', 'the', 'and', 'a', 'an']]
 
     if model_type == "owlv2":
-        queries = [["a photo of " + obj for obj in words]]
-        return queries
+        # Return just the list of queries for Owlv2, not nested list
+        queries = ["a photo of " + obj for obj in words]
+        print("Owlv2 queries:", queries)
+        return queries
     else:  # DINO
-        # DINO
+        # DINO query format
+        query = f"a {words[0]}."
+        print("DINO query:", query)
+        return query
 
 def detect_objects(query_text):
     if state.current_image is None:
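
With this change, Owlv2 receives a flat list of per-phrase queries (its processor treats a list of strings as several queries against one image, and a nested list as a batch of images), while Grounding DINO receives a single lowercase string terminated by a period. A quick sanity check of the new behavior, calling `format_query_for_model` exactly as committed above:

```python
# "detect" and "a" are in the stop-word list, so only "lion" survives.
assert format_query_for_model("Detect a lion", "owlv2") == ["a photo of lion"]
assert format_query_for_model("Detect a lion", "dino") == "a lion."
```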
@@ -94,6 +97,7 @@ def detect_objects(query_text):
     draw = ImageDraw.Draw(image)
 
     if state.current_model == "owlv2":
+        # For Owlv2, pass the text queries directly
         inputs = owlv2_processor(text=query_text, images=image, return_tensors="pt").to(device)
         with torch.no_grad():
             outputs = owlv2_model(**inputs)
@@ -101,6 +105,7 @@ def detect_objects(query_text):
             outputs=outputs, threshold=0.2, target_sizes=torch.Tensor([image.size[::-1]])
         )
     else:  # DINO
+        # For DINO, pass the single text query
         inputs = dino_processor(images=image, text=query_text, return_tensors="pt").to(device)
         with torch.no_grad():
             outputs = dino_model(**inputs)
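
Both branches still leave raw model outputs that have to be mapped back to pixel-space boxes; `target_sizes=torch.Tensor([image.size[::-1]])` reverses PIL's `(width, height)` into the `(height, width)` order the post-processors expect. The post-processing call itself sits just outside these hunks, so the sketch below assumes the usual Hugging Face convention of one result dict per image with `"boxes"`, `"scores"`, and `"labels"` keys:

```python
# Sketch only: `results` is assumed to come from the processor's
# post-processing step; the dict layout is the standard HF convention.
result = results[0]  # one dict per input image
for box, score in zip(result["boxes"], result["scores"]):
    x0, y0, x1, y1 = [round(v) for v in box.tolist()]
    draw.rectangle([x0, y0, x1, y1], outline="red", width=3)
    draw.text((x0, max(0, y0 - 12)), f"{score:.2f}", fill="red")
```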
@@ -125,7 +130,6 @@ def detect_objects(query_text):
         "message": f"Detected {len(boxes)} objects"
     }
 
-
 def identify_plant():
     if state.current_image is None:
         return {"error": "No image provided"}
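
End to end, a hypothetical caller would chain the two committed functions like this (the wiring below is illustrative, not part of this commit; it assumes `state.current_image` and `state.current_model` are already populated):

```python
# Hypothetical usage: format the natural-language request for the active
# model, then run detection on the previously uploaded image.
query = format_query_for_model("count the lions", model_type=state.current_model)
result = detect_objects(query)
print(result.get("message"))  # e.g. "Detected 2 objects"
```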