File size: 2,432 Bytes
496390f
16320e6
e6dbd5f
 
496390f
16320e6
 
 
 
e6dbd5f
496390f
 
236b12b
 
 
 
 
 
496390f
236b12b
496390f
e6dbd5f
 
16320e6
 
 
 
 
 
e6dbd5f
 
16320e6
e6dbd5f
 
ad9b013
 
e6dbd5f
 
16320e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6dbd5f
 
 
16320e6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Standard library
import logging
import os
import secrets

# Third-party
import numpy as np
from fastapi import FastAPI, Request, HTTPException, Depends, Header
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

# Configure module-level logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI dependency that validates the Authorization header.
async def check_authorization(authorization: str = Header(..., alias="Authorization")):
    """Validate a ``Bearer`` token against the ``AUTHORIZATION`` env var.

    Parameters:
        authorization: Raw ``Authorization`` header value,
            expected to look like ``Bearer <token>``.

    Returns:
        The bare token string when it matches the configured secret.

    Raises:
        HTTPException: 401 when the header is malformed, no secret is
            configured, or the token does not match.
    """
    # Strip the "Bearer " prefix; anything else is a malformed header.
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")

    token = authorization[len("Bearer "):]
    expected = os.environ.get("AUTHORIZATION")
    # Reject explicitly when no secret is configured, and compare in
    # constant time to avoid leaking the token via a timing side channel.
    if expected is None or not secrets.compare_digest(token, expected):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return token

app = FastAPI()

try:
    # Load the embedding model once at import time; reused by every request.
    model = SentenceTransformer("BAAI/bge-large-zh-v1.5")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    # HTTPException is only meaningful inside a request handler; at import
    # time nothing catches it. Raise a RuntimeError (chained to the cause)
    # so the startup failure is unambiguous.
    raise RuntimeError("Model loading failed") from e

class EmbeddingRequest(BaseModel):
    """Request body for ``POST /embeddings``: a single non-empty string."""
    # Text to embed; length is bounded at 1000 characters (not tokens).
    input: str = Field(..., min_length=1, max_length=1000)

@app.post("/embeddings")
# NOTE(review): the auth dependency is currently disabled; re-enable by
# restoring the Depends(check_authorization) parameter below.
# async def embeddings(request: EmbeddingRequest, authorization: str = Depends(check_authorization)):
async def embeddings(request: EmbeddingRequest):
    """Return an OpenAI-compatible embedding for a single input string.

    Parameters:
        request: Validated request body carrying the text to embed.

    Returns:
        A dict shaped like OpenAI's ``/v1/embeddings`` response, with one
        embedding entry for the input text.

    Raises:
        HTTPException: 500 on any failure while encoding.
    """
    input_text = request.input
    # Single source of truth for the model id reported in responses.
    model_name = "BAAI/bge-large-zh-v1.5"

    try:
        # Defensive guard; normally unreachable because the Pydantic model
        # enforces min_length=1 on `input`.
        if not input_text:
            return {
                "object": "list",
                "data": [],
                "model": model_name,
                "usage": {
                    "prompt_tokens": 0,
                    "total_tokens": 0
                }
            }

        # Compute the embedding. Use a distinct local name instead of
        # shadowing this endpoint function's own name (`embeddings`).
        vector = model.encode(input_text)

        # NOTE(review): these "token" counts are character counts, not model
        # tokens — confirm whether clients rely on accurate token usage.
        char_count = len(input_text)
        return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": vector.tolist(),
                    "index": 0
                }
            ],
            "model": model_name,
            "usage": {
                "prompt_tokens": char_count,
                "total_tokens": char_count
            }
        }
    except Exception as e:
        logger.error(f"Error processing embeddings: {e}")
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail="Internal Server Error") from e