from fastapi import FastAPI, Request, HTTPException, Depends, Header
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
from typing import Union, List  # type imports needed for the request model
import numpy as np
import logging, os
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Dependency that validates the Authorization header
async def check_authorization(authorization: str = Header(..., alias="Authorization")):
    # Expect the "Bearer <token>" scheme
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")
    # Strip the "Bearer " prefix to get the bare token
    token = authorization[len("Bearer "):]
    # The expected token is read from the AUTHORIZATION environment variable
    if token != os.environ.get("AUTHORIZATION"):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return token

app = FastAPI()
try:
    # Load the Sentence Transformer model
    model = SentenceTransformer("BAAI/bge-large-zh-v1.5")
except Exception as e:
    # Abort startup if the model cannot be loaded; an HTTPException is only
    # meaningful inside a request handler, so fail with a plain error here
    logger.error(f"Failed to load model: {e}")
    raise RuntimeError("Model loading failed") from e

class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]  # a single string or a list of strings
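
# Example request bodies accepted by this model (illustrative only):
#   {"input": "single sentence"}
#   {"input": ["sentence one", "sentence two"]}
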
@app.post("/v1/embeddings")
async def embeddings(request: EmbeddingRequest, authorization: str = Depends(check_authorization)):
    input = request.input
    try:
        # Empty input: return an empty OpenAI-style list response
        if not input:
            return {
                "object": "list",
                "data": [],
                "model": "BAAI/bge-large-zh-v1.5",
                "usage": {
                    "prompt_tokens": 0,
                    "total_tokens": 0
                }
            }
        # Calculate embeddings; always encode a list so a single string and a
        # list of strings are handled uniformly
        texts = [input] if isinstance(input, str) else input
        embeddings = model.encode(texts, normalize_embeddings=True)
        # Format the embeddings in the OpenAI-compatible format: one entry per input text
        data = {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": embedding.tolist(),
                    "index": i
                }
                for i, embedding in enumerate(embeddings)
            ],
            "model": "BAAI/bge-large-zh-v1.5",
            "usage": {
                # Approximation: character/item count rather than a true token count
                "prompt_tokens": len(input),
                "total_tokens": len(input)
            }
        }
        return data
    except Exception as e:
        logger.error(f"Error processing embeddings: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
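
# A minimal local-run sketch, assuming uvicorn is installed (it is not imported
# above) and that the AUTHORIZATION environment variable holds the expected token.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against the running service; the token "secret" is only an
# illustration and must match the AUTHORIZATION environment variable:
#   curl http://localhost:8000/v1/embeddings \
#     -H "Authorization: Bearer secret" \
#     -H "Content-Type: application/json" \
#     -d '{"input": ["你好世界", "hello world"]}'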