bibibi12345 committed on
Commit
40acccd
·
1 Parent(s): f16ae1a

added project files

.DS_Store ADDED
Binary file (6.15 kB).
 
Dockerfile ADDED
@@ -0,0 +1,20 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ # Install dependencies
+ COPY app/requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY app/ .
+
+ # Create a directory for the credentials
+ RUN mkdir -p /app/credentials
+
+ # Expose the port the app listens on
+ EXPOSE 7860
+
+ # Command to run the application
+ # Use the default Hugging Face port 7860
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
app/config.py ADDED
@@ -0,0 +1,24 @@
+ import os
+
+ # Default password if not set in environment
+ DEFAULT_PASSWORD = "123456"
+
+ # Get password from environment variable or use default
+ API_KEY = os.environ.get("API_KEY", DEFAULT_PASSWORD)
+
+ # Function to validate API key
+ def validate_api_key(api_key: str) -> bool:
+     """
+     Validate the provided API key against the configured key
+
+     Args:
+         api_key: The API key to validate
+
+     Returns:
+         bool: True if the key is valid, False otherwise
+     """
+     if not API_KEY:
+         # If no API key is configured, authentication is disabled
+         return True
+
+     return api_key == API_KEY
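
For reference, a quick sketch of how this helper behaves when the API_KEY environment variable is left unset, so the "123456" default applies:

import config

config.validate_api_key("123456")     # True  - matches the default key
config.validate_api_key("wrong-key")  # False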
app/main.py ADDED
@@ -0,0 +1,780 @@
+ from fastapi import FastAPI, HTTPException, Depends, Header, Request
+ from fastapi.responses import JSONResponse, StreamingResponse
+ from fastapi.security import APIKeyHeader
+ from pydantic import BaseModel, ConfigDict, Field
+ from typing import List, Dict, Any, Optional, Union, Literal
+ import base64
+ import re
+ import json
+ import time
+ import os
+ import glob
+ import random
+ from google.oauth2 import service_account
+ import config
+
+ from google.genai import types
+
+ from google import genai
+
+ client = None
+
+ app = FastAPI(title="OpenAI to Gemini Adapter")
+
+ # API Key security scheme
+ api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
+
+ # Dependency for API key validation
+ async def get_api_key(authorization: Optional[str] = Header(None)):
+     if authorization is None:
+         raise HTTPException(
+             status_code=401,
+             detail="Missing API key. Please include 'Authorization: Bearer YOUR_API_KEY' header."
+         )
+
+     # Check if the header starts with "Bearer "
+     if not authorization.startswith("Bearer "):
+         raise HTTPException(
+             status_code=401,
+             detail="Invalid API key format. Use 'Authorization: Bearer YOUR_API_KEY'"
+         )
+
+     # Extract the API key
+     api_key = authorization.replace("Bearer ", "")
+
+     # Validate the API key
+     if not config.validate_api_key(api_key):
+         raise HTTPException(
+             status_code=401,
+             detail="Invalid API key"
+         )
+
+     return api_key
+
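
For illustration, this is what a client-side call that satisfies the get_api_key dependency might look like (a sketch using the requests package, assuming the docker-compose setup below, so host port 8050 and the default key):

import requests

resp = requests.get(
    "http://localhost:8050/v1/models",
    headers={"Authorization": "Bearer 123456"},  # must match the configured API_KEY
)
print(resp.status_code)  # 200 if the key is accepted, 401 otherwise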
+ # Credential Manager for handling multiple service accounts
+ class CredentialManager:
+     def __init__(self, default_credentials_dir="/app/credentials"):
+         # Use environment variable if set, otherwise use default
+         self.credentials_dir = os.environ.get("CREDENTIALS_DIR", default_credentials_dir)
+         self.credentials_files = []
+         self.current_index = 0
+         self.credentials = None
+         self.project_id = None
+         self.load_credentials_list()
+
+     def load_credentials_list(self):
+         """Load the list of available credential files"""
+         # Look for all .json files in the credentials directory
+         pattern = os.path.join(self.credentials_dir, "*.json")
+         self.credentials_files = glob.glob(pattern)
+
+         if not self.credentials_files:
+             print(f"No credential files found in {self.credentials_dir}")
+             return False
+
+         print(f"Found {len(self.credentials_files)} credential files: {[os.path.basename(f) for f in self.credentials_files]}")
+         return True
+
+     def refresh_credentials_list(self):
+         """Refresh the list of credential files (useful if files are added/removed)"""
+         old_count = len(self.credentials_files)
+         self.load_credentials_list()
+         new_count = len(self.credentials_files)
+
+         if old_count != new_count:
+             print(f"Credential files updated: {old_count} -> {new_count}")
+
+         return len(self.credentials_files) > 0
+
+     def get_next_credentials(self, _attempts=0):
+         """Rotate to the next credential file and load it"""
+         # Bound the retries so a directory full of unreadable files cannot recurse forever
+         if not self.credentials_files or _attempts >= len(self.credentials_files):
+             return None, None
+
+         # Get the next credential file in rotation
+         file_path = self.credentials_files[self.current_index]
+         self.current_index = (self.current_index + 1) % len(self.credentials_files)
+
+         try:
+             credentials = service_account.Credentials.from_service_account_file(file_path, scopes=['https://www.googleapis.com/auth/cloud-platform'])
+             project_id = credentials.project_id
+             print(f"Loaded credentials from {file_path} for project: {project_id}")
+             self.credentials = credentials
+             self.project_id = project_id
+             return credentials, project_id
+         except Exception as e:
+             print(f"Error loading credentials from {file_path}: {e}")
+             # Try the next file if this one fails
+             if len(self.credentials_files) > 1:
+                 print("Trying next credential file...")
+                 return self.get_next_credentials(_attempts + 1)
+             return None, None
+
+     def get_random_credentials(self, _attempts=0):
+         """Get a random credential file and load it"""
+         # Same retry bound as get_next_credentials
+         if not self.credentials_files or _attempts >= len(self.credentials_files):
+             return None, None
+
+         # Choose a random credential file
+         file_path = random.choice(self.credentials_files)
+
+         try:
+             credentials = service_account.Credentials.from_service_account_file(file_path, scopes=['https://www.googleapis.com/auth/cloud-platform'])
+             project_id = credentials.project_id
+             print(f"Loaded credentials from {file_path} for project: {project_id}")
+             self.credentials = credentials
+             self.project_id = project_id
+             return credentials, project_id
+         except Exception as e:
+             print(f"Error loading credentials from {file_path}: {e}")
+             # Try another random file if this one fails
+             if len(self.credentials_files) > 1:
+                 print("Trying another credential file...")
+                 return self.get_random_credentials(_attempts + 1)
+             return None, None
+
+ # Initialize the credential manager
+ credential_manager = CredentialManager()
+
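
A minimal sketch of the round-robin behavior, assuming two hypothetical files sa-a.json and sa-b.json sit in the credentials directory (order follows whatever glob returns):

manager = CredentialManager(default_credentials_dir="./credentials")
creds_a, project_a = manager.get_next_credentials()  # loads sa-a.json
creds_b, project_b = manager.get_next_credentials()  # loads sa-b.json
creds_c, project_c = manager.get_next_credentials()  # wraps around to sa-a.json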
+ # Define data models
+ class ImageUrl(BaseModel):
+     url: str
+
+ class ContentPartImage(BaseModel):
+     type: Literal["image_url"]
+     image_url: ImageUrl
+
+ class ContentPartText(BaseModel):
+     type: Literal["text"]
+     text: str
+
+ class OpenAIMessage(BaseModel):
+     role: str
+     content: Union[str, List[Union[ContentPartText, ContentPartImage, Dict[str, Any]]]]
+
+ class OpenAIRequest(BaseModel):
+     model: str
+     messages: List[OpenAIMessage]
+     temperature: Optional[float] = 1.0
+     max_tokens: Optional[int] = None
+     top_p: Optional[float] = 1.0
+     top_k: Optional[int] = None
+     stream: Optional[bool] = False
+     stop: Optional[List[str]] = None
+     presence_penalty: Optional[float] = None
+     frequency_penalty: Optional[float] = None
+     seed: Optional[int] = None
+     logprobs: Optional[int] = None
+     response_logprobs: Optional[bool] = None
+     n: Optional[int] = None  # Maps to candidate_count in Vertex AI
+
+     # Allow extra fields to pass through without causing validation errors
+     model_config = ConfigDict(extra='allow')
+
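
For reference, one request body that validates against these models (a sketch; the model id is one of those served by /v1/models below):

request_body = {
    "model": "gemini-2.0-flash",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    "temperature": 0.7,
    "stream": False,
}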
+ # Configure authentication
+ def init_vertex_ai():
+     global client  # Ensure we modify the global client variable
+     try:
+         # Priority 1: Check for credentials JSON content in environment variable (Hugging Face)
+         credentials_json_str = os.environ.get("GOOGLE_CREDENTIALS_JSON")
+         if credentials_json_str:
+             try:
+                 credentials_info = json.loads(credentials_json_str)
+                 credentials = service_account.Credentials.from_service_account_info(credentials_info, scopes=['https://www.googleapis.com/auth/cloud-platform'])
+                 project_id = credentials.project_id
+                 client = genai.Client(vertexai=True, credentials=credentials, project=project_id, location="us-central1")
+                 print(f"Initialized Vertex AI using GOOGLE_CREDENTIALS_JSON env var for project: {project_id}")
+                 return True
+             except Exception as e:
+                 print(f"Error loading credentials from GOOGLE_CREDENTIALS_JSON: {e}")
+                 # Fall through to other methods if this fails
+
+         # Priority 2: Try to use the credential manager to get credentials from files
+         credentials, project_id = credential_manager.get_next_credentials()
+
+         if credentials and project_id:
+             client = genai.Client(vertexai=True, credentials=credentials, project=project_id, location="us-central1")
+             print(f"Initialized Vertex AI using Credential Manager for project: {project_id}")
+             return True
+
+         # Priority 3: Fall back to GOOGLE_APPLICATION_CREDENTIALS environment variable (file path)
+         file_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
+         if file_path and os.path.exists(file_path):
+             try:
+                 credentials = service_account.Credentials.from_service_account_file(file_path, scopes=['https://www.googleapis.com/auth/cloud-platform'])
+                 project_id = credentials.project_id
+                 client = genai.Client(vertexai=True, credentials=credentials, project=project_id, location="us-central1")
+                 print(f"Initialized Vertex AI using GOOGLE_APPLICATION_CREDENTIALS file path for project: {project_id}")
+                 return True
+             except Exception as e:
+                 print(f"Error loading credentials from GOOGLE_APPLICATION_CREDENTIALS path {file_path}: {e}")
+
+         # If none of the methods worked
+         print(f"Error: No valid credentials found. Tried GOOGLE_CREDENTIALS_JSON, Credential Manager ({credential_manager.credentials_dir}), and GOOGLE_APPLICATION_CREDENTIALS.")
+         return False
+     except Exception as e:
+         print(f"Error initializing authentication: {e}")
+         return False
+
+ # Initialize Vertex AI at startup
+ @app.on_event("startup")
+ async def startup_event():
+     if not init_vertex_ai():
+         print("WARNING: Failed to initialize Vertex AI authentication")
+
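
To illustrate the Priority 1 path: on platforms like Hugging Face, the whole service-account JSON can be supplied through the GOOGLE_CREDENTIALS_JSON variable before the app starts. A sketch with placeholder values only:

import json, os

# All field values here are placeholders; use the real service-account JSON in deployment secrets
os.environ["GOOGLE_CREDENTIALS_JSON"] = json.dumps({
    "type": "service_account",
    "project_id": "my-project",
    "private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
    "client_email": "sa@my-project.iam.gserviceaccount.com",
})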
+ # Conversion functions
+ def create_gemini_prompt(messages: List[OpenAIMessage]) -> Union[str, List[Any]]:
+     """
+     Convert OpenAI messages to Gemini format.
+     Returns either a string prompt or a list of content parts if images are present.
+     """
+     # Check if any message contains image content
+     has_images = False
+     for message in messages:
+         if isinstance(message.content, list):
+             for part in message.content:
+                 if isinstance(part, dict) and part.get('type') == 'image_url':
+                     has_images = True
+                     break
+                 elif isinstance(part, ContentPartImage):
+                     has_images = True
+                     break
+         if has_images:
+             break
+
+     # If no images, use the text-only format
+     if not has_images:
+         prompt = ""
+
+         # Extract system message if present
+         system_message = None
+         for message in messages:
+             if message.role == "system":
+                 # Handle both string and list[dict] content types
+                 if isinstance(message.content, str):
+                     system_message = message.content
+                 elif isinstance(message.content, list) and message.content and isinstance(message.content[0], dict) and 'text' in message.content[0]:
+                     system_message = message.content[0]['text']
+                 else:
+                     # Handle unexpected format or raise error? For now, assume it's usable or skip.
+                     system_message = str(message.content)  # Fallback, might need refinement
+                 break
+
+         # If system message exists, prepend it
+         if system_message:
+             prompt += f"System: {system_message}\n\n"
+
+         # Add other messages
+         for message in messages:
+             if message.role == "system":
+                 continue  # Already handled
+
+             # Handle both string and list[dict] content types
+             content_text = ""
+             if isinstance(message.content, str):
+                 content_text = message.content
+             elif isinstance(message.content, list) and message.content and isinstance(message.content[0], dict) and 'text' in message.content[0]:
+                 content_text = message.content[0]['text']
+             else:
+                 # Fallback for unexpected format
+                 content_text = str(message.content)
+
+             if message.role == "user":
+                 prompt += f"Human: {content_text}\n"
+             elif message.role == "assistant":
+                 prompt += f"AI: {content_text}\n"
+
+         # Add final AI prompt if last message was from user
+         if messages[-1].role == "user":
+             prompt += "AI: "
+
+         return prompt
+
+     # If images are present, create a list of content parts
+     gemini_contents = []
+
+     # Extract system message if present and add it first
+     for message in messages:
+         if message.role == "system":
+             if isinstance(message.content, str):
+                 gemini_contents.append(f"System: {message.content}")
+             elif isinstance(message.content, list):
+                 # Extract text from system message
+                 system_text = ""
+                 for part in message.content:
+                     if isinstance(part, dict) and part.get('type') == 'text':
+                         system_text += part.get('text', '')
+                     elif isinstance(part, ContentPartText):
+                         system_text += part.text
+                 if system_text:
+                     gemini_contents.append(f"System: {system_text}")
+             break
+
+     # Process user and assistant messages
+     for message in messages:
+         if message.role == "system":
+             continue  # Already handled
+
+         # For string content, add as text
+         if isinstance(message.content, str):
+             prefix = "Human: " if message.role == "user" else "AI: "
+             gemini_contents.append(f"{prefix}{message.content}")
+
+         # For list content, process each part
+         elif isinstance(message.content, list):
+             # First collect all text parts
+             text_content = ""
+
+             for part in message.content:
+                 # Handle text parts
+                 if isinstance(part, dict) and part.get('type') == 'text':
+                     text_content += part.get('text', '')
+                 elif isinstance(part, ContentPartText):
+                     text_content += part.text
+
+             # Add the combined text content if any
+             if text_content:
+                 prefix = "Human: " if message.role == "user" else "AI: "
+                 gemini_contents.append(f"{prefix}{text_content}")
+
+             # Then process image parts
+             for part in message.content:
+                 # Handle image parts
+                 if isinstance(part, dict) and part.get('type') == 'image_url':
+                     image_url = part.get('image_url', {}).get('url', '')
+                     if image_url.startswith('data:'):
+                         # Extract mime type and base64 data
+                         mime_match = re.match(r'data:([^;]+);base64,(.+)', image_url)
+                         if mime_match:
+                             mime_type, b64_data = mime_match.groups()
+                             image_bytes = base64.b64decode(b64_data)
+                             gemini_contents.append(types.Part.from_bytes(data=image_bytes, mime_type=mime_type))
+                 elif isinstance(part, ContentPartImage):
+                     image_url = part.image_url.url
+                     if image_url.startswith('data:'):
+                         # Extract mime type and base64 data
+                         mime_match = re.match(r'data:([^;]+);base64,(.+)', image_url)
+                         if mime_match:
+                             mime_type, b64_data = mime_match.groups()
+                             image_bytes = base64.b64decode(b64_data)
+                             gemini_contents.append(types.Part.from_bytes(data=image_bytes, mime_type=mime_type))
+
+     return gemini_contents
+
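
A quick sketch of the text-only path: with no image parts, the function flattens the conversation into a single transcript-style string and appends a trailing "AI: " turn when the last message is from the user.

messages = [
    OpenAIMessage(role="system", content="Be brief."),
    OpenAIMessage(role="user", content="Hi"),
]
create_gemini_prompt(messages)
# -> 'System: Be brief.\n\nHuman: Hi\nAI: '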
+ def create_generation_config(request: OpenAIRequest) -> Dict[str, Any]:
+     config = {}
+
+     # Basic parameters that were already supported
+     if request.temperature is not None:
+         config["temperature"] = request.temperature
+
+     if request.max_tokens is not None:
+         config["max_output_tokens"] = request.max_tokens
+
+     if request.top_p is not None:
+         config["top_p"] = request.top_p
+
+     if request.top_k is not None:
+         config["top_k"] = request.top_k
+
+     if request.stop is not None:
+         config["stop_sequences"] = request.stop
+
+     # Additional parameters with direct mappings
+     if request.presence_penalty is not None:
+         config["presence_penalty"] = request.presence_penalty
+
+     if request.frequency_penalty is not None:
+         config["frequency_penalty"] = request.frequency_penalty
+
+     if request.seed is not None:
+         config["seed"] = request.seed
+
+     if request.logprobs is not None:
+         config["logprobs"] = request.logprobs
+
+     if request.response_logprobs is not None:
+         config["response_logprobs"] = request.response_logprobs
+
+     # Map OpenAI's 'n' parameter to Vertex AI's 'candidate_count'
+     if request.n is not None:
+         config["candidate_count"] = request.n
+
+     return config
+
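
For instance, a request that sets temperature, max_tokens and n comes out as the following config dict (a sketch; note top_p is included because it defaults to 1.0 rather than None):

req = OpenAIRequest(model="gemini-2.0-flash", messages=[], temperature=0.2, max_tokens=256, n=2)
create_generation_config(req)
# -> {'temperature': 0.2, 'max_output_tokens': 256, 'top_p': 1.0, 'candidate_count': 2}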
+ # Response format conversion
+ def convert_to_openai_format(gemini_response, model: str) -> Dict[str, Any]:
+     # Handle multiple candidates if present
+     if hasattr(gemini_response, 'candidates') and len(gemini_response.candidates) > 1:
+         choices = []
+         for i, candidate in enumerate(gemini_response.candidates):
+             choices.append({
+                 "index": i,
+                 "message": {
+                     "role": "assistant",
+                     "content": candidate.text
+                 },
+                 "finish_reason": "stop"
+             })
+     else:
+         # Handle single response (backward compatibility)
+         choices = [
+             {
+                 "index": 0,
+                 "message": {
+                     "role": "assistant",
+                     "content": gemini_response.text
+                 },
+                 "finish_reason": "stop"
+             }
+         ]
+
+     # Include logprobs if available
+     for i, choice in enumerate(choices):
+         if hasattr(gemini_response, 'candidates') and i < len(gemini_response.candidates):
+             candidate = gemini_response.candidates[i]
+             if hasattr(candidate, 'logprobs'):
+                 choice["logprobs"] = candidate.logprobs
+
+     return {
+         "id": f"chatcmpl-{int(time.time())}",
+         "object": "chat.completion",
+         "created": int(time.time()),
+         "model": model,
+         "choices": choices,
+         "usage": {
+             "prompt_tokens": 0,  # Would need token counting logic
+             "completion_tokens": 0,
+             "total_tokens": 0
+         }
+     }
+
+ def convert_chunk_to_openai(chunk, model: str, response_id: str, candidate_index: int = 0) -> str:
+     chunk_content = chunk.text if hasattr(chunk, 'text') else ""
+
+     chunk_data = {
+         "id": response_id,
+         "object": "chat.completion.chunk",
+         "created": int(time.time()),
+         "model": model,
+         "choices": [
+             {
+                 "index": candidate_index,
+                 "delta": {
+                     "content": chunk_content
+                 },
+                 "finish_reason": None
+             }
+         ]
+     }
+
+     # Add logprobs if available
+     if hasattr(chunk, 'logprobs'):
+         chunk_data["choices"][0]["logprobs"] = chunk.logprobs
+
+     return f"data: {json.dumps(chunk_data)}\n\n"
+
+ def create_final_chunk(model: str, response_id: str, candidate_count: int = 1) -> str:
+     choices = []
+     for i in range(candidate_count):
+         choices.append({
+             "index": i,
+             "delta": {},
+             "finish_reason": "stop"
+         })
+
+     final_chunk = {
+         "id": response_id,
+         "object": "chat.completion.chunk",
+         "created": int(time.time()),
+         "model": model,
+         "choices": choices
+     }
+
+     return f"data: {json.dumps(final_chunk)}\n\n"
+
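
Concretely, each call to convert_chunk_to_openai emits one SSE event; a chunk carrying the text "Hello" would serialize roughly as follows (timestamp illustrative):

data: {"id": "chatcmpl-1712345678", "object": "chat.completion.chunk", "created": 1712345678, "model": "gemini-2.0-flash", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}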
+ # /v1/models endpoint
+ @app.get("/v1/models")
+ async def list_models(api_key: str = Depends(get_api_key)):
+     # Based on current information for Vertex AI models
+     models = [
+         {
+             "id": "gemini-2.5-pro-exp-03-25",
+             "object": "model",
+             "created": int(time.time()),
+             "owned_by": "google",
+             "permission": [],
+             "root": "gemini-2.5-pro-exp-03-25",
+             "parent": None,
+         },
+         {
+             "id": "gemini-2.5-pro-exp-03-25-search",
+             "object": "model",
+             "created": int(time.time()),
+             "owned_by": "google",
+             "permission": [],
+             "root": "gemini-2.5-pro-exp-03-25",
+             "parent": None,
+         },
+         {
+             "id": "gemini-2.0-flash",
+             "object": "model",
+             "created": int(time.time()),
+             "owned_by": "google",
+             "permission": [],
+             "root": "gemini-2.0-flash",
+             "parent": None,
+         },
+         {
+             "id": "gemini-2.0-flash-search",
+             "object": "model",
+             "created": int(time.time()),
+             "owned_by": "google",
+             "permission": [],
+             "root": "gemini-2.0-flash",
+             "parent": None,
+         },
+         {
+             "id": "gemini-2.0-flash-lite",
+             "object": "model",
+             "created": int(time.time()),
+             "owned_by": "google",
+             "permission": [],
+             "root": "gemini-2.0-flash-lite",
+             "parent": None,
+         },
+         {
+             "id": "gemini-2.0-flash-lite-search",
+             "object": "model",
+             "created": int(time.time()),
+             "owned_by": "google",
+             "permission": [],
+             "root": "gemini-2.0-flash-lite",
+             "parent": None,
+         },
+         {
+             "id": "gemini-2.0-pro-exp-02-05",
+             "object": "model",
+             "created": int(time.time()),
+             "owned_by": "google",
+             "permission": [],
+             "root": "gemini-2.0-pro-exp-02-05",
+             "parent": None,
+         },
+         {
+             "id": "gemini-1.5-flash",
+             "object": "model",
+             "created": int(time.time()),
+             "owned_by": "google",
+             "permission": [],
+             "root": "gemini-1.5-flash",
+             "parent": None,
+         },
+         {
+             "id": "gemini-1.5-flash-8b",
+             "object": "model",
+             "created": int(time.time()),
+             "owned_by": "google",
+             "permission": [],
+             "root": "gemini-1.5-flash-8b",
+             "parent": None,
+         },
+         {
+             "id": "gemini-1.5-pro",
+             "object": "model",
+             "created": int(time.time()),
+             "owned_by": "google",
+             "permission": [],
+             "root": "gemini-1.5-pro",
+             "parent": None,
+         },
+         {
+             "id": "gemini-1.0-pro-002",
+             "object": "model",
+             "created": int(time.time()),
+             "owned_by": "google",
+             "permission": [],
+             "root": "gemini-1.0-pro-002",
+             "parent": None,
+         },
+         {
+             "id": "gemini-1.0-pro-vision-001",
+             "object": "model",
+             "created": int(time.time()),
+             "owned_by": "google",
+             "permission": [],
+             "root": "gemini-1.0-pro-vision-001",
+             "parent": None,
+         },
+         {
+             "id": "gemini-embedding-exp",
+             "object": "model",
+             "created": int(time.time()),
+             "owned_by": "google",
+             "permission": [],
+             "root": "gemini-embedding-exp",
+             "parent": None,
+         }
+     ]
+
+     return {"object": "list", "data": models}
+
+ # OpenAI-compatible error response
+ def create_openai_error_response(status_code: int, message: str, error_type: str) -> Dict[str, Any]:
+     return {
+         "error": {
+             "message": message,
+             "type": error_type,
+             "code": status_code,
+             "param": None,
+         }
+     }
+
+ # Main chat completion endpoint
+ @app.post("/v1/chat/completions")
+ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
+     try:
+         # Validate model availability
+         models_response = await list_models()
+         if not request.model or not any(model["id"] == request.model for model in models_response.get("data", [])):
+             error_response = create_openai_error_response(
+                 400, f"Model '{request.model}' not found", "invalid_request_error"
+             )
+             return JSONResponse(status_code=400, content=error_response)
+
+         # Check if this is a grounded search model
+         is_grounded_search = request.model.endswith("-search")
+
+         # Extract the base model name (remove -search suffix if present)
+         gemini_model = request.model.replace("-search", "") if is_grounded_search else request.model
+
+         # Create generation config
+         generation_config = create_generation_config(request)
+
+         # Get fresh credentials for this request
+         credentials, project_id = credential_manager.get_next_credentials()
+
+         if not credentials or not project_id:
+             error_response = create_openai_error_response(
+                 500, "Failed to obtain valid credentials", "server_error"
+             )
+             return JSONResponse(status_code=500, content=error_response)
+
+         # Initialize Vertex AI with the rotated credentials
+         try:
+             # Re-initialize client for this request - credentials might have rotated
+             client = genai.Client(vertexai=True, credentials=credentials, project=project_id, location="us-central1")
+             print(f"Using credentials for project: {project_id} for this request")
+         except Exception as auth_error:
+             error_response = create_openai_error_response(
+                 500, f"Failed to initialize authentication: {str(auth_error)}", "server_error"
+             )
+             return JSONResponse(status_code=500, content=error_response)
+
+         # Initialize Gemini model
+         search_tool = types.Tool(google_search=types.GoogleSearch())
+
+         safety_settings = [
+             types.SafetySetting(
+                 category="HARM_CATEGORY_HATE_SPEECH",
+                 threshold="OFF"
+             ),
+             types.SafetySetting(
+                 category="HARM_CATEGORY_DANGEROUS_CONTENT",
+                 threshold="OFF"
+             ),
+             types.SafetySetting(
+                 category="HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                 threshold="OFF"
+             ),
+             types.SafetySetting(
+                 category="HARM_CATEGORY_HARASSMENT",
+                 threshold="OFF"
+             )
+         ]
+
+         generation_config["safety_settings"] = safety_settings
+         if is_grounded_search:
+             generation_config["tools"] = [search_tool]
+
+         # Create prompt from messages
+         prompt = create_gemini_prompt(request.messages)
+
+         if request.stream:
+             # Handle streaming response
+             async def stream_generator():
+                 response_id = f"chatcmpl-{int(time.time())}"
+                 candidate_count = request.n or 1
+
+                 try:
+                     # For streaming, we can only handle one candidate at a time
+                     # If multiple candidates are requested, we'll generate them sequentially
+                     for candidate_index in range(candidate_count):
+                         # Generate content with streaming
+                         # Handle both string and list content formats (for images)
+                         responses = client.models.generate_content_stream(
+                             model=gemini_model,
+                             contents=prompt,  # This can be either a string or a list of content parts
+                             config=generation_config,
+                         )
+
+                         # Convert and yield each chunk
+                         for response in responses:
+                             yield convert_chunk_to_openai(response, request.model, response_id, candidate_index)
+
+                     # Send final chunk with all candidates
+                     yield create_final_chunk(request.model, response_id, candidate_count)
+                     yield "data: [DONE]\n\n"
+
+                 except Exception as stream_error:
+                     # Format streaming errors in SSE format
+                     error_msg = f"Error during streaming: {str(stream_error)}"
+                     print(error_msg)
+                     error_response = create_openai_error_response(500, error_msg, "server_error")
+                     yield f"data: {json.dumps(error_response)}\n\n"
+                     yield "data: [DONE]\n\n"
+
+             return StreamingResponse(
+                 stream_generator(),
+                 media_type="text/event-stream"
+             )
+         else:
+             # Handle non-streaming response
+             try:
+                 # If multiple candidates are requested, set candidate_count
+                 if request.n and request.n > 1:
+                     # Make sure generation_config has candidate_count set
+                     if "candidate_count" not in generation_config:
+                         generation_config["candidate_count"] = request.n
+                 # Handle both string and list content formats (for images)
+                 response = client.models.generate_content(
+                     model=gemini_model,
+                     contents=prompt,  # This can be either a string or a list of content parts
+                     config=generation_config,
+                 )
+
+                 openai_response = convert_to_openai_format(response, request.model)
+                 return JSONResponse(content=openai_response)
+             except Exception as generate_error:
+                 error_msg = f"Error generating content: {str(generate_error)}"
+                 print(error_msg)
+                 error_response = create_openai_error_response(500, error_msg, "server_error")
+                 return JSONResponse(status_code=500, content=error_response)
+
+     except Exception as e:
+         error_msg = f"Error processing request: {str(e)}"
+         print(error_msg)
+         error_response = create_openai_error_response(500, error_msg, "server_error")
+         return JSONResponse(status_code=500, content=error_response)
+
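
End to end, the adapter is meant to work as a drop-in base URL for OpenAI clients. A sketch using the openai Python package (v1+, client-side only; assumes the docker-compose setup below, so host port 8050 and the default key):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8050/v1", api_key="123456")
completion = client.chat.completions.create(
    model="gemini-2.0-flash",
    messages=[{"role": "user", "content": "Say hello."}],
)
print(completion.choices[0].message.content)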
+ # Health check endpoint
+ @app.get("/health")
+ def health_check(api_key: str = Depends(get_api_key)):
+     # Refresh the credentials list to get the latest status
+     credential_manager.refresh_credentials_list()
+
+     return {
+         "status": "ok",
+         "credentials": {
+             "available": len(credential_manager.credentials_files),
+             "files": [os.path.basename(f) for f in credential_manager.credentials_files],
+             "current_index": credential_manager.current_index
+         }
+     }
app/requirements.txt ADDED
@@ -0,0 +1,6 @@
+ fastapi==0.110.0
+ uvicorn==0.27.1
+ google-auth==2.38.0
+ google-cloud-aiplatform==1.86.0
+ pydantic==2.6.1
+ google-genai==1.8.0
credentials/Placeholder Place credential json files here ADDED
File without changes
docker-compose.yml ADDED
@@ -0,0 +1,20 @@
+ version: '3.8'
+
+ services:
+   openai-to-gemini:
+     build:
+       context: .
+       dockerfile: Dockerfile
+     ports:
+       # Map host port 8050 to container port 7860 (for Hugging Face compatibility)
+       - "8050:7860"
+     volumes:
+       - ./credentials:/app/credentials
+     environment:
+       # This is kept for backward compatibility but our app now primarily uses the credential manager
+       - GOOGLE_APPLICATION_CREDENTIALS=/app/credentials/service-account.json
+       # Directory where credential files are stored (used by credential manager)
+       - CREDENTIALS_DIR=/app/credentials
+       # API key for authentication (default: 123456)
+       - API_KEY=123456
+     restart: unless-stopped
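
Once the stack is up, a quick smoke test from the host is to hit the authenticated health endpoint (a sketch using the requests package, with the default port mapping and API key above):

import requests

resp = requests.get(
    "http://localhost:8050/health",
    headers={"Authorization": "Bearer 123456"},
)
print(resp.json())  # e.g. {"status": "ok", "credentials": {"available": 1, ...}}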