|
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect |
|
from fastapi.responses import JSONResponse, HTMLResponse |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from pydantic import BaseModel, Field |
|
import uvicorn |
|
import json |
|
import os |
|
import asyncio |
|
from datetime import datetime |
|
from typing import List, Dict, Optional, Any |
|
import logging |
|
import uuid |
|
|
|
|
|
# Import the chat engine. If it is not importable from the current sys.path,
# retry after adding this file's own directory (the module is expected to live
# next to this file) and the current working directory, as before.
try:
    from recursive_thinking_ai import EnhancedRecursiveThinkingChat
except ModuleNotFoundError:
    import sys

    for _search_dir in (os.path.dirname(os.path.abspath(__file__)), '.'):
        # Avoid piling up duplicate entries on sys.path.
        if _search_dir not in sys.path:
            sys.path.append(_search_dir)

    from recursive_thinking_ai import EnhancedRecursiveThinkingChat
|
|
|
|
|
# Root logging for the whole service: INFO and above, one timestamped line
# per record in the form "time - logger - level - message".
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Logger shared by every endpoint in this module.
logger = logging.getLogger(name=__name__)
|
|
|
# The ASGI application; title/description/version feed the generated OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="Chain-of-Recursive-Thoughts: TEST",
    description="https://github.com/PhialsBasement/Chain-of-Recursive-Thoughts",
)

# Open CORS policy: any origin, method and header may call the API.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended for anything beyond a demo deployment.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
|
# In-memory session registry, keyed by session id. Each value is a dict with
# "chat" (the chat engine), "created_at" (ISO timestamp) and "model".
# NOTE(review): process-local and never pruned; sessions vanish on restart.
chat_instances: Dict[str, Dict[str, Any]] = {}

# API key read once at import time from the OPENROUTE_API environment variable.
API_KEY = os.environ.get("OPENROUTE_API")
if not API_KEY:
    # Only a warning here: the server still starts, but session creation will fail.
    logger.warning("The OPENROUTE_API environment variable is not set. Some features may not work.")
|
|
|
|
|
class ChatConfig(BaseModel):
    # Request body for POST /api/initialize.

    # Model identifier handed to the chat engine when the session is created.
    model: str = "mistralai/mistral-small-3.1-24b-instruct:free"
    # Sampling temperature, validated to the closed range [0.0, 1.0].
    temperature: Optional[float] = Field(0.7, ge=0.0, le=1.0)
|
|
|
class MessageRequest(BaseModel):
    # Request body for POST /api/send_message.

    # Target session, as returned by /api/initialize.
    session_id: str
    # The user's message text.
    message: str
    # Fixed number of thinking rounds (1-10); None lets the engine decide.
    thinking_rounds: Optional[int] = Field(None, ge=1, le=10)
    # Alternatives generated per round (1-5).
    alternatives_per_round: Optional[int] = Field(3, ge=1, le=5)
    # One-off temperature override in [0.0, 1.0]; None keeps the session's value.
    temperature: Optional[float] = Field(None, ge=0.0, le=1.0)
|
|
|
class SaveRequest(BaseModel):
    # Request body for POST /api/save.

    # Session whose data should be written out.
    session_id: str
    # Optional file name inside logs/; a timestamped name is generated when omitted.
    filename: Optional[str] = None
    # True writes the full thinking log, False only the conversation.
    full_log: bool = False
|
|
|
class SessionInfo(BaseModel):
    # Summary of one live session, as listed by GET /api/sessions.

    session_id: str
    # Number of user/assistant exchanges recorded so far.
    message_count: int
    # ISO-formatted creation timestamp.
    created_at: str
    # Model identifier the session was created with.
    model: str
|
|
|
class SessionResponse(BaseModel):
    # Response envelope for GET /api/sessions: one SessionInfo per live session.
    sessions: List[SessionInfo]
|
|
|
class InitResponse(BaseModel):
    # Response body for POST /api/initialize.

    # Identifier to pass to subsequent calls for this session.
    session_id: str
    # Outcome marker; the initialize endpoint sets "initialized".
    status: str
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve a single-page HTML client for the REST API.

    The page lets a user initialize a session (POST /api/initialize), send
    messages (POST /api/send_message) and view the response plus a per-round
    thinking log. Because the log is rendered with ``innerHTML``, every
    server-supplied value inserted into it is HTML-escaped first; those values
    derive from model output and must not be able to inject markup or script.

    Note: this is an f-string only for historical reasons, so every literal
    CSS/JS brace is doubled.
    """
    html_content = f"""
<!DOCTYPE html>
<html>
<head>
    <title>Chain-of-Recursive-Thoughts: TEST</title>
    <style>
        body {{
            font-family: Arial, sans-serif;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
            line-height: 1.6;
        }}
        h1 {{
            color: #333;
            border-bottom: 1px solid #eee;
            padding-bottom: 10px;
        }}
        .container {{
            background-color: #f9f9f9;
            border-radius: 5px;
            padding: 20px;
            margin-top: 20px;
        }}
        label {{
            display: block;
            margin-bottom: 5px;
            font-weight: bold;
        }}
        input, textarea, select {{
            width: 100%;
            padding: 8px;
            margin-bottom: 10px;
            border: 1px solid #ddd;
            border-radius: 4px;
        }}
        button {{
            background-color: #4CAF50;
            color: white;
            padding: 10px 15px;
            border: none;
            border-radius: 4px;
            cursor: pointer;
        }}
        button:hover {{
            background-color: #45a049;
        }}
        #response {{
            white-space: pre-wrap;
            background-color: #f5f5f5;
            padding: 15px;
            border-radius: 4px;
            margin-top: 20px;
            min-height: 100px;
        }}
        .log {{
            margin-top: 20px;
            font-size: 0.9em;
            color: #666;
        }}
    </style>
</head>
<body>
    <h1>Chain-of-Recursive-Thoughts: TEST</h1>
    <div class="container">
        <div id="init-form">
            <h2>1. Initialize Chat</h2>

            <!-- API key input removed -->
            <label for="model">Model:</label>
            <input type="text" id="model" value="mistralai/mistral-small-3.1-24b-instruct:free">

            <label for="temperature">Temperature:</label>
            <input type="number" id="temperature" min="0" max="1" step="0.1" value="0.7">

            <button onclick="initializeChat()">Initialize</button>
        </div>

        <div id="chat-form" style="display: none;">
            <h2>2. Send Message</h2>
            <p>Session ID: <span id="session-id"></span></p>

            <label for="message">Message:</label>
            <textarea id="message" rows="4" placeholder="Enter your message"></textarea>

            <label for="thinking-rounds">Thinking Rounds (optional):</label>
            <input type="number" id="thinking-rounds" min="1" max="10" placeholder="Auto">

            <label for="alternatives">Number of Alternatives (optional):</label>
            <input type="number" id="alternatives" min="1" max="5" value="3">

            <button onclick="sendMessage()">Send</button>
            <button onclick="resetChat()" style="background-color: #f44336;">Reset</button>
        </div>

        <div id="response-container" style="display: none;">
            <h2>3. Response</h2>
            <div id="response">The response will appear here...</div>
            <div class="log">
                <h3>Thinking Process Log:</h3>
                <div id="thinking-log"></div>
            </div>
        </div>
    </div>

    <div style="margin-top: 30px;">
        <p>Repo: https://github.com/PhialsBasement/Chain-of-Recursive-Thoughts</p>
        <p>Community: https://discord.gg/openfreeai</p>
    </div>

    <script>
        let currentSessionId = null;

        // Escape text before it is concatenated into innerHTML so that
        // model-generated content is displayed, never interpreted as markup.
        function escapeHtml(value) {{
            return String(value)
                .replace(/&/g, '&amp;')
                .replace(/</g, '&lt;')
                .replace(/>/g, '&gt;')
                .replace(/"/g, '&quot;')
                .replace(/'/g, '&#39;');
        }}

        async function initializeChat() {{
            const model = document.getElementById('model').value;
            const temperature = parseFloat(document.getElementById('temperature').value);

            try {{
                const response = await fetch('/api/initialize', {{
                    method: 'POST',
                    headers: {{
                        'Content-Type': 'application/json',
                    }},
                    body: JSON.stringify({{
                        model: model,
                        temperature: temperature
                    }}),
                }});

                const data = await response.json();

                if (response.ok) {{
                    currentSessionId = data.session_id;
                    document.getElementById('session-id').textContent = currentSessionId;
                    document.getElementById('init-form').style.display = 'none';
                    document.getElementById('chat-form').style.display = 'block';
                    document.getElementById('response-container').style.display = 'block';
                }} else {{
                    alert('Initialization failed: ' + (data.detail || 'Unknown error'));
                }}
            }} catch (error) {{
                alert('An error occurred: ' + error.message);
            }}
        }}

        async function sendMessage() {{
            if (!currentSessionId) {{
                alert('Please initialize a chat session first.');
                return;
            }}

            const message = document.getElementById('message').value;
            const thinkingRounds = document.getElementById('thinking-rounds').value;
            const alternatives = document.getElementById('alternatives').value;

            if (!message) {{
                alert('Please enter a message.');
                return;
            }}

            document.getElementById('response').textContent = 'Processing...';
            document.getElementById('thinking-log').textContent = '';

            try {{
                const response = await fetch('/api/send_message', {{
                    method: 'POST',
                    headers: {{
                        'Content-Type': 'application/json',
                    }},
                    body: JSON.stringify({{
                        session_id: currentSessionId,
                        message: message,
                        thinking_rounds: thinkingRounds ? parseInt(thinkingRounds) : null,
                        alternatives_per_round: alternatives ? parseInt(alternatives) : 3
                    }}),
                }});

                const data = await response.json();

                if (response.ok) {{
                    document.getElementById('response').textContent = data.response;

                    // Display thinking history (all server values escaped).
                    let thinkingLogHTML = '';
                    (data.thinking_history || []).forEach(item => {{
                        const selected = item.selected ? ' ✓ Selected' : '';
                        thinkingLogHTML += "<p><strong>Round " + escapeHtml(item.round) + selected + ":</strong> ";

                        if (item.explanation && item.selected) {{
                            thinkingLogHTML += "<br><em>Reason for selection: " + escapeHtml(item.explanation) + "</em>";
                        }}
                        thinkingLogHTML += "</p>";
                    }});

                    document.getElementById('thinking-log').innerHTML = thinkingLogHTML;
                }} else {{
                    document.getElementById('response').textContent = 'Error: ' + (data.detail || 'Unknown error');
                }}
            }} catch (error) {{
                document.getElementById('response').textContent = 'An error occurred: ' + error.message;
            }}
        }}

        function resetChat() {{
            currentSessionId = null;
            document.getElementById('init-form').style.display = 'block';
            document.getElementById('chat-form').style.display = 'none';
            document.getElementById('response-container').style.display = 'none';
            document.getElementById('message').value = '';
            document.getElementById('thinking-rounds').value = '';
            document.getElementById('alternatives').value = '3';
        }}
    </script>
</body>
</html>
"""
    return html_content
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: always reports healthy, stamped with the server's local time."""
    now = datetime.now()
    return {"status": "healthy", "timestamp": now.isoformat()}
|
|
|
@app.post("/api/initialize", response_model=InitResponse)
async def initialize_chat(config: ChatConfig):
    """Create a new chat session using the server-side API key.

    Args:
        config: model identifier and temperature for the new session.

    Returns:
        dict with the new ``session_id`` and ``status`` "initialized".

    Raises:
        HTTPException: 400 if the OPENROUTE_API environment variable is not
            set; 500 if constructing the chat engine fails.
    """
    # Checked before the try below so the 400 is not rewrapped as a 500.
    if not API_KEY:
        raise HTTPException(status_code=400, detail="The OPENROUTE_API environment variable is not set.")

    # Timestamp keeps ids roughly sortable; the uuid suffix keeps them unique
    # when several sessions are created within the same second.
    session_id = f"session_{datetime.now().strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}"

    try:
        chat = EnhancedRecursiveThinkingChat(
            api_key=API_KEY,
            model=config.model,
            temperature=config.temperature,
        )
    except Exception as e:
        logger.exception(f"Error initializing chat: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to initialize chat: {str(e)}") from e

    chat_instances[session_id] = {
        "chat": chat,
        "created_at": datetime.now().isoformat(),
        "model": config.model,
    }

    return {"session_id": session_id, "status": "initialized"}
|
|
|
@app.post("/api/send_message")
async def send_message(request: MessageRequest):
    """Run one recursive-thinking turn for a session and return the result.

    The per-request overrides ``thinking_rounds``, ``alternatives_per_round``
    and ``temperature`` are patched onto the session's chat object only for
    the duration of this call. They are always restored afterwards, even
    when the call fails, so they cannot leak into later requests.

    Returns:
        dict with ``session_id``, ``response``, ``thinking_rounds``,
        ``thinking_history`` and ``processing_time`` (seconds).

    Raises:
        HTTPException: 404 if the session does not exist; 500 for any other
            failure while processing the message.
    """
    try:
        session = chat_instances.get(request.session_id)
        if session is None:
            raise HTTPException(status_code=404, detail="Session not found")

        chat = session["chat"]

        # Snapshot everything that may be overridden so it can be put back.
        original_thinking_fn = chat._determine_thinking_rounds
        original_alternatives_fn = chat._generate_alternatives
        original_temperature = getattr(chat, "temperature", 0.7)

        try:
            if request.thinking_rounds is not None:
                # Force a fixed round count instead of the engine's own choice.
                chat._determine_thinking_rounds = lambda _: request.thinking_rounds

            if request.alternatives_per_round is not None:
                def modified_generate_alternatives(base_response, prompt, num_alternatives=3):
                    # The caller's count is ignored in favour of the request's value.
                    return original_alternatives_fn(
                        base_response, prompt, request.alternatives_per_round
                    )

                chat._generate_alternatives = modified_generate_alternatives

            if request.temperature is not None:
                setattr(chat, "temperature", request.temperature)

            logger.info(f"Processing message for session {request.session_id}")
            start_time = datetime.now()
            # NOTE(review): think_and_respond is called without await, so it
            # appears to be synchronous and blocks the event loop for the
            # whole run. Moving it to a worker thread would first require
            # making concurrent use of a shared chat object safe.
            result = chat.think_and_respond(request.message, verbose=True)
            processing_time = (datetime.now() - start_time).total_seconds()
            logger.info(f"Message processed in {processing_time:.2f} seconds")
        finally:
            # Restore unconditionally: a failed run must not leave overrides behind.
            chat._determine_thinking_rounds = original_thinking_fn
            chat._generate_alternatives = original_alternatives_fn
            if request.temperature is not None:
                setattr(chat, "temperature", original_temperature)

        return {
            "session_id": request.session_id,
            "response": result["response"],
            "thinking_rounds": result["thinking_rounds"],
            "thinking_history": result["thinking_history"],
            "processing_time": processing_time,
        }
    except HTTPException:
        # Already carries the intended status (e.g. 404); don't rewrap as 500.
        raise
    except Exception as e:
        logger.exception(f"Error processing message: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to process message: {str(e)}") from e
|
|
|
@app.post("/api/save")
async def save_conversation(request: SaveRequest):
    """Save a session's conversation, or its full thinking log, under ``logs/``.

    When ``request.filename`` is omitted, a name of the form
    ``recthink_<conversation|full_log>_<YYYYmmdd_HHMMSS>.json`` is generated.
    A client-supplied name is reduced to its final path component, so the
    write always stays inside the ``logs`` directory.

    Returns:
        dict with ``status`` "saved", the ``filename`` used and the ``path``
        written.

    Raises:
        HTTPException: 404 if the session does not exist; 400 if the
            supplied filename has no usable final component; 500 for any
            other failure while saving.
    """
    try:
        session = chat_instances.get(request.session_id)
        if session is None:
            raise HTTPException(status_code=404, detail="Session not found")

        chat = session["chat"]

        filename = request.filename
        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            log_type = "full_log" if request.full_log else "conversation"
            filename = f"recthink_{log_type}_{timestamp}.json"
        else:
            # Untrusted input: drop any directory part so names like "../x"
            # or "/etc/x" cannot escape the logs directory.
            filename = os.path.basename(filename)
            if filename in ("", ".", ".."):
                raise HTTPException(status_code=400, detail="Invalid filename")

        os.makedirs("logs", exist_ok=True)
        file_path = os.path.join("logs", filename)

        if request.full_log:
            chat.save_full_log(file_path)
        else:
            chat.save_conversation(file_path)

        return {"status": "saved", "filename": filename, "path": file_path}
    except HTTPException:
        # Preserve deliberate 4xx responses instead of rewrapping them as 500.
        raise
    except Exception as e:
        logger.exception(f"Error saving conversation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to save conversation: {str(e)}") from e
|
|
|
@app.get("/api/sessions", response_model=SessionResponse)
async def list_sessions():
    """Summarize every session held in memory.

    ``message_count`` is half the length of the chat's conversation history,
    i.e. the history is counted as user/assistant pairs.
    """
    summaries = [
        SessionInfo(
            session_id=sid,
            message_count=len(entry["chat"].conversation_history) // 2,
            created_at=entry["created_at"],
            model=entry["model"],
        )
        for sid, entry in chat_instances.items()
    ]
    return {"sessions": summaries}
|
|
|
@app.get("/api/sessions/{session_id}")
async def get_session(session_id: str):
    """Return metadata and the paired conversation for one session.

    The conversation history is read two entries at a time: the first of
    each pair is reported as "user" and the second as "assistant". A trailing
    unpaired entry is omitted.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    entry = chat_instances.get(session_id)
    if entry is None:
        raise HTTPException(status_code=404, detail="Session not found")

    # Feeding the same iterator to zip twice yields consecutive, non-overlapping pairs.
    history_iter = iter(entry["chat"].conversation_history)
    conversation = [
        {"user": user_turn, "assistant": assistant_turn}
        for user_turn, assistant_turn in zip(history_iter, history_iter)
    ]

    return {
        "session_id": session_id,
        "created_at": entry["created_at"],
        "model": entry["model"],
        "message_count": len(conversation),
        "conversation": conversation,
    }
|
|
|
@app.delete("/api/sessions/{session_id}")
async def delete_session(session_id: str):
    """Remove a session from the in-memory registry.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    removed = chat_instances.pop(session_id, None)
    if removed is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return {"status": "deleted", "session_id": session_id}
|
|
|
|
|
class ConnectionManager:
    """Registry of open WebSockets, keyed by session id (one socket per id)."""

    def __init__(self):
        # session_id -> most recently accepted WebSocket for that session
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, session_id: str, websocket: WebSocket):
        """Accept the socket, then register it, replacing any earlier one for the id."""
        await websocket.accept()
        self.active_connections[session_id] = websocket

    def disconnect(self, session_id: str):
        """Drop the registration for ``session_id``; unknown ids are a no-op."""
        self.active_connections.pop(session_id, None)

    async def send_json(self, session_id: str, data: dict):
        """Send ``data`` as JSON to the session's socket; silently skip if none."""
        target = self.active_connections.get(session_id)
        if target is not None:
            await target.send_json(data)


# Single process-wide manager shared by the WebSocket endpoint.
manager = ConnectionManager()
|
|
|
|
|
@app.websocket("/ws/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """Serve a WebSocket client attached to an existing chat session.

    Incoming frames are JSON text. A frame whose ``"type"`` is ``"message"``
    runs ``think_and_respond`` on its ``"content"``. The optional keys
    ``thinking_rounds``, ``alternatives_per_round`` and ``temperature``
    override the session's settings for that run only.

    Outgoing frames have ``"type"`` ``"status"``, ``"chunk"``, ``"final"`` or
    ``"error"``. Frames with any other (or no) ``"type"`` are ignored.
    Malformed JSON gets an ``"error"`` frame instead of closing the
    connection.

    While connected, the session's ``_call_api`` is wrapped so that each
    streamed API result is also forwarded as a ``"chunk"`` frame. The
    original is restored when the connection ends.
    """
    # Initialized up front so the finally clause can test them directly.
    chat = None
    original_call_api = None
    # The event loop keeps only weak references to tasks, so keep the
    # chunk-forwarding tasks alive until they complete.
    pending_tasks = set()

    try:
        await manager.connect(session_id, websocket)

        if session_id not in chat_instances:
            await websocket.send_json({"error": "Session not found"})
            await websocket.close()
            return

        chat = chat_instances[session_id]["chat"]

        original_call_api = chat._call_api

        async def stream_callback(chunk):
            await manager.send_json(session_id, {"type": "chunk", "content": chunk})

        def ws_call_api(messages, temperature=0.7, stream=True):
            result = original_call_api(messages, temperature, stream)
            if stream:
                # This wrapper is invoked synchronously from inside
                # think_and_respond, so the send is scheduled, not awaited.
                task = asyncio.create_task(stream_callback(result))
                pending_tasks.add(task)
                task.add_done_callback(pending_tasks.discard)
            return result

        chat._call_api = ws_call_api

        while True:
            data = await websocket.receive_text()

            try:
                message_data = json.loads(data)
            except json.JSONDecodeError as e:
                await manager.send_json(session_id, {"type": "error", "error": f"Invalid JSON: {e}"})
                continue

            if not isinstance(message_data, dict) or message_data.get("type") != "message":
                continue

            start_time = datetime.now()

            try:
                thinking_rounds = message_data.get("thinking_rounds", None)
                alternatives_per_round = message_data.get("alternatives_per_round", None)
                temperature = message_data.get("temperature", None)

                await manager.send_json(session_id, {
                    "type": "status",
                    "status": "processing",
                    "message": "Starting recursive thinking process...",
                })

                # Snapshot everything that may be overridden for this run.
                original_thinking_fn = chat._determine_thinking_rounds
                original_alternatives_fn = chat._generate_alternatives
                original_temperature = getattr(chat, "temperature", 0.7)

                try:
                    if thinking_rounds is not None:
                        chat._determine_thinking_rounds = lambda _: thinking_rounds

                    if alternatives_per_round is not None:
                        def modified_generate_alternatives(base_response, prompt, num_alternatives=3):
                            return original_alternatives_fn(base_response, prompt, alternatives_per_round)

                        chat._generate_alternatives = modified_generate_alternatives

                    if temperature is not None:
                        setattr(chat, "temperature", temperature)

                    result = chat.think_and_respond(message_data["content"], verbose=True)
                    processing_time = (datetime.now() - start_time).total_seconds()
                finally:
                    # Restore unconditionally so a failed run (e.g. missing
                    # "content") cannot leave overrides on the shared chat.
                    chat._determine_thinking_rounds = original_thinking_fn
                    chat._generate_alternatives = original_alternatives_fn
                    if temperature is not None:
                        setattr(chat, "temperature", original_temperature)

                await manager.send_json(session_id, {
                    "type": "final",
                    "response": result["response"],
                    "thinking_rounds": result["thinking_rounds"],
                    "thinking_history": result["thinking_history"],
                    "processing_time": processing_time,
                })

            except WebSocketDisconnect:
                # Let the outer handler treat this as a normal disconnect.
                raise
            except Exception as e:
                error_msg = str(e)
                logger.exception(f"Error in WebSocket message processing: {error_msg}")
                await manager.send_json(session_id, {
                    "type": "error",
                    "error": error_msg,
                })

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {session_id}")
    except Exception as e:
        error_msg = str(e)
        logger.error(f"WebSocket error: {error_msg}")
        try:
            await websocket.send_json({"type": "error", "error": error_msg})
        except Exception:
            # Best effort only: the socket is most likely already closed.
            logger.debug(f"Could not deliver error frame for session {session_id}")
    finally:
        if chat is not None and original_call_api is not None:
            chat._call_api = original_call_api

        # Unregister only if the manager still points at this socket; a newer
        # connection for the same session may have replaced it.
        if manager.active_connections.get(session_id) is websocket:
            manager.disconnect(session_id)
|
|
|
|
|
if __name__ == "__main__":
    port = 7860
    print(f"Starting server on port {port}")
    # Pass the application object instead of the "app:app" import string.
    # This works whatever this file is named, and it avoids importing this
    # module a second time under the name "app".
    uvicorn.run(app, host="0.0.0.0", port=port)
|
|