obichimav committed (verified)
Commit e7d4da2 · Parent(s): 9b71c84

Update app.py

Files changed (1): app.py (+12 -8)
app.py CHANGED

@@ -71,20 +71,23 @@ def encode_image_to_base64(image_array):
     return base64.b64encode(buffered.getvalue()).decode('utf-8')
 
 
-def format_query_for_model(text_input, model_type="owlv2"):
+# def format_query_for_model(text_input, model_type="owlv2"):
     """Format query based on model requirements"""
-    # Extract objects (e.g., "count frogs and horses" -> ["frog", "horse"])
+    # Extract objects (e.g., "detect a lion" -> "lion")
     text = text_input.lower()
     words = [w.strip('.,?!') for w in text.split()
              if w not in ['count', 'detect', 'show', 'me', 'the', 'and', 'a', 'an']]
 
     if model_type == "owlv2":
-        print([["a photo of " + obj for obj in words]])
-        return [["a photo of " + obj for obj in words]]
+        # Return just the list of queries for Owlv2, not nested list
+        queries = ["a photo of " + obj for obj in words]
+        print("Owlv2 queries:", queries)
+        return queries
     else: # DINO
-        # DINO only works with single object queries with format "a object."
-        print(f"a {words[0]}.")
-        return f"a {words[0]}."
+        # DINO query format
+        query = f"a {words[0]}."
+        print("DINO query:", query)
+        return query
 
 def detect_objects(query_text):
     if state.current_image is None:
@@ -94,6 +97,7 @@ def detect_objects(query_text):
     draw = ImageDraw.Draw(image)
 
     if state.current_model == "owlv2":
+        # For Owlv2, pass the text queries directly
         inputs = owlv2_processor(text=query_text, images=image, return_tensors="pt").to(device)
         with torch.no_grad():
             outputs = owlv2_model(**inputs)
@@ -101,6 +105,7 @@ def detect_objects(query_text):
             outputs=outputs, threshold=0.2, target_sizes=torch.Tensor([image.size[::-1]])
         )
     else: # DINO
+        # For DINO, pass the single text query
        inputs = dino_processor(images=image, text=query_text, return_tensors="pt").to(device)
         with torch.no_grad():
             outputs = dino_model(**inputs)
@@ -125,7 +130,6 @@ def detect_objects(query_text):
         "message": f"Detected {len(boxes)} objects"
     }
 
-
 def identify_plant():
     if state.current_image is None:
         return {"error": "No image provided"}