Spaces: Running on CPU Upgrade

sachin committed · commit d92ab0e
1 Parent(s): 19758dc
add-chat-completio

Browse files
- requirements.txt +2 -1
- src/server/main.py +52 -115
requirements.txt CHANGED
@@ -8,4 +8,5 @@ pillow
 pyjwt
 sqlalchemy
 passlib[bcrypt]
-pycryptodome
+pycryptodome
+openai
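The only dependency added is openai, which the reworked endpoint in src/server/main.py below uses as an async client. A quick sanity check that the dependency installs and the client constructs (the URL and key here are placeholders, not values from this commit):

# Assumes `pip install -r requirements.txt` has been run; base_url and api_key are illustrative only.
from openai import AsyncOpenAI

client = AsyncOpenAI(base_url="http://localhost:7860/v1", api_key="not-needed")
print(type(client).__name__)  # -> AsyncOpenAI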
src/server/main.py CHANGED
@@ -1283,10 +1283,12 @@ async def indic_custom_prompt_kannada_pdf(
     temp_file.close()
 from typing import List, Optional, Dict, Any
 
+from openai import AsyncOpenAI, OpenAIError
+# OpenAI-compatible request model
 class ChatCompletionRequest(BaseModel):
-    model: str = Field(default="gemma-3-12b-it", description="Model identifier
-    messages: List[Dict[str, str]] = Field(..., description="List of messages
-    max_tokens: Optional[int] = Field(None, description="Maximum
+    model: str = Field(default="gemma-3-12b-it", description="Model identifier")
+    messages: List[Dict[str, str]] = Field(..., description="List of messages")
+    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
     temperature: Optional[float] = Field(1.0, description="Sampling temperature")
     top_p: Optional[float] = Field(1.0, description="Nucleus sampling parameter")
     stream: Optional[bool] = Field(False, description="Whether to stream the response")
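For reference while reading the handler changes in the next hunk, a request body that the ChatCompletionRequest model above accepts might look like the following (values are illustrative, not taken from this commit):

# Illustrative payload for ChatCompletionRequest; field values are examples only.
example_request = {
    "model": "gemma-3-12b-it",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"}
    ],
    "max_tokens": 256,   # optional, defaults to None
    "temperature": 1.0,
    "top_p": 1.0,
    "stream": False      # the new handler rejects stream=True with a 400
}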
@@ -1305,151 +1307,86 @@ class ChatCompletionResponse(BaseModel):
     choices: List[ChatCompletionChoice]
     usage: Optional[Dict[str, int]] = None
 
-
-
-
-
-
-
-                "model": "gemma-3-12b-it",
-                "choices": [
-                    {
-                        "index": 0,
-                        "message": {
-                            "role": "assistant",
-                            "content": "Hello! How can I assist you today?"
-                        },
-                        "finish_reason": "stop"
-                    }
-                ],
-                "usage": {
-                    "prompt_tokens": 10,
-                    "completion_tokens": 10,
-                    "total_tokens": 20
-                }
-            }
-        }
-
-# Helper function to convert OpenAI messages to a prompt for llama-server
-def messages_to_prompt(messages: List[Dict[str, str]]) -> str:
-    prompt = ""
-    for msg in messages:
-        role = msg.get("role", "user")
-        content = msg.get("content", "")
-        if role == "system":
-            prompt += f"System: {content}\n"
-        elif role == "user":
-            prompt += f"User: {content}\n"
-        elif role == "assistant":
-            prompt += f"Assistant: {content}\n"
-    prompt += "Assistant: "
-    return prompt
+# Initialize OpenAI client
+openai_client = AsyncOpenAI(
+    base_url=os.getenv("DWANI_AI_LLM_URL"),  # e.g., https://<ngrok-url>.ngrok.io or http://localhost:7860
+    api_key=os.getenv("DWANI_AI_LLM_API_KEY", ""),  # Optional API key
+    timeout=30.0
+)
 
 @app.post("/v1/chat/completions",
           response_model=ChatCompletionResponse,
           summary="OpenAI-Compatible Chat Completions",
-          description="
-          tags=["Chat"]
-
-
-              400: {"description": "Invalid request parameters"},
-              500: {"description": "External llama-server error"},
-              504: {"description": "External llama-server timeout"}
-          })
-async def chat_completions(
-    request: Request,
-    body: ChatCompletionRequest
-):
-    logger.info("Processing chat completion request", extra={
+          description="Proxies chat completions to llama-server using OpenAI API format.",
+          tags=["Chat"])
+async def chat_completions(request: Request, body: ChatCompletionRequest):
+    logger.info("Received chat completion request", extra={
         "endpoint": "/v1/chat/completions",
         "model": body.model,
-        "
+        "messages": body.messages,
         "client_ip": request.client.host
     })
 
     # Validate messages
     if not body.messages:
+        logger.error("Messages field is empty", extra={"client_ip": request.client.host})
         raise HTTPException(status_code=400, detail="Messages cannot be empty")
 
-    # Prepare payload for llama-server
-    # Adjust this based on the actual llama-server API requirements
-    llama_payload = {
-        "prompt": messages_to_prompt(body.messages),
-        "max_tokens": body.max_tokens if body.max_tokens is not None else 512,
-        "temperature": body.temperature,
-        "top_p": body.top_p,
-        "stream": body.stream
-    }
-
-    external_url = f"{os.getenv('DWANI_AI_LLM_URL')}/v1/chat/completions"
-
-    # llama-server endpoint (adjust if different)
     start_time = time()
 
     try:
-
-
-
-
-
-
-
-
-        )
-        response.raise_for_status()
-
-        # Parse llama-server response
-        response_data = response.json()
-
-        # Transform llama-server response to OpenAI-compatible format
-        # Adjust based on actual response structure
-        completion_text = response_data.get("choices", [{}])[0].get("text", "")
-        finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", "stop")
+        # Proxy request to llama-server using OpenAI client
+        response = await openai_client.chat.completions.create(
+            model=body.model,
+            messages=body.messages,
+            max_tokens=body.max_tokens,
+            temperature=body.temperature,
+            top_p=body.top_p,
+            stream=body.stream
+        )
 
-        #
-
+        # Streaming not supported in this simple version
+        if body.stream:
+            logger.error("Streaming requested but not supported")
+            raise HTTPException(status_code=400, detail="Streaming not supported")
 
-        #
+        # Map OpenAI response to Pydantic model
         openai_response = ChatCompletionResponse(
-            id=
-            created=
-            model=
+            id=response.id,
+            created=response.created,
+            model=response.model,
             choices=[
                 ChatCompletionChoice(
-                    index=
+                    index=choice.index,
                     message={
-                        "role":
-                        "content":
+                        "role": choice.message.role,
+                        "content": choice.message.content
                     },
-                    finish_reason=finish_reason
-                )
+                    finish_reason=choice.finish_reason
+                ) for choice in response.choices
             ],
-            usage=
-
-
-
-
+            usage=(
+                {
+                    "prompt_tokens": response.usage.prompt_tokens,
+                    "completion_tokens": response.usage.completion_tokens,
+                    "total_tokens": response.usage.total_tokens
+                } if response.usage else None
+            )
         )
 
         logger.info(f"Chat completion successful in {time() - start_time:.2f} seconds", extra={
-            "response_length": len(
+            "response_length": len(response.choices[0].message.content if response.choices else 0)
         })
         return openai_response
 
-    except
-        logger.error("llama-server
-
-
-        logger.error(f"llama-server request failed: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"llama-server error: {str(e)}")
-    except ValueError as e:
-        logger.error(f"Invalid JSON response from llama-server: {str(e)}")
-        raise HTTPException(status_code=500, detail="Invalid response format from llama-server")
+    except OpenAIError as e:
+        logger.error(f"llama-server error: {str(e)}", extra={"client_ip": request.client.host})
+        status_code = 504 if "timeout" in str(e).lower() else 500
+        raise HTTPException(status_code=status_code, detail=f"llama-server error: {str(e)}")
     except Exception as e:
-        logger.error(f"
+        logger.error(f"Internal error: {str(e)}", extra={"client_ip": request.client.host})
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
-
-
+
 
 if __name__ == "__main__":
     # Ensure EXTERNAL_API_BASE_URL is set