ParthSadaria committed on
Commit
619453c
·
verified ·
1 Parent(s): 2ab4453

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +90 -11
main.py CHANGED
@@ -144,11 +144,76 @@ async def get_models():
144
  raise HTTPException(status_code=404, detail="models.json not found")
145
  except json.JSONDecodeError:
146
  raise HTTPException(status_code=500, detail="Error decoding models.json")
147
-
148
  @app.get("/models")
149
  async def fetch_models():
150
  return await get_models()
151
  server_status = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  @app.post("/chat/completions")
153
  @app.post("api/v1/chat/completions")
154
  async def get_completion(payload: Payload, request: Request):
@@ -167,12 +232,20 @@ async def get_completion(payload: Payload, request: Request):
167
  status_code=400,
168
  detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
169
  )
170
-
171
- usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")
172
-
173
- # Prepare payload
174
  payload_dict = payload.dict()
175
  payload_dict["model"] = model_to_use
 
 
 
 
 
 
 
 
 
 
176
 
177
  # Select the appropriate endpoint
178
  endpoint = secret_api_endpoint_2 if model_to_use in alternate_models else secret_api_endpoint
@@ -181,18 +254,18 @@ async def get_completion(payload: Payload, request: Request):
181
  current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
182
  aaip = request.client.host
183
  print(f"Time: {current_time}, {aaip}")
184
- # print(payload_dict)
185
  scraper = cloudscraper.create_scraper()
186
  async def stream_generator(payload_dict):
187
-
188
  # Prepare custom headers
189
  custom_headers = {
190
  'DNT': '1',
191
- # 'Origin': ENDPOINT_ORIGIN,
192
  'Priority': 'u=1, i',
193
- # 'Referer': ENDPOINT_ORIGIN
194
  }
195
 
 
 
 
196
  try:
197
  # Send POST request using CloudScraper with custom headers
198
  response = scraper.post(
@@ -214,10 +287,16 @@ async def get_completion(payload: Payload, request: Request):
214
  elif response.status_code >= 500:
215
  raise HTTPException(status_code=500, detail="Server error. Try again later.")
216
 
217
- # Stream response lines to the client
218
  for line in response.iter_lines():
219
  if line:
220
- yield line.decode('utf-8') + "\n"
 
 
 
 
 
 
221
  except requests.exceptions.RequestException as req_err:
222
  # Handle request-specific errors
223
  print(response.text)
 
144
  raise HTTPException(status_code=404, detail="models.json not found")
145
  except json.JSONDecodeError:
146
  raise HTTPException(status_code=500, detail="Error decoding models.json")
147
# Bug fix: the alias route previously read "api/v1/models" (no leading "/").
# Starlette asserts that every route path starts with "/", so registering the
# old path raised an AssertionError at application startup.
@app.get("/api/v1/models")
@app.get("/models")
async def fetch_models():
    """Return the model list on both the legacy /models and /api/v1/models paths."""
    return await get_models()
151
  server_status = True
152
+ import asyncio
153
+ import datetime
154
+ import hashlib
155
+ import json
156
+ from functools import lru_cache
157
+ from typing import Dict, Any
158
+
159
# Simple in-memory response cache with TTL expiry and oldest-first eviction.
class RequestCache:
    """Cache completed responses keyed by a hash of the request payload.

    Entries expire after ``expires_after`` seconds. When the cache is full,
    the entry with the oldest timestamp is evicted before a new one is added.
    NOTE(review): plain dict, no locking — assumes a single asyncio event
    loop with no concurrent mutation; confirm before using across workers.
    """

    def __init__(self, max_size=100, expires_after=300):
        # max_size: maximum number of cached entries (oldest evicted first).
        # expires_after: per-entry time-to-live in seconds (default 5 minutes).
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.max_size = max_size
        self.expires_after = expires_after

    def _generate_cache_key(self, payload_dict: Dict[str, Any]) -> str:
        """Return a stable MD5 hex digest identifying *payload_dict*.

        Time-sensitive fields are stripped so that otherwise-identical
        requests hash to the same key.
        """
        # A shallow copy suffices: only top-level keys are removed below,
        # and nested values are never mutated. (The old comment claimed a
        # deep copy, which was incorrect.)
        payload_for_hash = payload_dict.copy()

        # Drop fields that may differ between otherwise-identical requests.
        payload_for_hash.pop('request_id', None)
        payload_for_hash.pop('timestamp', None)

        # sort_keys=True makes the JSON — and therefore the hash — stable
        # regardless of dict insertion order.
        payload_json = json.dumps(payload_for_hash, sort_keys=True)

        # MD5 is acceptable here: the digest is a cache-lookup token,
        # not a security boundary.
        return hashlib.md5(payload_json.encode()).hexdigest()

    def get(self, payload_dict: Dict[str, Any]) -> Any:
        """Return the cached response for *payload_dict*, or None.

        Expired entries are deleted eagerly so they cannot occupy capacity
        (bug fix: the previous version left stale entries in the dict
        forever, which could force eviction of still-live entries).
        """
        cache_key = self._generate_cache_key(payload_dict)
        entry = self.cache.get(cache_key)
        if entry is None:
            return None

        age = (datetime.datetime.now() - entry['timestamp']).total_seconds()
        if age < self.expires_after:
            return entry['response']

        # Stale entry: remove it and report a miss.
        del self.cache[cache_key]
        return None

    def set(self, payload_dict: Dict[str, Any], response: Any):
        """Store *response* under the payload's cache key, evicting if full."""
        cache_key = self._generate_cache_key(payload_dict)

        # Evict the oldest entry only when adding a NEW key at capacity;
        # overwriting an existing key does not grow the cache, so evicting
        # then (as the previous version did) would needlessly drop an entry.
        if len(self.cache) >= self.max_size and cache_key not in self.cache:
            oldest_key = min(self.cache, key=lambda k: self.cache[k]['timestamp'])
            del self.cache[oldest_key]

        # Store the response together with its insertion timestamp.
        self.cache[cache_key] = {
            'response': response,
            'timestamp': datetime.datetime.now(),
        }


# Global cache instance shared by the request handlers.
request_cache = RequestCache()
216
+
217
  @app.post("/chat/completions")
218
  @app.post("api/v1/chat/completions")
219
  async def get_completion(payload: Payload, request: Request):
 
232
  status_code=400,
233
  detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
234
  )
235
+
236
+ # Convert payload to dictionary for caching
 
 
237
  payload_dict = payload.dict()
238
  payload_dict["model"] = model_to_use
239
+
240
+ # Check cache first
241
+ cached_response = request_cache.get(payload_dict)
242
+ if cached_response:
243
+ return StreamingResponse(
244
+ (line for line in cached_response),
245
+ media_type="application/json"
246
+ )
247
+
248
+ usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")
249
 
250
  # Select the appropriate endpoint
251
  endpoint = secret_api_endpoint_2 if model_to_use in alternate_models else secret_api_endpoint
 
254
  current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
255
  aaip = request.client.host
256
  print(f"Time: {current_time}, {aaip}")
257
+
258
  scraper = cloudscraper.create_scraper()
259
  async def stream_generator(payload_dict):
 
260
  # Prepare custom headers
261
  custom_headers = {
262
  'DNT': '1',
 
263
  'Priority': 'u=1, i',
 
264
  }
265
 
266
+ # Collect response lines to cache
267
+ response_lines = []
268
+
269
  try:
270
  # Send POST request using CloudScraper with custom headers
271
  response = scraper.post(
 
287
  elif response.status_code >= 500:
288
  raise HTTPException(status_code=500, detail="Server error. Try again later.")
289
 
290
+ # Stream response lines to the client and collect for caching
291
  for line in response.iter_lines():
292
  if line:
293
+ decoded_line = line.decode('utf-8') + "\n"
294
+ response_lines.append(decoded_line)
295
+ yield decoded_line
296
+
297
+ # Cache the entire response after successful streaming
298
+ request_cache.set(payload_dict, response_lines)
299
+
300
  except requests.exceptions.RequestException as req_err:
301
  # Handle request-specific errors
302
  print(response.text)