Update app/main.py
app/main.py (CHANGED, +78 −65)
@@ -184,9 +184,9 @@ class OpenAIRequest(BaseModel):
     # Allow extra fields to pass through without causing validation errors
     model_config = ConfigDict(extra='allow')
 
-# Configure authentication
+# Configure authentication - Initializes a fallback client and validates credential sources
 def init_vertex_ai():
-    global client #
+    global client # This will hold the fallback client if initialized
     try:
         # Priority 1: Check for credentials JSON content in environment variable (Hugging Face)
         credentials_json_str = os.environ.get("GOOGLE_CREDENTIALS_JSON")
@@ -224,57 +224,52 @@ def init_vertex_ai():
 
             # Initialize the client with the credentials
             try:
-                client = genai.Client(vertexai=True, credentials=credentials, project=project_id, location="us-central1")
-                …
+                # Initialize the global client ONLY if it hasn't been set yet
+                if client is None:
+                    client = genai.Client(vertexai=True, credentials=credentials, project=project_id, location="us-central1")
+                    print(f"INFO: Initialized fallback Vertex AI client using GOOGLE_CREDENTIALS_JSON env var for project: {project_id}")
+                else:
+                    print(f"INFO: Fallback client already initialized. GOOGLE_CREDENTIALS_JSON credentials validated for project: {project_id}")
+                # Even if client was already set, we return True because this method worked
+                return True
             except Exception as client_err:
-                print(f"ERROR: Failed to initialize genai.Client from GOOGLE_CREDENTIALS_JSON: {client_err}")
+                print(f"ERROR: Failed to initialize genai.Client from GOOGLE_CREDENTIALS_JSON: {client_err}")
                 raise
-            return True
         except Exception as e:
-            …
-            pass # Add pass to avoid empty block error
+            print(f"WARNING: Error processing GOOGLE_CREDENTIALS_JSON: {e}. Will try other methods.")
             # Fall through to other methods if this fails
 
         # Priority 2: Try to use the credential manager to get credentials from files
         # print(f"Trying credential manager (directory: {credential_manager.credentials_dir})") # Reduced verbosity
-        credentials, project_id = credential_manager.get_next_credentials()
-
-        if credentials and project_id:
+        # Priority 2: Try to use the credential manager to get credentials from files
+        # We call get_next_credentials here mainly to validate it works and log the first file found
+        # The actual rotation happens per-request
+        print(f"INFO: Checking Credential Manager (directory: {credential_manager.credentials_dir})")
+        cm_credentials, cm_project_id = credential_manager.get_next_credentials() # Use temp vars
+
+        if cm_credentials and cm_project_id:
             try:
-                client = genai.Client(vertexai=True, credentials=credentials, project=project_id, location="us-central1")
-                …
+                # Initialize the global client ONLY if it hasn't been set yet
+                if client is None:
+                    client = genai.Client(vertexai=True, credentials=cm_credentials, project=cm_project_id, location="us-central1")
+                    print(f"INFO: Initialized fallback Vertex AI client using Credential Manager for project: {cm_project_id}")
+                    return True # Successfully initialized global client
+                else:
+                    print(f"INFO: Fallback client already initialized. Credential Manager validated for project: {cm_project_id}")
+                    # Don't return True here if client was already set, let it fall through to check GAC
             except Exception as e:
-                print(f"ERROR: Failed to initialize client with credentials from Credential Manager file ({credential_manager.credentials_dir}): {e}")
-            …
+                print(f"ERROR: Failed to initialize client with credentials from Credential Manager file ({credential_manager.credentials_dir}): {e}")
+        else:
+            print(f"INFO: No credentials loaded via Credential Manager.")
 
-        file_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
-        if file_path:
-            # print(f"Checking GOOGLE_APPLICATION_CREDENTIALS file path: {file_path}") # Reduced verbosity
-            if os.path.exists(file_path):
-                try:
-                    # print(f"File exists, attempting to load credentials") # Reduced verbosity
-                    credentials = service_account.Credentials.from_service_account_file(
-                        file_path,
-                        scopes=['https://www.googleapis.com/auth/cloud-platform']
-                    )
-                    project_id = credentials.project_id
-                    print(f"Successfully loaded credentials from file for project: {project_id}")
-
-                    try:
-                        client = genai.Client(vertexai=True, credentials=credentials, project=project_id, location="us-central1")
-                        # print(f"Initialized Vertex AI using GOOGLE_APPLICATION_CREDENTIALS file path for project: {project_id}") # Reduced verbosity
-                        return True
-                    except Exception as client_err:
-                        print(f"ERROR: Failed to initialize client with credentials from GOOGLE_APPLICATION_CREDENTIALS file ({file_path}): {client_err}") # Added context
-                except Exception as e:
-                    print(f"ERROR: Failed to load credentials from GOOGLE_APPLICATION_CREDENTIALS path ({file_path}): {e}") # Added context
-            else:
-                print(f"ERROR: GOOGLE_APPLICATION_CREDENTIALS file does not exist at path: {file_path}")
 
         # If none of the methods worked, this error is still useful
-        # …
-        …
+        # If we reach here, either no method worked, or a prior method already initialized the client
+        if client is not None:
+            print("INFO: Fallback client initialization check complete.")
+            return True # A fallback client exists
+        else:
+            print(f"ERROR: No valid credentials found or failed to initialize client. Tried GOOGLE_CREDENTIALS_JSON, Credential Manager ({credential_manager.credentials_dir}), and GOOGLE_APPLICATION_CREDENTIALS.")
+            return False
     except Exception as e:
         print(f"Error initializing authentication: {e}")
         return False
@@ -283,9 +278,9 @@ def init_vertex_ai():
 @app.on_event("startup")
 async def startup_event():
     if init_vertex_ai():
-        print("INFO: Vertex AI client successfully initialized.")
+        print("INFO: Fallback Vertex AI client initialization check completed successfully.")
     else:
-        print("ERROR: Failed to initialize Vertex AI client. Please check credential configuration (GOOGLE_CREDENTIALS_JSON, /app/credentials/*.json, or GOOGLE_APPLICATION_CREDENTIALS) and logs for details.")
+        print("ERROR: Failed to initialize a fallback Vertex AI client. API will likely fail. Please check credential configuration (GOOGLE_CREDENTIALS_JSON, /app/credentials/*.json, or GOOGLE_APPLICATION_CREDENTIALS) and logs for details.")
 
 # Conversion functions
 # Define supported roles for Gemini API
@@ -651,11 +646,11 @@ def create_generation_config(request: OpenAIRequest) -> Dict[str, Any]:
         config["stop_sequences"] = request.stop
 
     # Additional parameters with direct mappings
-    if request.presence_penalty is not None:
-        config["presence_penalty"] = request.presence_penalty
+    # if request.presence_penalty is not None:
+    #     config["presence_penalty"] = request.presence_penalty
 
-    if request.frequency_penalty is not None:
-        config["frequency_penalty"] = request.frequency_penalty
+    # if request.frequency_penalty is not None:
+    #     config["frequency_penalty"] = request.frequency_penalty
 
     if request.seed is not None:
         config["seed"] = request.seed
@@ -988,7 +983,7 @@ def create_openai_error_response(status_code: int, message: str, error_type: str
     }
 
 @app.post("/v1/chat/completions")
-async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
+async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)): # Add request parameter
     try:
         # Validate model availability
         models_response = await list_models()
@@ -1016,14 +1011,32 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
         # Create generation config
         generation_config = create_generation_config(request)
 
-        # …
-        …
+        # --- Determine which client to use (Rotation or Fallback) ---
+        client_to_use = None
+        rotated_credentials, rotated_project_id = credential_manager.get_next_credentials()
+
+        if rotated_credentials and rotated_project_id:
+            try:
+                # Create a request-specific client using the rotated credentials
+                client_to_use = genai.Client(vertexai=True, credentials=rotated_credentials, project=rotated_project_id, location="us-central1")
+                print(f"INFO: Using rotated credential for project: {rotated_project_id} (Index: {credential_manager.current_index - 1 if credential_manager.current_index > 0 else len(credential_manager.credentials_files) - 1})") # Log which credential was used
+            except Exception as e:
+                print(f"ERROR: Failed to create client from rotated credential: {e}. Will attempt fallback.")
+                client_to_use = None # Ensure it's None if creation failed
+
+        # If rotation failed or wasn't possible, try the fallback client
+        if client_to_use is None:
+            global client # Access the fallback client initialized at startup
+            if client is not None:
+                client_to_use = client
+                print("INFO: Using fallback Vertex AI client.")
+            else:
+                # Critical error: No rotated client AND no fallback client
+                error_response = create_openai_error_response(
+                    500, "Vertex AI client not available (Rotation failed and no fallback)", "server_error"
+                )
+                return JSONResponse(status_code=500, content=error_response)
+        # --- Client determined ---
 
         # Common safety settings
         safety_settings = [
@@ -1036,7 +1049,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
 
 
         # --- Helper function to make the API call (handles stream/non-stream) ---
-        async def make_gemini_call(model_name, prompt_func, current_gen_config):
+        async def make_gemini_call(client_instance, model_name, prompt_func, current_gen_config): # Add client_instance parameter
             prompt = prompt_func(request.messages)
 
             # Log prompt structure
@@ -1058,7 +1071,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
             # Check if fake streaming is enabled (directly from environment variable)
             fake_streaming = os.environ.get("FAKE_STREAMING", "false").lower() == "true"
             if fake_streaming:
-                return await fake_stream_generator(model_name, prompt, current_gen_config, request)
+                return await fake_stream_generator(client_instance, model_name, prompt, current_gen_config, request) # Pass client_instance
 
             # Regular streaming call
             response_id = f"chatcmpl-{int(time.time())}"
@@ -1070,7 +1083,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
                 try:
                     for candidate_index in range(candidate_count):
                         print(f"Sending streaming request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
-                        responses = await client.aio.models.generate_content_stream(
+                        responses = await client_instance.aio.models.generate_content_stream( # Use client_instance
                             model=model_name,
                             contents=prompt,
                             config=current_gen_config,
@@ -1109,7 +1122,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
             # Non-streaming call
             try:
                 print(f"Sending request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
-                response = await client.aio.models.generate_content(
+                response = await client_instance.aio.models.generate_content( # Use client_instance
                     model=model_name,
                     contents=prompt,
                     config=current_gen_config,
@@ -1152,7 +1165,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
                 current_config = attempt["config_modifier"](generation_config.copy())
 
                 try:
-                    result = await make_gemini_call(attempt["model"], attempt["prompt_func"], current_config)
+                    result = await make_gemini_call(client_to_use, attempt["model"], attempt["prompt_func"], current_config) # Pass client_to_use
 
                     # For streaming, the result is StreamingResponse, success is determined inside make_gemini_call raising an error on failure
                     # For non-streaming, if make_gemini_call doesn't raise, it's successful
@@ -1230,7 +1243,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
                 current_config["system_instruction"] = encryption_instructions
 
             try:
-                result = await make_gemini_call(current_model_name, current_prompt_func, current_config)
+                result = await make_gemini_call(client_to_use, current_model_name, current_prompt_func, current_config) # Pass client_to_use
                 return result
             except Exception as e:
                 # Handle potential errors for non-auto models
@@ -1326,7 +1339,7 @@ def is_response_valid(response):
     return False
 
 # --- Fake streaming implementation ---
-async def fake_stream_generator(model_name, prompt, current_gen_config, request):
+async def fake_stream_generator(client_instance, model_name, prompt, current_gen_config, request): # Add client_instance parameter
     """
     Simulates streaming by making a non-streaming API call and chunking the response.
     While waiting for the response, sends keep-alive messages to the client.
@@ -1337,7 +1350,7 @@ async def fake_stream_generator(model_name, prompt, current_gen_config, request)
         # Create a task for the non-streaming API call
         print(f"FAKE STREAMING: Making non-streaming request to Gemini API (Model: {model_name})")
         api_call_task = asyncio.create_task(
-            client.aio.models.generate_content(
+            client_instance.aio.models.generate_content( # Use client_instance
                 model=model_name,
                 contents=prompt,
                 config=current_gen_config,