import torch
from transformers import pipeline, BitsAndBytesConfig
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import requests
from PIL import Image
from io import BytesIO

# Set up device (CPU or GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Configure quantization if using GPU
if device == "cuda":
    print("GPU found. Using 4-bit quantization.")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16
    )
else:
    print("GPU not found. Using CPU with default settings.")
    quantization_config = None

# Load the model pipeline. The quantization config must actually be passed
# through to the model (via model_kwargs) or it is silently ignored; a
# 4-bit-quantized model is placed on the GPU automatically by bitsandbytes,
# so `device` is only set explicitly on the CPU path.
model_id = "bczhou/tiny-llava-v1-hf"
if quantization_config is not None:
    pipe = pipeline("image-to-text", model=model_id,
                    model_kwargs={"quantization_config": quantization_config})
else:
    pipe = pipeline("image-to-text", model=model_id, device=device)

print(f"Using device: {device}")

# Initialize FastAPI application
app = FastAPI()

# Health check endpoint to ensure API is running
@app.get("/")
async def root():
    return {"message": "API is running fine."}

# Define Pydantic model for request input
class ImagePromptInput(BaseModel):
    image_url: str
    prompt: str
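
# Example request body for /generate (the image URL is an illustrative placeholder):
#   {"image_url": "https://example.com/cat.jpg", "prompt": "What is in this image?"}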

# FastAPI route for generating text from an image
@app.post("/generate")
async def generate_text(input_data: ImagePromptInput):
    try:
        # Download the image; fail fast on network errors or non-2xx responses
        # instead of handing an error page to PIL
        response = requests.get(input_data.image_url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
        image = image.resize((750, 500))  # Resize image to fixed dimensions

        # Create a full prompt to pass to the model
        full_prompt = f"USER: <image>\n{input_data.prompt}\nASSISTANT: "

        # Generate response using the model pipeline
        outputs = pipe(image, prompt=full_prompt, generate_kwargs={"max_new_tokens": 200})

        # Return generated text
        generated_text = outputs[0]["generated_text"]  # type: ignore
        return {"response": generated_text}

    except Exception as e:
        # Return error if something goes wrong
        raise HTTPException(status_code=500, detail=str(e))
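
# A minimal sketch of how to serve the app locally. This assumes uvicorn is
# installed and that this file is named app.py; adjust the module name to match:
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
# Once the server is up, the endpoint can be exercised with a request like the
# one below (the image URL is a hypothetical placeholder):
#
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"image_url": "https://example.com/cat.jpg", "prompt": "Describe this image."}'
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)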