ParthSadaria committed
Commit eefae44 · verified · 1 Parent(s): 619453c

Update main.py

Files changed (1)
  1. main.py +13 -90
main.py CHANGED
@@ -4,6 +4,8 @@ from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
 from pydantic import BaseModel
 import httpx
+import hashlib
+from functools import lru_cache
 from pathlib import Path  # Import Path from pathlib
 import requests
 import re
@@ -149,71 +151,6 @@ async def get_models():
 async def fetch_models():
     return await get_models()
 server_status = True
-import asyncio
-import datetime
-import hashlib
-import json
-from functools import lru_cache
-from typing import Dict, Any
-
-# A simple in-memory cache with time-based expiry (evicts the oldest entry when full)
-class RequestCache:
-    def __init__(self, max_size=100, expires_after=300):  # 100 entries, expires after 5 minutes
-        self.cache: Dict[str, Dict[str, Any]] = {}
-        self.max_size = max_size
-        self.expires_after = expires_after
-
-    def _generate_cache_key(self, payload_dict: Dict[str, Any]) -> str:
-        """
-        Generate a unique cache key based on the payload.
-        Exclude any time-sensitive or dynamic fields.
-        """
-        # Copy (shallow) to avoid modifying the original payload
-        payload_for_hash = payload_dict.copy()
-
-        # Remove fields that might change between identical requests
-        payload_for_hash.pop('request_id', None)
-        payload_for_hash.pop('timestamp', None)
-
-        # Convert to a sorted, stable JSON string for hashing
-        payload_json = json.dumps(payload_for_hash, sort_keys=True)
-
-        # Create a hash of the payload
-        return hashlib.md5(payload_json.encode()).hexdigest()
-
-    def get(self, payload_dict: Dict[str, Any]) -> Any:
-        """
-        Retrieve the cached response if it exists and has not expired.
-        """
-        cache_key = self._generate_cache_key(payload_dict)
-
-        # Check that the cache entry exists and has not expired
-        if (cache_key in self.cache and
-                (datetime.datetime.now() - self.cache[cache_key]['timestamp']).total_seconds() < self.expires_after):
-            return self.cache[cache_key]['response']
-
-        return None
-
-    def set(self, payload_dict: Dict[str, Any], response: Any):
-        """
-        Store a response in the cache, managing cache size.
-        """
-        cache_key = self._generate_cache_key(payload_dict)
-
-        # Evict the oldest entry if the cache is full
-        if len(self.cache) >= self.max_size:
-            oldest_key = min(self.cache, key=lambda k: self.cache[k]['timestamp'])
-            del self.cache[oldest_key]
-
-        # Store the response with a timestamp
-        self.cache[cache_key] = {
-            'response': response,
-            'timestamp': datetime.datetime.now()
-        }
-
-# Global cache instance
-request_cache = RequestCache()
-
 @app.post("/chat/completions")
 @app.post("api/v1/chat/completions")
 async def get_completion(payload: Payload, request: Request):
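For reference, the removed RequestCache keyed entries on an MD5 hash of the payload serialized as sorted JSON, with the volatile request_id and timestamp fields dropped first, so identical request bodies collapsed to one cache entry. A minimal, self-contained sketch of that key derivation (the payload values are hypothetical):

import hashlib
import json

def cache_key(payload: dict) -> str:
    # Mirrors RequestCache._generate_cache_key from the removed hunk above.
    stable = payload.copy()          # shallow copy; nested values are shared
    stable.pop('request_id', None)   # drop fields that vary between identical requests
    stable.pop('timestamp', None)
    return hashlib.md5(json.dumps(stable, sort_keys=True).encode()).hexdigest()

# Two payloads differing only in request_id map to the same key, i.e. a cache hit.
a = {"model": "m", "messages": [{"role": "user", "content": "hi"}], "request_id": "1"}
b = {"model": "m", "messages": [{"role": "user", "content": "hi"}], "request_id": "2"}
assert cache_key(a) == cache_key(b)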
@@ -232,20 +169,12 @@ async def get_completion(payload: Payload, request: Request):
             status_code=400,
             detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
         )
-
-    # Convert payload to dictionary for caching
+
+    usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")
+
+    # Prepare payload
     payload_dict = payload.dict()
     payload_dict["model"] = model_to_use
-
-    # Check cache first
-    cached_response = request_cache.get(payload_dict)
-    if cached_response:
-        return StreamingResponse(
-            (line for line in cached_response),
-            media_type="application/json"
-        )
-
-    usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")
 
     # Select the appropriate endpoint
     endpoint = secret_api_endpoint_2 if model_to_use in alternate_models else secret_api_endpoint
@@ -253,19 +182,19 @@ async def get_completion(payload: Payload, request: Request):
     # Current time and IP logging
     current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
     aaip = request.client.host
-    print(f"Time: {current_time}, {aaip}")
-
+    print(f"Time: {current_time}, {aaip} , {model_to_use}")
+    # print(payload_dict)
     scraper = cloudscraper.create_scraper()
     async def stream_generator(payload_dict):
+
         # Prepare custom headers
         custom_headers = {
             'DNT': '1',
+            # 'Origin': ENDPOINT_ORIGIN,
             'Priority': 'u=1, i',
+            # 'Referer': ENDPOINT_ORIGIN
         }
 
-        # Collect response lines to cache
-        response_lines = []
-
         try:
             # Send POST request using CloudScraper with custom headers
             response = scraper.post(
@@ -287,16 +216,10 @@ async def get_completion(payload: Payload, request: Request):
             elif response.status_code >= 500:
                 raise HTTPException(status_code=500, detail="Server error. Try again later.")
 
-            # Stream response lines to the client and collect for caching
+            # Stream response lines to the client
             for line in response.iter_lines():
                 if line:
-                    decoded_line = line.decode('utf-8') + "\n"
-                    response_lines.append(decoded_line)
-                    yield decoded_line
-
-            # Cache the entire response after successful streaming
-            request_cache.set(payload_dict, response_lines)
-
+                    yield line.decode('utf-8') + "\n"
         except requests.exceptions.RequestException as req_err:
             # Handle request-specific errors
             print(response.text)
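With the cache removed, every request now streams straight through from the upstream endpoint. A minimal client sketch for exercising the route follows; the base URL is a placeholder, and the stream field is an assumption, since the Payload model is defined elsewhere in main.py and not shown in this diff:

import httpx

BASE_URL = "http://localhost:8000"  # placeholder; wherever this FastAPI app is served

payload = {
    "model": "gpt-4o-mini",  # hypothetical; must appear in /models or the route returns 400
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,          # assumed field, not confirmed by this diff
}

# Read the newline-delimited stream exactly as the server's generator yields it.
with httpx.stream("POST", f"{BASE_URL}/chat/completions", json=payload, timeout=60.0) as r:
    r.raise_for_status()
    for line in r.iter_lines():
        if line:
            print(line)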
 