xzerus committed
Commit 11bbd27 · verified · Parent: f3d47d3

Update app.py

Files changed (1)
  1. app.py +21 -57
app.py CHANGED
@@ -5,6 +5,12 @@ from decord import VideoReader, cpu
 from PIL import Image
 from torchvision.transforms.functional import InterpolationMode
 from transformers import AutoModel, AutoTokenizer
+from fastapi import FastAPI, UploadFile, File
+from typing import List
+from io import BytesIO
+
+# FastAPI app initialization
+app = FastAPI()
 
 # Device Configuration
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -22,21 +28,6 @@ def build_transform(input_size):
     ])
     return transform
 
-def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
-    best_ratio_diff = float('inf')
-    best_ratio = (1, 1)
-    area = width * height
-    for ratio in target_ratios:
-        target_aspect_ratio = ratio[0] / ratio[1]
-        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
-        if ratio_diff < best_ratio_diff:
-            best_ratio_diff = ratio_diff
-            best_ratio = ratio
-        elif ratio_diff == best_ratio_diff:
-            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
-                best_ratio = ratio
-    return best_ratio
-
 def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
     orig_width, orig_height = image.size
     aspect_ratio = orig_width / orig_height
@@ -46,16 +37,11 @@ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
         i * j <= max_num and i * j >= min_num)
     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
 
-    target_aspect_ratio = find_closest_aspect_ratio(
-        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
-
-    target_width = image_size * target_aspect_ratio[0]
-    target_height = image_size * target_aspect_ratio[1]
-    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
+    target_width = image_size * target_ratios[0][0]
+    target_height = image_size * target_ratios[0][1]
     resized_img = image.resize((target_width, target_height))
     processed_images = []
-    for i in range(blocks):
+    for i in range(target_ratios[0][0] * target_ratios[0][1]):
         box = (
             (i % (target_width // image_size)) * image_size,
             (i // (target_width // image_size)) * image_size,
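Note on the change above: target_ratios is sorted by tile count (x[0] * x[1]), so with min_num=1 the new target_ratios[0] is always (1, 1) and every image is resized to a single image_size × image_size tile. If the aspect-aware grid selection is still wanted without the removed helper, an inline equivalent is sketched below (same logic as the removed find_closest_aspect_ratio, minus its area-based tie-break):

# Sketch only: inline replacement for the removed find_closest_aspect_ratio.
# Picks the grid whose aspect ratio is closest to the input image's;
# the original helper's area-based tie-break is omitted here.
best_ratio = min(target_ratios, key=lambda r: abs(aspect_ratio - r[0] / r[1]))
target_width = image_size * best_ratio[0]
target_height = image_size * best_ratio[1]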
@@ -64,13 +50,12 @@ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
         )
         split_img = resized_img.crop(box)
         processed_images.append(split_img)
-    assert len(processed_images) == blocks
     if use_thumbnail and len(processed_images) != 1:
         thumbnail_img = image.resize((image_size, image_size))
         processed_images.append(thumbnail_img)
     return processed_images
 
-def load_image(image_file, input_size=448, max_num=12):
+def load_image(image_file: BytesIO, input_size=448, max_num=12):
     image = Image.open(image_file).convert('RGB')
     transform = build_transform(input_size=input_size)
     images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
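Since PIL's Image.open accepts any file-like object, the retyped load_image behaves the same on in-memory uploads as on open file handles; a quick local check might look like this (sample.jpg is a placeholder path):

# Hypothetical local check of the new BytesIO signature.
from io import BytesIO

with open("sample.jpg", "rb") as f:
    pixel_values = load_image(BytesIO(f.read()))
print(pixel_values.shape)  # e.g. torch.Size([1, 3, 448, 448]) with the (1, 1) grid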
@@ -78,38 +63,6 @@ def load_image(image_file, input_size=448, max_num=12):
     pixel_values = torch.stack(pixel_values).to(device)
     return pixel_values
 
-def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
-    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
-    max_frame = len(vr) - 1
-    fps = float(vr.get_avg_fps())
-
-    pixel_values_list, num_patches_list = [], []
-    transform = build_transform(input_size=input_size)
-    frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
-    for frame_index in frame_indices:
-        img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
-        img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
-        pixel_values = [transform(tile) for tile in img]
-        pixel_values = torch.stack(pixel_values)
-        num_patches_list.append(pixel_values.shape[0])
-        pixel_values_list.append(pixel_values)
-    pixel_values = torch.cat(pixel_values_list)
-    return pixel_values, num_patches_list
-
-def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
-    if bound:
-        start, end = bound[0], bound[1]
-    else:
-        start, end = -100000, 100000
-    start_idx = max(first_idx, round(start * fps))
-    end_idx = min(round(end * fps), max_frame)
-    seg_size = float(end_idx - start_idx) / num_segments
-    frame_indices = np.array([
-        int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
-        for idx in range(num_segments)
-    ])
-    return frame_indices
-
 # Load Model
 path = 'OpenGVLab/InternVL2_5-1B'
 model = AutoModel.from_pretrained(
@@ -119,3 +72,14 @@ model = AutoModel.from_pretrained(
     trust_remote_code=True
 ).eval().to(device)
 tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
+
+@app.post("/predict")
+async def predict(file: UploadFile = File(...), question: str = "Describe the image"):
+    # Load and preprocess the image
+    file_bytes = BytesIO(await file.read())
+    pixel_values = load_image(file_bytes)
+
+    # Generate a response
+    generation_config = dict(max_new_tokens=1024, do_sample=True)
+    response, _ = model.chat(tokenizer, pixel_values, question, generation_config)
+    return {"question": question, "response": response}