openfree commited on
Commit
2624ce9
·
verified ·
1 Parent(s): fba7498

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -45
app.py CHANGED
@@ -53,7 +53,6 @@ if not API_KEY:
53
 
54
  # Pydantic models for request/response validation
55
  class ChatConfig(BaseModel):
56
- # Removed api_key field; only model and temperature are received
57
  model: str = "mistralai/mistral-small-3.1-24b-instruct:free"
58
  temperature: Optional[float] = Field(default=0.7, ge=0.0, le=1.0)
59
 
@@ -82,7 +81,6 @@ class InitResponse(BaseModel):
82
  session_id: str
83
  status: str
84
 
85
- # Simple HTML interface
86
  @app.get("/", response_class=HTMLResponse)
87
  async def root():
88
  """
@@ -404,9 +402,8 @@ async def root():
404
  </body>
405
  </html>
406
  """
407
- return html_content
408
 
409
- # Health check endpoint
410
  @app.get("/health")
411
  async def health_check():
412
  """Health check endpoint"""
@@ -419,11 +416,9 @@ async def initialize_chat(config: ChatConfig):
419
  # Generate a session ID
420
  session_id = f"session_{datetime.now().strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}"
421
 
422
- # If the environment variable is missing, raise an error
423
  if not API_KEY:
424
  raise HTTPException(status_code=400, detail="The OPENROUTE_API environment variable is not set.")
425
 
426
- # Initialize the chat instance
427
  chat = EnhancedRecursiveThinkingChat(
428
  api_key=API_KEY,
429
  model=config.model,
@@ -434,7 +429,6 @@ async def initialize_chat(config: ChatConfig):
434
  "created_at": datetime.now().isoformat(),
435
  "model": config.model
436
  }
437
-
438
  return {"session_id": session_id, "status": "initialized"}
439
  except Exception as e:
440
  logger.error(f"Error initializing chat: {str(e)}")
@@ -453,8 +447,21 @@ async def send_message_original(request: MessageRequest):
453
  # Make a direct call to the LLM without recursion logic
454
  messages = [{"role": "user", "content": request.message}]
455
  response_data = chat._call_api(messages, temperature=chat.temperature, stream=False)
456
- # Extract the text from the response
457
- original_text = response_data["choices"][0]["message"]["content"]
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  return {"response": original_text.strip()}
459
  except Exception as e:
460
  logger.error(f"Error getting original response: {str(e)}")
@@ -462,20 +469,21 @@ async def send_message_original(request: MessageRequest):
462
 
463
  @app.post("/api/send_message")
464
  async def send_message(request: MessageRequest):
465
- """Send a message and get a response with the chain-of-thought process (HTTP-based, not streaming)."""
 
 
 
466
  try:
467
  if request.session_id not in chat_instances:
468
  raise HTTPException(status_code=404, detail="Session not found")
469
 
470
  chat = chat_instances[request.session_id]["chat"]
471
 
472
- # Override class parameters if provided
473
  original_thinking_fn = chat._determine_thinking_rounds
474
  original_alternatives_fn = chat._generate_alternatives
475
  original_temperature = getattr(chat, "temperature", 0.7)
476
 
477
  if request.thinking_rounds is not None:
478
- # Override the thinking rounds determination
479
  chat._determine_thinking_rounds = lambda _: request.thinking_rounds
480
 
481
  if request.alternatives_per_round is not None:
@@ -483,18 +491,16 @@ async def send_message(request: MessageRequest):
483
  return original_alternatives_fn(base_response, prompt, request.alternatives_per_round)
484
  chat._generate_alternatives = modified_generate_alternatives
485
 
486
- # Override temperature if provided
487
  if request.temperature is not None:
488
  setattr(chat, "temperature", request.temperature)
489
 
490
- # Process the message
491
  logger.info(f"Processing message for session {request.session_id}")
492
  start_time = datetime.now()
493
  result = chat.think_and_respond(request.message, verbose=True)
494
  processing_time = (datetime.now() - start_time).total_seconds()
495
  logger.info(f"Message processed in {processing_time:.2f} seconds")
496
 
497
- # Restore original functions and parameters
498
  chat._determine_thinking_rounds = original_thinking_fn
499
  chat._generate_alternatives = original_alternatives_fn
500
  if request.temperature is not None:
@@ -513,21 +519,19 @@ async def send_message(request: MessageRequest):
513
 
514
  @app.post("/api/save")
515
  async def save_conversation(request: SaveRequest):
516
- """Save the conversation or the full thinking log"""
517
  try:
518
  if request.session_id not in chat_instances:
519
  raise HTTPException(status_code=404, detail="Session not found")
520
 
521
  chat = chat_instances[request.session_id]["chat"]
522
 
523
- # Generate default filename if not provided
524
  filename = request.filename
525
  if filename is None:
526
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
527
  log_type = "full_log" if request.full_log else "conversation"
528
  filename = f"recthink_{log_type}_{timestamp}.json"
529
 
530
- # Make sure the output directory exists
531
  os.makedirs("logs", exist_ok=True)
532
  file_path = os.path.join("logs", filename)
533
 
@@ -543,31 +547,28 @@ async def save_conversation(request: SaveRequest):
543
 
544
  @app.get("/api/sessions", response_model=SessionResponse)
545
  async def list_sessions():
546
- """List all active chat sessions"""
547
  sessions = []
548
  for session_id, session_data in chat_instances.items():
549
  chat = session_data["chat"]
550
- message_count = len(chat.conversation_history) // 2 # Each message-response pair counts as 2
551
-
552
  sessions.append(SessionInfo(
553
  session_id=session_id,
554
  message_count=message_count,
555
  created_at=session_data["created_at"],
556
  model=session_data["model"]
557
  ))
558
-
559
  return {"sessions": sessions}
560
 
561
  @app.get("/api/sessions/{session_id}")
562
  async def get_session(session_id: str):
563
- """Get details for a specific chat session"""
564
  if session_id not in chat_instances:
565
  raise HTTPException(status_code=404, detail="Session not found")
566
 
567
  session_data = chat_instances[session_id]
568
  chat = session_data["chat"]
569
 
570
- # Extract conversation history
571
  conversation = []
572
  for i in range(0, len(chat.conversation_history), 2):
573
  if i+1 < len(chat.conversation_history):
@@ -586,14 +587,12 @@ async def get_session(session_id: str):
586
 
587
  @app.delete("/api/sessions/{session_id}")
588
  async def delete_session(session_id: str):
589
- """Delete a chat session"""
590
  if session_id not in chat_instances:
591
  raise HTTPException(status_code=404, detail="Session not found")
592
-
593
  del chat_instances[session_id]
594
  return {"status": "deleted", "session_id": session_id}
595
 
596
- # WebSocket connection manager
597
  class ConnectionManager:
598
  def __init__(self):
599
  self.active_connections: Dict[str, WebSocket] = {}
@@ -612,7 +611,6 @@ class ConnectionManager:
612
 
613
  manager = ConnectionManager()
614
 
615
- # WebSocket for streaming the thinking process
616
  @app.websocket("/ws/{session_id}")
617
  async def websocket_endpoint(websocket: WebSocket, session_id: str):
618
  try:
@@ -624,40 +622,31 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
624
  return
625
 
626
  chat = chat_instances[session_id]["chat"]
627
-
628
- # Set up a custom callback to stream the thinking process
629
  original_call_api = chat._call_api
630
 
631
  async def stream_callback(chunk):
632
  await manager.send_json(session_id, {"type": "chunk", "content": chunk})
633
 
634
- # Override the _call_api method to also send updates via WebSocket
635
  def ws_call_api(messages, temperature=0.7, stream=True):
636
  result = original_call_api(messages, temperature, stream)
637
- # Send the chunk via WebSocket if we're streaming
638
  if stream:
639
  asyncio.create_task(stream_callback(result))
640
  return result
641
 
642
- # Replace the method temporarily
643
  chat._call_api = ws_call_api
644
 
645
- # Wait for messages from the client
646
  while True:
647
  data = await websocket.receive_text()
648
  message_data = json.loads(data)
649
 
650
  if message_data["type"] == "message":
651
- # Process the message
652
  start_time = datetime.now()
653
 
654
  try:
655
- # Get parameters if they exist
656
  thinking_rounds = message_data.get("thinking_rounds", None)
657
  alternatives_per_round = message_data.get("alternatives_per_round", None)
658
  temperature = message_data.get("temperature", None)
659
 
660
- # Override if needed
661
  original_thinking_fn = chat._determine_thinking_rounds
662
  original_alternatives_fn = chat._generate_alternatives
663
  original_temperature = getattr(chat, "temperature", 0.7)
@@ -673,24 +662,20 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
673
  if temperature is not None:
674
  setattr(chat, "temperature", temperature)
675
 
676
- # Send a status message that we've started processing
677
  await manager.send_json(session_id, {
678
  "type": "status",
679
  "status": "processing",
680
  "message": "Starting recursive thinking process..."
681
  })
682
 
683
- # Process the message with chain-of-thought
684
  result = chat.think_and_respond(message_data["content"], verbose=True)
685
  processing_time = (datetime.now() - start_time).total_seconds()
686
 
687
- # Restore original functions
688
  chat._determine_thinking_rounds = original_thinking_fn
689
  chat._generate_alternatives = original_alternatives_fn
690
  if temperature is not None:
691
  setattr(chat, "temperature", original_temperature)
692
 
693
- # Send the final result
694
  await manager.send_json(session_id, {
695
  "type": "final",
696
  "response": result["response"],
@@ -706,7 +691,6 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
706
  "type": "error",
707
  "error": error_msg
708
  })
709
-
710
  except WebSocketDisconnect:
711
  logger.info(f"WebSocket disconnected: {session_id}")
712
  manager.disconnect(session_id)
@@ -718,14 +702,11 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
718
  except:
719
  pass
720
  finally:
721
- # Restore the original method if needed
722
  if 'chat' in locals() and 'original_call_api' in locals():
723
  chat._call_api = original_call_api
724
 
725
- # Make sure to disconnect
726
  manager.disconnect(session_id)
727
 
728
- # Use port 7860 for Hugging Face Spaces
729
  if __name__ == "__main__":
730
  port = 7860
731
  print(f"Starting server on port {port}")
 
53
 
54
  # Pydantic models for request/response validation
55
  class ChatConfig(BaseModel):
 
56
  model: str = "mistralai/mistral-small-3.1-24b-instruct:free"
57
  temperature: Optional[float] = Field(default=0.7, ge=0.0, le=1.0)
58
 
 
81
  session_id: str
82
  status: str
83
 
 
84
  @app.get("/", response_class=HTMLResponse)
85
  async def root():
86
  """
 
402
  </body>
403
  </html>
404
  """
405
+ return HTMLResponse(content=html_content)
406
 
 
407
  @app.get("/health")
408
  async def health_check():
409
  """Health check endpoint"""
 
416
  # Generate a session ID
417
  session_id = f"session_{datetime.now().strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}"
418
 
 
419
  if not API_KEY:
420
  raise HTTPException(status_code=400, detail="The OPENROUTE_API environment variable is not set.")
421
 
 
422
  chat = EnhancedRecursiveThinkingChat(
423
  api_key=API_KEY,
424
  model=config.model,
 
429
  "created_at": datetime.now().isoformat(),
430
  "model": config.model
431
  }
 
432
  return {"session_id": session_id, "status": "initialized"}
433
  except Exception as e:
434
  logger.error(f"Error initializing chat: {str(e)}")
 
447
  # Make a direct call to the LLM without recursion logic
448
  messages = [{"role": "user", "content": request.message}]
449
  response_data = chat._call_api(messages, temperature=chat.temperature, stream=False)
450
+
451
+ # The structure of response_data depends on the underlying LLM.
452
+ # We'll try to handle both "message" and "text" keys as possible.
453
+ if isinstance(response_data, dict) and "choices" in response_data:
454
+ first_choice = response_data["choices"][0]
455
+ if "message" in first_choice and "content" in first_choice["message"]:
456
+ original_text = first_choice["message"]["content"]
457
+ elif "text" in first_choice:
458
+ original_text = first_choice["text"]
459
+ else:
460
+ original_text = str(first_choice)
461
+ else:
462
+ # If for some reason the response is not in the expected format, just convert to string
463
+ original_text = str(response_data)
464
+
465
  return {"response": original_text.strip()}
466
  except Exception as e:
467
  logger.error(f"Error getting original response: {str(e)}")
 
469
 
470
  @app.post("/api/send_message")
471
  async def send_message(request: MessageRequest):
472
+ """
473
+ Send a message and get a response with the chain-of-thought process (HTTP-based, not streaming).
474
+ Primarily left here for completeness, but the user-facing code calls the WebSocket for streaming.
475
+ """
476
  try:
477
  if request.session_id not in chat_instances:
478
  raise HTTPException(status_code=404, detail="Session not found")
479
 
480
  chat = chat_instances[request.session_id]["chat"]
481
 
 
482
  original_thinking_fn = chat._determine_thinking_rounds
483
  original_alternatives_fn = chat._generate_alternatives
484
  original_temperature = getattr(chat, "temperature", 0.7)
485
 
486
  if request.thinking_rounds is not None:
 
487
  chat._determine_thinking_rounds = lambda _: request.thinking_rounds
488
 
489
  if request.alternatives_per_round is not None:
 
491
  return original_alternatives_fn(base_response, prompt, request.alternatives_per_round)
492
  chat._generate_alternatives = modified_generate_alternatives
493
 
 
494
  if request.temperature is not None:
495
  setattr(chat, "temperature", request.temperature)
496
 
 
497
  logger.info(f"Processing message for session {request.session_id}")
498
  start_time = datetime.now()
499
  result = chat.think_and_respond(request.message, verbose=True)
500
  processing_time = (datetime.now() - start_time).total_seconds()
501
  logger.info(f"Message processed in {processing_time:.2f} seconds")
502
 
503
+ # Restore original
504
  chat._determine_thinking_rounds = original_thinking_fn
505
  chat._generate_alternatives = original_alternatives_fn
506
  if request.temperature is not None:
 
519
 
520
  @app.post("/api/save")
521
  async def save_conversation(request: SaveRequest):
522
+ """Save the conversation or the full thinking log."""
523
  try:
524
  if request.session_id not in chat_instances:
525
  raise HTTPException(status_code=404, detail="Session not found")
526
 
527
  chat = chat_instances[request.session_id]["chat"]
528
 
 
529
  filename = request.filename
530
  if filename is None:
531
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
532
  log_type = "full_log" if request.full_log else "conversation"
533
  filename = f"recthink_{log_type}_{timestamp}.json"
534
 
 
535
  os.makedirs("logs", exist_ok=True)
536
  file_path = os.path.join("logs", filename)
537
 
 
547
 
548
  @app.get("/api/sessions", response_model=SessionResponse)
549
  async def list_sessions():
550
+ """List all active chat sessions."""
551
  sessions = []
552
  for session_id, session_data in chat_instances.items():
553
  chat = session_data["chat"]
554
+ message_count = len(chat.conversation_history) // 2
 
555
  sessions.append(SessionInfo(
556
  session_id=session_id,
557
  message_count=message_count,
558
  created_at=session_data["created_at"],
559
  model=session_data["model"]
560
  ))
 
561
  return {"sessions": sessions}
562
 
563
  @app.get("/api/sessions/{session_id}")
564
  async def get_session(session_id: str):
565
+ """Get details for a specific chat session."""
566
  if session_id not in chat_instances:
567
  raise HTTPException(status_code=404, detail="Session not found")
568
 
569
  session_data = chat_instances[session_id]
570
  chat = session_data["chat"]
571
 
 
572
  conversation = []
573
  for i in range(0, len(chat.conversation_history), 2):
574
  if i+1 < len(chat.conversation_history):
 
587
 
588
  @app.delete("/api/sessions/{session_id}")
589
  async def delete_session(session_id: str):
590
+ """Delete a chat session."""
591
  if session_id not in chat_instances:
592
  raise HTTPException(status_code=404, detail="Session not found")
 
593
  del chat_instances[session_id]
594
  return {"status": "deleted", "session_id": session_id}
595
 
 
596
  class ConnectionManager:
597
  def __init__(self):
598
  self.active_connections: Dict[str, WebSocket] = {}
 
611
 
612
  manager = ConnectionManager()
613
 
 
614
  @app.websocket("/ws/{session_id}")
615
  async def websocket_endpoint(websocket: WebSocket, session_id: str):
616
  try:
 
622
  return
623
 
624
  chat = chat_instances[session_id]["chat"]
 
 
625
  original_call_api = chat._call_api
626
 
627
  async def stream_callback(chunk):
628
  await manager.send_json(session_id, {"type": "chunk", "content": chunk})
629
 
 
630
  def ws_call_api(messages, temperature=0.7, stream=True):
631
  result = original_call_api(messages, temperature, stream)
 
632
  if stream:
633
  asyncio.create_task(stream_callback(result))
634
  return result
635
 
 
636
  chat._call_api = ws_call_api
637
 
 
638
  while True:
639
  data = await websocket.receive_text()
640
  message_data = json.loads(data)
641
 
642
  if message_data["type"] == "message":
 
643
  start_time = datetime.now()
644
 
645
  try:
 
646
  thinking_rounds = message_data.get("thinking_rounds", None)
647
  alternatives_per_round = message_data.get("alternatives_per_round", None)
648
  temperature = message_data.get("temperature", None)
649
 
 
650
  original_thinking_fn = chat._determine_thinking_rounds
651
  original_alternatives_fn = chat._generate_alternatives
652
  original_temperature = getattr(chat, "temperature", 0.7)
 
662
  if temperature is not None:
663
  setattr(chat, "temperature", temperature)
664
 
 
665
  await manager.send_json(session_id, {
666
  "type": "status",
667
  "status": "processing",
668
  "message": "Starting recursive thinking process..."
669
  })
670
 
 
671
  result = chat.think_and_respond(message_data["content"], verbose=True)
672
  processing_time = (datetime.now() - start_time).total_seconds()
673
 
 
674
  chat._determine_thinking_rounds = original_thinking_fn
675
  chat._generate_alternatives = original_alternatives_fn
676
  if temperature is not None:
677
  setattr(chat, "temperature", original_temperature)
678
 
 
679
  await manager.send_json(session_id, {
680
  "type": "final",
681
  "response": result["response"],
 
691
  "type": "error",
692
  "error": error_msg
693
  })
 
694
  except WebSocketDisconnect:
695
  logger.info(f"WebSocket disconnected: {session_id}")
696
  manager.disconnect(session_id)
 
702
  except:
703
  pass
704
  finally:
 
705
  if 'chat' in locals() and 'original_call_api' in locals():
706
  chat._call_api = original_call_api
707
 
 
708
  manager.disconnect(session_id)
709
 
 
710
  if __name__ == "__main__":
711
  port = 7860
712
  print(f"Starting server on port {port}")