geqintan committed on
Commit
c302efc
·
1 Parent(s): e0922e7
Files changed (1) hide show
  1. app.py +25 -35
app.py CHANGED
@@ -33,42 +33,32 @@ class EmbeddingRequest(BaseModel):
33
  input: Union[str, List[str]] # 修复类型定义
34
 
35
  @app.post("/v1/embeddings")
36
- async def embeddings(request:EmbeddingRequest, authorization: str = Depends(check_authorization)):
37
- input = request.input
 
 
38
 
39
- try:
40
- if not input:
41
- return {
42
- "object": "list",
43
- "data": [],
44
- "model": "BAAI/bge-large-zh-v1.5",
45
- "usage": {
46
- "prompt_tokens": 0,
47
- "total_tokens": 0
48
- }
49
- }
50
 
51
- # Calculate embeddings
52
- embeddings = model.encode(input, normalize_embeddings=True)
53
 
54
- # Format the embeddings in OpenAI compatible format
55
- data = {
56
- "object": "list",
57
- "data": [
58
- {
59
- "object": "embedding",
60
- "embedding": embeddings.tolist(),
61
- "index": 0
62
- }
63
- ],
64
- "model": "BAAI/bge-large-zh-v1.5",
65
- "usage": {
66
- "prompt_tokens": len(input),
67
- "total_tokens": len(input)
68
- }
69
- }
70
 
71
- return data
72
- except Exception as e:
73
- logger.error(f"Error processing embeddings: {e}")
74
- raise HTTPException(status_code=500, detail="Internal Server Error")
 
 
 
 
 
 
33
  input: Union[str, List[str]] # 修复类型定义
34
 
35
  @app.post("/v1/embeddings")
36
+ async def embeddings(request: EmbeddingRequest, authorization: str = Depends(check_authorization)):
37
+ input_data = request.input
38
+ # 统一转换为列表处理
39
+ inputs = [input_data] if isinstance(input_data, str) else input_data
40
 
41
+ if not inputs:
42
+ return { ... } # 空输入处理
 
 
 
 
 
 
 
 
 
43
 
44
+ # 计算嵌入向量(二维numpy数组)
45
+ embeddings = model.encode(inputs, normalize_embeddings=True)
46
 
47
+ # 构建符合OpenAI格式的响应
48
+ data_entries = []
49
+ for idx, embed in enumerate(embeddings):
50
+ data_entries.append({
51
+ "object": "embedding",
52
+ "embedding": embed.tolist(), # 每个embed是一维数组
53
+ "index": idx
54
+ })
 
 
 
 
 
 
 
 
55
 
56
+ return {
57
+ "object": "list",
58
+ "data": data_entries, # 包含每个输入的嵌入对象
59
+ "model": "BAAI/bge-large-zh-v1.5",
60
+ "usage": {
61
+ "prompt_tokens": sum(len(text) for text in inputs), # 粗略估计token数
62
+ "total_tokens": sum(len(text) for text in inputs)
63
+ }
64
+ }