gnilets commited on
Commit
0bd4df8
·
verified ·
1 Parent(s): 3d1942f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -113,7 +113,7 @@ async def proxy_openai_api(request: Request):
113
 
114
  request_body = await request.json() if request.method in {'POST', 'PUT'} else None
115
 
116
- async def stream_api_response(api_key: str):
117
  update_authorization_header(api_key)
118
  try:
119
  streaming = client.stream(request.method, COMPLETIONS_URL, headers=headers, params=request.query_params, json=request_body)
@@ -135,13 +135,19 @@ async def proxy_openai_api(request: Request):
135
  for api_key in API_KEYS:
136
  response_generator = stream_api_response(api_key)
137
  try:
138
- first_chunk = await response_generator.__anext__()
139
  if first_chunk == 'auth_error':
140
  print(f'ключ API {api_key} недействителен или превышен лимит отправки запросов')
141
  continue
142
  else:
143
  headers_to_forward = {k: v for k, v in headers.items() if k.lower() not in {'content-length', 'content-encoding', 'alt-svc'}}
144
- return OverrideStreamResponse(chain([first_chunk], response_generator), headers=headers_to_forward)
 
 
 
 
 
 
145
  except StopAsyncIteration:
146
  continue
147
  raise HTTPException(status_code=401, detail='все ключи API использованы, доступ запрещен.')
 
113
 
114
  request_body = await request.json() if request.method in {'POST', 'PUT'} else None
115
 
116
+ async def stream_api_response(api_key: str) -> AsyncIterable[str]:
117
  update_authorization_header(api_key)
118
  try:
119
  streaming = client.stream(request.method, COMPLETIONS_URL, headers=headers, params=request.query_params, json=request_body)
 
135
  for api_key in API_KEYS:
136
  response_generator = stream_api_response(api_key)
137
  try:
138
+ first_chunk = await anext(response_generator)
139
  if first_chunk == 'auth_error':
140
  print(f'ключ API {api_key} недействителен или превышен лимит отправки запросов')
141
  continue
142
  else:
143
  headers_to_forward = {k: v for k, v in headers.items() if k.lower() not in {'content-length', 'content-encoding', 'alt-svc'}}
144
+
145
+ async def combined_generator():
146
+ yield first_chunk
147
+ async for chunk in response_generator:
148
+ yield chunk
149
+
150
+ return OverrideStreamResponse(combined_generator(), headers=headers_to_forward)
151
  except StopAsyncIteration:
152
  continue
153
  raise HTTPException(status_code=401, detail='все ключи API использованы, доступ запрещен.')