embedding / app.py
geqintan's picture
update
c302efc
from fastapi import FastAPI, Request, HTTPException, Depends, Header
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
from typing import Union, List # 添加必要的类型导入
import numpy as np
import logging, os
# Configure logging at INFO level and create the logger used by this module.
logging.basicConfig(
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
# Dependency that validates the Authorization header of incoming requests.
async def check_authorization(authorization: str = Header(..., alias="Authorization")):
    """Validate a ``Bearer`` token against the ``AUTHORIZATION`` environment variable.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        The bearer token with the ``"Bearer "`` prefix removed.

    Raises:
        HTTPException: 401 if the header does not start with ``"Bearer "``,
            if the token does not match the configured secret, or if no
            secret is configured.
    """
    import hmac  # local import: only this dependency needs it

    scheme = "Bearer "
    if not authorization.startswith(scheme):
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")
    # Strip the "Bearer " prefix (including the trailing space).
    token = authorization[len(scheme):]

    expected = os.environ.get("AUTHORIZATION")
    # Fail closed: an unset or empty secret must never authorize a request.
    # compare_digest avoids leaking how many leading characters matched via
    # timing; comparing bytes keeps it safe for non-ASCII header values.
    if not expected or not hmac.compare_digest(
        token.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return token
app = FastAPI()

try:
    # Load the Sentence Transformer model once, when the module is imported,
    # so every request reuses the same instance.
    model = SentenceTransformer("BAAI/bge-large-zh-v1.5")
except Exception as e:
    # This runs at import time, outside any request, so an HTTPException has
    # no response to attach to. Log the full traceback and abort startup with
    # an ordinary error that keeps the original cause chained.
    logger.exception("Failed to load model: %s", e)
    raise RuntimeError("Model loading failed") from e
class EmbeddingRequest(BaseModel):
    """Request body carrying the text(s) to embed."""

    # Either a single string or a list of strings.
    input: Union[str, List[str]]
@app.post("/v1/embeddings")
async def embeddings(request: EmbeddingRequest, authorization: str = Depends(check_authorization)):
    """Embed the requested text(s) and return an OpenAI-style embeddings response.

    Args:
        request: Parsed request body; ``request.input`` is either one string
            or a list of strings.
        authorization: Token returned by ``check_authorization``. It is not
            used here; declaring the dependency is what enforces the check.

    Returns:
        A dict with ``object``, ``data`` (one entry per input, in input
        order), ``model`` and ``usage``. For an empty input list ``data`` is
        empty and the usage counts are 0.
    """
    model_name = "BAAI/bge-large-zh-v1.5"

    # Normalize to a list so a single string and a batch share one code path.
    raw_input = request.input
    texts = [raw_input] if isinstance(raw_input, str) else raw_input

    if texts:
        # NOTE(review): encode() is a synchronous call inside an async handler;
        # consider offloading it to a worker thread if request latency matters.
        vectors = model.encode(texts, normalize_embeddings=True)
        data_entries = [
            {
                "object": "embedding",
                "embedding": vector.tolist(),  # convert each row to a JSON-friendly list
                "index": idx,
            }
            for idx, vector in enumerate(vectors)
        ]
    else:
        # Nothing to encode: skip the model call and return a well-formed,
        # empty response instead of an unserializable placeholder.
        data_entries = []

    # Rough estimate only: counts characters, not real tokenizer tokens.
    approx_tokens = sum(len(text) for text in texts)

    return {
        "object": "list",
        "data": data_entries,
        "model": model_name,
        "usage": {
            "prompt_tokens": approx_tokens,
            "total_tokens": approx_tokens,
        },
    }