embedding / app.py
geqintan's picture
update
52ef174
raw
history blame
2.63 kB
from fastapi import FastAPI, Request, HTTPException, Depends, Header
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
from typing import Union, List # 添加必要的类型导入
import numpy as np
import logging, os
# Configure logging: the root logger emits INFO and above, and this module
# gets its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Dependency that validates the Authorization header of incoming requests.
async def check_authorization(authorization: str = Header(..., alias="Authorization")):
    """Validate a ``Bearer`` token against the ``AUTHORIZATION`` env variable.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        The token with the ``"Bearer "`` prefix removed.

    Raises:
        HTTPException: 401 when the header does not start with ``"Bearer "``,
            when the expected secret is unset or empty, or when the token
            does not match it.
    """
    # Function-scope import keeps this block self-contained.
    import hmac

    prefix = "Bearer "
    if not authorization.startswith(prefix):
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")

    # Strip "Bearer " (including the trailing space) to get the bare token.
    token = authorization[len(prefix):]
    expected = os.environ.get("AUTHORIZATION")

    # An unset or empty secret must never authorize anyone; otherwise an
    # empty token would compare equal to an empty secret.
    if not expected:
        raise HTTPException(status_code=401, detail="Unauthorized access")

    # Compare in constant time so response timing does not leak how much of
    # the secret a guess got right. Bytes are used because compare_digest
    # rejects non-ASCII str arguments.
    token_matches = hmac.compare_digest(
        token.encode("utf-8", "surrogateescape"),
        expected.encode("utf-8", "surrogateescape"),
    )
    if not token_matches:
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return token
# Application instance; the route below is registered on it.
app = FastAPI()

# Load the Sentence Transformer model once at import time so that every
# request reuses the same instance.
try:
    model = SentenceTransformer("BAAI/bge-large-zh-v1.5")
except Exception as e:
    # HTTPException only makes sense while answering a request. At import
    # time there is no request, so log the traceback and abort startup with
    # an ordinary error chained to the real cause.
    logger.exception("Failed to load model: %s", e)
    raise RuntimeError("Model loading failed") from e
# Request body accepted by the embeddings endpoint.
class EmbeddingRequest(BaseModel):
    # Text to embed: either a single string or a list of strings.
    input: Union[str, List[str]]
@app.post("/v1/embeddings")
async def embeddings(request: EmbeddingRequest, authorization: str = Depends(check_authorization)):
    """Return embeddings for the submitted text(s) in an OpenAI-style payload.

    Args:
        request: JSON body whose ``input`` is one string or a list of strings.
        authorization: Token returned by ``check_authorization``. It is not
            used in the body; declaring it makes the dependency run so that
            unauthorized callers are rejected before any work is done.

    Returns:
        A dict with ``object``, ``data``, ``model`` and ``usage`` keys.
        ``data`` holds one ``{"object": "embedding", "embedding": [...],
        "index": i}`` entry per input text, in input order. An empty string
        or empty list yields an empty ``data`` list and zero usage.

    Raises:
        HTTPException: 500 when encoding or formatting the embeddings fails.
    """
    model_name = "BAAI/bge-large-zh-v1.5"
    payload = request.input

    # Nothing to embed: answer with an empty list instead of calling the model.
    if not payload:
        return {
            "object": "list",
            "data": [],
            "model": model_name,
            "usage": {
                "prompt_tokens": 0,
                "total_tokens": 0,
            },
        }

    # Treat a bare string as a batch of one so both input shapes share a
    # single code path and every text gets its own indexed entry.
    texts = [payload] if isinstance(payload, str) else list(payload)

    try:
        # NOTE(review): model.encode is a synchronous call made directly
        # inside this coroutine, so the event loop waits while it runs.
        vectors = model.encode(texts, normalize_embeddings=True)
        items = [
            {
                "object": "embedding",
                "embedding": vector.tolist(),
                "index": position,
            }
            for position, vector in enumerate(vectors)
        ]
    except Exception as e:
        logger.exception("Error processing embeddings: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error") from e

    # Usage is approximated by the total character count of the inputs; it
    # is not a real tokenizer count.
    char_count = sum(len(text) for text in texts)
    return {
        "object": "list",
        "data": items,
        "model": model_name,
        "usage": {
            "prompt_tokens": char_count,
            "total_tokens": char_count,
        },
    }