# # main.py
# from fastapi import FastAPI, File, UploadFile
# from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
# from transformers.image_utils import load_image
# import torch
# from io import BytesIO
# import os
# from dotenv import load_dotenv
# from PIL import Image
# from huggingface_hub import login

# # Load environment variables
# load_dotenv()

# # Set the cache directory to a writable path
# os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
# token = os.getenv("huggingface_ankit")

# # Login to the Hugging Face Hub
# login(token)

# app = FastAPI()

# model_id = "google/paligemma2-3b-mix-448"
# model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to("cuda")
# processor = PaliGemmaProcessor.from_pretrained(model_id)

# def predict(image):
#     prompt = "<image> ocr"
#     model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
#     input_len = model_inputs["input_ids"].shape[-1]
#     with torch.inference_mode():
#         generation = model.generate(**model_inputs, max_new_tokens=200)
#         torch.cuda.empty_cache()
#     decoded = processor.decode(generation[0], skip_special_tokens=True)  # [len(prompt):].lstrip("\n")
#     return decoded

# @app.post("/extract_text")
# async def extract_text(file: UploadFile = File(...)):
#     image = Image.open(BytesIO(await file.read())).convert("RGB")  # Ensure it's a valid PIL image
#     text = predict(image)
#     return {"extracted_text": text}

# @app.post("/batch_extract_text")
# async def batch_extract_text(files: list[UploadFile] = File(...)):
#     # if len(files) > 20:
#     #     return {"error": "A maximum of 20 images can be processed at a time."}
#     images = [Image.open(BytesIO(await file.read())).convert("RGB") for file in files]
#     prompts = ["OCR"] * len(images)
#     model_inputs = processor(text=prompts, images=images, return_tensors="pt").to(torch.bfloat16).to(model.device)
#     input_len = model_inputs["input_ids"].shape[-1]
#     with torch.inference_mode():
#         generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
#         torch.cuda.empty_cache()
#     extracted_texts = [processor.decode(generations[i], skip_special_tokens=True) for i in range(len(images))]
#     return {"extracted_texts": extracted_texts}

# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=7860)
from fastapi import FastAPI, File, UploadFile, BackgroundTasks
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
import torch
from io import BytesIO
import os
from dotenv import load_dotenv
from PIL import Image
from huggingface_hub import login
import gc
import logging
from typing import List
import time
import numpy as np
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv()
# Set the cache directory to a writable path
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
token = os.getenv("huggingface_ankit")
# Login to the Hugging Face Hub (skip if no token is configured, rather than failing at import time)
if token:
    login(token)
else:
    logger.warning("huggingface_ankit token not set; skipping Hugging Face Hub login")
app = FastAPI()
# Global variables for model and processor
model = None
processor = None
def load_model():
    """Load model and processor when needed."""
    global model, processor
    if model is None:
        model_id = "google/paligemma2-3b-mix-448"
        logger.info(f"Loading model {model_id}")
        # Load model with memory-efficient settings
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,  # Use lower precision for memory efficiency
        )
        processor = PaliGemmaProcessor.from_pretrained(model_id)
        logger.info("Model loaded successfully")
def clean_memory():
    """Force garbage collection and clear the CUDA cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        logger.info(f"Memory allocated after clearing cache: {torch.cuda.memory_allocated()} bytes")
    logger.info("Memory cleaned")
def predict(image):
    """Run OCR on a single image."""
    load_model()  # Ensure model is loaded

    # Process input; PaliGemma expects the <image> token ahead of the task prefix
    prompt = "<image> ocr"
    model_inputs = processor(text=prompt, images=image, return_tensors="pt")

    # Move inputs to the model device, matching the model dtype for floating-point tensors
    model_inputs = {
        k: v.to(model.device, dtype=model.dtype) if v.is_floating_point() else v.to(model.device)
        for k, v in model_inputs.items()
    }
    input_len = model_inputs["input_ids"].shape[-1]

    # Generate with memory optimization
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=200)

    # Decode only the newly generated tokens (drop the echoed prompt)
    decoded = processor.decode(generation[0][input_len:], skip_special_tokens=True)

    # Clean up intermediates
    del model_inputs, generation
    clean_memory()
    return decoded
@app.post("/extract_text")
async def extract_text(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
"""Extract text from a single image"""
try:
start_time = time.time()
image = Image.open(BytesIO(await file.read())).convert("RGB")
text = predict(image)
# Schedule cleanup after response
background_tasks.add_task(clean_memory)
logger.info(f"Processing completed in {time.time() - start_time:.2f} seconds")
return {"extracted_text": text}
except Exception as e:
logger.error(f"Error processing image: {str(e)}")
return {"error": str(e)}
@app.post("/batch_extract_text")
async def batch_extract_text(batch_size:int, background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
"""Extract text from multiple images with batching"""
try:
start_time = time.time()
# Limit batch size for memory management
max_batch_size = batch_size # Adjust based on your GPU memory
# if len(files) > 32:
# return {"error": "A maximum of 20 images can be processed at a time."}
load_model() # Ensure model is loaded
all_results = []
# Process in smaller batches
for i in range(0, len(files), max_batch_size):
batch_files = files[i:i+max_batch_size]
# Load images
images = []
for file in batch_files:
image_data = await file.read()
img = Image.open(BytesIO(image_data)).convert("RGB")
images.append(img)
# Create batch inputs
prompts = ["<image> ocr"] * len(images)
model_inputs = processor(text=prompts, images=images, return_tensors="pt")
# Move to appropriate device
model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
# Generate with memory optimization
with torch.inference_mode():
generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
# Decode outputs
batch_results = [processor.decode(generations[i], skip_special_tokens=True) for i in range(len(images))]
all_results.extend(batch_results)
# Clean up batch resources
del model_inputs, generations, images
clean_memory()
# Schedule cleanup after response
background_tasks.add_task(clean_memory)
logger.info(f"Batch processing completed in {time.time() - start_time:.2f} seconds")
return {"extracted_texts": all_results}
except Exception as e:
logger.error(f"Error in batch processing: {str(e)}")
return {"error": str(e)}
# Health check endpoint
@app.get("/health")
async def health_check():
    # Generate a random image (20x40 pixels) with random RGB values
    random_data = np.random.randint(0, 256, (20, 40, 3), dtype=np.uint8)
    # Create an image from the random data and run it through the model
    image = Image.fromarray(random_data)
    predict(image)
    clean_memory()
    return {"status": "healthy"}
# if __name__ == "__main__":
#     import uvicorn
#     # Start the server with proper worker configuration
#     uvicorn.run(
#         app,
#         host="0.0.0.0",
#         port=7860,
#         log_level="info",
#         workers=1,  # Multiple workers can cause GPU memory issues
#     )
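
# Example client usage: a minimal sketch of how the two endpoints can be called,
# kept commented out so it never runs on import. It assumes the server is reachable
# at http://localhost:7860 and that the `requests` library is installed; the file
# names below are placeholders, not files shipped with this repo.
#
# import requests
#
# # Single-image OCR: the image goes in the multipart `file` field
# with open("sample.png", "rb") as f:
#     resp = requests.post(
#         "http://localhost:7860/extract_text",
#         files={"file": ("sample.png", f, "image/png")},
#     )
# print(resp.json().get("extracted_text"))
#
# # Batch OCR: `batch_size` is a query parameter, images go in repeated `files` fields
# with open("page1.png", "rb") as f1, open("page2.png", "rb") as f2:
#     resp = requests.post(
#         "http://localhost:7860/batch_extract_text",
#         params={"batch_size": 4},
#         files=[
#             ("files", ("page1.png", f1, "image/png")),
#             ("files", ("page2.png", f2, "image/png")),
#         ],
#     )
# print(resp.json().get("extracted_texts"))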