Ali2206 committed
Commit 083dc3a · verified · 1 Parent(s): 5bfcdc0

Update src/txagent/txagent.py

Files changed (1):
  1. src/txagent/txagent.py +135 -141
src/txagent/txagent.py CHANGED
@@ -14,7 +14,6 @@ from .toolrag import ToolRAGModel
 import torch
 import logging
 from difflib import SequenceMatcher
-import asyncio
 import threading
 
 logger = logging.getLogger(__name__)
@@ -455,6 +454,140 @@ Patient Record Excerpt:
         logger.debug("Quick summary output: %s", output[:100])
         return output
 
+    def run_background_report(self, message: str, history: list, temperature: float,
+                              max_new_tokens: int, max_token: int, call_agent: bool,
+                              conversation: gr.State, max_round: int, seed: int,
+                              call_agent_level: int, report_path: str):
+        """Run detailed report generation in the background and save to file"""
+        logger.debug("Starting background report for message: %s", message[:100])
+        combined_response = ""
+        history_copy = history.copy()
+
+        picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
+            call_agent, call_agent_level, message)
+        conversation = self.initialize_conversation(message, conversation, history_copy)
+
+        next_round = True
+        current_round = 0
+        enable_summary = False
+        last_status = {}
+        token_overflow = False
+
+        if self.enable_checker:
+            checker = ReasoningTraceChecker(message, conversation, init_index=len(conversation))
+
+        try:
+            while next_round and current_round < max_round:
+                current_round += 1
+                last_outputs = []
+                if last_outputs:
+                    function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
+                        last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
+                        message_for_call_agent=message, call_agent=call_agent,
+                        call_agent_level=call_agent_level, temperature=temperature)
+
+                    if special_tool_call == 'Finish':
+                        next_round = False
+                        conversation.extend(function_call_messages)
+                        combined_response += function_call_messages[0]['content'] + "\n"
+                        break
+
+                    if (self.enable_summary or token_overflow) and not call_agent:
+                        enable_summary = True
+                    last_status = self.function_result_summary(
+                        conversation, status=last_status, enable_summary=enable_summary)
+
+                    if function_call_messages:
+                        conversation.extend(function_call_messages)
+                        combined_response += tool_result_format(function_call_messages) + "\n"
+                    else:
+                        next_round = False
+                        combined_response += ''.join(last_outputs).replace("</s>", "") + "\n"
+                        break
+
+                if self.enable_checker:
+                    good_status, wrong_info = checker.check_conversation()
+                    if not good_status:
+                        logger.warning("Checker error: %s", wrong_info)
+                        break
+
+                tools = picked_tools_prompt
+                last_outputs_str, token_overflow = self.llm_infer(
+                    messages=conversation, temperature=temperature, tools=tools,
+                    max_new_tokens=max_new_tokens, max_token=max_token, seed=seed, check_token_status=True)
+
+                if last_outputs_str is None:
+                    if self.force_finish:
+                        last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
+                            conversation, temperature, max_new_tokens, max_token)
+                        combined_response += last_outputs_str + "\n"
+                        break
+                    combined_response += "Token limit exceeded.\n"
+                    break
+
+                combined_response += last_outputs_str + "\n"
+                last_outputs.append(last_outputs_str)
+
+            if next_round and self.force_finish:
+                last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
+                    conversation, temperature, max_new_tokens, max_token)
+                combined_response += last_outputs_str + "\n"
+
+            # Save report
+            try:
+                with open(report_path, "w", encoding="utf-8") as f:
+                    f.write(combined_response)
+                logger.info("Detailed report saved to %s", report_path)
+            except Exception as e:
+                logger.error("Failed to save report: %s", e)
+
+        except Exception as e:
+            logger.error("Background report error: %s", e)
+            combined_response += f"Error: {e}\n"
+            with open(report_path, "w", encoding="utf-8") as f:
+                f.write(combined_response)
+
+        finally:
+            torch.cuda.empty_cache()
+            gc.collect()
+
+    def run_gradio_chat(self, message: str, history: list, temperature: float,
+                        max_new_tokens: int, max_token: int, call_agent: bool,
+                        conversation: gr.State, max_round: int = 3, seed: int = None,
+                        call_agent_level: int = 0, sub_agent_task: str = None,
+                        uploaded_files: list = None, report_path: str = None):
+        logger.debug("Chat started, message: %s", message[:100])
+        if not message or len(message.strip()) < 5:
+            yield "Please provide a valid message or upload files to analyze."
+            return
+
+        if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
+            return
+
+        clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
+        has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
+        call_agent = call_agent and not has_clinical_data
+
+        # Generate quick summary
+        quick_summary = self.run_quick_summary(
+            message, temperature=temperature, max_new_tokens=256, max_token=1024)
+        history.append(ChatMessage(role="assistant", content=f"**Quick Summary:**\n{quick_summary}"))
+        yield history
+
+        # Start background report generation
+        if report_path:
+            threading.Thread(
+                target=self.run_background_report,
+                args=(message, history, temperature, max_new_tokens, max_token, call_agent,
+                      conversation, max_round, seed, call_agent_level, report_path),
+                daemon=True
+            ).start()
+            history.append(ChatMessage(
+                role="assistant",
+                content="Generating detailed report in the background. Download will be available when ready."
+            ))
+            yield history
+
     def run_self_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
         logger.debug("Starting self agent")
         conversation = self.set_system_prompt([], self.self_prompt)
@@ -583,143 +716,4 @@ Summarize the function responses in one sentence with all necessary information.
                 setattr(self, key, value)
                 updated_attributes[key] = value
         logger.debug("Updated parameters: %s", updated_attributes)
-        return updated_attributes
-
-    async def run_background_report(self, message: str, history: list, temperature: float,
-                                    max_new_tokens: int, max_token: int, call_agent: bool,
-                                    conversation: gr.State, max_round: int, seed: int,
-                                    call_agent_level: int, report_path: str):
-        """Run detailed report generation in the background and save to file"""
-        logger.debug("Starting background report for message: %s", message[:100])
-        combined_response = ""
-        history_copy = history.copy()
-
-        picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
-            call_agent, call_agent_level, message)
-        conversation = self.initialize_conversation(message, conversation, history_copy)
-
-        next_round = True
-        current_round = 0
-        enable_summary = False
-        last_status = {}
-        token_overflow = False
-
-        if self.enable_checker:
-            checker = ReasoningTraceChecker(message, conversation, init_index=len(conversation))
-
-        try:
-            while next_round and current_round < max_round:
-                current_round += 1
-                last_outputs = []
-                if last_outputs:
-                    function_call_messages, picked_tools_prompt, special_tool_call, _ = yield from self.run_function_call_stream(
-                        last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
-                        message_for_call_agent=message, call_agent=call_agent,
-                        call_agent_level=call_agent_level, temperature=temperature,
-                        return_gradio_history=False)
-
-                    if special_tool_call == 'Finish':
-                        next_round = False
-                        conversation.extend(function_call_messages)
-                        combined_response += function_call_messages[0]['content'] + "\n"
-                        break
-
-                    if (self.enable_summary or token_overflow) and not call_agent:
-                        enable_summary = True
-                    last_status = self.function_result_summary(
-                        conversation, status=last_status, enable_summary=enable_summary)
-
-                    if function_call_messages:
-                        conversation.extend(function_call_messages)
-                        combined_response += tool_result_format(function_call_messages) + "\n"
-                    else:
-                        next_round = False
-                        combined_response += ''.join(last_outputs).replace("</s>", "") + "\n"
-                        break
-
-                if self.enable_checker:
-                    good_status, wrong_info = checker.check_conversation()
-                    if not good_status:
-                        logger.warning("Checker error: %s", wrong_info)
-                        break
-
-                tools = picked_tools_prompt
-                last_outputs_str, token_overflow = self.llm_infer(
-                    messages=conversation, temperature=temperature, tools=tools,
-                    max_new_tokens=max_new_tokens, max_token=max_token, seed=seed, check_token_status=True)
-
-                if last_outputs_str is None:
-                    if self.force_finish:
-                        last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                            conversation, temperature, max_new_tokens, max_token)
-                        combined_response += last_outputs_str + "\n"
-                        break
-                    combined_response += "Token limit exceeded.\n"
-                    break
-
-                combined_response += last_outputs_str + "\n"
-                last_outputs.append(last_outputs_str)
-
-            if next_round and self.force_finish:
-                last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                    conversation, temperature, max_new_tokens, max_token)
-                combined_response += last_outputs_str + "\n"
-
-            # Save report
-            try:
-                with open(report_path, "w", encoding="utf-8") as f:
-                    f.write(combined_response)
-                logger.info("Detailed report saved to %s", report_path)
-            except Exception as e:
-                logger.error("Failed to save report: %s", e)
-
-        except Exception as e:
-            logger.error("Background report error: %s", e)
-            combined_response += f"Error: {e}\n"
-            with open(report_path, "w", encoding="utf-8") as f:
-                f.write(combined_response)
-
-        finally:
-            torch.cuda.empty_cache()
-            gc.collect()
-
-    def run_gradio_chat(self, message: str, history: list, temperature: float,
-                        max_new_tokens: int, max_token: int, call_agent: bool,
-                        conversation: gr.State, max_round: int = 3, seed: int = None,
-                        call_agent_level: int = 0, sub_agent_task: str = None,
-                        uploaded_files: list = None, report_path: str = None):
-        logger.debug("Chat started, message: %s", message[:100])
-        if not message or len(message.strip()) < 5:
-            yield "Please provide a valid message or upload files to analyze."
-            return
-
-        if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
-            return
-
-        clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
-        has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
-        call_agent = call_agent and not has_clinical_data
-
-        # Generate quick summary
-        quick_summary = self.run_quick_summary(
-            message, temperature=temperature, max_new_tokens=256, max_token=1024)
-        history.append(ChatMessage(role="assistant", content=f"**Quick Summary:**\n{quick_summary}"))
-        yield history
-
-        # Start background report generation
-        if report_path:
-            loop = asyncio.get_event_loop()
-            threading.Thread(
-                target=lambda: loop.run_until_complete(
-                    self.run_background_report(
-                        message, history, temperature, max_new_tokens, max_token, call_agent,
-                        conversation, max_round, seed, call_agent_level, report_path
-                    )
-                ),
-                daemon=True
-            ).start()
-            history.append(ChatMessage(
-                role="assistant",
-                content="Generating detailed report in the background. Download will be available when ready."
-            ))
-            yield history
+        return updated_attributes
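
Note on the pattern: the removed code drove run_background_report as an async coroutine from a worker thread via loop.run_until_complete() on a loop obtained with asyncio.get_event_loop() in the Gradio thread, which is at best fragile (an event loop is bound to the thread that created it), and the removed version even used `yield from` inside an `async def`, which is a syntax error in Python. The replacement makes run_background_report a plain blocking method handed directly to threading.Thread. Below is a minimal, self-contained sketch of that pattern; generate_report, start_background_report, and report.txt are illustrative stand-ins, not part of the TxAgent API.

import threading
import time
from pathlib import Path

def generate_report(message: str, report_path: str) -> None:
    """Blocking worker: build the report, then write it to disk.

    Mirrors run_background_report's contract: on failure, the error text
    is still written to report_path so the caller has something to show.
    """
    try:
        time.sleep(1)  # stand-in for the multi-round LLM/tool-call loop
        Path(report_path).write_text(f"Detailed report for: {message}\n", encoding="utf-8")
    except Exception as e:
        Path(report_path).write_text(f"Error: {e}\n", encoding="utf-8")

def start_background_report(message: str, report_path: str) -> threading.Thread:
    # daemon=True, as in the commit: the worker must not keep the process alive.
    t = threading.Thread(target=generate_report, args=(message, report_path), daemon=True)
    t.start()
    return t

if __name__ == "__main__":
    t = start_background_report("patient record excerpt", "report.txt")
    print("Quick summary can be yielded immediately; report builds in background.")
    t.join()  # a UI would poll for the file instead of joining
    print(Path("report.txt").read_text(encoding="utf-8"))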
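
On the delivery side, run_gradio_chat only announces that the report is being generated; nothing in this commit pushes the finished file back to the UI, so a caller has to poll report_path. A hypothetical helper (not part of the commit) might look like:

import os
import time
from typing import Optional

def wait_for_report(report_path: str, timeout: float = 120.0, interval: float = 1.0) -> Optional[str]:
    """Poll until the background thread has written report_path, or give up."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if os.path.exists(report_path) and os.path.getsize(report_path) > 0:
            with open(report_path, encoding="utf-8") as f:
                return f.read()
        time.sleep(interval)
    return None

One caveat with this scheme: run_background_report opens report_path with "w" before writing, so a poller can observe an empty or partial file. Writing to a temporary file and os.replace()-ing it into place once complete would make the existence check a reliable completion signal.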