mknolan committed on
Commit 38c6c45 · verified · 1 Parent(s): 68f0b83

Upload app.py


Basic app with single image and PDF slide analysis functionality

Files changed (1)
  1. app.py +391 -0
app.py ADDED
@@ -0,0 +1,391 @@
+ import os
+ import sys
+ import math
+ import numpy as np
+ import tempfile
+ import torch
+ import torchvision.transforms as T
+ from torchvision.transforms.functional import InterpolationMode
+ from PIL import Image
+ import gradio as gr
+ from transformers import AutoModel, AutoTokenizer
+ import pdf2image
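+ # Note: pdf2image relies on the system Poppler utilities (e.g. the poppler-utils package) to rasterize PDF pages.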
+
+ # Constants
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+
+ # Configuration
+ MODEL_NAME = "OpenGVLab/InternVL2_5-8B"
+ IMAGE_SIZE = 448
+
+ # Set up environment variables
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
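+ # Capping the allocator's split size helps limit CUDA memory fragmentation while the 8B model loads.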
+
+ # Utility functions for image processing
+ def build_transform(input_size):
+     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+     transform = T.Compose([
+         T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+         T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+         T.ToTensor(),
+         T.Normalize(mean=MEAN, std=STD)
+     ])
+     return transform
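+ # The transform yields a 3 x input_size x input_size tensor per tile, normalized with the ImageNet statistics above.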
+
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+     best_ratio_diff = float('inf')
+     best_ratio = (1, 1)
+     area = width * height
+     for ratio in target_ratios:
+         target_aspect_ratio = ratio[0] / ratio[1]
+         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+         if ratio_diff < best_ratio_diff:
+             best_ratio_diff = ratio_diff
+             best_ratio = ratio
+         elif ratio_diff == best_ratio_diff:
+             if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                 best_ratio = ratio
+     return best_ratio
+
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+     orig_width, orig_height = image.size
+     aspect_ratio = orig_width / orig_height
+
+     # calculate the existing image aspect ratio
+     target_ratios = set(
+         (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+         i * j <= max_num and i * j >= min_num)
+     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+     # find the closest aspect ratio to the target
+     target_aspect_ratio = find_closest_aspect_ratio(
+         aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+     # calculate the target width and height
+     target_width = image_size * target_aspect_ratio[0]
+     target_height = image_size * target_aspect_ratio[1]
+     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+     # resize the image
+     resized_img = image.resize((target_width, target_height))
+     processed_images = []
+     for i in range(blocks):
+         box = (
+             (i % (target_width // image_size)) * image_size,
+             (i // (target_width // image_size)) * image_size,
+             ((i % (target_width // image_size)) + 1) * image_size,
+             ((i // (target_width // image_size)) + 1) * image_size
+         )
+         # split the image
+         split_img = resized_img.crop(box)
+         processed_images.append(split_img)
+     assert len(processed_images) == blocks
+     if use_thumbnail and len(processed_images) != 1:
+         thumbnail_img = image.resize((image_size, image_size))
+         processed_images.append(thumbnail_img)
+     return processed_images
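+ # Example: a 1344x896 image (aspect ratio 1.5) is resized and cut into a 3x2 grid of six 448px tiles.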
+
+ # Function to split model across GPUs
+ def split_model(model_name):
+     device_map = {}
+     world_size = torch.cuda.device_count()
+     if world_size <= 1:
+         return "auto"
+
+     num_layers = {
+         'InternVL2_5-1B': 24,
+         'InternVL2_5-2B': 24,
+         'InternVL2_5-4B': 36,
+         'InternVL2_5-8B': 32,
+         'InternVL2_5-26B': 48,
+         'InternVL2_5-38B': 64,
+         'InternVL2_5-78B': 80
+     }[model_name]
+
+     # Since the first GPU will be used for ViT, treat it as half a GPU.
+     num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
+     num_layers_per_gpu = [num_layers_per_gpu] * world_size
+     num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
+     layer_cnt = 0
+     for i, num_layer in enumerate(num_layers_per_gpu):
+         for j in range(num_layer):
+             device_map[f'language_model.model.layers.{layer_cnt}'] = i
+             layer_cnt += 1
+     device_map['vision_model'] = 0
+     device_map['mlp1'] = 0
+     device_map['language_model.model.tok_embeddings'] = 0
+     device_map['language_model.model.embed_tokens'] = 0
+     device_map['language_model.model.rotary_emb'] = 0
+     device_map['language_model.output'] = 0
+     device_map['language_model.model.norm'] = 0
+     device_map['language_model.lm_head'] = 0
+     device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
+
+     return device_map
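+ # Example: with 2 GPUs and the 8B model (32 layers), roughly layers 0-10, the vision tower, and the final layer land on GPU 0, the rest on GPU 1.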
+
+ # Model loading function
+ def load_model():
+     print(f"\n=== Loading {MODEL_NAME} ===")
+     print(f"CUDA available: {torch.cuda.is_available()}")
+
+     if torch.cuda.is_available():
+         print(f"GPU count: {torch.cuda.device_count()}")
+         for i in range(torch.cuda.device_count()):
+             print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
+
+         # Memory info
+         print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
+         print(f"Allocated GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
+         print(f"Reserved GPU memory: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
+
+     # Determine device map
+     device_map = "auto"
+     if torch.cuda.is_available() and torch.cuda.device_count() > 1:
+         model_short_name = MODEL_NAME.split('/')[-1]
+         device_map = split_model(model_short_name)
+
+     # Load model and tokenizer
+     try:
+         model = AutoModel.from_pretrained(
+             MODEL_NAME,
+             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+             low_cpu_mem_usage=True,
+             trust_remote_code=True,
+             device_map=device_map
+         )
+
+         tokenizer = AutoTokenizer.from_pretrained(
+             MODEL_NAME,
+             use_fast=False,
+             trust_remote_code=True
+         )
+
+         print(f"✓ Model and tokenizer loaded successfully!")
+         return model, tokenizer
+     except Exception as e:
+         print(f"❌ Error loading model: {e}")
+         import traceback
+         traceback.print_exc()
+         return None, None
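+ # Loading in bfloat16 roughly halves weight memory versus float32 (about 16 GB for the 8B parameters).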
+
+ # Extract slides from uploaded PDF file
+ def extract_slides_from_pdf(file_obj):
+     try:
+         # gr.File may hand back a filepath string or a file-like object depending on the Gradio version
+         upload_path = file_obj if isinstance(file_obj, str) else file_obj.name
+         file_extension = os.path.splitext(upload_path)[1].lower()
+
+         # Check if it's a PDF
+         if file_extension != '.pdf':
+             return []
+
+         # Create a temporary copy of the upload
+         with open(upload_path, 'rb') as src, tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
+             temp_file.write(src.read())
+             temp_path = temp_file.name
+
+         # Extract images from PDF using pdf2image
+         slides = []
+         try:
+             images = pdf2image.convert_from_path(temp_path, dpi=300)
+             slides = [(f"Slide {i+1}", img) for i, img in enumerate(images)]
+         except Exception as e:
+             print(f"Error converting PDF: {e}")
+
+         # Clean up temporary file
+         os.unlink(temp_path)
+
+         return slides
+
+     except Exception as e:
+         import traceback
+         error_msg = f"Error extracting slides: {str(e)}\n{traceback.format_exc()}"
+         print(error_msg)
+         return []
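+ # Rendering at dpi=300 keeps slide text legible for text-extraction prompts, at the cost of larger images per page.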
+
+ # Image analysis function
+ def analyze_image(model, tokenizer, image, prompt):
+     try:
+         # Check if image is valid
+         if image is None:
+             return "Please upload an image first."
+
+         # Tile the image and convert each tile to a normalized tensor
+         processed_images = dynamic_preprocess(image, image_size=IMAGE_SIZE)
+         transform = build_transform(IMAGE_SIZE)
+         pixel_values = torch.stack([transform(img) for img in processed_images])
+
+         # Move inputs to the right device and dtype
+         if torch.cuda.is_available():
+             pixel_values = pixel_values.to(torch.bfloat16).cuda()
+         else:
+             pixel_values = pixel_values.to(torch.float32)
+
+         # Prepare the prompt
+         question = f"<image>\n{prompt}"
+
+         # Generate a response via the chat() helper exposed by the model's remote code
+         generation_config = dict(max_new_tokens=512, do_sample=False)
+         with torch.no_grad():
+             response = model.chat(tokenizer, pixel_values, question, generation_config)
+
+         return response
+     except Exception as e:
+         import traceback
+         error_msg = f"Error analyzing image: {str(e)}\n{traceback.format_exc()}"
+         return error_msg
+
+ # Analyze multiple slides from a PDF
+ def analyze_pdf_slides(model, tokenizer, file_obj, prompt, num_slides=2):
+     try:
+         if file_obj is None:
+             return "Please upload a PDF file."
+
+         # Extract slides from PDF
+         slides = extract_slides_from_pdf(file_obj)
+
+         if not slides:
+             return "No slides were extracted from the file. Please check that it's a valid PDF."
+
+         # Limit to the requested number of slides (the slider value may arrive as a float)
+         slides = slides[:int(num_slides)]
+
+         # Analyze each slide
+         analyses = []
+         for slide_title, slide_image in slides:
+             analysis = analyze_image(model, tokenizer, slide_image, prompt)
+             analyses.append((slide_title, analysis))
+
+         # Format the results
+         result = ""
+         for slide_title, analysis in analyses:
+             result += f"## {slide_title}\n\n{analysis}\n\n---\n\n"
+
+         return result
+
+     except Exception as e:
+         import traceback
+         error_msg = f"Error analyzing slides: {str(e)}\n{traceback.format_exc()}"
+         return error_msg
+
+ # Main function
+ def main():
+     # Load the model
+     model, tokenizer = load_model()
+
+     if model is None:
+         # Create an error interface if model loading failed
+         demo = gr.Interface(
+             fn=lambda x: "Model loading failed. Please check the logs for details.",
+             inputs=gr.Textbox(),
+             outputs=gr.Textbox(),
+             title="InternVL2.5 Analyzer - Error",
+             description="The model failed to load. Please check the logs for more information."
+         )
+         return demo
+
+     # Create an interface with tabs
+     with gr.Blocks(title="InternVL2.5 Analyzer") as demo:
+         gr.Markdown("# InternVL2.5 Image and Slide Analyzer")
+
+         with gr.Tabs():
+             # Single Image Analysis Tab
+             with gr.TabItem("Single Image Analysis"):
+                 # Predefined prompts for analysis
+                 image_prompts = [
+                     "Describe this image in detail.",
+                     "What can you tell me about this image?",
+                     "Is there any text in this image? If so, can you read it?",
+                     "What is the main subject of this image?",
+                     "What emotions or feelings does this image convey?",
+                     "Describe the composition and visual elements of this image.",
+                     "Summarize what you see in this image in one paragraph."
+                 ]
+
+                 with gr.Row():
+                     image_input = gr.Image(type="pil", label="Upload Image")
+                     image_prompt = gr.Dropdown(
+                         choices=image_prompts,
+                         value=image_prompts[0],
+                         label="Select a prompt",
+                         allow_custom_value=True
+                     )
+
+                 image_analyze_btn = gr.Button("Analyze Image")
+                 image_output = gr.Textbox(label="Analysis Results", lines=15)
+
+                 # Handle the image analysis action
+                 image_analyze_btn.click(
+                     fn=lambda img, prompt: analyze_image(model, tokenizer, img, prompt),
+                     inputs=[image_input, image_prompt],
+                     outputs=image_output
+                 )
+
+             # PDF Slides Analysis Tab
+             with gr.TabItem("PDF Slides Analysis"):
+                 slide_prompts = [
+                     "Analyze this slide and describe its contents.",
+                     "What is the main message of this slide?",
+                     "Extract all the text visible in this slide.",
+                     "What are the key points presented in this slide?",
+                     "Describe the visual elements and layout of this slide."
+                 ]
+
+                 with gr.Row():
+                     file_input = gr.File(label="Upload PDF")
+                     slide_prompt = gr.Dropdown(
+                         choices=slide_prompts,
+                         value=slide_prompts[0],
+                         label="Select a prompt",
+                         allow_custom_value=True
+                     )
+
+                 num_slides = gr.Slider(
+                     minimum=1,
+                     maximum=5,
+                     value=2,
+                     step=1,
+                     label="Number of Slides to Analyze"
+                 )
+
+                 slides_analyze_btn = gr.Button("Analyze Slides")
+                 slides_output = gr.Markdown(label="Analysis Results")
+
+                 # Handle the slides analysis action
+                 slides_analyze_btn.click(
+                     fn=lambda file, prompt, num: analyze_pdf_slides(model, tokenizer, file, prompt, num),
+                     inputs=[file_input, slide_prompt, num_slides],
+                     outputs=slides_output
+                 )
+
+                 # Add example if available
+                 if os.path.exists("example_slides/test_slides.pdf"):
+                     gr.Examples(
+                         examples=[
+                             ["example_slides/test_slides.pdf", "Extract all the text visible in this slide.", 2]
+                         ],
+                         inputs=[file_input, slide_prompt, num_slides]
+                     )
+
+     return demo
+
+ # Run the application
+ if __name__ == "__main__":
+     try:
+         # Create and launch the interface
+         demo = main()
+         demo.launch(server_name="0.0.0.0")
+     except Exception as e:
+         print(f"Error starting the application: {e}")
+         import traceback
+         traceback.print_exc()