|
import logging
import os
import secrets
import threading
from typing import List, Union

import numpy as np
from fastapi import Depends, FastAPI, Header, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
|
|
|
|
|
# Module logger for this service; the root logger is configured to emit INFO
# and above so request/model errors logged below are visible.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
async def check_authorization(authorization: str = Header(..., alias="Authorization")) -> str:
    """FastAPI dependency that enforces a static bearer token.

    The client must send ``Authorization: Bearer <token>`` where ``<token>``
    equals the value of the ``AUTHORIZATION`` environment variable.

    Args:
        authorization: Raw value of the required ``Authorization`` header.

    Returns:
        The token presented by the client.

    Raises:
        HTTPException: 401 if the header is not a bearer credential, if the
            server has no (or an empty) ``AUTHORIZATION`` configured, or if
            the token does not match.
    """
    # Auth schemes are case-insensitive (RFC 7235 section 2.1), so accept
    # "bearer" as well as "Bearer".
    scheme, separator, token = authorization.partition(" ")
    if not separator or scheme.lower() != "bearer":
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")

    expected = os.environ.get("AUTHORIZATION")
    if not expected:
        # Fail closed: an unset or empty secret must never let an empty
        # token ("Bearer ") through.
        logger.error("AUTHORIZATION environment variable is not set; rejecting request")
        raise HTTPException(status_code=401, detail="Unauthorized access")

    # Constant-time comparison so response timing does not leak the secret.
    # Compare bytes because compare_digest rejects non-ASCII str arguments.
    if not secrets.compare_digest(token.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Unauthorized access")

    return token
|
|
|
app = FastAPI()

# Load the embedding model once at import time so every request reuses it.
try:
    model = SentenceTransformer("BAAI/bge-large-zh-v1.5")
except Exception as e:
    # This is a startup failure, not a request failure: HTTPException only
    # has meaning inside a request handler. Log the full traceback and abort
    # the import with a plain RuntimeError that chains the original cause.
    logger.exception("Failed to load model: %s", e)
    raise RuntimeError("Model loading failed") from e
|
|
|
class EmbeddingRequest(BaseModel):
    # JSON body accepted by POST /v1/embeddings (OpenAI-style request shape).
    # Only ``input`` is declared; other keys a client sends are presumably
    # dropped under pydantic's default ``extra`` handling -- confirm if the
    # model config is ever changed.

    # A single text or a batch of texts to embed. ``Field(...)`` marks it required.
    input: Union[str, List[str]] = Field(...)
|
|
|
# Serializes use of the shared model. Encoding now runs in worker threads,
# while the original handler only ever ran one encode at a time (it blocked
# the event loop). NOTE(review): the model/tokenizer is presumably not safe to
# call concurrently from several threads -- keep this lock unless verified.
_MODEL_LOCK = threading.Lock()


def _count_tokens(texts):
    """Return how many tokens the model's own tokenizer produces for *texts*.

    Usage is informational only, so if the tokenizer output cannot be read
    this falls back to character counts instead of failing the request.
    """
    try:
        features = model.tokenize(texts)
        # Padding positions are 0 in the mask, so its sum counts real tokens.
        return int(features["attention_mask"].sum())
    except (AttributeError, KeyError, TypeError):
        logger.warning("Token counting failed; reporting character counts", exc_info=True)
        return sum(len(text) for text in texts)


def _encode_batch(texts):
    """Embed *texts* (blocking) and return ``(vectors, token_count)``.

    Meant to run in a worker thread so the event loop stays responsive.
    """
    with _MODEL_LOCK:
        vectors = model.encode(texts, normalize_embeddings=True)
        token_count = _count_tokens(texts)
    return vectors, token_count


@app.post("/v1/embeddings")
async def embeddings(request: EmbeddingRequest, authorization: str = Depends(check_authorization)):
    """OpenAI-compatible embeddings endpoint.

    ``request`` must be annotated with the pydantic model. Without the
    annotation, FastAPI treats it as a query parameter and never parses the
    JSON body.

    Args:
        request: Parsed JSON body; ``input`` is one string or a list of strings.
        authorization: Token returned by ``check_authorization``. It is unused
            here; the dependency exists to reject unauthorized requests.

    Returns:
        ``{"object": "list", "data": [...], "model": ..., "usage": {...}}``.
        ``data`` holds one ``{"object": "embedding", "embedding": [...],
        "index": i}`` entry per input string, in input order. Empty input
        (``""`` or ``[]``) yields an empty ``data`` list and zero usage.

    Raises:
        HTTPException: 500 if encoding fails.
    """
    model_name = "BAAI/bge-large-zh-v1.5"
    raw_input = request.input

    if not raw_input:
        return {
            "object": "list",
            "data": [],
            "model": model_name,
            "usage": {
                "prompt_tokens": 0,
                "total_tokens": 0,
            },
        }

    # Always encode a list so the result has one row per input string. A
    # single str would otherwise produce a 1-D vector, and a list a 2-D array.
    texts = [raw_input] if isinstance(raw_input, str) else list(raw_input)

    try:
        vectors, token_count = await run_in_threadpool(_encode_batch, texts)
    except Exception as e:
        logger.exception("Error processing embeddings: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error") from e

    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "embedding": vector.tolist(),
                "index": index,
            }
            for index, vector in enumerate(vectors)
        ],
        "model": model_name,
        "usage": {
            "prompt_tokens": token_count,
            "total_tokens": token_count,
        },
    }