|
import logging
import os
import secrets
from typing import List, Union

import numpy as np
from fastapi import Depends, FastAPI, Header, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
|
|
|
|
|
# Module-level logger for this service; the root logger is configured to emit
# INFO and above (basicConfig is a no-op if root already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
async def check_authorization(authorization: str = Header(..., alias="Authorization")):
    """FastAPI dependency that validates a ``Bearer`` token.

    The token must equal the value of the ``AUTHORIZATION`` environment
    variable.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        The bearer token (the header value with the ``"Bearer "`` prefix removed).

    Raises:
        HTTPException: 401 if the header does not start with ``"Bearer "``, if
            the ``AUTHORIZATION`` environment variable is unset or empty, or if
            the token does not match it.
    """
    prefix = "Bearer "
    if not authorization.startswith(prefix):
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")

    token = authorization[len(prefix):]
    expected = os.environ.get("AUTHORIZATION")

    # Fail closed: an unset OR empty configured secret must never authorize a
    # request (otherwise a bare "Bearer " header would match an empty value).
    # compare_digest runs in constant time to avoid leaking the secret through
    # response timing; comparing bytes avoids TypeError on non-ASCII input.
    if not expected or not secrets.compare_digest(
        token.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Unauthorized access")

    return token
|
|
|
app = FastAPI()

# Load the embedding model once at import time so every request reuses it.
try:
    model = SentenceTransformer("BAAI/bge-large-zh-v1.5")
except Exception as e:
    # Log the full traceback, then abort startup. HTTPException is only
    # meaningful inside a request handler; at import time it would just crash
    # the process with a misleading HTTP error, so raise a plain startup error
    # and keep the original cause chained for debugging.
    logger.exception("Failed to load model: %s", e)
    raise RuntimeError("Model loading failed") from e
|
|
|
class EmbeddingRequest(BaseModel):
    """JSON body accepted by ``POST /v1/embeddings``.

    ``input`` is required and may be either a single string or a list of
    strings to embed.
    """

    input: Union[str, List[str]] = Field(...)
|
|
|
@app.post("/v1/embeddings")
async def embeddings(request: EmbeddingRequest, authorization: str = Depends(check_authorization)):
    """Embed one or more strings and return an OpenAI-style embeddings response.

    Args:
        request: Parsed request body; ``request.input`` is a string or a list
            of strings.
        authorization: Token returned by ``check_authorization``; declared so
            the auth dependency runs before this handler. Not otherwise used.

    Returns:
        A dict with ``object`` (``"list"``), ``data`` (one
        ``{"object": "embedding", "embedding": [...], "index": i}`` entry per
        encoded vector), ``model`` and ``usage``.

    Raises:
        HTTPException: 500 if the model fails to encode the inputs.
    """
    model_name = "BAAI/bge-large-zh-v1.5"

    raw_input = request.input
    inputs = [raw_input] if isinstance(raw_input, str) else raw_input

    # NOTE(review): "tokens" here are counted as characters, not tokenizer
    # tokens — an approximation; confirm whether clients rely on real counts.
    char_count = sum(len(text) for text in inputs)

    if not inputs:
        # Nothing to encode: return a well-formed, JSON-serializable empty result.
        return {
            "object": "list",
            "data": [],
            "model": model_name,
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    try:
        # encode() is blocking compute; run it in a worker thread so the event
        # loop keeps serving other requests while the model runs.
        vectors = await run_in_threadpool(model.encode, inputs, normalize_embeddings=True)
    except Exception as exc:
        logger.exception("Embedding failed: %s", exc)
        raise HTTPException(status_code=500, detail="Embedding failed") from exc

    data_entries = [
        {"object": "embedding", "embedding": vector.tolist(), "index": idx}
        for idx, vector in enumerate(vectors)
    ]

    return {
        "object": "list",
        "data": data_entries,
        "model": model_name,
        "usage": {
            "prompt_tokens": char_count,
            "total_tokens": char_count,
        },
    }