# Spaces: Sleeping
# (Page-status residue captured together with the source; not part of the program.)
import os
import shutil
import tempfile

import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from gtts import gTTS
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
# Application object; the middleware below is registered on it.
app = FastAPI()

# CORS configuration: any origin, any HTTP method and any request header is
# accepted, and credentialed requests are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive combination -- confirm it is intended for this service.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Model setup: prefer the GIT captioning model and fall back to the ViT-GPT2
# pipeline if anything goes wrong while loading it.  USE_GIT records which
# backend is active so callers know whether to use processor/git_model or
# captioner.
try:
    processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
    git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
    git_model.eval()
except Exception as load_error:
    # Any loading failure switches the module to the fallback captioner.
    print(f"[INFO] Falling back to ViT: {load_error}")
    captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
    USE_GIT = False
else:
    USE_GIT = True
def generate_caption(image_path: str) -> str:
    """Return a text caption for the image stored at ``image_path``.

    When ``USE_GIT`` is true the image is opened, converted to RGB, passed
    through ``processor`` and ``git_model.generate`` (``max_length=50``), and
    the first decoded sequence is returned.  Otherwise the path is handed to
    the ``captioner`` pipeline and the ``generated_text`` of its first result
    is returned.

    Args:
        image_path: Filesystem path of the image to caption.

    Returns:
        The generated caption.

    Raises:
        RuntimeError: If any step fails.  The message is
            ``"Caption generation failed: <original error>"`` and the
            original exception is chained as ``__cause__``.
    """
    try:
        if USE_GIT:
            # Close the file handle as soon as the pixels have been copied;
            # convert() returns a separate, fully loaded image.
            with Image.open(image_path) as source_image:
                image = source_image.convert("RGB")
            inputs = processor(images=image, return_tensors="pt")
            outputs = git_model.generate(**inputs, max_length=50)
            return processor.batch_decode(outputs, skip_special_tokens=True)[0]

        result = captioner(image_path)
        return result[0]['generated_text']
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Caption generation failed: {e}") from e
async def caption_image(file: UploadFile = File(...)):
    """Caption an uploaded image and render the caption as speech.

    The upload is copied to a temporary file, captioned with
    ``generate_caption``, and the caption is saved as an MP3 (via gTTS) in
    the system temp directory.  The temporary image is always removed before
    returning; the MP3 is left in place so it can be fetched afterwards.

    NOTE(review): no route decorator is visible on this handler in the
    reviewed source -- confirm how it is registered on ``app``.
    NOTE(review): captioning and speech synthesis run synchronously inside
    this coroutine; consider offloading them if request latency matters.

    Args:
        file: The uploaded image.  Its content type must be JPEG, PNG, GIF
            or WEBP.

    Returns:
        A dict with ``"answer"`` (the caption) and ``"audio"`` (a
        ``/files/<name>`` path for the generated MP3).

    Raises:
        HTTPException: 400 for an unsupported content type; 500, with the
            error text as detail, for any other failure.
    """
    # Validate file type
    valid_types = ['image/jpeg', 'image/png', 'image/gif', 'image/webp']
    if file.content_type not in valid_types:
        raise HTTPException(
            status_code=400,
            detail="Please upload a valid image (JPEG, PNG, GIF, or WEBP)"
        )

    # Bound up front so the cleanup in ``finally`` never has to probe locals().
    temp_path = None
    try:
        # The upload may arrive without a filename; fall back to no suffix.
        suffix = os.path.splitext(file.filename or "")[1]

        # Save temp file
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp:
            # Record the path before copying so a failed copy is still
            # cleaned up below (delete=False leaves the file on disk).
            temp_path = temp.name
            shutil.copyfileobj(file.file, temp)

        # Generate caption
        caption = generate_caption(temp_path)

        # Generate audio
        audio_path = os.path.join(tempfile.gettempdir(), f"caption_{os.path.basename(temp_path)}.mp3")
        tts = gTTS(text=caption)
        tts.save(audio_path)

        return {
            "answer": caption,
            "audio": f"/files/{os.path.basename(audio_path)}"
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=str(e)
        ) from e
    finally:
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)
async def get_file(filename: str):
    """Serve a file from the system temp directory by its bare file name.

    NOTE(review): no route decorator is visible on this handler in the
    reviewed source -- confirm how it is registered on ``app``.

    Args:
        filename: Name of a file located directly inside the temp directory.
            Names containing path components (or ``.`` / ``..``) are treated
            as not found so a request cannot reach outside that directory.

    Returns:
        A ``FileResponse`` for the requested file.

    Raises:
        HTTPException: 404 when the name is not a plain file name or no
            regular file with that name exists in the temp directory.
    """
    # Only accept a plain file name: anything with a directory part, a drive
    # prefix, or a relative-directory token could escape the temp directory.
    is_plain_name = (
        bool(filename)
        and filename not in (".", "..")
        and os.path.basename(filename) == filename
    )
    if not is_plain_name:
        raise HTTPException(status_code=404, detail="File not found")

    file_path = os.path.join(tempfile.gettempdir(), filename)
    # isfile (not exists) so a directory is reported as missing instead of
    # being handed to FileResponse.
    if os.path.isfile(file_path):
        return FileResponse(file_path)
    raise HTTPException(status_code=404, detail="File not found")