import gradio as gr
import spaces
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
from datetime import datetime
import os
import torch
import gc

# Set PyTorch memory allocation configuration
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
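# Rationale (per PyTorch's CUDA allocator options, stated here as an assumption about the
# author's intent): expandable_segments lets the caching allocator grow existing segments
# instead of creating new fragmented ones, and max_split_size_mb stops it from splitting
# cached blocks larger than 128 MB; both reduce fragmentation-driven OOMs on small GPUs.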

DESCRIPTION = "[Sparrow Qwen2-VL-2B Backend](https://github.com/katanaml/sparrow)"

def process_image(image_filepath, max_width=800, max_height=1000):
    if image_filepath is None:
        raise ValueError("No image provided. Please upload an image before submitting.")

    img = Image.open(image_filepath)
    width, height = img.size

    # Calculate new dimensions while maintaining aspect ratio
    if width > max_width or height > max_height:
        aspect_ratio = width / height
        new_width, new_height = width, height
        if new_width > max_width:
            new_width = max_width
            new_height = int(new_width / aspect_ratio)
        if new_height > max_height:
            new_height = max_height
            new_width = int(new_height * aspect_ratio)
    else:
        new_width, new_height = width, height

    # Resize the image if needed
    if new_width != width or new_height != height:
        img = img.resize((new_width, new_height), Image.LANCZOS)

    # JPEG cannot store an alpha channel, so normalize to RGB before saving
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Generate a temporary filename - use /tmp for better space management
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"/tmp/image_{timestamp}.jpg"  # Use JPEG for smaller file size

    # Save with optimized compression
    img.save(filename, format='JPEG', quality=85, optimize=True)

    return os.path.abspath(filename), new_width, new_height


# Model is loaded lazily, with memory optimizations but without 4-bit quantization
model = None
processor = None


def load_model():
    # Load model with memory optimizations
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        torch_dtype=torch.float16,               # Use fp16 for memory efficiency
        device_map="auto",
        attn_implementation="flash_attention_2"  # Use FlashAttention 2 (requires flash-attn to be installed)
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    return model, processor


def run_inference(input_imgs, text_input):
    global model, processor

    # Lazy load model
    if model is None or processor is None:
        model, processor = load_model()

    results = []

    # Process images one at a time to avoid OOM issues
    for image in input_imgs:
        # Clear cache before processing each image
        torch.cuda.empty_cache()
        gc.collect()

        # Process image with reduced dimensions
        image_path, width, height = process_image(image)

        try:
            # Create messages with optimized image
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image_path,
                            "resized_height": height,
                            "resized_width": width
                        },
                        {
                            "type": "text",
                            "text": text_input
                        }
                    ]
                }
            ]

            # Prepare inputs with memory optimization
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)

            # Clear unused memory
            del messages
            torch.cuda.empty_cache()

            # Process inputs with truncation to control memory usage
            inputs = processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                truncation=True,   # Add truncation
                max_length=768,    # Limit context length
                return_tensors="pt",
            )

            # Move to GPU efficiently
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

            # Clean up variables to free memory
            del text, image_inputs, video_inputs
            torch.cuda.empty_cache()

            # Generate with optimized parameters
            with torch.inference_mode():  # More efficient than no_grad
                generated_ids = model.generate(
                    **inputs,
                    max_new_tokens=1024,  # Reduced from 4096
                    do_sample=False,      # Deterministic generation uses less memory
                    use_cache=True,       # Use KV cache
                    num_beams=1           # Disable beam search to save memory
                )

            # Trim the prompt tokens from the output, then decode
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
            ]
            raw_output = processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True
            )
            results.append(raw_output[0])
            print(f"Processed: {image_path}")

            # Clear tensors from GPU memory
            del inputs, generated_ids, generated_ids_trimmed
            torch.cuda.empty_cache()
            gc.collect()
        finally:
            # Clean up temporary files
            if os.path.exists(image_path):
                os.remove(image_path)

    return results


# Gradio interface
css = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Qwen2-VL-2B Input"):
        with gr.Row():
            with gr.Column():
                input_imgs = gr.Files(file_types=["image"], label="Upload Document Images")
                text_input = gr.Textbox(label="Query")
                submit_btn = gr.Button(value="Submit", variant="primary")
            with gr.Column():
                # elem_id="output" ties this component to the #output rule in the CSS above
                output_text = gr.Textbox(label="Response", elem_id="output")

        submit_btn.click(run_inference, [input_imgs, text_input], [output_text])

# Use a smaller queue size to manage memory
demo.queue(api_open=True, max_size=3)
demo.launch(debug=True)
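
# --- Illustrative client call (sketch, not part of the app above) ---
# A minimal example of how a client could query this backend with gradio_client,
# assuming a Space id of "user/sparrow-qwen2-vl-2b" and the default api_name
# "/run_inference" derived from the handler function name; both names are
# assumptions for illustration, not taken from this file.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("user/sparrow-qwen2-vl-2b")   # hypothetical Space id
#   result = client.predict(
#       [handle_file("invoice_page1.jpg")],       # list of images for the gr.Files input
#       "retrieve invoice_number and total",      # query text
#       api_name="/run_inference",                # assumed default endpoint name
#   )
#   print(result)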