xzerus committed
Commit 895c285 · verified · 1 Parent(s): 11bbd27

Update app.py

Files changed (1)
  1. app.py +89 -70
app.py CHANGED
@@ -1,85 +1,104 @@
- import numpy as np
  import torch
  import torchvision.transforms as T
- from decord import VideoReader, cpu
  from PIL import Image
- from torchvision.transforms.functional import InterpolationMode
- from transformers import AutoModel, AutoTokenizer
- from fastapi import FastAPI, UploadFile, File
- from typing import List
- from io import BytesIO

- # FastAPI app initialization
- app = FastAPI()
-
- # Device Configuration
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  IMAGENET_MEAN = (0.485, 0.456, 0.406)
  IMAGENET_STD = (0.229, 0.224, 0.225)

  def build_transform(input_size):
-     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
      transform = T.Compose([
-         T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
-         T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
          T.ToTensor(),
-         T.Normalize(mean=MEAN, std=STD)
      ])
      return transform

- def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
-     orig_width, orig_height = image.size
-     aspect_ratio = orig_width / orig_height
-
-     target_ratios = set(
-         (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
-         i * j <= max_num and i * j >= min_num)
-     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
-     target_width = image_size * target_ratios[0][0]
-     target_height = image_size * target_ratios[0][1]
-     resized_img = image.resize((target_width, target_height))
-     processed_images = []
-     for i in range(target_ratios[0][0] * target_ratios[0][1]):
-         box = (
-             (i % (target_width // image_size)) * image_size,
-             (i // (target_width // image_size)) * image_size,
-             ((i % (target_width // image_size)) + 1) * image_size,
-             ((i // (target_width // image_size)) + 1) * image_size
-         )
-         split_img = resized_img.crop(box)
-         processed_images.append(split_img)
-     if use_thumbnail and len(processed_images) != 1:
-         thumbnail_img = image.resize((image_size, image_size))
-         processed_images.append(thumbnail_img)
-     return processed_images
-
- def load_image(image_file: BytesIO, input_size=448, max_num=12):
-     image = Image.open(image_file).convert('RGB')
-     transform = build_transform(input_size=input_size)
-     images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
-     pixel_values = [transform(image) for image in images]
-     pixel_values = torch.stack(pixel_values).to(device)
-     return pixel_values
-
- # Load Model
- path = 'OpenGVLab/InternVL2_5-1B'
  model = AutoModel.from_pretrained(
-     path,
-     low_cpu_mem_usage=True,
-     use_flash_attn=False,
-     trust_remote_code=True
- ).eval().to(device)
- tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
-
- @app.post("/predict")
- async def predict(file: UploadFile = File(...), question: str = "Describe the image"):
-     # Load and preprocess the image
-     file_bytes = BytesIO(await file.read())
-     pixel_values = load_image(file_bytes)
-
-     # Generate a response
-     generation_config = dict(max_new_tokens=1024, do_sample=True)
-     response, _ = model.chat(tokenizer, pixel_values, question, generation_config)
-     return {"question": question, "response": response}

  import torch
  import torchvision.transforms as T
  from PIL import Image
+ from threading import Thread
+ from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
+ import gradio as gr
+ import logging

+ # Set up logging
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

+ # ImageNet normalization values
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
  IMAGENET_STD = (0.229, 0.224, 0.225)

  def build_transform(input_size):
+     """
+     Build the preprocessing pipeline for input images.
+     """
      transform = T.Compose([
+         T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+         T.Resize((input_size, input_size), interpolation=T.InterpolationMode.BICUBIC),
          T.ToTensor(),
+         T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
      ])
      return transform

+ def preprocess_image(image, input_size=448):
+     """
+     Preprocess the image into the tensor format expected by the model.
+     """
+     logging.info("Starting image preprocessing...")
+     transform = build_transform(input_size)
+     tensor_image = transform(image).unsqueeze(0)  # Add batch dimension
+     logging.info(f"Image preprocessed. Shape: {tensor_image.shape}")
+     return tensor_image
+
+ # Load the model and tokenizer
+ logging.info("Loading model from Hugging Face Hub...")
+ model_path = "OpenGVLab/InternVL2_5-1B"  # Hugging Face Hub model ID
  model = AutoModel.from_pretrained(
+     model_path,
+     torch_dtype=torch.bfloat16,
+     trust_remote_code=True,
+ ).eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
+
+ # Add the `<image>` token if missing
+ if "<image>" not in tokenizer.get_vocab():
+     tokenizer.add_tokens(["<image>"])
+     logging.info("Added `<image>` token to tokenizer vocabulary.")
+     model.resize_token_embeddings(len(tokenizer))  # Resize model embeddings
+
+ assert "<image>" in tokenizer.get_vocab(), "Error: `<image>` token is missing from tokenizer vocabulary."
+
+ def describe_image(image):
+     """
+     Generate a description for the uploaded image with streamed output.
+     """
+     try:
+         logging.info("Processing uploaded image...")
+         pixel_values = preprocess_image(image, input_size=448).to(torch.bfloat16)
+
+         prompt = "<image>\nExtract text from the image, respond with only the extracted text."
+         logging.info(f"Prompt: {prompt}")
+
+         # Streamer for live text output
+         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10)
+         generation_config = dict(max_new_tokens=512, do_sample=True, streamer=streamer)
+
+         logging.info("Starting model inference...")
+         thread = Thread(target=model.chat, kwargs=dict(
+             tokenizer=tokenizer, pixel_values=pixel_values, question=prompt,
+             history=None, return_history=False, generation_config=generation_config,
+         ))
+         thread.start()
+
+         generated_text = ''
+         for new_text in streamer:
+             if new_text == model.conv_template.sep:
+                 break
+             generated_text += new_text
+             yield generated_text  # Stream the accumulated text so far
+
+         logging.info("Inference complete.")
+     except Exception as e:
+         logging.error(f"Error during processing: {e}")
+         yield f"Error: {e}"
+
+ # Gradio interface
+ logging.info("Setting up Gradio interface...")
+ interface = gr.Interface(
+     fn=describe_image,
+     inputs=gr.Image(type="pil"),
+     outputs=gr.Textbox(label="Extracted Text", lines=10, interactive=False),
+     title="Image to Text",
+     description="Upload an image to extract text using the pretrained model.",
+     live=True,  # Re-run automatically when the input changes
+ )
+
+ if __name__ == "__main__":
+     logging.info("Launching Gradio interface...")
+     interface.launch(server_name="0.0.0.0", server_port=7860)
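
With this commit, app.py serves the model through a Gradio interface instead of the old FastAPI `/predict` route, so clients talk to it through Gradio's API rather than a raw POST request. Below is a minimal client-side sketch, not part of the commit, using the `gradio_client` package: the local URL and port 7860 follow from `interface.launch()` above, while the `/predict` endpoint name is the Gradio default for a single `gr.Interface`, and the sample file name is hypothetical.

# Hypothetical client for the Gradio app started by app.py (assumes gradio_client >= 1.0).
from gradio_client import Client, handle_file

client = Client("http://localhost:7860")   # assumed local launch, per interface.launch() above
result = client.predict(
    handle_file("sample_image.png"),       # hypothetical local image file to run OCR on
    api_name="/predict",                   # default endpoint name for a single gr.Interface
)
print(result)                              # final extracted text once streaming has finished

For a streaming endpoint like `describe_image`, `client.predict` returns only the final value; `client.submit` could be used instead to iterate over intermediate outputs as they arrive.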