Nadav Eden committed
Commit b7a2f31 · 1 Parent(s): e425a6f

adding vlm support

Files changed (3):
  1. app.py +108 -10
  2. requirements.txt +4 -0
  3. utils.py +42 -0
app.py CHANGED

@@ -1,7 +1,11 @@
 #!/usr/bin/env python3
 
 import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM
+ from PIL import Image
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, Qwen2VLForConditionalGeneration
+ from utils import image_to_base64, rescale_bounding_boxes, draw_bounding_boxes, florence_draw_bboxes
+ from qwen_vl_utils import process_vision_info
+ import re
 
 llms = {
     "Qwen2-1.5B": {"model": "Qwen/Qwen2-1.5B-Instruct", "prefix": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
@@ -12,25 +16,31 @@ llms = {
     "DeepSeek-Coder": {"model": "DeepSeek/DeepSeek-Coder", "prefix": "You are a helpful assistant."},
 }
 
- vlms = dict()
+ vlms = {
+     "Florence-2-base": {"model": "microsoft/Florence-2-base", "prefix": "help me"},
+     "Florence-2-large": {"model": "microsoft/Florence-2-large", "prefix": "help me"},
+     "Qwen2-vl-2B": {"model": "Qwen/Qwen2-VL-2B-Instruct", "prefix": "You are a helpful assistant to detect objects in images. When asked to detect elements based on a description you return bounding boxes for all elements in the form of [xmin, ymin, xmax, ymax] with the values being scaled to 1000 by 1000 pixels. When there is more than one result, answer with a list of bounding boxes in the form of [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax], ...]."},
+     "Qwen2-vl-7B": {"model": "Qwen/Qwen2-VL-7B-Instruct", "prefix": "You are a helpful assistant to detect objects in images. When asked to detect elements based on a description you return bounding boxes for all elements in the form of [xmin, ymin, xmax, ymax] with the values being scaled to 1000 by 1000 pixels. When there is more than one result, answer with a list of bounding boxes in the form of [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax], ...]."},
+     "Qwen2.5-vl-3B": {"model": "Qwen/Qwen2.5-VL-3B-Instruct", "prefix": "You are a helpful assistant to detect objects in images. When asked to detect elements based on a description you return bounding boxes for all elements in the form of [xmin, ymin, xmax, ymax] with the values being scaled to 1000 by 1000 pixels. When there is more than one result, answer with a list of bounding boxes in the form of [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax], ...]."}
+ }
+ 
+ tasks = ["<OD>", "<OCR>", "<CAPTION>", "<OCR_WITH_REGION>"]
 
- def run_example(text_input, model_id="Qwen2-1.5B"):
+ def run_llm(text_input, model_id="Qwen2-1.5B"):
     global messages
     tokenizer = AutoTokenizer.from_pretrained(llms[model_id]["model"], trust_remote_code=True)
     model = AutoModelForCausalLM.from_pretrained(llms[model_id]["model"], trust_remote_code=True)
- 
-     system_prompt = llms[model_id]["prefix"]
 
     if messages is None:
         messages = [
-             {"role": "system", "content": system_prompt},
+             {"role": "system", "content": llms[model_id]["prefix"]},
             {"role": "user", "content": text_input},
         ]
     else:
         messages.append({"role": "user", "content": text_input})
 
 
-     text = tokenizer.apply_chat_template(
+     text = tokenizer.apply_chat_template (
         messages,
         tokenize=False,
         add_generation_prompt=True,
@@ -51,12 +61,86 @@ def run_example(text_input, model_id="Qwen2-1.5B"):
 
     return response
 
+ def run_vlm(image, text_input, model_id="Qwen2-vl-2B", prompt="<OD>"):
+     if "Qwen" in model_id:
+         model = Qwen2VLForConditionalGeneration.from_pretrained(vlms[model_id]["model"], torch_dtype="auto", device_map="auto")
+     else:
+         model = AutoModelForCausalLM.from_pretrained(vlms[model_id]["model"], trust_remote_code=True)
+     processor = AutoProcessor.from_pretrained(vlms[model_id]["model"], trust_remote_code=True)
+ 
+     if "Qwen" in model_id:
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image", "image": f"data:image;base64,{image_to_base64(image)}"},
+                     {"type": "text", "text": vlms[model_id]["prefix"]},
+                     {"type": "text", "text": text_input},
+                 ],
+             }
+         ]
+         text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         image_inputs, video_inputs = process_vision_info(messages)
+         inputs = processor(
+             text=[text],
+             images=image_inputs,
+             videos=video_inputs,
+             padding=True,
+             return_tensors="pt",
+         ).to(model.device)
+         generated_ids = model.generate(**inputs, max_new_tokens=256)
+         generated_ids_trimmed = [
+             out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+         ]
+         output_text = processor.batch_decode(
+             generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+         )
+         print(output_text)
+         pattern = r'\[\s*([.\d]+)\s*,\s*([.\d]+)\s*,\s*([.\d]+)\s*,\s*([.\d]+)\s*\]'
+         matches = re.findall(pattern, str(output_text))
+         parsed_boxes = [[float(num) for num in match] for match in matches]
+         scaled_boxes = rescale_bounding_boxes(parsed_boxes, image.width, image.height)
+         print(scaled_boxes)
+         draw = draw_bounding_boxes(image, scaled_boxes)
+     else:
+         messages = prompt + text_input
+         inputs = processor(text=messages, images=image, return_tensors="pt").to(model.device)
+         generated_ids = model.generate(
+             input_ids=inputs["input_ids"],
+             pixel_values=inputs["pixel_values"],
+             max_new_tokens=1024,
+             early_stopping=False,
+             do_sample=False,
+             num_beams=3,
+         )
+         generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+         parsed_answer = processor.post_process_generation(
+             generated_text,
+             task=prompt,
+             image_size=(image.width, image.height)
+         )
+         print(parsed_answer)
+         if prompt == '<OD>':
+             parsed_boxes = parsed_answer['<OD>']['bboxes']
+             draw = florence_draw_bboxes(image, parsed_answer)
+             output_text = "None"
+         elif prompt == '<OCR>':
+             output_text = parsed_answer['<OCR>']
+             draw = image
+             parsed_boxes = None
+ 
+     return output_text, parsed_boxes, draw
 
 messages = list()
 def reset_conversation():
     global messages
     messages = list()
 
+ def update_task_dropdown(model):
+     if "Florence" in model:
+         return gr.Dropdown(visible=True)
+     return gr.Dropdown(visible=False)
+ 
 with gr.Blocks() as demo:
     gr.Markdown(
         """
@@ -74,20 +158,34 @@ with gr.Blocks() as demo:
                 model_output_text = gr.Textbox(label="Model Output Text")
 
 
-         submit_btn.click(run_example,
+         submit_btn.click(run_llm,
                          [text_input, model_selector],
                          [model_output_text])
 
         reset_btn.click(reset_conversation)
 
     with gr.Tab(label="VLM (WIP)"):
+         # taken from https://huggingface.co/spaces/maxiw/Qwen2-VL-Detection/blob/main/app.py
         with gr.Row():
             with gr.Column():
                 input_img = gr.Image(label="Input Image", type="pil")
-                 model_selector = gr.Dropdown(choices=list(vlms.keys()), label="Model", value="Qwen2-1.5B")
+                 model_selector = gr.Dropdown(choices=list(vlms.keys()), label="Model", value="Florence-2-base")
+                 task_select = gr.Dropdown(choices=tasks, label="Task", value="<OD>")
                 text_input = gr.Textbox(label="User Prompt")
                 submit_btn = gr.Button(value="Submit")
-                 reset_btn = gr.Button(value="Reset conversation")
+             with gr.Column():
+                 model_output_text = gr.Textbox(label="Model Output Text")
+                 parsed_boxes = gr.Textbox(label="Parsed Boxes")
+                 annotated_image = gr.Image(label="Annotated Image")
+ 
+         model_selector.change(update_task_dropdown, inputs=model_selector, outputs=task_select)
+ 
+ 
+         submit_btn.click(run_vlm,
+                          [input_img, text_input, model_selector, task_select],
+                          [model_output_text, parsed_boxes, annotated_image])
+ 
+ 
 
 if __name__ == "__main__":
     demo.launch()
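
Note on the Qwen branch of run_vlm above: the system prefix asks the model to emit boxes on a 1000×1000 grid, and the regex plus rescale_bounding_boxes map them back to image pixels. A minimal standalone sketch of that post-processing step, using an invented model output string for a 640×480 image (not taken from an actual run):

import re

def parse_and_rescale(output_text, image_width, image_height):
    # Pull every [xmin, ymin, xmax, ymax] group out of the model's free-text answer.
    pattern = r'\[\s*([.\d]+)\s*,\s*([.\d]+)\s*,\s*([.\d]+)\s*,\s*([.\d]+)\s*\]'
    boxes = [[float(v) for v in m] for m in re.findall(pattern, output_text)]
    # The prompt requests coordinates on a 1000x1000 grid, so scale back to pixel space.
    x_scale, y_scale = image_width / 1000, image_height / 1000
    return [[x1 * x_scale, y1 * y_scale, x2 * x_scale, y2 * y_scale] for x1, y1, x2, y2 in boxes]

# Invented example answer; a real run would come from processor.batch_decode(...).
sample = "The two cups are at [[120, 80, 300, 260], [510, 90, 690, 250]]."
print(parse_and_rescale(sample, 640, 480))
# -> [[76.8, 38.4, 192.0, 124.8], [326.4, 43.2, 441.6, 120.0]]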
requirements.txt CHANGED
@@ -1,4 +1,8 @@
huggingface_hub==0.25.2
torch
+ torchvision
transformers
gradio
+ Pillow
+ qwen_vl_utils
+ accelerate>=0.26.0
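
The new pins back the VLM tab: torchvision and accelerate support loading the Qwen2-VL checkpoints with device_map="auto", Pillow backs the drawing helpers in utils.py, and qwen_vl_utils provides process_vision_info. A quick import check, assuming the requirements install cleanly:

# Smoke test for the added dependencies (assumes `pip install -r requirements.txt` has run).
import torch, torchvision, transformers, accelerate
from PIL import Image
from qwen_vl_utils import process_vision_info

print("torch", torch.__version__, "| torchvision", torchvision.__version__,
      "| transformers", transformers.__version__, "| accelerate", accelerate.__version__)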
utils.py ADDED
@@ -0,0 +1,42 @@
+ import base64
+ from PIL import ImageDraw, ImageFont
+ from io import BytesIO
+ 
+ def image_to_base64(image):
+     buffered = BytesIO()
+     image.save(buffered, format="PNG")
+     img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+     return img_str
+ 
+ 
+ def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=2):
+     draw = ImageDraw.Draw(image)
+     for box in bounding_boxes:
+         xmin, ymin, xmax, ymax = box
+         draw.rectangle([xmin, ymin, xmax, ymax], outline=outline_color, width=line_width)
+     return image
+ 
+ def florence_draw_bboxes(image, bounding_boxes, outline_color="red", line_width=2):
+     draw = ImageDraw.Draw(image)
+     # font = ImageFont.truetype("sans-serif.ttf", 16)
+     for bbox, label in zip(bounding_boxes['<OD>']['bboxes'], bounding_boxes['<OD>']['labels']):
+         x1, y1, x2, y2 = bbox
+         draw.rectangle([x1, y1, x2, y2], outline=outline_color, width=line_width)
+         draw.text((x1, y1), label, (255, 255, 255))
+     return image
+ 
+ 
+ def rescale_bounding_boxes(bounding_boxes, original_width, original_height, scaled_width=1000, scaled_height=1000):
+     x_scale = original_width / scaled_width
+     y_scale = original_height / scaled_height
+     rescaled_boxes = []
+     for box in bounding_boxes:
+         xmin, ymin, xmax, ymax = box
+         rescaled_box = [
+             xmin * x_scale,
+             ymin * y_scale,
+             xmax * x_scale,
+             ymax * y_scale
+         ]
+         rescaled_boxes.append(rescaled_box)
+     return rescaled_boxes
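
A short usage sketch of the new helpers, run against a synthetic image rather than real model output (the boxes are made up for illustration):

from PIL import Image
from utils import image_to_base64, rescale_bounding_boxes, draw_bounding_boxes

# Synthetic 640x480 canvas standing in for an uploaded image.
img = Image.new("RGB", (640, 480), "white")

# Boxes in the 1000x1000 coordinate space the Qwen prompt asks for.
model_boxes = [[100, 100, 400, 300], [550, 200, 900, 700]]
pixel_boxes = rescale_bounding_boxes(model_boxes, img.width, img.height)

annotated = draw_bounding_boxes(img, pixel_boxes)
annotated.save("annotated.png")

# Base64 payload in the same form app.py embeds in the Qwen chat message.
print(image_to_base64(annotated)[:60] + "...")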