# aiWeb / main.py
# benkada's picture
# Update main.py
# 0133631 verified
# raw
# history blame
# 6.35 kB
import os, io
from io import BytesIO
from pathlib import Path

from fastapi import FastAPI, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
from PyPDF2 import PdfReader
from docx import Document
from PIL import Image
# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
# Hugging Face API token, read from the HF_TOKEN environment variable
# (Space secret or local env); None when the variable is not set.
HUGGINGFACE_TOKEN = os.environ.get("HF_TOKEN")
# Listening port: Spaces export PORT, local runs fall back to 7860.
PORT = int(os.environ.get("PORT", 7860))
# The ASGI application object; the metadata below is what the generated
# OpenAPI documentation reports.
app = FastAPI(
    title="AI-Powered Web-App API",
    description="Backend for summarisation, captioning & QA",
    version="1.2.1",
)

# CORS: the API is open to every origin, method and header.
# Credentials stay disabled on purpose: a wildcard origin must not be combined
# with credentialed requests, and none of the routes visible in this file read
# cookies or Authorization headers, so nothing depends on them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# -----------------------------------------------------------------------------
# OPTIONAL STATIC FILES (only if ./static exists)
# -----------------------------------------------------------------------------
static_dir = Path("static")
# Mount only when ./static is an actual directory. exists() would also be
# true for a regular file named "static", which cannot be served as a
# directory of assets.
if static_dir.is_dir():
    app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
# -----------------------------------------------------------------------------
# HUGGING FACE INFERENCE CLIENTS
# -----------------------------------------------------------------------------
# One inference client per hosted model; all three share the same token.
summary_client = InferenceClient(
    model="facebook/bart-large-cnn",
    token=HUGGINGFACE_TOKEN,
)
qa_client = InferenceClient(
    model="deepset/roberta-base-squad2",
    token=HUGGINGFACE_TOKEN,
)
image_caption_client = InferenceClient(
    model="nlpconnect/vit-gpt2-image-captioning",
    token=HUGGINGFACE_TOKEN,
)
# -----------------------------------------------------------------------------
# UTILITIES
# -----------------------------------------------------------------------------
def extract_text_from_pdf(content: bytes) -> str:
    """Return the text of every page of a PDF, joined by newlines.

    Args:
        content: Raw bytes of the PDF file.

    Returns:
        The per-page text joined with ``"\\n"`` and stripped of leading and
        trailing whitespace. A page with no extractable text contributes an
        empty string.
    """
    pdf = PdfReader(io.BytesIO(content))
    page_texts = []
    for page in pdf.pages:
        extracted = page.extract_text()
        page_texts.append(extracted if extracted else "")
    return "\n".join(page_texts).strip()
def extract_text_from_docx(content: bytes) -> str:
    """Return the paragraph text of a DOCX document, joined by newlines.

    Args:
        content: Raw bytes of the DOCX file.

    Returns:
        The text of each entry in ``doc.paragraphs`` joined with ``"\\n"``
        and stripped of leading and trailing whitespace.
    """
    document = Document(io.BytesIO(content))
    paragraph_texts = [paragraph.text for paragraph in document.paragraphs]
    joined = "\n".join(paragraph_texts)
    return joined.strip()
def process_uploaded_file(file: "UploadFile") -> str:
    """Extract plain text from an uploaded PDF, DOCX or TXT file.

    The format is chosen from the extension of ``file.filename``, compared
    case-insensitively.

    Args:
        file: Upload whose ``filename`` names the format and whose ``file``
            attribute is a readable binary stream.

    Returns:
        The extracted text with leading and trailing whitespace removed.

    Raises:
        ValueError: If the filename is missing, has no extension, or has an
            extension other than ``pdf``, ``docx`` or ``txt``. A ``.txt``
            upload that is not valid UTF-8 raises ``UnicodeDecodeError``,
            which is itself a ``ValueError``.
    """
    # filename can be None, and a name without a dot has no extension at all
    # (a file literally called "pdf" is not a PDF).
    _, dot, suffix = (file.filename or "").rpartition(".")
    ext = suffix.lower() if dot else ""
    if ext not in ("pdf", "docx", "txt"):
        raise ValueError("Unsupported file type")

    content = file.file.read()
    if ext == "pdf":
        return extract_text_from_pdf(content)
    if ext == "docx":
        return extract_text_from_docx(content)
    # utf-8-sig decodes plain UTF-8 as well, but drops a leading byte-order
    # mark, which strip() would otherwise leave at the start of the text.
    return content.decode("utf-8-sig").strip()
# -----------------------------------------------------------------------------
# ROUTES
# -----------------------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
async def serve_index():
    """Serve the frontend page stored in ``index.html``.

    The path is relative, so the file is looked up from the process's
    current working directory.
    """
    index_page = FileResponse("index.html")
    return index_page
# -------------------- Summarisation ------------------------------------------
@app.post("/api/summarize")
async def summarize_document(file: UploadFile = File(...)):
    """Summarise an uploaded document.

    Args:
        file: The uploaded document (multipart form field ``file``).

    Returns:
        ``{"result": <summary>}`` on success, or ``{"result": ...}`` with a
        fixed notice when fewer than 20 characters of text were extracted.
        A ``ValueError`` raised while reading the upload (for example an
        unsupported file type) produces HTTP 400; any other failure produces
        HTTP 500. Both error bodies have the shape ``{"error": <message>}``.
    """
    try:
        try:
            # File parsing is synchronous; run it in the threadpool so this
            # coroutine does not stall the event loop while it works.
            text = await run_in_threadpool(process_uploaded_file, file)
        except ValueError as exc:
            # The upload itself is at fault, so report a client error.
            return JSONResponse(
                status_code=400,
                content={"error": f"Summarisation failure: {exc}"},
            )

        if len(text) < 20:
            return {"result": "Document too short to summarise."}

        # Only the first 3000 characters are sent to the model. The call is a
        # blocking request, hence the threadpool.
        summary_raw = await run_in_threadpool(
            summary_client.summarization, text[:3000]
        )

        # The response is handled as a list, a dict, or anything else
        # (stringified) — the client's return shape is not assumed.
        if isinstance(summary_raw, list):
            summary_txt = summary_raw[0].get("summary_text", str(summary_raw))
        elif isinstance(summary_raw, dict):
            summary_txt = summary_raw.get("summary_text", str(summary_raw))
        else:
            summary_txt = str(summary_raw)
        return {"result": summary_txt}
    except Exception as exc:
        return JSONResponse(
            status_code=500,
            content={"error": f"Summarisation failure: {exc}"},
        )
# -------------------- Image Caption -----------------------------------------
@app.post("/api/caption")
async def caption_image(file: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    The image is converted to RGB, shrunk in place to fit within 1024x1024,
    re-encoded as JPEG and sent to the captioning model.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        ``{"result": <caption>}`` on success. An upload Pillow cannot decode
        produces HTTP 400; any other failure produces HTTP 500. Both error
        bodies have the shape ``{"error": <message>}``.
    """
    try:
        img_bytes = await file.read()
        try:
            img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except OSError as exc:
            # Pillow reports unreadable or truncated image data with OSError
            # (UnidentifiedImageError is a subclass): a client error.
            return JSONResponse(
                status_code=400,
                content={"error": f"Caption failure: {exc}"},
            )
        img.thumbnail((1024, 1024))
        buf = BytesIO()
        img.save(buf, format="JPEG")

        # The inference call is a blocking request; run it in the threadpool
        # so the event loop stays free for other requests.
        result = await run_in_threadpool(
            image_caption_client.image_to_text, buf.getvalue()
        )

        if isinstance(result, dict):
            caption = result.get("generated_text") or result.get("caption") or "No caption found."
        elif isinstance(result, list):
            caption = result[0].get("generated_text", "No caption found.")
        else:
            caption = str(result)
        return {"result": caption}
    except Exception as exc:
        return JSONResponse(
            status_code=500,
            content={"error": f"Caption failure: {exc}"},
        )
# -------------------- Question Answering ------------------------------------
@app.post("/api/qa")
async def question_answering(file: UploadFile = File(...), question: str = Form(...)):
    """Answer a question about an uploaded image or document.

    For uploads whose content type starts with ``image/`` the context is a
    caption generated for the image; otherwise it is the first 3000
    characters of the text extracted from the document.

    Args:
        file: The uploaded image or document (multipart form field ``file``).
        question: The question to answer (form field ``question``).

    Returns:
        ``{"result": <answer>}`` on success, or ``{"result": ...}`` with a
        fixed notice when no context could be derived. An undecodable image
        or a ``ValueError`` while reading the document produces HTTP 400;
        any other failure produces HTTP 500. Both error bodies have the
        shape ``{"error": <message>}``.
    """
    try:
        # content_type may be None when the client sends no Content-Type for
        # the part; treat that as "not an image" instead of failing.
        if (file.content_type or "").startswith("image/"):
            img_bytes = await file.read()
            try:
                img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
            except OSError as exc:
                # Pillow reports unreadable image data with OSError subclasses.
                return JSONResponse(
                    status_code=400,
                    content={"error": f"QA failure: {exc}"},
                )
            img.thumbnail((1024, 1024))
            jpeg = BytesIO()
            img.save(jpeg, format="JPEG")
            # Blocking request: keep it off the event loop.
            res = await run_in_threadpool(
                image_caption_client.image_to_text, jpeg.getvalue()
            )
            # Accept the same response shapes as the caption endpoint, so a
            # list is unpacked rather than stringified into the context.
            if isinstance(res, dict):
                context = res.get("generated_text")
            elif isinstance(res, list):
                context = res[0].get("generated_text")
            else:
                context = str(res)
        else:
            try:
                document_text = await run_in_threadpool(process_uploaded_file, file)
            except ValueError as exc:
                # The upload itself is at fault, so report a client error.
                return JSONResponse(
                    status_code=400,
                    content={"error": f"QA failure: {exc}"},
                )
            context = document_text[:3000]

        if not context:
            return {"result": "No context – cannot answer."}

        answer = await run_in_threadpool(
            qa_client.question_answering, question=question, context=context
        )
        return {"result": answer.get("answer", "No answer found.")}
    except Exception as exc:
        return JSONResponse(
            status_code=500,
            content={"error": f"QA failure: {exc}"},
        )
# -------------------- Health -------------------------------------------------
@app.get("/api/health")
async def health():
    """Report service status, whether an HF token is configured, and the API version."""
    token_present = bool(HUGGINGFACE_TOKEN)
    return dict(status="healthy", hf_token_set=token_present, version=app.version)
# -----------------------------------------------------------------------------
# ENTRYPOINT
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Direct execution only. uvicorn is imported inside the guard, so merely
    # importing this module does not require it.
    import uvicorn

    uvicorn.run(app, port=PORT, host="0.0.0.0")