bibibi12345 committed
Commit 3c9b1bd · verified · 1 Parent(s): ef5e32f

Update app/main.py

Files changed (1): app/main.py (+78 -65)
app/main.py CHANGED
@@ -184,9 +184,9 @@ class OpenAIRequest(BaseModel):
     # Allow extra fields to pass through without causing validation errors
     model_config = ConfigDict(extra='allow')
 
-# Configure authentication
+# Configure authentication - Initializes a fallback client and validates credential sources
 def init_vertex_ai():
-    global client # Ensure we modify the global client variable
+    global client # This will hold the fallback client if initialized
     try:
         # Priority 1: Check for credentials JSON content in environment variable (Hugging Face)
         credentials_json_str = os.environ.get("GOOGLE_CREDENTIALS_JSON")
@@ -224,57 +224,52 @@ def init_vertex_ai():
 
             # Initialize the client with the credentials
             try:
-                client = genai.Client(vertexai=True, credentials=credentials, project=project_id, location="us-central1")
-                # print(f"Initialized Vertex AI using GOOGLE_CREDENTIALS_JSON env var for project: {project_id}") # Reduced verbosity
+                # Initialize the global client ONLY if it hasn't been set yet
+                if client is None:
+                    client = genai.Client(vertexai=True, credentials=credentials, project=project_id, location="us-central1")
+                    print(f"INFO: Initialized fallback Vertex AI client using GOOGLE_CREDENTIALS_JSON env var for project: {project_id}")
+                else:
+                    print(f"INFO: Fallback client already initialized. GOOGLE_CREDENTIALS_JSON credentials validated for project: {project_id}")
+                # Even if client was already set, we return True because this method worked
+                return True
             except Exception as client_err:
-                print(f"ERROR: Failed to initialize genai.Client from GOOGLE_CREDENTIALS_JSON: {client_err}") # Added context
+                print(f"ERROR: Failed to initialize genai.Client from GOOGLE_CREDENTIALS_JSON: {client_err}")
                 raise
-            return True
         except Exception as e:
-            # print(f"Error loading credentials from GOOGLE_CREDENTIALS_JSON: {e}") # Reduced verbosity, error logged above
-            pass # Add pass to avoid empty block error
+            print(f"WARNING: Error processing GOOGLE_CREDENTIALS_JSON: {e}. Will try other methods.")
             # Fall through to other methods if this fails
 
         # Priority 2: Try to use the credential manager to get credentials from files
         # print(f"Trying credential manager (directory: {credential_manager.credentials_dir})") # Reduced verbosity
-        credentials, project_id = credential_manager.get_next_credentials()
-
-        if credentials and project_id:
+        # Priority 2: Try to use the credential manager to get credentials from files
+        # We call get_next_credentials here mainly to validate it works and log the first file found
+        # The actual rotation happens per-request
+        print(f"INFO: Checking Credential Manager (directory: {credential_manager.credentials_dir})")
+        cm_credentials, cm_project_id = credential_manager.get_next_credentials() # Use temp vars
+
+        if cm_credentials and cm_project_id:
             try:
-                client = genai.Client(vertexai=True, credentials=credentials, project=project_id, location="us-central1")
-                # print(f"Initialized Vertex AI using Credential Manager for project: {project_id}") # Reduced verbosity
-                return True
+                # Initialize the global client ONLY if it hasn't been set yet
+                if client is None:
+                    client = genai.Client(vertexai=True, credentials=cm_credentials, project=cm_project_id, location="us-central1")
+                    print(f"INFO: Initialized fallback Vertex AI client using Credential Manager for project: {cm_project_id}")
+                    return True # Successfully initialized global client
+                else:
+                    print(f"INFO: Fallback client already initialized. Credential Manager validated for project: {cm_project_id}")
+                    # Don't return True here if client was already set, let it fall through to check GAC
             except Exception as e:
-                print(f"ERROR: Failed to initialize client with credentials from Credential Manager file ({credential_manager.credentials_dir}): {e}") # Added context
-
-        # Priority 3: Fall back to GOOGLE_APPLICATION_CREDENTIALS environment variable (file path)
-        file_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
-        if file_path:
-            # print(f"Checking GOOGLE_APPLICATION_CREDENTIALS file path: {file_path}") # Reduced verbosity
-            if os.path.exists(file_path):
-                try:
-                    # print(f"File exists, attempting to load credentials") # Reduced verbosity
-                    credentials = service_account.Credentials.from_service_account_file(
-                        file_path,
-                        scopes=['https://www.googleapis.com/auth/cloud-platform']
-                    )
-                    project_id = credentials.project_id
-                    print(f"Successfully loaded credentials from file for project: {project_id}")
-
-                    try:
-                        client = genai.Client(vertexai=True, credentials=credentials, project=project_id, location="us-central1")
-                        # print(f"Initialized Vertex AI using GOOGLE_APPLICATION_CREDENTIALS file path for project: {project_id}") # Reduced verbosity
-                        return True
-                    except Exception as client_err:
-                        print(f"ERROR: Failed to initialize client with credentials from GOOGLE_APPLICATION_CREDENTIALS file ({file_path}): {client_err}") # Added context
-                except Exception as e:
-                    print(f"ERROR: Failed to load credentials from GOOGLE_APPLICATION_CREDENTIALS path ({file_path}): {e}") # Added context
-            else:
-                print(f"ERROR: GOOGLE_APPLICATION_CREDENTIALS file does not exist at path: {file_path}")
+                print(f"ERROR: Failed to initialize client with credentials from Credential Manager file ({credential_manager.credentials_dir}): {e}")
+        else:
+            print(f"INFO: No credentials loaded via Credential Manager.")
 
         # If none of the methods worked, this error is still useful
-        # print(f"ERROR: No valid credentials found. Tried GOOGLE_CREDENTIALS_JSON, Credential Manager ({credential_manager.credentials_dir}), and GOOGLE_APPLICATION_CREDENTIALS.")
-        return False
+        # If we reach here, either no method worked, or a prior method already initialized the client
+        if client is not None:
+            print("INFO: Fallback client initialization check complete.")
+            return True # A fallback client exists
+        else:
+            print(f"ERROR: No valid credentials found or failed to initialize client. Tried GOOGLE_CREDENTIALS_JSON, Credential Manager ({credential_manager.credentials_dir}), and GOOGLE_APPLICATION_CREDENTIALS.")
+            return False
     except Exception as e:
         print(f"Error initializing authentication: {e}")
         return False
@@ -283,9 +278,9 @@ def init_vertex_ai():
 @app.on_event("startup")
 async def startup_event():
     if init_vertex_ai():
-        print("INFO: Vertex AI client successfully initialized.")
+        print("INFO: Fallback Vertex AI client initialization check completed successfully.")
     else:
-        print("ERROR: Failed to initialize Vertex AI client. Please check credential configuration (GOOGLE_CREDENTIALS_JSON, /app/credentials/*.json, or GOOGLE_APPLICATION_CREDENTIALS) and logs for details.")
+        print("ERROR: Failed to initialize a fallback Vertex AI client. API will likely fail. Please check credential configuration (GOOGLE_CREDENTIALS_JSON, /app/credentials/*.json, or GOOGLE_APPLICATION_CREDENTIALS) and logs for details.")
 
 # Conversion functions
 # Define supported roles for Gemini API
@@ -651,11 +646,11 @@ def create_generation_config(request: OpenAIRequest) -> Dict[str, Any]:
         config["stop_sequences"] = request.stop
 
     # Additional parameters with direct mappings
-    if request.presence_penalty is not None:
-        config["presence_penalty"] = request.presence_penalty
+    # if request.presence_penalty is not None:
+    #     config["presence_penalty"] = request.presence_penalty
 
-    if request.frequency_penalty is not None:
-        config["frequency_penalty"] = request.frequency_penalty
+    # if request.frequency_penalty is not None:
+    #     config["frequency_penalty"] = request.frequency_penalty
 
     if request.seed is not None:
         config["seed"] = request.seed
@@ -988,7 +983,7 @@ def create_openai_error_response(status_code: int, message: str, error_type: str
     }
 
 @app.post("/v1/chat/completions")
-async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
+async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)): # Add request parameter
     try:
         # Validate model availability
         models_response = await list_models()
@@ -1016,14 +1011,32 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
         # Create generation config
         generation_config = create_generation_config(request)
 
-        # Use the globally initialized client (from startup)
-        global client
-        if client is None:
-            error_response = create_openai_error_response(
-                500, "Vertex AI client not initialized", "server_error"
-            )
-            return JSONResponse(status_code=500, content=error_response)
-        print(f"Using globally initialized client.")
+        # --- Determine which client to use (Rotation or Fallback) ---
+        client_to_use = None
+        rotated_credentials, rotated_project_id = credential_manager.get_next_credentials()
+
+        if rotated_credentials and rotated_project_id:
+            try:
+                # Create a request-specific client using the rotated credentials
+                client_to_use = genai.Client(vertexai=True, credentials=rotated_credentials, project=rotated_project_id, location="us-central1")
+                print(f"INFO: Using rotated credential for project: {rotated_project_id} (Index: {credential_manager.current_index -1 if credential_manager.current_index > 0 else len(credential_manager.credentials_files) - 1})") # Log which credential was used
+            except Exception as e:
+                print(f"ERROR: Failed to create client from rotated credential: {e}. Will attempt fallback.")
+                client_to_use = None # Ensure it's None if creation failed
+
+        # If rotation failed or wasn't possible, try the fallback client
+        if client_to_use is None:
+            global client # Access the fallback client initialized at startup
+            if client is not None:
+                client_to_use = client
+                print("INFO: Using fallback Vertex AI client.")
+            else:
+                # Critical error: No rotated client AND no fallback client
+                error_response = create_openai_error_response(
                    500, "Vertex AI client not available (Rotation failed and no fallback)", "server_error"
+                )
+                return JSONResponse(status_code=500, content=error_response)
+        # --- Client determined ---
 
         # Common safety settings
         safety_settings = [
@@ -1036,7 +1049,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
 
 
         # --- Helper function to make the API call (handles stream/non-stream) ---
-        async def make_gemini_call(model_name, prompt_func, current_gen_config):
+        async def make_gemini_call(client_instance, model_name, prompt_func, current_gen_config): # Add client_instance parameter
             prompt = prompt_func(request.messages)
 
             # Log prompt structure
@@ -1058,7 +1071,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
             # Check if fake streaming is enabled (directly from environment variable)
             fake_streaming = os.environ.get("FAKE_STREAMING", "false").lower() == "true"
             if fake_streaming:
-                return await fake_stream_generator(model_name, prompt, current_gen_config, request)
+                return await fake_stream_generator(client_instance, model_name, prompt, current_gen_config, request) # Pass client_instance
 
             # Regular streaming call
             response_id = f"chatcmpl-{int(time.time())}"
@@ -1070,7 +1083,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
                 try:
                     for candidate_index in range(candidate_count):
                         print(f"Sending streaming request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
-                        responses = await client.aio.models.generate_content_stream(
+                        responses = await client_instance.aio.models.generate_content_stream( # Use client_instance
                             model=model_name,
                             contents=prompt,
                             config=current_gen_config,
@@ -1109,7 +1122,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
             # Non-streaming call
             try:
                 print(f"Sending request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
-                response = await client.aio.models.generate_content(
+                response = await client_instance.aio.models.generate_content( # Use client_instance
                     model=model_name,
                     contents=prompt,
                     config=current_gen_config,
@@ -1152,7 +1165,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
             current_config = attempt["config_modifier"](generation_config.copy())
 
             try:
-                result = await make_gemini_call(attempt["model"], attempt["prompt_func"], current_config)
+                result = await make_gemini_call(client_to_use, attempt["model"], attempt["prompt_func"], current_config) # Pass client_to_use
 
                 # For streaming, the result is StreamingResponse, success is determined inside make_gemini_call raising an error on failure
                 # For non-streaming, if make_gemini_call doesn't raise, it's successful
@@ -1230,7 +1243,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
             current_config["system_instruction"] = encryption_instructions
 
             try:
-                result = await make_gemini_call(current_model_name, current_prompt_func, current_config)
+                result = await make_gemini_call(client_to_use, current_model_name, current_prompt_func, current_config) # Pass client_to_use
                 return result
             except Exception as e:
                 # Handle potential errors for non-auto models
@@ -1326,7 +1339,7 @@ def is_response_valid(response):
     return False
 
 # --- Fake streaming implementation ---
-async def fake_stream_generator(model_name, prompt, current_gen_config, request):
+async def fake_stream_generator(client_instance, model_name, prompt, current_gen_config, request): # Add client_instance parameter
    """
    Simulates streaming by making a non-streaming API call and chunking the response.
    While waiting for the response, sends keep-alive messages to the client.
@@ -1337,7 +1350,7 @@ async def fake_stream_generator(model_name, prompt, current_gen_config, request):
    # Create a task for the non-streaming API call
    print(f"FAKE STREAMING: Making non-streaming request to Gemini API (Model: {model_name})")
    api_call_task = asyncio.create_task(
-        client.aio.models.generate_content(
+        client_instance.aio.models.generate_content( # Use client_instance
            model=model_name,
            contents=prompt,
            config=current_gen_config,
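
The substance of the commit is the new per-request client selection in chat_completions: each request pulls the next credential from the credential manager and builds a fresh client, falling back to the global client created at startup only when rotation fails. A minimal standalone sketch of that pattern, assuming the credential_manager object and the genai.Client(vertexai=True, ...) call shown in the diff (the select_client helper name is illustrative, not part of the commit):

from google import genai

def select_client(credential_manager, fallback_client):
    """Pick the client for one request: rotated credential first, startup fallback second."""
    rotated_credentials, rotated_project_id = credential_manager.get_next_credentials()
    if rotated_credentials and rotated_project_id:
        try:
            # Fresh client bound to whichever service account the rotation returned
            return genai.Client(
                vertexai=True,
                credentials=rotated_credentials,
                project=rotated_project_id,
                location="us-central1",
            )
        except Exception as e:
            print(f"ERROR: Failed to create client from rotated credential: {e}. Will attempt fallback.")
    # fallback_client is the module-level client initialized at startup; it may be None
    return fallback_client

The design point visible in the diff: a rotation failure degrades to the old single-client behavior instead of failing the request outright, and a 500 is returned only when neither client exists.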
 
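The diff leans on a CredentialManager exposing credentials_dir, credentials_files, current_index, and get_next_credentials(); the class itself is defined elsewhere in the repo. Below is a sketch of how such a round-robin manager could look, reconstructed purely from those attribute names and the /app/credentials/*.json path mentioned in the startup message; the body is an assumption, not the repo's actual implementation:

import glob
import os

from google.oauth2 import service_account


class CredentialManager:
    """Assumed shape: round-robin over service-account JSON key files in a directory."""

    def __init__(self, credentials_dir="/app/credentials"):
        self.credentials_dir = credentials_dir
        self.credentials_files = sorted(glob.glob(os.path.join(credentials_dir, "*.json")))
        self.current_index = 0

    def get_next_credentials(self):
        """Return (credentials, project_id) from the next key file, or (None, None)."""
        if not self.credentials_files:
            return None, None
        file_path = self.credentials_files[self.current_index]
        # Advance first so one bad file does not pin the rotation in place
        self.current_index = (self.current_index + 1) % len(self.credentials_files)
        try:
            credentials = service_account.Credentials.from_service_account_file(
                file_path,
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
            return credentials, credentials.project_id
        except Exception as e:
            print(f"ERROR: Failed to load credentials from {file_path}: {e}")
            return None, None

Under this reading, the index arithmetic in the rotation log line (current_index - 1, wrapping to the last file) simply recovers which file the call just consumed.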