rockerritesh committed on
Commit cc844d3 · verified · 1 Parent(s): 5c0034a

Upload 3 files

Files changed (3)
  1. Dockerfile +38 -0
  2. main.py +65 -0
  3. requirements.txt +10 -0
Dockerfile ADDED
@@ -0,0 +1,38 @@
+ # Use an official Python runtime as a parent image
+ FROM python:3.9-slim
+
+ # Set the working directory
+ WORKDIR /app
+
+
+ # Install necessary system dependencies
+ RUN apt-get update && apt-get install -y gcc && rm -rf /var/lib/apt/lists/*
+
+ # Set the CC environment variable to ensure TorchInductor uses the correct compiler
+ ENV CC=gcc
+
+ # Copy the requirements file and install dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Create cache and config directories with appropriate permissions
+ RUN mkdir -p /app/cache && chmod 777 /app/cache
+ RUN mkdir -p /app/config && chmod 777 /app/config
+ RUN mkdir -p /app/triton_cache && chmod 777 /app/triton_cache
+ RUN mkdir -p /app/torchinductor_cache && chmod 777 /app/torchinductor_cache
+
+ # Set environment variables for Hugging Face cache, config, Triton, and TorchInductor directories
+ ENV HF_HOME=/app/cache
+ ENV XDG_CONFIG_HOME=/app/config
+ ENV TRITON_CACHE_DIR=/app/triton_cache
+ ENV TORCHINDUCTOR_CACHE_DIR=/app/torchinductor_cache
+
+
+ # Copy the application code
+ COPY main.py .
+
+ # Expose the port FastAPI will run on
+ EXPOSE 7860
+
+ # Command to run the FastAPI app
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,65 @@
+ # main.py
+ from fastapi import FastAPI, File, UploadFile
+ from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
+ from transformers.image_utils import load_image
+ import torch
+ from io import BytesIO
+ import os
+ from dotenv import load_dotenv
+ from PIL import Image
+
+ from huggingface_hub import login
+
+ # Load environment variables
+ load_dotenv()
+
+ # Set the cache directory to a writable path
+ os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
+
+ token = os.getenv("huggingface_ankit")
+ # Login to the Hugging Face Hub
+ login(token)
+
+ app = FastAPI()
+
+ model_id = "google/paligemma2-3b-mix-448"
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to('cuda')
+ processor = PaliGemmaProcessor.from_pretrained(model_id)
+
+ def predict(image):
+     prompt = "<image> ocr"
+     model_inputs = processor(text=prompt, images=image, return_tensors="pt").to('cuda')
+     input_len = model_inputs["input_ids"].shape[-1]
+     with torch.inference_mode():
+         generation = model.generate(**model_inputs, max_new_tokens=200)
+         torch.cuda.empty_cache()
+     decoded = processor.decode(generation[0][input_len:], skip_special_tokens=True)  # decode only the newly generated tokens, not the prompt
+     return decoded
+
+ @app.post("/extract_text")
+ async def extract_text(file: UploadFile = File(...)):
+     image = Image.open(BytesIO(await file.read())).convert("RGB")  # Ensure it's a valid PIL image
+     text = predict(image)
+     return {"extracted_text": text}
+
+ @app.post("/batch_extract_text")
+ async def batch_extract_text(files: list[UploadFile] = File(...)):
+     if len(files) > 20:
+         return {"error": "A maximum of 20 images can be processed at a time."}
+
+     images = [Image.open(BytesIO(await file.read())).convert("RGB") for file in files]
+     prompts = ["<image> ocr"] * len(images)  # same prompt as the single-image endpoint
+
+     model_inputs = processor(text=prompts, images=images, return_tensors="pt").to(model.device)  # keep inputs in the model's own dtype
+     input_len = model_inputs["input_ids"].shape[-1]
+
+     with torch.inference_mode():
+         generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
+         torch.cuda.empty_cache()
+     extracted_texts = [processor.decode(generations[i][input_len:], skip_special_tokens=True) for i in range(len(images))]
+
+     return {"extracted_texts": extracted_texts}
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ fastapi
+ uvicorn
+ numpy
+ huggingface_hub
+ python-dotenv
+ transformers
+ torch
+ accelerate
+ pillow
+ python-multipart