from fastapi import FastAPI, HTTPException, Depends, Header
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from typing import Union, List  # type hints for the request body
import logging
import os

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Dependency that validates the Authorization header
async def check_authorization(authorization: str = Header(..., alias="Authorization")):
    # Expect the header in the form "Bearer <token>"
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")

    # Strip the "Bearer " prefix to get the raw token
    token = authorization[len("Bearer "):]
    if token != os.environ.get("AUTHORIZATION"):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return token

app = FastAPI()

try:
    # Load the Sentence Transformer model once at startup
    model = SentenceTransformer("BAAI/bge-large-zh-v1.5")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    # HTTPException is only meaningful inside a request handler; fail fast at startup instead
    raise RuntimeError("Model loading failed") from e

class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]  # a single string or a list of strings, as in the OpenAI API

@app.post("/v1/embeddings")
async def embeddings(request: EmbeddingRequest, authorization: str = Depends(check_authorization)):
    # Typing the body as EmbeddingRequest lets FastAPI parse and validate the JSON payload
    texts = request.input

    try:
        if not texts:
            return {
                "object": "list",
                "data": [],
                "model": "BAAI/bge-large-zh-v1.5",
                "usage": {
                    "prompt_tokens": 0,
                    "total_tokens": 0
                }
            }

        # Calculate embeddings; always encode a list so the response shape is uniform
        if isinstance(texts, str):
            texts = [texts]
        embeddings = model.encode(texts, normalize_embeddings=True)

        # Format the embeddings in OpenAI-compatible format: one entry per input text
        data = {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": embedding.tolist(),
                    "index": i
                }
                for i, embedding in enumerate(embeddings)
            ],
            "model": "BAAI/bge-large-zh-v1.5",
            "usage": {
                # Rough approximation: character count, not a real token count
                "prompt_tokens": sum(len(t) for t in texts),
                "total_tokens": sum(len(t) for t in texts)
            }
        }

        return data
    except Exception as e:
        logger.error(f"Error processing embeddings: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
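For reference, a minimal client sketch is shown below. It assumes the app is served with uvicorn on localhost port 8000, that the AUTHORIZATION environment variable holds the same token on the client and the server, and that the requests library is installed; none of these details come from the file itself.

# Hypothetical client call; the host, port, and token source are assumptions
import os
import requests

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    headers={"Authorization": f"Bearer {os.environ['AUTHORIZATION']}"},
    json={"input": ["你好，世界", "hello world"]},
)
response.raise_for_status()
vectors = [item["embedding"] for item in response.json()["data"]]
print(len(vectors), len(vectors[0]))  # number of inputs, embedding dimension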