"""Minimal OpenAI-compatible embeddings service backed by a SentenceTransformer model.

Exposes ``POST /v1/embeddings``, guarded by a static bearer token that is read
from the ``AUTHORIZATION`` environment variable.
"""

import logging
import os
import secrets
from typing import List, Optional, Union  # typing helpers for the request model

import numpy as np
from fastapi import Depends, FastAPI, Header, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

# Configure logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Embedding model to load; also echoed back in the "model" field of responses.
MODEL_NAME = "BAAI/bge-large-zh-v1.5"


async def check_authorization(
    authorization: Optional[str] = Header(None, alias="Authorization"),
) -> str:
    """Validate the ``Authorization: Bearer <token>`` request header.

    The expected token is read from the ``AUTHORIZATION`` environment variable
    on every call.

    Args:
        authorization: Raw value of the ``Authorization`` header, or None if absent.

    Returns:
        The bearer token supplied by the client.

    Raises:
        HTTPException: 401 if the header is missing or not a Bearer credential,
            if no token is configured on the server, or if the token does not match.
    """
    if authorization is None:
        # Report a missing credential as 401 rather than a 422 validation error.
        raise HTTPException(status_code=401, detail="Missing Authorization header")

    # Split "<scheme> <token>"; the scheme name is case-insensitive (RFC 7235).
    scheme, sep, token = authorization.partition(" ")
    if not sep or scheme.lower() != "bearer":
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")

    expected = os.environ.get("AUTHORIZATION")
    if not expected:
        # Fail closed: an unset or empty secret must never let an empty token through.
        logger.error("AUTHORIZATION environment variable is not set; rejecting request")
        raise HTTPException(status_code=401, detail="Unauthorized access")

    # Constant-time comparison so response timing does not leak the secret.
    if not secrets.compare_digest(token.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return token


app = FastAPI()

try:
    # Load the Sentence Transformer model once, at import time.
    model = SentenceTransformer(MODEL_NAME)
except Exception as e:
    # This runs outside any request, so an HTTPException would be meaningless;
    # log the full traceback and abort startup with the original cause attached.
    logger.exception("Failed to load model: %s", e)
    raise RuntimeError("Model loading failed") from e


class EmbeddingRequest(BaseModel):
    """Request body for ``/v1/embeddings``: one string or a list of strings."""

    input: Union[str, List[str]]


@app.post("/v1/embeddings")
async def embeddings(request: EmbeddingRequest, authorization: str = Depends(check_authorization)):
    """Embed the request's input text(s) and return an OpenAI-style response.

    Args:
        request: Parsed body; ``input`` is a single string or a list of strings.
        authorization: Bearer token already validated by ``check_authorization``.

    Returns:
        A dict with ``object="list"``, one ``data`` entry per input (in input
        order), the model name, and a ``usage`` block. An empty input list
        yields an empty ``data`` list and zero usage.
    """
    input_data = request.input
    # Normalize to a list so single and batch inputs share one code path.
    inputs = [input_data] if isinstance(input_data, str) else input_data

    # Rough estimate only: counts characters, not real tokenizer tokens.
    token_estimate = sum(len(text) for text in inputs)

    if inputs:
        # encode() is synchronous, compute-heavy work; run it in the threadpool so
        # it does not block the event loop for other in-flight requests.
        vectors = await run_in_threadpool(model.encode, inputs, normalize_embeddings=True)
    else:
        # Skip the model entirely for an empty batch.
        vectors = []

    # Build the OpenAI-compatible response; each row of `vectors` is one input's embedding.
    data_entries = [
        {"object": "embedding", "embedding": vector.tolist(), "index": idx}
        for idx, vector in enumerate(vectors)
    ]
    return {
        "object": "list",
        "data": data_entries,
        "model": MODEL_NAME,
        "usage": {
            "prompt_tokens": token_estimate,
            "total_tokens": token_estimate,
        },
    }