alessandro trinca tornidor committed
Commit 8959fb9 · 1 Parent(s): d60d246

[refactor] move routes to dedicated routes.py, move app helper functions to dedicated app_helpers.py

Files changed (3)
  1. app/main.py +6 -333
  2. app/routes.py +19 -0
  3. utils/app_helpers.py +322 -0
app/main.py CHANGED
@@ -1,360 +1,33 @@
-import argparse
 import logging
 import os
-import re
 import sys
-from typing import Callable
-
-import cv2
 import gradio as gr
-import nh3
-import numpy as np
-import torch
-import torch.nn.functional as F
 from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 from fastapi.templating import Jinja2Templates
-from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
 
-from model.LISA import LISAForCausalLM
-from model.llava import conversation as conversation_lib
-from model.llava.mm_utils import tokenizer_image_token
-from model.segment_anything.utils.transforms import ResizeLongestSide
-from utils import constants, session_logger, utils
+from app import routes
+from utils import app_helpers, session_logger
 
 
 session_logger.change_logging(logging.DEBUG)
 
 CUSTOM_GRADIO_PATH = "/"
 app = FastAPI(title="lisa_app", version="1.0")
+app.include_router(routes.router)
 
 FASTAPI_STATIC = os.getenv("FASTAPI_STATIC")
 os.makedirs(FASTAPI_STATIC, exist_ok=True)
 app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")
 templates = Jinja2Templates(directory="templates")
-placeholders = utils.create_placeholder_variables()
-
-
-@app.get("/health")
-@session_logger.set_uuid_logging
-def health() -> str:
-    import json
-
-    try:
-        logging.info("health check")
-        return json.dumps({"msg": "ok"})
-    except Exception as e:
-        logging.error(f"exception:{e}.")
-        return json.dumps({"msg": "request failed"})
-
-
-@session_logger.set_uuid_logging
-def parse_args(args_to_parse):
-    parser = argparse.ArgumentParser(description="LISA chat")
-    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1-explanatory")
-    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
-    parser.add_argument(
-        "--precision",
-        default="fp16",
-        type=str,
-        choices=["fp32", "bf16", "fp16"],
-        help="precision for inference",
-    )
-    parser.add_argument("--image_size", default=1024, type=int, help="image size")
-    parser.add_argument("--model_max_length", default=512, type=int)
-    parser.add_argument("--lora_r", default=8, type=int)
-    parser.add_argument(
-        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
-    )
-    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
-    parser.add_argument("--load_in_8bit", action="store_true", default=False)
-    parser.add_argument("--load_in_4bit", action="store_true", default=True)
-    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
-    parser.add_argument(
-        "--conv_type",
-        default="llava_v1",
-        type=str,
-        choices=["llava_v1", "llava_llama_2"],
-    )
-    return parser.parse_args(args_to_parse)
-
-
-@session_logger.set_uuid_logging
-def get_cleaned_input(input_str):
-    logging.info(f"start cleaning of input_str: {input_str}.")
-    input_str = nh3.clean(
-        input_str,
-        tags={
-            "a",
-            "abbr",
-            "acronym",
-            "b",
-            "blockquote",
-            "code",
-            "em",
-            "i",
-            "li",
-            "ol",
-            "strong",
-            "ul",
-        },
-        attributes={
-            "a": {"href", "title"},
-            "abbr": {"title"},
-            "acronym": {"title"},
-        },
-        url_schemes={"http", "https", "mailto"},
-        link_rel=None,
-    )
-    logging.info(f"cleaned input_str: {input_str}.")
-    return input_str
-
-
-@session_logger.set_uuid_logging
-def set_image_precision_by_args(input_image, precision):
-    if precision == "bf16":
-        input_image = input_image.bfloat16()
-    elif precision == "fp16":
-        input_image = input_image.half()
-    else:
-        input_image = input_image.float()
-    return input_image
-
-
-@session_logger.set_uuid_logging
-def preprocess(
-    x,
-    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
-    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
-    img_size=1024,
-) -> torch.Tensor:
-    """Normalize pixel values and pad to a square input."""
-    logging.info("preprocess started")
-    # Normalize colors
-    x = (x - pixel_mean) / pixel_std
-    # Pad
-    h, w = x.shape[-2:]
-    padh = img_size - h
-    padw = img_size - w
-    x = F.pad(x, (0, padw, 0, padh))
-    logging.info("preprocess ended")
-    return x
-
-
-@session_logger.set_uuid_logging
-def get_model(args_to_parse):
-    logging.info("starting model preparation...")
-    os.makedirs(args_to_parse.vis_save_path, exist_ok=True)
-
-    # global tokenizer, tokenizer
-    # Create model
-    _tokenizer = AutoTokenizer.from_pretrained(
-        args_to_parse.version,
-        cache_dir=None,
-        model_max_length=args_to_parse.model_max_length,
-        padding_side="right",
-        use_fast=False,
-    )
-    _tokenizer.pad_token = _tokenizer.unk_token
-    args_to_parse.seg_token_idx = _tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
-    torch_dtype = torch.float32
-    if args_to_parse.precision == "bf16":
-        torch_dtype = torch.bfloat16
-    elif args_to_parse.precision == "fp16":
-        torch_dtype = torch.half
-    kwargs = {"torch_dtype": torch_dtype}
-    if args_to_parse.load_in_4bit:
-        kwargs.update(
-            {
-                "torch_dtype": torch.half,
-                "load_in_4bit": True,
-                "quantization_config": BitsAndBytesConfig(
-                    load_in_4bit=True,
-                    bnb_4bit_compute_dtype=torch.float16,
-                    bnb_4bit_use_double_quant=True,
-                    bnb_4bit_quant_type="nf4",
-                    llm_int8_skip_modules=["visual_model"],
-                ),
-            }
-        )
-    elif args_to_parse.load_in_8bit:
-        kwargs.update(
-            {
-                "torch_dtype": torch.half,
-                "quantization_config": BitsAndBytesConfig(
-                    llm_int8_skip_modules=["visual_model"],
-                    load_in_8bit=True,
-                ),
-            }
-        )
-    _model = LISAForCausalLM.from_pretrained(
-        args_to_parse.version, low_cpu_mem_usage=True, vision_tower=args_to_parse.vision_tower, seg_token_idx=args_to_parse.seg_token_idx, **kwargs
-    )
-    _model.config.eos_token_id = _tokenizer.eos_token_id
-    _model.config.bos_token_id = _tokenizer.bos_token_id
-    _model.config.pad_token_id = _tokenizer.pad_token_id
-    _model.get_model().initialize_vision_modules(_model.get_model().config)
-    vision_tower = _model.get_model().get_vision_tower()
-    vision_tower.to(dtype=torch_dtype)
-    if args_to_parse.precision == "bf16":
-        _model = _model.bfloat16().cuda()
-    elif (
-        args_to_parse.precision == "fp16" and (not args_to_parse.load_in_4bit) and (not args_to_parse.load_in_8bit)
-    ):
-        vision_tower = _model.get_model().get_vision_tower()
-        _model.model.vision_tower = None
-        import deepspeed
-
-        model_engine = deepspeed.init_inference(
-            model=_model,
-            dtype=torch.half,
-            replace_with_kernel_inject=True,
-            replace_method="auto",
-        )
-        _model = model_engine.module
-        _model.model.vision_tower = vision_tower.half().cuda()
-    elif args_to_parse.precision == "fp32":
-        _model = _model.float().cuda()
-    vision_tower = _model.get_model().get_vision_tower()
-    vision_tower.to(device=args_to_parse.local_rank)
-    _clip_image_processor = CLIPImageProcessor.from_pretrained(_model.config.vision_tower)
-    _transform = ResizeLongestSide(args_to_parse.image_size)
-    _model.eval()
-    logging.info("model preparation ok!")
-    return _model, _clip_image_processor, _tokenizer, _transform
-
-
-@session_logger.set_uuid_logging
-def get_inference_model_by_args(args_to_parse):
-    logging.info(f"args_to_parse:{args_to_parse}, creating model...")
-    model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
-    logging.info("created model, preparing inference function")
-    no_seg_out, error_happened = placeholders["no_seg_out"], placeholders["error_happened"]
-
-    @session_logger.set_uuid_logging
-    def inference(input_str, input_image):
-        ## filter out special chars
-
-        input_str = get_cleaned_input(input_str)
-        logging.info(f"input_str type: {type(input_str)}, input_image type: {type(input_image)}.")
-        logging.info(f"input_str: {input_str}.")
-
-        ## input valid check
-        if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
-            output_str = "[Error] Invalid input: ", input_str
-            return error_happened, output_str
-
-        # Model Inference
-        conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
-        conv.messages = []
-
-        prompt = input_str
-        prompt = utils.DEFAULT_IMAGE_TOKEN + "\n" + prompt
-        if args_to_parse.use_mm_start_end:
-            replace_token = (
-                utils.DEFAULT_IM_START_TOKEN + utils.DEFAULT_IMAGE_TOKEN + utils.DEFAULT_IM_END_TOKEN
-            )
-            prompt = prompt.replace(utils.DEFAULT_IMAGE_TOKEN, replace_token)
-
-        conv.append_message(conv.roles[0], prompt)
-        conv.append_message(conv.roles[1], "")
-        prompt = conv.get_prompt()
-
-        image_np = cv2.imread(input_image)
-        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
-        original_size_list = [image_np.shape[:2]]
-
-        image_clip = (
-            clip_image_processor.preprocess(image_np, return_tensors="pt")[
-                "pixel_values"
-            ][0]
-            .unsqueeze(0)
-            .cuda()
-        )
-        logging.info(f"image_clip type: {type(image_clip)}.")
-        image_clip = set_image_precision_by_args(image_clip, args_to_parse.precision)
-
-        image = transform.apply_image(image_np)
-        resize_list = [image.shape[:2]]
-
-        image = (
-            preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
-            .unsqueeze(0)
-            .cuda()
-        )
-        logging.info(f"image_clip type: {type(image_clip)}.")
-        image = set_image_precision_by_args(image, args_to_parse.precision)
-
-        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
-        input_ids = input_ids.unsqueeze(0).cuda()
-
-        output_ids, pred_masks = model.evaluate(
-            image_clip,
-            image,
-            input_ids,
-            resize_list,
-            original_size_list,
-            max_new_tokens=512,
-            tokenizer=tokenizer,
-        )
-        output_ids = output_ids[0][output_ids[0] != utils.IMAGE_TOKEN_INDEX]
-
-        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
-        text_output = text_output.replace("\n", "").replace("  ", " ")
-        text_output = text_output.split("ASSISTANT: ")[-1]
-
-        logging.info(f"text_output type: {type(text_output)}, text_output: {text_output}.")
-        save_img = None
-        for i, pred_mask in enumerate(pred_masks):
-            if pred_mask.shape[0] == 0:
-                continue
-
-            pred_mask = pred_mask.detach().cpu().numpy()[0]
-            pred_mask = pred_mask > 0
-
-            save_img = image_np.copy()
-            save_img[pred_mask] = (
-                image_np * 0.5
-                + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
-            )[pred_mask]
-
-        output_str = f"ASSISTANT: {text_output}"
-        output_image = no_seg_out if save_img is None else save_img
-        logging.info(f"output_image type: {type(output_image)}.")
-        return output_image, output_str
-
-    logging.info("prepared inference function!")
-    return inference
-
-
-@session_logger.set_uuid_logging
-def get_gradio_interface(
-    fn_inference: Callable
-):
-    return gr.Interface(
-        fn_inference,
-        inputs=[
-            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
-            gr.Image(type="filepath", label="Input Image")
-        ],
-        outputs=[
-            gr.Image(type="pil", label="Segmentation Output"),
-            gr.Textbox(lines=1, placeholder=None, label="Text Output")
-        ],
-        title=constants.title,
-        description=constants.description,
-        article=constants.article,
-        examples=constants.examples,
-        allow_flagging="auto"
-    )
 
 
 logging.info(f"sys.argv:{sys.argv}.")
-args = parse_args([])
+args = app_helpers.parse_args([])
 logging.info(f"prepared default arguments:{args}.")
-inference_fn = get_inference_model_by_args(args)
+inference_fn = app_helpers.get_inference_model_by_args(args)
 logging.info(f"prepared inference_fn function:{inference_fn.__name__}, creating gradio interface...")
-io = get_gradio_interface(inference_fn)
+io = app_helpers.get_gradio_interface(inference_fn)
 logging.info("created gradio interface")
 app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)
 logging.info("mounted gradio app within fastapi")
app/routes.py ADDED
@@ -0,0 +1,19 @@
+import json
+import logging
+from fastapi import APIRouter
+
+from utils import session_logger
+
+
+router = APIRouter()
+
+
+@router.get("/health")
+@session_logger.set_uuid_logging
+def health() -> str:
+    try:
+        logging.info("health check")
+        return json.dumps({"msg": "ok"})
+    except Exception as e:
+        logging.error(f"exception:{e}.")
+        return json.dumps({"msg": "request failed"})
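
The health check now lives on an APIRouter that app/main.py registers via app.include_router(routes.router). A minimal sketch, assuming FastAPI's TestClient, of exercising the route in isolation without loading any model weights; illustrative only, not part of this commit:

```python
# Illustrative sketch (assumption): mount the new router on a bare FastAPI app
# and call /health directly.
from fastapi import FastAPI
from fastapi.testclient import TestClient

from app import routes

test_app = FastAPI()
test_app.include_router(routes.router)
client = TestClient(test_app)

response = client.get("/health")
assert response.status_code == 200
# health() returns a JSON string such as '{"msg": "ok"}'; FastAPI serializes
# that string again, so response.json() yields the string itself.
```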
utils/app_helpers.py ADDED
@@ -0,0 +1,322 @@
+import argparse
+import logging
+import os
+import re
+from typing import Callable
+import cv2
+import gradio as gr
+import nh3
+import numpy as np
+import torch
+import torch.nn.functional as F
+from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
+
+from . import constants, session_logger, utils
+from model.LISA import LISAForCausalLM
+from model.llava import conversation as conversation_lib
+from model.llava.mm_utils import tokenizer_image_token
+from model.segment_anything.utils.transforms import ResizeLongestSide
+
+
+placeholders = utils.create_placeholder_variables()
+
+
+@session_logger.set_uuid_logging
+def parse_args(args_to_parse):
+    parser = argparse.ArgumentParser(description="LISA chat")
+    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1-explanatory")
+    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
+    parser.add_argument(
+        "--precision",
+        default="fp16",
+        type=str,
+        choices=["fp32", "bf16", "fp16"],
+        help="precision for inference",
+    )
+    parser.add_argument("--image_size", default=1024, type=int, help="image size")
+    parser.add_argument("--model_max_length", default=512, type=int)
+    parser.add_argument("--lora_r", default=8, type=int)
+    parser.add_argument(
+        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
+    )
+    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
+    parser.add_argument("--load_in_8bit", action="store_true", default=False)
+    parser.add_argument("--load_in_4bit", action="store_true", default=True)
+    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
+    parser.add_argument(
+        "--conv_type",
+        default="llava_v1",
+        type=str,
+        choices=["llava_v1", "llava_llama_2"],
+    )
+    return parser.parse_args(args_to_parse)
+
+
+@session_logger.set_uuid_logging
+def get_cleaned_input(input_str):
+    logging.info(f"start cleaning of input_str: {input_str}.")
+    input_str = nh3.clean(
+        input_str,
+        tags={
+            "a",
+            "abbr",
+            "acronym",
+            "b",
+            "blockquote",
+            "code",
+            "em",
+            "i",
+            "li",
+            "ol",
+            "strong",
+            "ul",
+        },
+        attributes={
+            "a": {"href", "title"},
+            "abbr": {"title"},
+            "acronym": {"title"},
+        },
+        url_schemes={"http", "https", "mailto"},
+        link_rel=None,
+    )
+    logging.info(f"cleaned input_str: {input_str}.")
+    return input_str
+
+
+@session_logger.set_uuid_logging
+def set_image_precision_by_args(input_image, precision):
+    if precision == "bf16":
+        input_image = input_image.bfloat16()
+    elif precision == "fp16":
+        input_image = input_image.half()
+    else:
+        input_image = input_image.float()
+    return input_image
+
+
+@session_logger.set_uuid_logging
+def preprocess(
+    x,
+    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
+    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
+    img_size=1024,
+) -> torch.Tensor:
+    """Normalize pixel values and pad to a square input."""
+    logging.info("preprocess started")
+    # Normalize colors
+    x = (x - pixel_mean) / pixel_std
+    # Pad
+    h, w = x.shape[-2:]
+    padh = img_size - h
+    padw = img_size - w
+    x = F.pad(x, (0, padw, 0, padh))
+    logging.info("preprocess ended")
+    return x
+
+
+@session_logger.set_uuid_logging
+def get_model(args_to_parse):
+    logging.info("starting model preparation...")
+    os.makedirs(args_to_parse.vis_save_path, exist_ok=True)
+
+    # global tokenizer, tokenizer
+    # Create model
+    _tokenizer = AutoTokenizer.from_pretrained(
+        args_to_parse.version,
+        cache_dir=None,
+        model_max_length=args_to_parse.model_max_length,
+        padding_side="right",
+        use_fast=False,
+    )
+    _tokenizer.pad_token = _tokenizer.unk_token
+    args_to_parse.seg_token_idx = _tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
+    torch_dtype = torch.float32
+    if args_to_parse.precision == "bf16":
+        torch_dtype = torch.bfloat16
+    elif args_to_parse.precision == "fp16":
+        torch_dtype = torch.half
+    kwargs = {"torch_dtype": torch_dtype}
+    if args_to_parse.load_in_4bit:
+        kwargs.update(
+            {
+                "torch_dtype": torch.half,
+                "load_in_4bit": True,
+                "quantization_config": BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.float16,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                    llm_int8_skip_modules=["visual_model"],
+                ),
+            }
+        )
+    elif args_to_parse.load_in_8bit:
+        kwargs.update(
+            {
+                "torch_dtype": torch.half,
+                "quantization_config": BitsAndBytesConfig(
+                    llm_int8_skip_modules=["visual_model"],
+                    load_in_8bit=True,
+                ),
+            }
+        )
+    _model = LISAForCausalLM.from_pretrained(
+        args_to_parse.version, low_cpu_mem_usage=True, vision_tower=args_to_parse.vision_tower, seg_token_idx=args_to_parse.seg_token_idx, **kwargs
+    )
+    _model.config.eos_token_id = _tokenizer.eos_token_id
+    _model.config.bos_token_id = _tokenizer.bos_token_id
+    _model.config.pad_token_id = _tokenizer.pad_token_id
+    _model.get_model().initialize_vision_modules(_model.get_model().config)
+    vision_tower = _model.get_model().get_vision_tower()
+    vision_tower.to(dtype=torch_dtype)
+    if args_to_parse.precision == "bf16":
+        _model = _model.bfloat16().cuda()
+    elif (
+        args_to_parse.precision == "fp16" and (not args_to_parse.load_in_4bit) and (not args_to_parse.load_in_8bit)
+    ):
+        vision_tower = _model.get_model().get_vision_tower()
+        _model.model.vision_tower = None
+        import deepspeed
+
+        model_engine = deepspeed.init_inference(
+            model=_model,
+            dtype=torch.half,
+            replace_with_kernel_inject=True,
+            replace_method="auto",
+        )
+        _model = model_engine.module
+        _model.model.vision_tower = vision_tower.half().cuda()
+    elif args_to_parse.precision == "fp32":
+        _model = _model.float().cuda()
+    vision_tower = _model.get_model().get_vision_tower()
+    vision_tower.to(device=args_to_parse.local_rank)
+    _clip_image_processor = CLIPImageProcessor.from_pretrained(_model.config.vision_tower)
+    _transform = ResizeLongestSide(args_to_parse.image_size)
+    _model.eval()
+    logging.info("model preparation ok!")
+    return _model, _clip_image_processor, _tokenizer, _transform
+
+
+@session_logger.set_uuid_logging
+def get_inference_model_by_args(args_to_parse):
+    logging.info(f"args_to_parse:{args_to_parse}, creating model...")
+    model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
+    logging.info("created model, preparing inference function")
+    no_seg_out, error_happened = placeholders["no_seg_out"], placeholders["error_happened"]
+
+    @session_logger.set_uuid_logging
+    def inference(input_str, input_image):
+        ## filter out special chars
+
+        input_str = get_cleaned_input(input_str)
+        logging.info(f"input_str type: {type(input_str)}, input_image type: {type(input_image)}.")
+        logging.info(f"input_str: {input_str}.")
+
+        ## input valid check
+        if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
+            output_str = "[Error] Invalid input: ", input_str
+            return error_happened, output_str
+
+        # Model Inference
+        conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
+        conv.messages = []
+
+        prompt = input_str
+        prompt = utils.DEFAULT_IMAGE_TOKEN + "\n" + prompt
+        if args_to_parse.use_mm_start_end:
+            replace_token = (
+                utils.DEFAULT_IM_START_TOKEN + utils.DEFAULT_IMAGE_TOKEN + utils.DEFAULT_IM_END_TOKEN
+            )
+            prompt = prompt.replace(utils.DEFAULT_IMAGE_TOKEN, replace_token)
+
+        conv.append_message(conv.roles[0], prompt)
+        conv.append_message(conv.roles[1], "")
+        prompt = conv.get_prompt()
+
+        image_np = cv2.imread(input_image)
+        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
+        original_size_list = [image_np.shape[:2]]
+
+        image_clip = (
+            clip_image_processor.preprocess(image_np, return_tensors="pt")[
+                "pixel_values"
+            ][0]
+            .unsqueeze(0)
+            .cuda()
+        )
+        logging.info(f"image_clip type: {type(image_clip)}.")
+        image_clip = set_image_precision_by_args(image_clip, args_to_parse.precision)
+
+        image = transform.apply_image(image_np)
+        resize_list = [image.shape[:2]]
+
+        image = (
+            preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
+            .unsqueeze(0)
+            .cuda()
+        )
+        logging.info(f"image_clip type: {type(image_clip)}.")
+        image = set_image_precision_by_args(image, args_to_parse.precision)
+
+        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
+        input_ids = input_ids.unsqueeze(0).cuda()
+
+        output_ids, pred_masks = model.evaluate(
+            image_clip,
+            image,
+            input_ids,
+            resize_list,
+            original_size_list,
+            max_new_tokens=512,
+            tokenizer=tokenizer,
+        )
+        output_ids = output_ids[0][output_ids[0] != utils.IMAGE_TOKEN_INDEX]
+
+        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
+        text_output = text_output.replace("\n", "").replace("  ", " ")
+        text_output = text_output.split("ASSISTANT: ")[-1]
+
+        logging.info(f"text_output type: {type(text_output)}, text_output: {text_output}.")
+        save_img = None
+        for i, pred_mask in enumerate(pred_masks):
+            if pred_mask.shape[0] == 0:
+                continue
+
+            pred_mask = pred_mask.detach().cpu().numpy()[0]
+            pred_mask = pred_mask > 0
+
+            save_img = image_np.copy()
+            save_img[pred_mask] = (
+                image_np * 0.5
+                + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
+            )[pred_mask]
+
+        output_str = f"ASSISTANT: {text_output}"
+        output_image = no_seg_out if save_img is None else save_img
+        logging.info(f"output_image type: {type(output_image)}.")
+        return output_image, output_str
+
+    logging.info("prepared inference function!")
+    return inference
+
+
+@session_logger.set_uuid_logging
+def get_gradio_interface(
+    fn_inference: Callable
+):
+    return gr.Interface(
+        fn_inference,
+        inputs=[
+            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
+            gr.Image(type="filepath", label="Input Image")
+        ],
+        outputs=[
+            gr.Image(type="pil", label="Segmentation Output"),
+            gr.Textbox(lines=1, placeholder=None, label="Text Output")
+        ],
+        title=constants.title,
+        description=constants.description,
+        article=constants.article,
+        examples=constants.examples,
+        allow_flagging="auto"
+    )
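
utils/app_helpers.py now owns argument parsing, model loading, the inference closure, and the Gradio interface factory, so the same pipeline can also be driven without the FastAPI layer. A minimal sketch mirroring what app/main.py does at import time; the standalone launch() call is an assumption, not part of this commit:

```python
# Minimal sketch (assumption): use the relocated helpers directly.
from utils import app_helpers

# Build the default argument namespace (no CLI args), load the LISA model,
# and obtain the inference closure.
args = app_helpers.parse_args([])
inference_fn = app_helpers.get_inference_model_by_args(args)

# Wrap it in the Gradio interface and serve it standalone instead of
# mounting it inside FastAPI.
io = app_helpers.get_gradio_interface(inference_fn)
io.launch()
```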