sachin committed · Commit 19758dc
1 Parent(s): ac28f35
add-chat-completio
Files changed: src/server/main.py (+168 -0)
src/server/main.py
CHANGED
@@ -1281,6 +1281,174 @@ async def indic_custom_prompt_kannada_pdf(
     finally:
         # Close the temporary file to ensure it's fully written
         temp_file.close()
+from typing import List, Optional, Dict, Any
+
+class ChatCompletionRequest(BaseModel):
+    model: str = Field(default="gemma-3-12b-it", description="Model identifier (e.g., gemma-3-12b-it)")
+    messages: List[Dict[str, str]] = Field(..., description="List of messages in the conversation")
+    max_tokens: Optional[int] = Field(None, description="Maximum number of tokens to generate")
+    temperature: Optional[float] = Field(1.0, description="Sampling temperature")
+    top_p: Optional[float] = Field(1.0, description="Nucleus sampling parameter")
+    stream: Optional[bool] = Field(False, description="Whether to stream the response")
+
+# OpenAI-compatible response model
+class ChatCompletionChoice(BaseModel):
+    index: int
+    message: Dict[str, str]
+    finish_reason: Optional[str]
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int
+    model: str
+    choices: List[ChatCompletionChoice]
+    usage: Optional[Dict[str, int]] = None
+
+    class Config:
+        schema_extra = {
+            "example": {
+                "id": "chatcmpl-123",
+                "object": "chat.completion",
+                "created": 1698765432,
+                "model": "gemma-3-12b-it",
+                "choices": [
+                    {
+                        "index": 0,
+                        "message": {
+                            "role": "assistant",
+                            "content": "Hello! How can I assist you today?"
+                        },
+                        "finish_reason": "stop"
+                    }
+                ],
+                "usage": {
+                    "prompt_tokens": 10,
+                    "completion_tokens": 10,
+                    "total_tokens": 20
+                }
+            }
+        }
+
+# Helper function to convert OpenAI messages to a prompt for llama-server
+def messages_to_prompt(messages: List[Dict[str, str]]) -> str:
+    prompt = ""
+    for msg in messages:
+        role = msg.get("role", "user")
+        content = msg.get("content", "")
+        if role == "system":
+            prompt += f"System: {content}\n"
+        elif role == "user":
+            prompt += f"User: {content}\n"
+        elif role == "assistant":
+            prompt += f"Assistant: {content}\n"
+    prompt += "Assistant: "
+    return prompt
+
+@app.post("/v1/chat/completions",
+          response_model=ChatCompletionResponse,
+          summary="OpenAI-Compatible Chat Completions",
+          description="Proxy endpoint to generate chat completions using llama-server with gemma-3-12b-it model, compatible with OpenAI's API.",
+          tags=["Chat"],
+          responses={
+              200: {"description": "Chat completion response", "model": ChatCompletionResponse},
+              400: {"description": "Invalid request parameters"},
+              500: {"description": "External llama-server error"},
+              504: {"description": "External llama-server timeout"}
+          })
+async def chat_completions(
+    request: Request,
+    body: ChatCompletionRequest
+):
+    logger.info("Processing chat completion request", extra={
+        "endpoint": "/v1/chat/completions",
+        "model": body.model,
+        "messages_count": len(body.messages),
+        "client_ip": request.client.host
+    })
+
+    # Validate messages
+    if not body.messages:
+        raise HTTPException(status_code=400, detail="Messages cannot be empty")
+
+    # Prepare payload for llama-server
+    # Adjust this based on the actual llama-server API requirements
+    llama_payload = {
+        "prompt": messages_to_prompt(body.messages),
+        "max_tokens": body.max_tokens if body.max_tokens is not None else 512,
+        "temperature": body.temperature,
+        "top_p": body.top_p,
+        "stream": body.stream
+    }
+
+    external_url = f"{os.getenv('DWANI_AI_LLM_URL')}/v1/chat/completions"
+
+    # llama-server endpoint (adjust if different)
+    start_time = time()
+
+    try:
+        response = requests.post(
+            external_url,
+            json=llama_payload,
+            headers={
+                "accept": "application/json",
+                "Content-Type": "application/json"
+            },
+            timeout=30
+        )
+        response.raise_for_status()
+
+        # Parse llama-server response
+        response_data = response.json()
+
+        # Transform llama-server response to OpenAI-compatible format
+        # Adjust based on actual response structure
+        completion_text = response_data.get("choices", [{}])[0].get("text", "")
+        finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", "stop")
+
+        # Generate a unique ID for the response
+        completion_id = f"chatcmpl-{int(time())}"
+
+        # Build OpenAI-compatible response
+        openai_response = ChatCompletionResponse(
+            id=completion_id,
+            created=int(time()),
+            model=body.model,
+            choices=[
+                ChatCompletionChoice(
+                    index=0,
+                    message={
+                        "role": "assistant",
+                        "content": completion_text.strip()
+                    },
+                    finish_reason=finish_reason
+                )
+            ],
+            usage={
+                "prompt_tokens": len(llama_payload["prompt"].split()),  # Rough estimate
+                "completion_tokens": len(completion_text.split()),  # Rough estimate
+                "total_tokens": len(llama_payload["prompt"].split()) + len(completion_text.split())
+            }
+        )
+
+        logger.info(f"Chat completion successful in {time() - start_time:.2f} seconds", extra={
+            "response_length": len(completion_text)
+        })
+        return openai_response
+
+    except requests.Timeout:
+        logger.error("llama-server request timed out")
+        raise HTTPException(status_code=504, detail="llama-server timeout")
+    except requests.RequestException as e:
+        logger.error(f"llama-server request failed: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"llama-server error: {str(e)}")
+    except ValueError as e:
+        logger.error(f"Invalid JSON response from llama-server: {str(e)}")
+        raise HTTPException(status_code=500, detail="Invalid response format from llama-server")
+    except Exception as e:
+        logger.error(f"Unexpected error: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
 
 
 if __name__ == "__main__":
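
For context, the sketch below shows how the new /v1/chat/completions endpoint could be exercised from a client once the server is running. It is not part of the commit: the base URL is an assumption (point it at wherever the FastAPI app is served), and the payload simply follows the ChatCompletionRequest schema added above.

    # Client-side sketch (not part of this commit). The base URL is an
    # assumption; adjust it to the running FastAPI deployment.
    import requests

    BASE_URL = "http://localhost:7860"  # assumed host/port

    payload = {
        "model": "gemma-3-12b-it",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        "max_tokens": 64,
        "temperature": 0.7,
    }

    resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=60)
    resp.raise_for_status()
    data = resp.json()

    # The response follows the OpenAI chat.completion shape defined by
    # ChatCompletionResponse, so the reply text lives here:
    print(data["choices"][0]["message"]["content"])

For the two messages above, messages_to_prompt flattens the conversation into "System: You are a helpful assistant.\nUser: Hello!\nAssistant: " before forwarding it to llama-server, so the upstream model receives a plain role-tagged transcript rather than the structured message list. Because the response model mirrors OpenAI's chat.completion object, existing OpenAI client code can usually target this endpoint by changing only the base URL.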