Update app.py
Browse files
app.py
CHANGED
@@ -53,7 +53,6 @@ if not API_KEY:
|
|
53 |
|
54 |
# Pydantic models for request/response validation
|
55 |
class ChatConfig(BaseModel):
|
56 |
-
# Removed api_key field; only model and temperature are received
|
57 |
model: str = "mistralai/mistral-small-3.1-24b-instruct:free"
|
58 |
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=1.0)
|
59 |
|
@@ -82,7 +81,6 @@ class InitResponse(BaseModel):
|
|
82 |
session_id: str
|
83 |
status: str
|
84 |
|
85 |
-
# Simple HTML interface
|
86 |
@app.get("/", response_class=HTMLResponse)
|
87 |
async def root():
|
88 |
"""
|
@@ -404,9 +402,8 @@ async def root():
|
|
404 |
</body>
|
405 |
</html>
|
406 |
"""
|
407 |
-
return html_content
|
408 |
|
409 |
-
# Health check endpoint
|
410 |
@app.get("/health")
|
411 |
async def health_check():
|
412 |
"""Health check endpoint"""
|
@@ -419,11 +416,9 @@ async def initialize_chat(config: ChatConfig):
|
|
419 |
# Generate a session ID
|
420 |
session_id = f"session_{datetime.now().strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
421 |
|
422 |
-
# If the environment variable is missing, raise an error
|
423 |
if not API_KEY:
|
424 |
raise HTTPException(status_code=400, detail="The OPENROUTE_API environment variable is not set.")
|
425 |
|
426 |
-
# Initialize the chat instance
|
427 |
chat = EnhancedRecursiveThinkingChat(
|
428 |
api_key=API_KEY,
|
429 |
model=config.model,
|
@@ -434,7 +429,6 @@ async def initialize_chat(config: ChatConfig):
|
|
434 |
"created_at": datetime.now().isoformat(),
|
435 |
"model": config.model
|
436 |
}
|
437 |
-
|
438 |
return {"session_id": session_id, "status": "initialized"}
|
439 |
except Exception as e:
|
440 |
logger.error(f"Error initializing chat: {str(e)}")
|
@@ -453,8 +447,21 @@ async def send_message_original(request: MessageRequest):
|
|
453 |
# Make a direct call to the LLM without recursion logic
|
454 |
messages = [{"role": "user", "content": request.message}]
|
455 |
response_data = chat._call_api(messages, temperature=chat.temperature, stream=False)
|
456 |
-
|
457 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
458 |
return {"response": original_text.strip()}
|
459 |
except Exception as e:
|
460 |
logger.error(f"Error getting original response: {str(e)}")
|
@@ -462,20 +469,21 @@ async def send_message_original(request: MessageRequest):
|
|
462 |
|
463 |
@app.post("/api/send_message")
|
464 |
async def send_message(request: MessageRequest):
|
465 |
-
"""
|
|
|
|
|
|
|
466 |
try:
|
467 |
if request.session_id not in chat_instances:
|
468 |
raise HTTPException(status_code=404, detail="Session not found")
|
469 |
|
470 |
chat = chat_instances[request.session_id]["chat"]
|
471 |
|
472 |
-
# Override class parameters if provided
|
473 |
original_thinking_fn = chat._determine_thinking_rounds
|
474 |
original_alternatives_fn = chat._generate_alternatives
|
475 |
original_temperature = getattr(chat, "temperature", 0.7)
|
476 |
|
477 |
if request.thinking_rounds is not None:
|
478 |
-
# Override the thinking rounds determination
|
479 |
chat._determine_thinking_rounds = lambda _: request.thinking_rounds
|
480 |
|
481 |
if request.alternatives_per_round is not None:
|
@@ -483,18 +491,16 @@ async def send_message(request: MessageRequest):
|
|
483 |
return original_alternatives_fn(base_response, prompt, request.alternatives_per_round)
|
484 |
chat._generate_alternatives = modified_generate_alternatives
|
485 |
|
486 |
-
# Override temperature if provided
|
487 |
if request.temperature is not None:
|
488 |
setattr(chat, "temperature", request.temperature)
|
489 |
|
490 |
-
# Process the message
|
491 |
logger.info(f"Processing message for session {request.session_id}")
|
492 |
start_time = datetime.now()
|
493 |
result = chat.think_and_respond(request.message, verbose=True)
|
494 |
processing_time = (datetime.now() - start_time).total_seconds()
|
495 |
logger.info(f"Message processed in {processing_time:.2f} seconds")
|
496 |
|
497 |
-
# Restore original
|
498 |
chat._determine_thinking_rounds = original_thinking_fn
|
499 |
chat._generate_alternatives = original_alternatives_fn
|
500 |
if request.temperature is not None:
|
@@ -513,21 +519,19 @@ async def send_message(request: MessageRequest):
|
|
513 |
|
514 |
@app.post("/api/save")
|
515 |
async def save_conversation(request: SaveRequest):
|
516 |
-
"""Save the conversation or the full thinking log"""
|
517 |
try:
|
518 |
if request.session_id not in chat_instances:
|
519 |
raise HTTPException(status_code=404, detail="Session not found")
|
520 |
|
521 |
chat = chat_instances[request.session_id]["chat"]
|
522 |
|
523 |
-
# Generate default filename if not provided
|
524 |
filename = request.filename
|
525 |
if filename is None:
|
526 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
527 |
log_type = "full_log" if request.full_log else "conversation"
|
528 |
filename = f"recthink_{log_type}_{timestamp}.json"
|
529 |
|
530 |
-
# Make sure the output directory exists
|
531 |
os.makedirs("logs", exist_ok=True)
|
532 |
file_path = os.path.join("logs", filename)
|
533 |
|
@@ -543,31 +547,28 @@ async def save_conversation(request: SaveRequest):
|
|
543 |
|
544 |
@app.get("/api/sessions", response_model=SessionResponse)
|
545 |
async def list_sessions():
|
546 |
-
"""List all active chat sessions"""
|
547 |
sessions = []
|
548 |
for session_id, session_data in chat_instances.items():
|
549 |
chat = session_data["chat"]
|
550 |
-
message_count = len(chat.conversation_history) // 2
|
551 |
-
|
552 |
sessions.append(SessionInfo(
|
553 |
session_id=session_id,
|
554 |
message_count=message_count,
|
555 |
created_at=session_data["created_at"],
|
556 |
model=session_data["model"]
|
557 |
))
|
558 |
-
|
559 |
return {"sessions": sessions}
|
560 |
|
561 |
@app.get("/api/sessions/{session_id}")
|
562 |
async def get_session(session_id: str):
|
563 |
-
"""Get details for a specific chat session"""
|
564 |
if session_id not in chat_instances:
|
565 |
raise HTTPException(status_code=404, detail="Session not found")
|
566 |
|
567 |
session_data = chat_instances[session_id]
|
568 |
chat = session_data["chat"]
|
569 |
|
570 |
-
# Extract conversation history
|
571 |
conversation = []
|
572 |
for i in range(0, len(chat.conversation_history), 2):
|
573 |
if i+1 < len(chat.conversation_history):
|
@@ -586,14 +587,12 @@ async def get_session(session_id: str):
|
|
586 |
|
587 |
@app.delete("/api/sessions/{session_id}")
|
588 |
async def delete_session(session_id: str):
|
589 |
-
"""Delete a chat session"""
|
590 |
if session_id not in chat_instances:
|
591 |
raise HTTPException(status_code=404, detail="Session not found")
|
592 |
-
|
593 |
del chat_instances[session_id]
|
594 |
return {"status": "deleted", "session_id": session_id}
|
595 |
|
596 |
-
# WebSocket connection manager
|
597 |
class ConnectionManager:
|
598 |
def __init__(self):
|
599 |
self.active_connections: Dict[str, WebSocket] = {}
|
@@ -612,7 +611,6 @@ class ConnectionManager:
|
|
612 |
|
613 |
manager = ConnectionManager()
|
614 |
|
615 |
-
# WebSocket for streaming the thinking process
|
616 |
@app.websocket("/ws/{session_id}")
|
617 |
async def websocket_endpoint(websocket: WebSocket, session_id: str):
|
618 |
try:
|
@@ -624,40 +622,31 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
|
|
624 |
return
|
625 |
|
626 |
chat = chat_instances[session_id]["chat"]
|
627 |
-
|
628 |
-
# Set up a custom callback to stream the thinking process
|
629 |
original_call_api = chat._call_api
|
630 |
|
631 |
async def stream_callback(chunk):
|
632 |
await manager.send_json(session_id, {"type": "chunk", "content": chunk})
|
633 |
|
634 |
-
# Override the _call_api method to also send updates via WebSocket
|
635 |
def ws_call_api(messages, temperature=0.7, stream=True):
|
636 |
result = original_call_api(messages, temperature, stream)
|
637 |
-
# Send the chunk via WebSocket if we're streaming
|
638 |
if stream:
|
639 |
asyncio.create_task(stream_callback(result))
|
640 |
return result
|
641 |
|
642 |
-
# Replace the method temporarily
|
643 |
chat._call_api = ws_call_api
|
644 |
|
645 |
-
# Wait for messages from the client
|
646 |
while True:
|
647 |
data = await websocket.receive_text()
|
648 |
message_data = json.loads(data)
|
649 |
|
650 |
if message_data["type"] == "message":
|
651 |
-
# Process the message
|
652 |
start_time = datetime.now()
|
653 |
|
654 |
try:
|
655 |
-
# Get parameters if they exist
|
656 |
thinking_rounds = message_data.get("thinking_rounds", None)
|
657 |
alternatives_per_round = message_data.get("alternatives_per_round", None)
|
658 |
temperature = message_data.get("temperature", None)
|
659 |
|
660 |
-
# Override if needed
|
661 |
original_thinking_fn = chat._determine_thinking_rounds
|
662 |
original_alternatives_fn = chat._generate_alternatives
|
663 |
original_temperature = getattr(chat, "temperature", 0.7)
|
@@ -673,24 +662,20 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
|
|
673 |
if temperature is not None:
|
674 |
setattr(chat, "temperature", temperature)
|
675 |
|
676 |
-
# Send a status message that we've started processing
|
677 |
await manager.send_json(session_id, {
|
678 |
"type": "status",
|
679 |
"status": "processing",
|
680 |
"message": "Starting recursive thinking process..."
|
681 |
})
|
682 |
|
683 |
-
# Process the message with chain-of-thought
|
684 |
result = chat.think_and_respond(message_data["content"], verbose=True)
|
685 |
processing_time = (datetime.now() - start_time).total_seconds()
|
686 |
|
687 |
-
# Restore original functions
|
688 |
chat._determine_thinking_rounds = original_thinking_fn
|
689 |
chat._generate_alternatives = original_alternatives_fn
|
690 |
if temperature is not None:
|
691 |
setattr(chat, "temperature", original_temperature)
|
692 |
|
693 |
-
# Send the final result
|
694 |
await manager.send_json(session_id, {
|
695 |
"type": "final",
|
696 |
"response": result["response"],
|
@@ -706,7 +691,6 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
|
|
706 |
"type": "error",
|
707 |
"error": error_msg
|
708 |
})
|
709 |
-
|
710 |
except WebSocketDisconnect:
|
711 |
logger.info(f"WebSocket disconnected: {session_id}")
|
712 |
manager.disconnect(session_id)
|
@@ -718,14 +702,11 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
|
|
718 |
except:
|
719 |
pass
|
720 |
finally:
|
721 |
-
# Restore the original method if needed
|
722 |
if 'chat' in locals() and 'original_call_api' in locals():
|
723 |
chat._call_api = original_call_api
|
724 |
|
725 |
-
# Make sure to disconnect
|
726 |
manager.disconnect(session_id)
|
727 |
|
728 |
-
# Use port 7860 for Hugging Face Spaces
|
729 |
if __name__ == "__main__":
|
730 |
port = 7860
|
731 |
print(f"Starting server on port {port}")
|
|
|
53 |
|
54 |
# Pydantic models for request/response validation
|
55 |
class ChatConfig(BaseModel):
|
|
|
56 |
model: str = "mistralai/mistral-small-3.1-24b-instruct:free"
|
57 |
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=1.0)
|
58 |
|
|
|
81 |
session_id: str
|
82 |
status: str
|
83 |
|
|
|
84 |
@app.get("/", response_class=HTMLResponse)
|
85 |
async def root():
|
86 |
"""
|
|
|
402 |
</body>
|
403 |
</html>
|
404 |
"""
|
405 |
+
return HTMLResponse(content=html_content)
|
406 |
|
|
|
407 |
@app.get("/health")
|
408 |
async def health_check():
|
409 |
"""Health check endpoint"""
|
|
|
416 |
# Generate a session ID
|
417 |
session_id = f"session_{datetime.now().strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
418 |
|
|
|
419 |
if not API_KEY:
|
420 |
raise HTTPException(status_code=400, detail="The OPENROUTE_API environment variable is not set.")
|
421 |
|
|
|
422 |
chat = EnhancedRecursiveThinkingChat(
|
423 |
api_key=API_KEY,
|
424 |
model=config.model,
|
|
|
429 |
"created_at": datetime.now().isoformat(),
|
430 |
"model": config.model
|
431 |
}
|
|
|
432 |
return {"session_id": session_id, "status": "initialized"}
|
433 |
except Exception as e:
|
434 |
logger.error(f"Error initializing chat: {str(e)}")
|
|
|
447 |
# Make a direct call to the LLM without recursion logic
|
448 |
messages = [{"role": "user", "content": request.message}]
|
449 |
response_data = chat._call_api(messages, temperature=chat.temperature, stream=False)
|
450 |
+
|
451 |
+
# The structure of response_data depends on the underlying LLM.
|
452 |
+
# We try to handle both the "message" and "text" keys where present.
|
453 |
+
if isinstance(response_data, dict) and "choices" in response_data:
|
454 |
+
first_choice = response_data["choices"][0]
|
455 |
+
if "message" in first_choice and "content" in first_choice["message"]:
|
456 |
+
original_text = first_choice["message"]["content"]
|
457 |
+
elif "text" in first_choice:
|
458 |
+
original_text = first_choice["text"]
|
459 |
+
else:
|
460 |
+
original_text = str(first_choice)
|
461 |
+
else:
|
462 |
+
# If for some reason the response is not in the expected format, just convert to string
|
463 |
+
original_text = str(response_data)
|
464 |
+
|
465 |
return {"response": original_text.strip()}
|
466 |
except Exception as e:
|
467 |
logger.error(f"Error getting original response: {str(e)}")
|
|
|
469 |
|
470 |
@app.post("/api/send_message")
|
471 |
async def send_message(request: MessageRequest):
|
472 |
+
"""
|
473 |
+
Send a message and get a response with the chain-of-thought process (HTTP-based, not streaming).
|
474 |
+
Primarily left here for completeness, but the user-facing code calls the WebSocket for streaming.
|
475 |
+
"""
|
476 |
try:
|
477 |
if request.session_id not in chat_instances:
|
478 |
raise HTTPException(status_code=404, detail="Session not found")
|
479 |
|
480 |
chat = chat_instances[request.session_id]["chat"]
|
481 |
|
|
|
482 |
original_thinking_fn = chat._determine_thinking_rounds
|
483 |
original_alternatives_fn = chat._generate_alternatives
|
484 |
original_temperature = getattr(chat, "temperature", 0.7)
|
485 |
|
486 |
if request.thinking_rounds is not None:
|
|
|
487 |
chat._determine_thinking_rounds = lambda _: request.thinking_rounds
|
488 |
|
489 |
if request.alternatives_per_round is not None:
|
|
|
491 |
return original_alternatives_fn(base_response, prompt, request.alternatives_per_round)
|
492 |
chat._generate_alternatives = modified_generate_alternatives
|
493 |
|
|
|
494 |
if request.temperature is not None:
|
495 |
setattr(chat, "temperature", request.temperature)
|
496 |
|
|
|
497 |
logger.info(f"Processing message for session {request.session_id}")
|
498 |
start_time = datetime.now()
|
499 |
result = chat.think_and_respond(request.message, verbose=True)
|
500 |
processing_time = (datetime.now() - start_time).total_seconds()
|
501 |
logger.info(f"Message processed in {processing_time:.2f} seconds")
|
502 |
|
503 |
+
# Restore original
|
504 |
chat._determine_thinking_rounds = original_thinking_fn
|
505 |
chat._generate_alternatives = original_alternatives_fn
|
506 |
if request.temperature is not None:
|
|
|
519 |
|
520 |
@app.post("/api/save")
|
521 |
async def save_conversation(request: SaveRequest):
|
522 |
+
"""Save the conversation or the full thinking log."""
|
523 |
try:
|
524 |
if request.session_id not in chat_instances:
|
525 |
raise HTTPException(status_code=404, detail="Session not found")
|
526 |
|
527 |
chat = chat_instances[request.session_id]["chat"]
|
528 |
|
|
|
529 |
filename = request.filename
|
530 |
if filename is None:
|
531 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
532 |
log_type = "full_log" if request.full_log else "conversation"
|
533 |
filename = f"recthink_{log_type}_{timestamp}.json"
|
534 |
|
|
|
535 |
os.makedirs("logs", exist_ok=True)
|
536 |
file_path = os.path.join("logs", filename)
|
537 |
|
|
|
547 |
|
548 |
@app.get("/api/sessions", response_model=SessionResponse)
|
549 |
async def list_sessions():
|
550 |
+
"""List all active chat sessions."""
|
551 |
sessions = []
|
552 |
for session_id, session_data in chat_instances.items():
|
553 |
chat = session_data["chat"]
|
554 |
+
message_count = len(chat.conversation_history) // 2
|
|
|
555 |
sessions.append(SessionInfo(
|
556 |
session_id=session_id,
|
557 |
message_count=message_count,
|
558 |
created_at=session_data["created_at"],
|
559 |
model=session_data["model"]
|
560 |
))
|
|
|
561 |
return {"sessions": sessions}
|
562 |
|
563 |
@app.get("/api/sessions/{session_id}")
|
564 |
async def get_session(session_id: str):
|
565 |
+
"""Get details for a specific chat session."""
|
566 |
if session_id not in chat_instances:
|
567 |
raise HTTPException(status_code=404, detail="Session not found")
|
568 |
|
569 |
session_data = chat_instances[session_id]
|
570 |
chat = session_data["chat"]
|
571 |
|
|
|
572 |
conversation = []
|
573 |
for i in range(0, len(chat.conversation_history), 2):
|
574 |
if i+1 < len(chat.conversation_history):
|
|
|
587 |
|
588 |
@app.delete("/api/sessions/{session_id}")
|
589 |
async def delete_session(session_id: str):
|
590 |
+
"""Delete a chat session."""
|
591 |
if session_id not in chat_instances:
|
592 |
raise HTTPException(status_code=404, detail="Session not found")
|
|
|
593 |
del chat_instances[session_id]
|
594 |
return {"status": "deleted", "session_id": session_id}
|
595 |
|
|
|
596 |
class ConnectionManager:
|
597 |
def __init__(self):
|
598 |
self.active_connections: Dict[str, WebSocket] = {}
|
|
|
611 |
|
612 |
manager = ConnectionManager()
|
613 |
|
|
|
614 |
@app.websocket("/ws/{session_id}")
|
615 |
async def websocket_endpoint(websocket: WebSocket, session_id: str):
|
616 |
try:
|
|
|
622 |
return
|
623 |
|
624 |
chat = chat_instances[session_id]["chat"]
|
|
|
|
|
625 |
original_call_api = chat._call_api
|
626 |
|
627 |
async def stream_callback(chunk):
|
628 |
await manager.send_json(session_id, {"type": "chunk", "content": chunk})
|
629 |
|
|
|
630 |
def ws_call_api(messages, temperature=0.7, stream=True):
|
631 |
result = original_call_api(messages, temperature, stream)
|
|
|
632 |
if stream:
|
633 |
asyncio.create_task(stream_callback(result))
|
634 |
return result
|
635 |
|
|
|
636 |
chat._call_api = ws_call_api
|
637 |
|
|
|
638 |
while True:
|
639 |
data = await websocket.receive_text()
|
640 |
message_data = json.loads(data)
|
641 |
|
642 |
if message_data["type"] == "message":
|
|
|
643 |
start_time = datetime.now()
|
644 |
|
645 |
try:
|
|
|
646 |
thinking_rounds = message_data.get("thinking_rounds", None)
|
647 |
alternatives_per_round = message_data.get("alternatives_per_round", None)
|
648 |
temperature = message_data.get("temperature", None)
|
649 |
|
|
|
650 |
original_thinking_fn = chat._determine_thinking_rounds
|
651 |
original_alternatives_fn = chat._generate_alternatives
|
652 |
original_temperature = getattr(chat, "temperature", 0.7)
|
|
|
662 |
if temperature is not None:
|
663 |
setattr(chat, "temperature", temperature)
|
664 |
|
|
|
665 |
await manager.send_json(session_id, {
|
666 |
"type": "status",
|
667 |
"status": "processing",
|
668 |
"message": "Starting recursive thinking process..."
|
669 |
})
|
670 |
|
|
|
671 |
result = chat.think_and_respond(message_data["content"], verbose=True)
|
672 |
processing_time = (datetime.now() - start_time).total_seconds()
|
673 |
|
|
|
674 |
chat._determine_thinking_rounds = original_thinking_fn
|
675 |
chat._generate_alternatives = original_alternatives_fn
|
676 |
if temperature is not None:
|
677 |
setattr(chat, "temperature", original_temperature)
|
678 |
|
|
|
679 |
await manager.send_json(session_id, {
|
680 |
"type": "final",
|
681 |
"response": result["response"],
|
|
|
691 |
"type": "error",
|
692 |
"error": error_msg
|
693 |
})
|
|
|
694 |
except WebSocketDisconnect:
|
695 |
logger.info(f"WebSocket disconnected: {session_id}")
|
696 |
manager.disconnect(session_id)
|
|
|
702 |
except:
|
703 |
pass
|
704 |
finally:
|
|
|
705 |
if 'chat' in locals() and 'original_call_api' in locals():
|
706 |
chat._call_api = original_call_api
|
707 |
|
|
|
708 |
manager.disconnect(session_id)
|
709 |
|
|
|
710 |
if __name__ == "__main__":
|
711 |
port = 7860
|
712 |
print(f"Starting server on port {port}")
|