ParthSadaria commited on
Commit
1d32d66
·
verified ·
1 Parent(s): b8bbdba

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +31 -46
main.py CHANGED
@@ -7,6 +7,7 @@ import httpx
7
  from pathlib import Path # Import Path from pathlib
8
  import requests
9
  import re
 
10
  import json
11
  from typing import Optional
12
 
@@ -119,62 +120,46 @@ async def get_models():
119
  async def fetch_models():
120
  return await get_models()
121
 
 
 
122
  @app.post("/chat/completions")
123
  @app.post("/v1/chat/completions")
124
  async def get_completion(payload: Payload, request: Request):
125
- model_to_use = payload.model
126
  payload_dict = payload.dict()
127
- payload_dict["model"] = model_to_use
128
-
129
  # Select the appropriate endpoint
130
  endpoint = secret_api_endpoint_2 if model_to_use in alternate_models else secret_api_endpoint
131
-
132
  print(payload_dict)
133
-
134
  async def stream_generator(payload_dict):
135
- async with httpx.AsyncClient() as client:
136
- try:
137
- async with client.stream("POST", f"{endpoint}/v1/chat/completions", json=payload_dict, timeout=10) as response:
138
- if response.status_code == 422:
139
- # Handle unprocessable entity errors
140
- raise HTTPException(status_code=422, detail="Unprocessable entity. Check your payload.")
141
- elif response.status_code == 400:
142
- # Handle bad request errors
143
- raise HTTPException(status_code=400, detail="Bad request. Verify input data.")
144
- elif response.status_code == 403:
145
- # Handle forbidden access
146
- raise HTTPException(status_code=403, detail="Forbidden. You do not have access to this resource.")
147
- elif response.status_code == 404:
148
- # Handle not found errors
149
- raise HTTPException(status_code=404, detail="The requested resource was not found.")
150
- elif response.status_code >= 500:
151
- # Handle server errors
152
- raise HTTPException(status_code=500, detail="Server error. Try again later.")
153
-
154
- response.raise_for_status() # Raise HTTPStatusError for non-200 responses not explicitly handled
155
-
156
- # Stream response to the client
157
- async for line in response.aiter_lines():
158
- if line:
159
- yield f"{line}\n"
160
- except httpx.HTTPStatusError as status_err:
161
- # Catch specific HTTP errors
162
- raise HTTPException(
163
- status_code=status_err.response.status_code,
164
- detail=f"HTTP error: {status_err.response.text}"
165
- )
166
- except httpx.TimeoutException:
167
- # Handle timeout errors
168
- raise HTTPException(status_code=504, detail="Request timed out. Please try again later.")
169
- except httpx.RequestError as req_err:
170
- # Handle generic request errors
171
- raise HTTPException(status_code=500, detail=f"Request failed: {req_err}")
172
- except Exception as e:
173
- # Catch any unexpected exceptions
174
- raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}")
175
 
176
  return StreamingResponse(stream_generator(payload_dict), media_type="application/json")
177
-
178
  @app.on_event("startup")
179
  async def startup_event():
180
  print("API endpoints:")
 
7
  from pathlib import Path # Import Path from pathlib
8
  import requests
9
  import re
10
+ import cloudscraper
11
  import json
12
  from typing import Optional
13
 
 
120
  async def fetch_models():
121
  return await get_models()
122
 
123
+ import cloudscraper
124
+
125
  @app.post("/chat/completions")
126
  @app.post("/v1/chat/completions")
127
  async def get_completion(payload: Payload, request: Request):
 
128
  payload_dict = payload.dict()
 
 
129
  # Select the appropriate endpoint
130
  endpoint = secret_api_endpoint_2 if model_to_use in alternate_models else secret_api_endpoint
 
131
  print(payload_dict)
 
132
  async def stream_generator(payload_dict):
133
+ scraper = cloudscraper.create_scraper() # Create a CloudScraper session
134
+ try:
135
+ # Send POST request using CloudScraper
136
+ response = scraper.post(f"{endpoint}/v1/chat/completions", json=payload_dict, stream=True)
137
+
138
+ # Check response status
139
+ if response.status_code == 422:
140
+ raise HTTPException(status_code=422, detail="Unprocessable entity. Check your payload.")
141
+ elif response.status_code == 400:
142
+ raise HTTPException(status_code=400, detail="Bad request. Verify input data.")
143
+ elif response.status_code == 403:
144
+ raise HTTPException(status_code=403, detail="Forbidden. You do not have access to this resource.")
145
+ elif response.status_code == 404:
146
+ raise HTTPException(status_code=404, detail="The requested resource was not found.")
147
+ elif response.status_code >= 500:
148
+ raise HTTPException(status_code=500, detail="Server error. Try again later.")
149
+
150
+ # Stream response lines to the client
151
+ for line in response.iter_lines():
152
+ if line:
153
+ yield line.decode('utf-8') + "\n"
154
+
155
+ except requests.exceptions.RequestException as req_err:
156
+ # Handle request-specific errors
157
+ raise HTTPException(status_code=500, detail=f"Request failed: {req_err}")
158
+ except Exception as e:
159
+ # Handle unexpected errors
160
+ raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
  return StreamingResponse(stream_generator(payload_dict), media_type="application/json")
 
163
  @app.on_event("startup")
164
  async def startup_event():
165
  print("API endpoints:")