ethiotech4848 committed on
Commit
2c2b2b5
·
verified ·
1 Parent(s): 099df15

Create fastapi_app.py

Browse files
Files changed (1) hide show
  1. fastapi_app.py +132 -0
fastapi_app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request, Response
2
+ from fastapi.responses import JSONResponse, StreamingResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ import uvicorn
5
+ import json
6
+
7
+ from typegpt_api import generate, model_mapping, simplified_models
8
+ from api_info import developer_info, model_providers
9
+
10
app = FastAPI()

# CORS: any origin, method, and header may call this API.
# Credentialed requests (cookies / HTTP auth) are NOT enabled: a wildcard
# origin cannot legitimately be combined with credentials, and the FastAPI
# documentation states that allow_origins, allow_methods and allow_headers
# must not be ["*"] when allow_credentials is True. If credentialed
# cross-origin requests are ever required, replace the wildcard with an
# explicit list of trusted origins before turning credentials back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
20
+
21
@app.get("/health_check")
async def health_check():
    """Liveness probe: always answer with a fixed ``{"status": "OK"}`` payload."""
    status_payload = {"status": "OK"}
    return status_payload
24
+
25
@app.get("/models")
async def get_models():
    """List every model of every provider found in ``model_providers``.

    Returns:
        A JSON object ``{"object": "list", "data": [...]}`` with one entry
        per model, carrying the model id, the provider name, and the
        provider's description. If anything raises while building the
        listing, an HTTP 500 response with the exception text is returned
        instead.
    """
    try:
        # Flatten provider -> models into one list of model descriptors.
        entries = [
            {
                "id": model_id,
                "object": "model",
                "provider": provider_name,
                "description": details["description"],
            }
            for provider_name, details in model_providers.items()
            for model_id in details["models"]
        ]
        return JSONResponse(content={"object": "list", "data": entries})
    except Exception as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=500)
44
+
45
@app.post("/chat/completions")
async def chat_completions(request: Request):
    """Forward a chat-completion request to ``generate``.

    The request body must be a JSON object containing at least ``model``
    and ``messages``. The other sampling options are optional and fall back
    to the defaults listed below.

    Returns:
        * ``stream`` falsy: a ``JSONResponse`` wrapping the value returned
          by ``generate``.
        * ``stream`` truthy: a ``text/event-stream`` response emitting one
          ``data: <json>`` event per chunk and a final ``data: [DONE]``
          event. If ``generate`` fails mid-stream, a
          ``data: {"error": ...}`` event is emitted before ``[DONE]``.
        * HTTP 400 when the body is not a JSON object or a required field
          is missing; HTTP 500 with the exception text when ``generate``
          raises in the non-streaming path.
    """
    # Imported here so this handler is self-contained; it is part of the
    # FastAPI package the module already depends on.
    from fastapi.concurrency import run_in_threadpool

    try:
        body = await request.json()
    except Exception:
        return JSONResponse(content={"error": "Invalid JSON payload"}, status_code=400)

    # Well-formed JSON that is not an object (list, string, number, null)
    # has no .get(); treat it as a client error rather than crashing.
    if not isinstance(body, dict):
        return JSONResponse(content={"error": "Invalid JSON payload"}, status_code=400)

    model = body.get("model")
    messages = body.get("messages")

    # Validate required parameters
    if not model:
        return JSONResponse(content={"error": "The 'model' parameter is required."}, status_code=400)
    if not messages:
        return JSONResponse(content={"error": "The 'messages' parameter is required."}, status_code=400)

    # Arguments shared by the streaming and non-streaming calls; only the
    # ``stream`` flag differs between them.
    generate_kwargs = {
        "model": model,
        "messages": messages,
        "temperature": body.get("temperature", 0.7),
        "top_p": body.get("top_p", 1.0),
        "n": body.get("n", 1),
        "stop": body.get("stop"),
        "max_tokens": body.get("max_tokens"),
        "presence_penalty": body.get("presence_penalty", 0.0),
        "frequency_penalty": body.get("frequency_penalty", 0.0),
        "logit_bias": body.get("logit_bias"),
        "user": body.get("user"),
        "timeout": 30,  # fixed upstream timeout
    }

    try:
        if body.get("stream", False):
            # A plain (non-async) generator on purpose: ``generate`` is
            # consumed with a blocking ``for`` loop, and StreamingResponse
            # iterates synchronous generators in a worker thread, so the
            # event loop stays free while chunks are awaited.
            def generate_stream():
                try:
                    for chunk in generate(stream=True, **generate_kwargs):
                        yield f"data: {json.dumps(chunk)}\n\n"
                except Exception as exc:
                    # The response has already started, so the status code
                    # can no longer change; report the failure in-band.
                    error_event = json.dumps({"error": str(exc)})
                    yield f"data: {error_event}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Transfer-Encoding": "chunked",
                },
            )

        # ``generate`` is a blocking call; run it in the thread pool so
        # other requests are still served while it waits.
        response = await run_in_threadpool(generate, stream=False, **generate_kwargs)
        return JSONResponse(content=response)
    except Exception as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=500)
126
+
127
@app.get("/developer_info")
async def get_developer_info():
    """Serve the imported ``developer_info`` object as a JSON response."""
    info_response = JSONResponse(content=developer_info)
    return info_response
130
+
131
# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    bind_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **bind_options)