Summarization / appImage.py
ikraamkb's picture
Update appImage.py
d5d3aa6 verified
raw
history blame
3 kB
import os
import shutil
import tempfile

import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from gtts import gTTS
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
# Application instance that the routes below register on.
app = FastAPI()

# CORS settings: wildcard origins, methods and headers, with credentials
# enabled.
# NOTE(review): wildcard origins together with credentials is very permissive
# -- confirm this is intended before exposing the service publicly.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# Load the captioning backend once at import time. The GIT checkpoint is
# tried first; if anything in that setup raises, the ViT-GPT2 pipeline is
# loaded instead and USE_GIT records which backend is active.
try:
    processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
    git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
    git_model.eval()
except Exception as e:
    print(f"[INFO] Falling back to ViT: {e}")
    captioner = pipeline(
        "image-to-text",
        model="nlpconnect/vit-gpt2-image-captioning",
    )
    USE_GIT = False
else:
    USE_GIT = True
def generate_caption(image_path: str) -> str:
    """Return a text caption for the image stored at ``image_path``.

    Uses the GIT processor/model pair when ``USE_GIT`` is true, otherwise the
    fallback ``captioner`` pipeline.

    Args:
        image_path: Filesystem path of the image to caption.

    Returns:
        The first caption produced by the active backend.

    Raises:
        RuntimeError: If opening the image or running the backend fails. The
            original exception is chained as ``__cause__``.
    """
    try:
        if USE_GIT:
            # Close the source file deterministically: the caller deletes the
            # file right after captioning, and an image left open (e.g. a
            # multi-frame GIF) can keep the handle alive.
            with Image.open(image_path) as source:
                image = source.convert("RGB")
            inputs = processor(images=image, return_tensors="pt")
            outputs = git_model.generate(**inputs, max_length=50)
            return processor.batch_decode(outputs, skip_special_tokens=True)[0]

        result = captioner(image_path)
        return result[0]["generated_text"]
    except Exception as e:
        raise RuntimeError(f"Caption generation failed: {e}") from e
@app.post("/imagecaption/")
async def caption_image(file: UploadFile = File(...)):
    """Caption an uploaded image and synthesize the caption as speech.

    Args:
        file: Uploaded image. Its content type must be JPEG, PNG, GIF or WEBP.

    Returns:
        A dict with ``answer`` (the caption text) and ``audio`` (the
        ``/files/...`` URL path of the generated MP3).

    Raises:
        HTTPException: 400 for an unsupported content type; 500 if saving the
            upload, captioning, or speech synthesis fails.
    """
    # Validate file type
    valid_types = ['image/jpeg', 'image/png', 'image/gif', 'image/webp']
    if file.content_type not in valid_types:
        raise HTTPException(
            status_code=400,
            detail="Please upload a valid image (JPEG, PNG, GIF, or WEBP)"
        )

    temp_path = None
    try:
        # The upload may carry no filename; fall back to an empty suffix
        # instead of failing inside os.path.splitext.
        suffix = os.path.splitext(file.filename or "")[1]

        # Save the upload to a temp file that outlives the `with` block.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp:
            # Record the path before copying so the `finally` clause can
            # remove the file even when the copy itself fails.
            temp_path = temp.name
            shutil.copyfileobj(file.file, temp)

        # Generate caption
        caption = generate_caption(temp_path)

        # Generate audio next to the upload in the system temp directory.
        # NOTE: nothing in this handler removes the MP3 afterwards; it stays
        # on disk so it can be fetched through the /files/ route.
        audio_name = f"caption_{os.path.basename(temp_path)}.mp3"
        audio_path = os.path.join(tempfile.gettempdir(), audio_name)
        gTTS(text=caption).save(audio_path)

        return {
            "answer": caption,
            "audio": f"/files/{audio_name}"
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The uploaded image is only needed for captioning; always remove it.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)
@app.get("/files/{filename}")
async def get_file(filename: str):
    """Serve a generated caption audio file from the system temp directory.

    Only bare file names of the form ``caption_*.mp3`` are served, which is
    the naming used for the audio produced by the caption endpoint. Anything
    else is reported as missing.

    Args:
        filename: Name of the audio file, as returned in the ``audio`` field
            of the caption response (without the ``/files/`` prefix).

    Returns:
        A ``FileResponse`` streaming the requested file.

    Raises:
        HTTPException: 404 if the name is not an allowed audio file name or
            no such regular file exists.
    """
    # The temp directory is shared with uploads and unrelated programs, so
    # refuse anything carrying a path component or not matching the
    # generated-audio naming scheme.
    has_no_path_parts = (
        filename == os.path.basename(filename) and "\\" not in filename
    )
    looks_like_audio = filename.startswith("caption_") and filename.endswith(".mp3")

    if has_no_path_parts and looks_like_audio:
        file_path = os.path.join(tempfile.gettempdir(), filename)
        # isfile (not exists) so a directory with a matching name yields 404.
        if os.path.isfile(file_path):
            return FileResponse(file_path)

    raise HTTPException(status_code=404, detail="File not found")