# jina-reranker / app.py
# Last change: "Update app.py" by Gopal2002 (commit c369d6c, verified)
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
from sentence_transformers import CrossEncoder
app = FastAPI()

# The reranker is built once, at import time, and shared by every request
# handled by this process.
#   * trust_remote_code=True lets transformers execute Python code shipped in
#     the model's Hub repository -- only appropriate for a trusted model id.
#   * torch_dtype="auto" asks transformers to pick the weight dtype from the
#     checkpoint/config instead of defaulting to float32.
model = CrossEncoder(
    "jinaai/jina-reranker-v2-base-multilingual",
    trust_remote_code=True,
    automodel_args=dict(torch_dtype="auto"),
)
# --- Request / response schemas ------------------------------------------
class RerankRequest(BaseModel):
    """JSON body accepted by ``POST /rerank``."""

    # Text that every document is scored against.
    query: str
    # Candidate passages to rank; a document's position in this list is what
    # the response reports back as its ``corpus_id``.
    documents: List[str]
class RerankResult(BaseModel):
    """One scored document in a ``/rerank`` response."""

    # Index of this document in the request's ``documents`` list, as reported
    # by ``CrossEncoder.rank``.
    corpus_id: int
    # Cross-encoder relevance score; higher ranks first. The numeric range
    # depends on the model's output activation -- not fixed by this file.
    score: float
    # The document text itself, echoed back for convenience.
    text: str
class RerankResponse(BaseModel):
    """Payload returned by ``POST /rerank``."""

    # The query from the request, echoed unchanged.
    query: str
    # Scored documents in the order produced by ``CrossEncoder.rank``
    # (best match first).
    results: List[RerankResult]
# Rerank endpoint built on CrossEncoder.rank()
@app.post("/rerank", response_model=RerankResponse)
def rerank(req: RerankRequest):
    """Score every document against the query and return them best-first.

    Args:
        req: The query and the candidate documents to rank.

    Returns:
        A ``RerankResponse`` echoing ``req.query`` with one ``RerankResult``
        per input document, in the order returned by ``CrossEncoder.rank``
        (highest score first). An empty ``documents`` list yields an empty
        ``results`` list.
    """
    # rank() can raise on an empty document list (with convert_to_tensor=True
    # there are no scores to stack, and some sentence-transformers releases
    # index the first query/document pair unconditionally). FastAPI would
    # turn that into a 500, but the correct answer is simply "nothing to rank".
    if not req.documents:
        return RerankResponse(query=req.query, results=[])

    # NOTE(review): convert_to_tensor=True keeps scores as torch tensors.
    # Presumably this avoids NumPy conversion, which cannot represent bfloat16
    # if torch_dtype="auto" loads the checkpoint in bf16 -- confirm before
    # switching to the NumPy default. float() below works on tensor scalars.
    rankings = model.rank(
        req.query,
        req.documents,
        return_documents=True,
        convert_to_tensor=True,
    )

    results = [
        RerankResult(
            corpus_id=ranking["corpus_id"],
            score=float(ranking["score"]),
            text=ranking["text"],
        )
        for ranking in rankings
    ]
    return RerankResponse(query=req.query, results=results)