# Jina_Re_Rank / app.py
# Source: Hugging Face Space file view (uploaded by Deep8591).
# Commit 622ed28 (verified): "Update app.py" -- file size 3.48 kB.
import os
import torch
from loguru import logger
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import CrossEncoder
from typing import List, Optional
# Create the FastAPI application. The metadata passed here is what the
# generated API documentation pages display.
app = FastAPI(
    title="Document Reranker API",
    version="1.0",
    description="An API for reranking documents using a CrossEncoder model.",
    redoc_url="/redoc",  # ReDoc UI
    docs_url="/docs",  # Swagger UI
)
# Register CORS handling so browser front ends hosted on other origins can
# call this API (optional, but convenient for frontend integration).
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins/methods/headers with allow_credentials=True -- confirm
# whether credentialed requests are actually needed, or list the allowed
# origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],  # Every origin is accepted (tighten as needed)
)
# Pick the compute device: CUDA when torch reports it as available, CPU
# otherwise, and log which one was chosen.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    # Include the name of GPU 0 so the log shows which card is in use.
    logger.warning(f"Using device: {DEVICE} (GPU: {torch.cuda.get_device_name(0)})")
else:
    logger.warning(f"Using device: {DEVICE} (Running on CPU)")
# Make sure the local cache directory exists before the model is fetched;
# exist_ok=True makes this a no-op on subsequent starts.
os.makedirs("models_cache", exist_ok=True)

# Load the model once, at import time, so that no request has to pay the
# loading cost. A failure here aborts start-up with a RuntimeError.
try:
    model = CrossEncoder(
        "jinaai/jina-reranker-v1-turbo-en",
        # NOTE(review): trust_remote_code=True presumably allows code shipped
        # with the model repository to run locally -- confirm this is intended.
        trust_remote_code=True,
        device=DEVICE,
        cache_dir="models_cache",
    )
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    # Chain explicitly so the original loading error stays attached as the
    # cause of the RuntimeError that callers / the process supervisor see.
    raise RuntimeError("Model loading failed. Check logs for details.") from e
class RerankerRequest(BaseModel):
    """Request payload accepted by the ``/rerank`` endpoint."""

    # Query that every document is scored against.
    query: str = Field(..., description="The search query string")
    # Candidate documents to be ordered by relevance to the query.
    documents: List[str] = Field(..., description="List of documents to rerank")
    # When false, results carry only score and index (document is left unset).
    return_documents: bool = Field(
        True, description="Whether to return document content in results"
    )
    # ge=1: zero or a negative number cannot describe "the top k results", so
    # such values are rejected at validation time instead of being handed to
    # the ranker.
    top_k: int = Field(3, ge=1, description="Number of top results to return")
class RankedResult(BaseModel):
    # One entry of the reranked output returned to the client.
    # Relevance score reported by the ranker for this document.
    score: float
    # Filled from the ranker's "corpus_id" value for this document.
    index: int
    # Document text; stays None when the request set return_documents to false.
    document: Optional[str] = Field(default=None)
class RerankerResponse(BaseModel):
    # Response envelope of the /rerank endpoint: the ranked entries, in the
    # order produced by the ranker.
    results: List[RankedResult] = Field(...)
@app.post("/rerank", response_model=RerankerResponse, tags=["Reranker"])
async def rerank_documents(request: RerankerRequest):
    """
    Reranks the given list of documents based on their relevance to the query.

    - **query**: The input query string.
    - **documents**: A list of documents to be reranked.
    - **return_documents**: Whether to include document content in results.
    - **top_k**: Number of top-ranked documents to return.

    Returns:

    - A list of ranked documents with scores and indexes. An empty
      **documents** list yields an empty result list.
    """
    # Nothing to score: answer directly rather than invoking the model on an
    # empty batch.
    if not request.documents:
        return RerankerResponse(results=[])

    try:
        # NOTE: model.rank is called synchronously, so this coroutine holds
        # the event loop until the call returns.
        ranked = model.rank(
            request.query,
            request.documents,
            return_documents=request.return_documents,
            top_k=request.top_k,
        )
        # Map each ranked entry onto the response schema. The explicit
        # float()/int() casts normalise non-native numeric values
        # (presumably NumPy scalars -- confirm) to plain Python numbers.
        formatted_results = [
            RankedResult(
                score=float(entry["score"]),
                index=int(entry["corpus_id"]),
                document=entry["text"] if request.return_documents else None,
            )
            for entry in ranked
        ]
        return RerankerResponse(results=formatted_results)
    except Exception as e:
        # Any failure while ranking or formatting is reported as HTTP 500.
        # logger.exception keeps the traceback in the logs, and "from e"
        # keeps the original error attached as the cause.
        logger.exception(f"Error in reranking: {e}")
        raise HTTPException(
            status_code=500, detail=f"Error in reranking: {str(e)}"
        ) from e
# Script entry point: when this file is executed directly, serve the app with
# Uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running as a script.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")