"""FastAPI backend exposing document summarization, image captioning,
and document question-answering endpoints."""

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import os
import tempfile

import torch
from PIL import Image
from transformers import pipeline

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFaceHub

app = FastAPI(
    title="AI-Powered Web Application API",
    description="API for document analysis, image captioning, and question answering",
    version="1.0.0",
)

# NOTE: a wildcard origin is convenient for local development but should be
# restricted to known origins before deploying publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Models are loaded lazily on first request so that importing this module
# (e.g. for tests) does not trigger model downloads.
summarizer = None
image_captioner = None
qa_chain = None

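# Alternative sketch (an assumption, not part of the original design): the
# models could instead be loaded once at startup so the first request does
# not pay the initialization cost, e.g.:
#
#     @app.on_event("startup")
#     def _load_models() -> None:
#         initialize_models()
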
# The upload endpoints below declare their parameters directly with
# File(...) / Form(...); FastAPI reads multipart form fields from those
# declarations, so no separate Pydantic request models are needed here.

def initialize_models():
    """Lazily initialize the AI models on first use."""
    global summarizer, image_captioner, qa_chain

    if summarizer is None:
        summarizer = pipeline(
            "summarization",
            model="facebook/bart-large-cnn",
            device=0 if torch.cuda.is_available() else -1,
        )

    if image_captioner is None:
        image_captioner = pipeline(
            "image-to-text",
            model="nlpconnect/vit-gpt2-image-captioning",
            device=0 if torch.cuda.is_available() else -1,
        )

    if qa_chain is None:
        # Requires a HUGGINGFACEHUB_API_TOKEN environment variable (or an
        # explicit huggingfacehub_api_token argument) to call the Hub API.
        llm = HuggingFaceHub(
            repo_id="google/flan-t5-large",
            model_kwargs={"temperature": 0.1, "max_length": 512},
        )

        qa_prompt = PromptTemplate(
            input_variables=["document", "question"],
            template="""
Using the provided document, answer the following question precisely.
If the answer cannot be determined from the document, respond with
'The answer cannot be determined from the provided document.'

Question: {question}

Rules:
1. Provide a concise answer (1-3 sentences maximum)
2. When possible, reference the specific section of the document that supports your answer
3. Maintain numerical precision when answering quantitative questions
4. For comparison questions, highlight both items being compared

Document: {document}
""",
        )
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)

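# LLMChain and HuggingFaceHub are deprecated in newer LangChain releases.
# A rough equivalent with the runnable-composition API would look like the
# sketch below (a migration pointer only, not verified against the versions
# pinned for this project):
#
#     qa_chain = qa_prompt | llm
#     answer = qa_chain.invoke({"document": doc, "question": q})
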
def extract_text_from_file(file: UploadFile) -> str:
    """Extract text from an uploaded file (only .txt is handled in this example)."""
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        temp_file.write(file.file.read())
        temp_path = temp_file.name

    try:
        if file.filename.endswith('.txt'):
            with open(temp_path, 'r', encoding='utf-8') as f:
                return f.read()
        else:
            raise HTTPException(
                status_code=415,
                detail="File type not supported in this example implementation",
            )
    finally:
        os.unlink(temp_path)

@app.post("/api/summarize")
|
|
async def summarize_document(file: UploadFile = File(...)):
|
|
"""Summarize a document"""
|
|
initialize_models()
|
|
|
|
try:
|
|
|
|
document_text = extract_text_from_file(file)
|
|
|
|
|
|
summary = summarizer(
|
|
document_text,
|
|
max_length=150,
|
|
min_length=30,
|
|
do_sample=False,
|
|
truncation=True
|
|
)
|
|
|
|
return JSONResponse(
|
|
content={"status": "success", "result": summary[0]['summary_text']},
|
|
status_code=200
|
|
)
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail=f"Error processing document: {str(e)}"
|
|
)
|
|
|
|
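# Hypothetical extension (not wired into the endpoint above): BART attends to
# roughly 1024 tokens, so long documents are silently truncated. One common
# workaround is to summarize fixed-size chunks and join the partial summaries.
# The word-based split below is a simplification of proper token counting.
def summarize_long_text(text: str, chunk_words: int = 700) -> str:
    """Sketch: summarize a long document chunk by chunk."""
    initialize_models()
    words = text.split()
    chunks = [
        " ".join(words[i:i + chunk_words])
        for i in range(0, len(words), chunk_words)
    ]
    partial = [
        summarizer(chunk, max_length=150, min_length=30,
                   do_sample=False, truncation=True)[0]['summary_text']
        for chunk in chunks
    ]
    return " ".join(partial)
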
@app.post("/api/caption")
|
|
async def generate_image_caption(file: UploadFile = File(...)):
|
|
"""Generate caption for an image"""
|
|
initialize_models()
|
|
|
|
try:
|
|
|
|
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
|
|
temp_file.write(file.file.read())
|
|
temp_path = temp_file.name
|
|
|
|
|
|
image = Image.open(temp_path)
|
|
|
|
|
|
caption = image_captioner(
|
|
image,
|
|
generate_kwargs={
|
|
"max_length": 50,
|
|
"num_beams": 4,
|
|
"early_stopping": True
|
|
}
|
|
)
|
|
|
|
return JSONResponse(
|
|
content={"status": "success", "result": caption[0]['generated_text']},
|
|
status_code=200
|
|
)
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail=f"Error processing image: {str(e)}"
|
|
)
|
|
finally:
|
|
if 'temp_path' in locals() and os.path.exists(temp_path):
|
|
os.unlink(temp_path)
|
|
|
|
@app.post("/api/qa")
|
|
async def answer_question(
|
|
file: UploadFile = File(...),
|
|
question: str = Form(...)
|
|
):
|
|
"""Answer questions based on document content"""
|
|
initialize_models()
|
|
|
|
try:
|
|
|
|
document_text = extract_text_from_file(file)
|
|
|
|
|
|
answer = qa_chain.run(document=document_text, question=question)
|
|
|
|
return JSONResponse(
|
|
content={"status": "success", "result": answer},
|
|
status_code=200
|
|
)
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail=f"Error processing question: {str(e)}"
|
|
)
|
|
|
|
@app.get("/")
|
|
async def health_check():
|
|
"""Health check endpoint"""
|
|
return {"status": "healthy", "version": "1.0.0"}
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
uvicorn.run(app, host="0.0.0.0", port=8000) |
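
# ---------------------------------------------------------------------------
# Example requests (illustrative; assumes the server is running on
# localhost:8000 and that sample.txt / photo.jpg exist locally):
#
#   curl -X POST http://localhost:8000/api/summarize -F "file=@sample.txt"
#   curl -X POST http://localhost:8000/api/caption -F "file=@photo.jpg"
#   curl -X POST http://localhost:8000/api/qa \
#        -F "file=@sample.txt" -F "question=What is the main topic?"
# ---------------------------------------------------------------------------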