# main.py
from fastapi import FastAPI, File, UploadFile, BackgroundTasks
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
import torch
from io import BytesIO
import os
from dotenv import load_dotenv
from PIL import Image
from huggingface_hub import login
import gc
import logging
from typing import List
import time
import numpy as np

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# Set the TorchInductor cache directory to a writable path
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"

token = os.getenv("huggingface_ankit")
# Login to the Hugging Face Hub
login(token)

app = FastAPI()

# Global variables for model and processor (loaded lazily)
model = None
processor = None


def load_model():
    """Load the model and processor on first use."""
    global model, processor
    if model is None:
        model_id = "google/paligemma2-3b-mix-448"
        logger.info(f"Loading model {model_id}")
        # Load model with memory-efficient settings
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,  # Use lower precision for memory efficiency
        )
        processor = PaliGemmaProcessor.from_pretrained(model_id)
        logger.info("Model loaded successfully")
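# Optional warm-up (sketch, not part of the original design): eagerly loading the
# model when the server starts avoids a slow first request. Left commented out so
# the lazy-loading behaviour above remains the default.
# @app.on_event("startup")
# async def warm_up_model():
#     load_model()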
def clean_memory():
    """Force garbage collection and clear the CUDA cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # Clear GPU cache
        logger.info(f"Memory allocated after clearing cache: {torch.cuda.memory_allocated()} bytes")
    logger.info("Memory cleaned")


def predict(image):
    """Process a single image."""
    load_model()  # Ensure model is loaded

    # Process input
    prompt = " ocr"
    model_inputs = processor(text=prompt, images=image, return_tensors="pt")

    # Move to appropriate device
    model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
    input_len = model_inputs["input_ids"].shape[-1]

    # Generate with memory optimization
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=200)

    # Decode only the newly generated tokens so the prompt is not echoed back
    decoded = processor.decode(generation[0][input_len:], skip_special_tokens=True)

    # Clean up intermediates
    del model_inputs, generation
    clean_memory()

    return decoded


@app.post("/extract_text")
async def extract_text(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
    """Extract text from a single image."""
    try:
        start_time = time.time()
        image = Image.open(BytesIO(await file.read())).convert("RGB")
        text = predict(image)

        # Schedule cleanup after the response is sent
        background_tasks.add_task(clean_memory)

        logger.info(f"Processing completed in {time.time() - start_time:.2f} seconds")
        return {"extracted_text": text}
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        return {"error": str(e)}


@app.post("/batch_extract_text")
async def batch_extract_text(batch_size: int, background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
    """Extract text from multiple images, processed in batches."""
    try:
        start_time = time.time()

        # Limit batch size for memory management
        max_batch_size = batch_size  # Adjust based on your GPU memory

        load_model()  # Ensure model is loaded

        all_results = []

        # Process in smaller batches
        for i in range(0, len(files), max_batch_size):
            batch_files = files[i:i + max_batch_size]

            # Load images
            images = []
            for file in batch_files:
                image_data = await file.read()
                img = Image.open(BytesIO(image_data)).convert("RGB")
                images.append(img)

            # Create batch inputs
            prompts = [" ocr"] * len(images)
            model_inputs = processor(text=prompts, images=images, return_tensors="pt")

            # Move to appropriate device
            model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
            input_len = model_inputs["input_ids"].shape[-1]

            # Generate with memory optimization
            with torch.inference_mode():
                generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)

            # Decode only the newly generated tokens for each image in the batch
            batch_results = [
                processor.decode(generations[j][input_len:], skip_special_tokens=True)
                for j in range(len(images))
            ]
            all_results.extend(batch_results)

            # Clean up batch resources
            del model_inputs, generations, images
            clean_memory()

        # Schedule cleanup after the response is sent
        background_tasks.add_task(clean_memory)

        logger.info(f"Batch processing completed in {time.time() - start_time:.2f} seconds")
        return {"extracted_texts": all_results}
    except Exception as e:
        logger.error(f"Error in batch processing: {str(e)}")
        return {"error": str(e)}


# Health check endpoint
@app.get("/health")
async def health_check():
    # Generate a random image (20x40 pixels) with random RGB values
    random_data = np.random.randint(0, 256, (20, 40, 3), dtype=np.uint8)
    # Create an image from the random data and run it through the model
    image = Image.fromarray(random_data)
    predict(image)
    clean_memory()
    return {"status": "healthy"}
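# Example client calls (sketch, assuming the service is reachable at
# http://localhost:7860; the filenames below are placeholders):
#
#   import requests
#
#   # Single image: the multipart field name must be "file".
#   with open("page.png", "rb") as f:
#       r = requests.post("http://localhost:7860/extract_text", files={"file": f})
#   print(r.json()["extracted_text"])
#
#   # Batch: "files" may be repeated; batch_size is a required query parameter.
#   with open("page1.png", "rb") as f1, open("page2.png", "rb") as f2:
#       r = requests.post(
#           "http://localhost:7860/batch_extract_text",
#           params={"batch_size": 4},
#           files=[("files", f1), ("files", f2)],
#       )
#   print(r.json()["extracted_texts"])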
"healthy"} # if __name__ == "__main__": # import uvicorn # # Start the server with proper worker configuration # uvicorn.run( # app, # host="0.0.0.0", # port=7860, # log_level="info", # workers=1 # Multiple workers can cause GPU memory issues # )