sachin committed on
Commit d92ab0e · 1 Parent(s): 19758dc

add-chat-completion

Files changed (2)
  1. requirements.txt +2 -1
  2. src/server/main.py +52 -115
requirements.txt CHANGED
@@ -8,4 +8,5 @@ pillow
 pyjwt
 sqlalchemy
 passlib[bcrypt]
-pycryptodome
+pycryptodome
+openai
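
The new openai entry provides the AsyncOpenAI client that the rewritten endpoint in src/server/main.py below relies on. A minimal sanity check, not part of the commit, assuming the v1+ openai SDK is what gets installed:

    # hypothetical post-install check: AsyncOpenAI only exists in openai>=1.0
    import openai
    from openai import AsyncOpenAI  # ImportError here means a pre-1.0 SDK is installed

    print(openai.__version__)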
src/server/main.py CHANGED
@@ -1283,10 +1283,12 @@ async def indic_custom_prompt_kannada_pdf(
     temp_file.close()
 from typing import List, Optional, Dict, Any

+from openai import AsyncOpenAI, OpenAIError
+# OpenAI-compatible request model
 class ChatCompletionRequest(BaseModel):
-    model: str = Field(default="gemma-3-12b-it", description="Model identifier (e.g., gemma-3-12b-it)")
-    messages: List[Dict[str, str]] = Field(..., description="List of messages in the conversation")
-    max_tokens: Optional[int] = Field(None, description="Maximum number of tokens to generate")
+    model: str = Field(default="gemma-3-12b-it", description="Model identifier")
+    messages: List[Dict[str, str]] = Field(..., description="List of messages")
+    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
     temperature: Optional[float] = Field(1.0, description="Sampling temperature")
     top_p: Optional[float] = Field(1.0, description="Nucleus sampling parameter")
     stream: Optional[bool] = Field(False, description="Whether to stream the response")
@@ -1305,151 +1307,86 @@ class ChatCompletionResponse(BaseModel):
     choices: List[ChatCompletionChoice]
     usage: Optional[Dict[str, int]] = None

-    class Config:
-        schema_extra = {
-            "example": {
-                "id": "chatcmpl-123",
-                "object": "chat.completion",
-                "created": 1698765432,
-                "model": "gemma-3-12b-it",
-                "choices": [
-                    {
-                        "index": 0,
-                        "message": {
-                            "role": "assistant",
-                            "content": "Hello! How can I assist you today?"
-                        },
-                        "finish_reason": "stop"
-                    }
-                ],
-                "usage": {
-                    "prompt_tokens": 10,
-                    "completion_tokens": 10,
-                    "total_tokens": 20
-                }
-            }
-        }
-
-# Helper function to convert OpenAI messages to a prompt for llama-server
-def messages_to_prompt(messages: List[Dict[str, str]]) -> str:
-    prompt = ""
-    for msg in messages:
-        role = msg.get("role", "user")
-        content = msg.get("content", "")
-        if role == "system":
-            prompt += f"System: {content}\n"
-        elif role == "user":
-            prompt += f"User: {content}\n"
-        elif role == "assistant":
-            prompt += f"Assistant: {content}\n"
-    prompt += "Assistant: "
-    return prompt
+# Initialize OpenAI client
+openai_client = AsyncOpenAI(
+    base_url=os.getenv("DWANI_AI_LLM_URL"),  # e.g., https://<ngrok-url>.ngrok.io or http://localhost:7860
+    api_key=os.getenv("DWANI_AI_LLM_API_KEY", ""),  # Optional API key
+    timeout=30.0
+)

 @app.post("/v1/chat/completions",
     response_model=ChatCompletionResponse,
     summary="OpenAI-Compatible Chat Completions",
-    description="Proxy endpoint to generate chat completions using llama-server with gemma-3-12b-it model, compatible with OpenAI's API.",
-    tags=["Chat"],
-    responses={
-        200: {"description": "Chat completion response", "model": ChatCompletionResponse},
-        400: {"description": "Invalid request parameters"},
-        500: {"description": "External llama-server error"},
-        504: {"description": "External llama-server timeout"}
-    })
-async def chat_completions(
-    request: Request,
-    body: ChatCompletionRequest
-):
-    logger.info("Processing chat completion request", extra={
+    description="Proxies chat completions to llama-server using OpenAI API format.",
+    tags=["Chat"])
+async def chat_completions(request: Request, body: ChatCompletionRequest):
+    logger.info("Received chat completion request", extra={
         "endpoint": "/v1/chat/completions",
         "model": body.model,
-        "messages_count": len(body.messages),
+        "messages": body.messages,
         "client_ip": request.client.host
     })

     # Validate messages
     if not body.messages:
+        logger.error("Messages field is empty", extra={"client_ip": request.client.host})
         raise HTTPException(status_code=400, detail="Messages cannot be empty")

-    # Prepare payload for llama-server
-    # Adjust this based on the actual llama-server API requirements
-    llama_payload = {
-        "prompt": messages_to_prompt(body.messages),
-        "max_tokens": body.max_tokens if body.max_tokens is not None else 512,
-        "temperature": body.temperature,
-        "top_p": body.top_p,
-        "stream": body.stream
-    }
-
-    external_url = f"{os.getenv('DWANI_AI_LLM_URL')}/v1/chat/completions"
-
-    # llama-server endpoint (adjust if different)
     start_time = time()

     try:
-        response = requests.post(
-            external_url,
-            json=llama_payload,
-            headers={
-                "accept": "application/json",
-                "Content-Type": "application/json"
-            },
-            timeout=30
+        # Proxy request to llama-server using OpenAI client
+        response = await openai_client.chat.completions.create(
+            model=body.model,
+            messages=body.messages,
+            max_tokens=body.max_tokens,
+            temperature=body.temperature,
+            top_p=body.top_p,
+            stream=body.stream
         )
-        response.raise_for_status()
-
-        # Parse llama-server response
-        response_data = response.json()
-
-        # Transform llama-server response to OpenAI-compatible format
-        # Adjust based on actual response structure
-        completion_text = response_data.get("choices", [{}])[0].get("text", "")
-        finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", "stop")

-        # Generate a unique ID for the response
-        completion_id = f"chatcmpl-{int(time.time())}"
+        # Streaming not supported in this simple version
+        if body.stream:
+            logger.error("Streaming requested but not supported")
+            raise HTTPException(status_code=400, detail="Streaming not supported")

-        # Build OpenAI-compatible response
+        # Map OpenAI response to Pydantic model
         openai_response = ChatCompletionResponse(
-            id=completion_id,
-            created=int(time.time()),
-            model=body.model,
+            id=response.id,
+            created=response.created,
+            model=response.model,
             choices=[
                 ChatCompletionChoice(
-                    index=0,
+                    index=choice.index,
                     message={
-                        "role": "assistant",
-                        "content": completion_text.strip()
+                        "role": choice.message.role,
+                        "content": choice.message.content
                     },
-                    finish_reason=finish_reason
-                )
+                    finish_reason=choice.finish_reason
+                ) for choice in response.choices
             ],
-            usage={
-                "prompt_tokens": len(llama_payload["prompt"].split()),  # Rough estimate
-                "completion_tokens": len(completion_text.split()),  # Rough estimate
-                "total_tokens": len(llama_payload["prompt"].split()) + len(completion_text.split())
-            }
+            usage=(
+                {
+                    "prompt_tokens": response.usage.prompt_tokens,
+                    "completion_tokens": response.usage.completion_tokens,
+                    "total_tokens": response.usage.total_tokens
+                } if response.usage else None
+            )
         )

         logger.info(f"Chat completion successful in {time() - start_time:.2f} seconds", extra={
-            "response_length": len(completion_text)
+            "response_length": len(response.choices[0].message.content if response.choices else 0)
         })
         return openai_response

-    except requests.Timeout:
-        logger.error("llama-server request timed out")
-        raise HTTPException(status_code=504, detail="llama-server timeout")
-    except requests.RequestException as e:
-        logger.error(f"llama-server request failed: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"llama-server error: {str(e)}")
-    except ValueError as e:
-        logger.error(f"Invalid JSON response from llama-server: {str(e)}")
-        raise HTTPException(status_code=500, detail="Invalid response format from llama-server")
+    except OpenAIError as e:
+        logger.error(f"llama-server error: {str(e)}", extra={"client_ip": request.client.host})
+        status_code = 504 if "timeout" in str(e).lower() else 500
+        raise HTTPException(status_code=status_code, detail=f"llama-server error: {str(e)}")
     except Exception as e:
-        logger.error(f"Unexpected error: {str(e)}")
+        logger.error(f"Internal error: {str(e)}", extra={"client_ip": request.client.host})
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
-
-
+

 if __name__ == "__main__":
     # Ensure EXTERNAL_API_BASE_URL is set
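
With this change the FastAPI app itself exposes an OpenAI-style chat completions route. A usage sketch, not part of the commit: the host and port are assumptions about how the app is served (e.g. uvicorn on localhost:8000), and DWANI_AI_LLM_URL must point at a reachable llama-server for the proxied call to succeed.

    # hypothetical client call against the new /v1/chat/completions route
    import requests

    payload = {
        "model": "gemma-3-12b-it",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 64,
    }
    resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"]["content"])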