alessandro trinca tornidor committed
Commit c41e6ce · 1 Parent(s): 72ceb76

[refactor] LISA app has all its functionalities back

Files changed (1): main.py (+233, -7)
main.py CHANGED
@@ -1,17 +1,28 @@
 import argparse
-import json
 import logging
 import os
+import re
 import sys
 from typing import Callable
 
+import cv2
 import gradio as gr
 import nh3
+import numpy as np
+import torch
+import torch.nn.functional as F
 from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 from fastapi.templating import Jinja2Templates
+from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
 
+from model.LISA import LISAForCausalLM
+from model.llava import conversation as conversation_lib
+from model.llava.mm_utils import tokenizer_image_token
+from model.segment_anything.utils.transforms import ResizeLongestSide
 from utils import constants, session_logger
+from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
+                         DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
 
 
 session_logger.change_logging(logging.DEBUG)
@@ -28,6 +39,8 @@ templates = Jinja2Templates(directory="templates")
 @app.get("/health")
 @session_logger.set_uuid_logging
 def health() -> str:
+    import json
+
     try:
         logging.info("health check")
         return json.dumps({"msg": "ok"})
@@ -98,24 +111,233 @@ def get_cleaned_input(input_str):
     return input_str
 
 
+@session_logger.set_uuid_logging
+def set_image_precision_by_args(input_image, precision):
+    if precision == "bf16":
+        input_image = input_image.bfloat16()
+    elif precision == "fp16":
+        input_image = input_image.half()
+    else:
+        input_image = input_image.float()
+    return input_image
+
+
+@session_logger.set_uuid_logging
+def preprocess(
+    x,
+    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
+    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
+    img_size=1024,
+) -> torch.Tensor:
+    """Normalize pixel values and pad to a square input."""
+    logging.info("preprocess started")
+    # Normalize colors
+    x = (x - pixel_mean) / pixel_std
+    # Pad to a square img_size x img_size input
+    h, w = x.shape[-2:]
+    padh = img_size - h
+    padw = img_size - w
+    x = F.pad(x, (0, padw, 0, padh))
+    logging.info("preprocess ended")
+    return x
+
+
+@session_logger.set_uuid_logging
+def get_model(args_to_parse):
+    logging.info("starting model preparation...")
+    os.makedirs(args_to_parse.vis_save_path, exist_ok=True)
+
+    # Create model
+    _tokenizer = AutoTokenizer.from_pretrained(
+        args_to_parse.version,
+        cache_dir=None,
+        model_max_length=args_to_parse.model_max_length,
+        padding_side="right",
+        use_fast=False,
+    )
+    _tokenizer.pad_token = _tokenizer.unk_token
+    args_to_parse.seg_token_idx = _tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
+    torch_dtype = torch.float32
+    if args_to_parse.precision == "bf16":
+        torch_dtype = torch.bfloat16
+    elif args_to_parse.precision == "fp16":
+        torch_dtype = torch.half
+    kwargs = {"torch_dtype": torch_dtype}
+    if args_to_parse.load_in_4bit:
+        kwargs.update(
+            {
+                "torch_dtype": torch.half,
+                "load_in_4bit": True,
+                "quantization_config": BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.float16,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                    llm_int8_skip_modules=["visual_model"],
+                ),
+            }
+        )
+    elif args_to_parse.load_in_8bit:
+        kwargs.update(
+            {
+                "torch_dtype": torch.half,
+                "quantization_config": BitsAndBytesConfig(
+                    llm_int8_skip_modules=["visual_model"],
+                    load_in_8bit=True,
+                ),
+            }
+        )
+    _model = LISAForCausalLM.from_pretrained(
+        args_to_parse.version, low_cpu_mem_usage=True, vision_tower=args_to_parse.vision_tower, seg_token_idx=args_to_parse.seg_token_idx, **kwargs
+    )
+    _model.config.eos_token_id = _tokenizer.eos_token_id
+    _model.config.bos_token_id = _tokenizer.bos_token_id
+    _model.config.pad_token_id = _tokenizer.pad_token_id
+    _model.get_model().initialize_vision_modules(_model.get_model().config)
+    vision_tower = _model.get_model().get_vision_tower()
+    vision_tower.to(dtype=torch_dtype)
+    if args_to_parse.precision == "bf16":
+        _model = _model.bfloat16().cuda()
+    elif (
+        args_to_parse.precision == "fp16" and (not args_to_parse.load_in_4bit) and (not args_to_parse.load_in_8bit)
+    ):
+        vision_tower = _model.get_model().get_vision_tower()
+        _model.model.vision_tower = None
+        import deepspeed
+
+        model_engine = deepspeed.init_inference(
+            model=_model,
+            dtype=torch.half,
+            replace_with_kernel_inject=True,
+            replace_method="auto",
+        )
+        _model = model_engine.module
+        _model.model.vision_tower = vision_tower.half().cuda()
+    elif args_to_parse.precision == "fp32":
+        _model = _model.float().cuda()
+    vision_tower = _model.get_model().get_vision_tower()
+    vision_tower.to(device=args_to_parse.local_rank)
+    _clip_image_processor = CLIPImageProcessor.from_pretrained(_model.config.vision_tower)
+    _transform = ResizeLongestSide(args_to_parse.image_size)
+    _model.eval()
+    logging.info("model preparation ok!")
+    return _model, _clip_image_processor, _tokenizer, _transform
+
+
 @session_logger.set_uuid_logging
 def get_inference_model_by_args(args_to_parse):
-    logging.info(f"args_to_parse:{args_to_parse}.")
+    logging.info(f"args_to_parse:{args_to_parse}, creating model...")
+    model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
+    logging.info("created model, preparing inference function")
 
     @session_logger.set_uuid_logging
     def inference(input_str, input_image):
-        logging.info(f"start cleaning input_str: {input_str}, type {type(input_str)}.")
-        output_str = get_cleaned_input(input_str)
-        logging.info(f"cleaned output_str: {output_str}, type {type(output_str)}.")
-        output_image = input_image
+        # filter out special chars
+        input_str = get_cleaned_input(input_str)
+        logging.info(f"input_str type: {type(input_str)}, input_image type: {type(input_image)}.")
+        logging.info(f"input_str: {input_str}.")
+
+        # input validity check
+        if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
+            output_str = f"[Error] Invalid input: {input_str}"
+            # an error happened: return the error placeholder image (BGR -> RGB)
+            output_image = cv2.imread("./resources/error_happened.png")[:, :, ::-1]
+            return output_image, output_str
+
+        # Model Inference
+        conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
+        conv.messages = []
+
+        prompt = input_str
+        prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
+        if args_to_parse.use_mm_start_end:
+            replace_token = (
+                DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+            )
+            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+        conv.append_message(conv.roles[0], prompt)
+        conv.append_message(conv.roles[1], "")
+        prompt = conv.get_prompt()
+
+        image_np = cv2.imread(input_image)
+        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
+        original_size_list = [image_np.shape[:2]]
+
+        image_clip = (
+            clip_image_processor.preprocess(image_np, return_tensors="pt")[
+                "pixel_values"
+            ][0]
+            .unsqueeze(0)
+            .cuda()
+        )
+        logging.info(f"image_clip type: {type(image_clip)}.")
+        image_clip = set_image_precision_by_args(image_clip, args_to_parse.precision)
+
+        image = transform.apply_image(image_np)
+        resize_list = [image.shape[:2]]
+
+        image = (
+            preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
+            .unsqueeze(0)
+            .cuda()
+        )
+        logging.info(f"image type: {type(image)}.")
+        image = set_image_precision_by_args(image, args_to_parse.precision)
+
+        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
+        input_ids = input_ids.unsqueeze(0).cuda()
+
+        output_ids, pred_masks = model.evaluate(
+            image_clip,
+            image,
+            input_ids,
+            resize_list,
+            original_size_list,
+            max_new_tokens=512,
+            tokenizer=tokenizer,
+        )
+        output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
+
+        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
+        text_output = text_output.replace("\n", "").replace("  ", " ")
+        text_output = text_output.split("ASSISTANT: ")[-1]
+
+        logging.info(f"text_output type: {type(text_output)}, text_output: {text_output}.")
+        save_img = None
+        for i, pred_mask in enumerate(pred_masks):
+            if pred_mask.shape[0] == 0:
+                continue
+
+            pred_mask = pred_mask.detach().cpu().numpy()[0]
+            pred_mask = pred_mask > 0
+
+            # overlay the predicted mask in red on the original image
+            save_img = image_np.copy()
+            save_img[pred_mask] = (
+                image_np * 0.5
+                + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
+            )[pred_mask]
+
+        output_str = f"ASSISTANT: {text_output}"
+        if save_img is not None:
+            output_image = save_img
+        else:
+            # no segmentation output: return the placeholder image (BGR -> RGB)
+            output_image = cv2.imread("./resources/no_seg_out.png")[:, :, ::-1]
         logging.info(f"output_image type: {type(output_image)}.")
         return output_image, output_str
 
+    logging.info("prepared inference function!")
     return inference
 
 
 @session_logger.set_uuid_logging
-def get_gradio_interface(fn_inference: Callable):
+def get_gradio_interface(
+    fn_inference: Callable
+):
     return gr.Interface(
         fn_inference,
         inputs=[
@@ -136,6 +358,10 @@ def get_gradio_interface(fn_inference: Callable):
 
 logging.info(f"sys.argv:{sys.argv}.")
 args = parse_args([])
+logging.info(f"prepared default arguments:{args}.")
 inference_fn = get_inference_model_by_args(args)
+logging.info(f"prepared inference_fn function:{inference_fn.__name__}, creating gradio interface...")
 io = get_gradio_interface(inference_fn)
+logging.info("created gradio interface")
 app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)
+logging.info("mounted gradio app within fastapi")
 
 
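For reference, the input guard in inference() admits only letters, spaces and basic punctuation, so digits, URLs or empty strings are rejected before any model work happens. A minimal standalone sketch with hypothetical sample strings:

import re

# Same pattern as the guard in inference().
PATTERN = r"^[A-Za-z ,.!?\'\"]+$"

for candidate in ["Where is the dog?", "segment the 2nd car", ""]:
    accepted = len(candidate) >= 1 and re.match(PATTERN, candidate) is not None
    print(f"{candidate!r}: {'accepted' if accepted else 'rejected'}")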
 
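preprocess() follows the Segment Anything normalize-then-pad convention: ResizeLongestSide scales the longest edge to 1024, then the right and bottom margins are zero-padded to a square. A quick sanity check with a dummy CHW tensor (the dummy shape is arbitrary):

import torch
import torch.nn.functional as F

x = torch.rand(3, 1024, 683)  # longest side already resized to 1024
mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)  # pixel stats on the 0-255 scale
std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

x = (x - mean) / std                      # normalize colors
h, w = x.shape[-2:]
x = F.pad(x, (0, 1024 - w, 0, 1024 - h))  # pad right and bottom only
assert x.shape == (3, 1024, 1024)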
 
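set_image_precision_by_args() and the torch_dtype selection in get_model() keep the image tensors in the same dtype as the model weights. Condensed into a sketch, the flag-to-dtype mapping is (cast_like_args is a hypothetical helper, not part of the diff):

import torch

DTYPES = {"bf16": torch.bfloat16, "fp16": torch.half}

def cast_like_args(t: torch.Tensor, precision: str) -> torch.Tensor:
    # Anything other than bf16/fp16 falls back to fp32, as in set_image_precision_by_args().
    return t.to(DTYPES.get(precision, torch.float32))

assert cast_like_args(torch.rand(2, 2), "fp16").dtype == torch.half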
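The mask loop at the end of inference() blends the original pixels 50/50 with pure red wherever the predicted mask is true. The same arithmetic in isolation, on a tiny hypothetical image and mask:

import numpy as np

image_np = np.random.randint(0, 255, (4, 4, 3), dtype=np.uint8)  # stand-in RGB image
pred_mask = np.zeros((4, 4), dtype=bool)
pred_mask[1:3, 1:3] = True  # hypothetical segmentation mask

overlay = image_np.copy()
overlay[pred_mask] = (
    image_np * 0.5
    + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
)[pred_mask]  # red tint only where the mask is true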
 
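Since gr.mount_gradio_app() attaches the Gradio UI to the FastAPI instance, a single ASGI server serves both the UI and the /health route. A possible launcher, assuming the module is main.py and the usual Hugging Face Spaces port 7860 (both assumptions, not stated in this diff):

import uvicorn

if __name__ == "__main__":
    # Serves the FastAPI app with the Gradio UI mounted at CUSTOM_GRADIO_PATH;
    # equivalent to `uvicorn main:app --host 0.0.0.0 --port 7860`.
    uvicorn.run("main:app", host="0.0.0.0", port=7860)

A GET on /health should then return {"msg": "ok"}.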