sachin committed on
Commit 19758dc · 1 Parent(s): ac28f35

add-chat-completio

Files changed (1):
  src/server/main.py  +168 -0
src/server/main.py CHANGED
@@ -1281,6 +1281,174 @@ async def indic_custom_prompt_kannada_pdf(
     finally:
         # Close the temporary file to ensure it's fully written
         temp_file.close()
+from typing import List, Optional, Dict, Any
+
+class ChatCompletionRequest(BaseModel):
+    model: str = Field(default="gemma-3-12b-it", description="Model identifier (e.g., gemma-3-12b-it)")
+    messages: List[Dict[str, str]] = Field(..., description="List of messages in the conversation")
+    max_tokens: Optional[int] = Field(None, description="Maximum number of tokens to generate")
+    temperature: Optional[float] = Field(1.0, description="Sampling temperature")
+    top_p: Optional[float] = Field(1.0, description="Nucleus sampling parameter")
+    stream: Optional[bool] = Field(False, description="Whether to stream the response (not currently supported)")
+
+# OpenAI-compatible response model
+class ChatCompletionChoice(BaseModel):
+    index: int
+    message: Dict[str, str]
+    finish_reason: Optional[str]
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int
+    model: str
+    choices: List[ChatCompletionChoice]
+    usage: Optional[Dict[str, int]] = None
+
+    class Config:
+        schema_extra = {
+            "example": {
+                "id": "chatcmpl-123",
+                "object": "chat.completion",
+                "created": 1698765432,
+                "model": "gemma-3-12b-it",
+                "choices": [
+                    {
+                        "index": 0,
+                        "message": {
+                            "role": "assistant",
+                            "content": "Hello! How can I assist you today?"
+                        },
+                        "finish_reason": "stop"
+                    }
+                ],
+                "usage": {
+                    "prompt_tokens": 10,
+                    "completion_tokens": 10,
+                    "total_tokens": 20
+                }
+            }
+        }
+
+# Helper function to convert OpenAI messages to a prompt for llama-server
+def messages_to_prompt(messages: List[Dict[str, str]]) -> str:
+    prompt = ""
+    for msg in messages:
+        role = msg.get("role", "user")
+        content = msg.get("content", "")
+        if role == "system":
+            prompt += f"System: {content}\n"
+        elif role == "user":
+            prompt += f"User: {content}\n"
+        elif role == "assistant":
+            prompt += f"Assistant: {content}\n"
+    prompt += "Assistant: "
+    return prompt
+
+@app.post("/v1/chat/completions",
+          response_model=ChatCompletionResponse,
+          summary="OpenAI-Compatible Chat Completions",
+          description="Proxy endpoint that generates chat completions via llama-server using the gemma-3-12b-it model, compatible with OpenAI's API.",
+          tags=["Chat"],
+          responses={
+              200: {"description": "Chat completion response", "model": ChatCompletionResponse},
+              400: {"description": "Invalid request parameters"},
+              500: {"description": "External llama-server error"},
+              504: {"description": "External llama-server timeout"}
+          })
+async def chat_completions(
+    request: Request,
+    body: ChatCompletionRequest
+):
+    logger.info("Processing chat completion request", extra={
+        "endpoint": "/v1/chat/completions",
+        "model": body.model,
+        "messages_count": len(body.messages),
+        "client_ip": request.client.host
+    })
+
+    # Validate messages
+    if not body.messages:
+        raise HTTPException(status_code=400, detail="Messages cannot be empty")
+
+    # Prepare payload for llama-server
+    # Adjust this based on the actual llama-server API requirements
+    llama_payload = {
+        "prompt": messages_to_prompt(body.messages),
+        "max_tokens": body.max_tokens if body.max_tokens is not None else 512,
+        "temperature": body.temperature,
+        "top_p": body.top_p,
+        "stream": False  # Streaming is not handled below, so always request a full response
+    }
+
+    # llama-server endpoint (adjust if different)
+    external_url = f"{os.getenv('DWANI_AI_LLM_URL')}/v1/chat/completions"
+
+    start_time = time()
+
+    try:
+        response = requests.post(
+            external_url,
+            json=llama_payload,
+            headers={
+                "accept": "application/json",
+                "Content-Type": "application/json"
+            },
+            timeout=30
+        )
+        response.raise_for_status()
+
+        # Parse llama-server response
+        response_data = response.json()
+
+        # Transform llama-server response to OpenAI-compatible format
+        # Adjust based on actual response structure
+        completion_text = response_data.get("choices", [{}])[0].get("text", "")
+        finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", "stop")
+
+        # Generate a unique ID for the response
+        completion_id = f"chatcmpl-{int(time())}"
+
+        # Build OpenAI-compatible response
+        openai_response = ChatCompletionResponse(
+            id=completion_id,
+            created=int(time()),
+            model=body.model,
+            choices=[
+                ChatCompletionChoice(
+                    index=0,
+                    message={
+                        "role": "assistant",
+                        "content": completion_text.strip()
+                    },
+                    finish_reason=finish_reason
+                )
+            ],
+            usage={
+                "prompt_tokens": len(llama_payload["prompt"].split()),  # Rough estimate
+                "completion_tokens": len(completion_text.split()),  # Rough estimate
+                "total_tokens": len(llama_payload["prompt"].split()) + len(completion_text.split())
+            }
+        )
+
+        logger.info(f"Chat completion successful in {time() - start_time:.2f} seconds", extra={
+            "response_length": len(completion_text)
+        })
+        return openai_response
+
+    except requests.Timeout:
+        logger.error("llama-server request timed out")
+        raise HTTPException(status_code=504, detail="llama-server timeout")
+    except requests.RequestException as e:
+        logger.error(f"llama-server request failed: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"llama-server error: {str(e)}")
+    except ValueError as e:
+        logger.error(f"Invalid JSON response from llama-server: {str(e)}")
+        raise HTTPException(status_code=500, detail="Invalid response format from llama-server")
+    except Exception as e:
+        logger.error(f"Unexpected error: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
 
 
 if __name__ == "__main__":
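
For context on what this proxy actually sends upstream: messages_to_prompt flattens the OpenAI-style message list into a plain-text transcript and ends it with an open "Assistant: " turn for the model to complete. A quick illustration, using only the helper added in this commit:

messages = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "What is FastAPI?"},
]

# messages_to_prompt(messages) returns:
#   "System: You are concise.\nUser: What is FastAPI?\nAssistant: "
print(messages_to_prompt(messages))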
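
And a minimal sketch of how a client might exercise the new endpoint once the server is running. The base URL here is an assumption for illustration; point it at the actual deployment:

import requests

BASE_URL = "http://localhost:8000"  # assumed host/port for this FastAPI server

payload = {
    "model": "gemma-3-12b-it",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    "max_tokens": 64,
    "temperature": 0.7,
}

resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=30)
resp.raise_for_status()

# The response follows OpenAI's chat.completion schema
print(resp.json()["choices"][0]["message"]["content"])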