|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import defaultdict |
|
from typing import Dict, List, Optional, Tuple |
|
|
|
from camel.messages import BaseMessage |
|
from camel.types import TerminationMode |
|
|
|
from .base import ResponseTerminator |
|
|
|
|
|
class ResponseWordsTerminator(ResponseTerminator): |
|
r"""Terminate agent when some words reached to occurrence |
|
limit by any message of the response. |
|
|
|
Args: |
|
words_dict (dict): Dictionary of words and its occurrence |
|
threshold. |
|
case_sensitive (bool): Whether count the words as |
|
case-sensitive. (default: :obj:`False`) |
|
mode (TerminationMode): Whether terminate agent if any |
|
or all pre-set words reached the threshold. |
|
(default: :obj:`TerminationMode.ANY`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
words_dict: Dict[str, int], |
|
case_sensitive: bool = False, |
|
mode: TerminationMode = TerminationMode.ANY, |
|
): |
|
super().__init__() |
|
self.words_dict = words_dict |
|
self.case_sensitive = case_sensitive |
|
self.mode = mode |
|
self._word_count_dict: List[Dict[str, int]] = [] |
|
self._validate() |
|
|
|
def _validate(self): |
|
if len(self.words_dict) == 0: |
|
raise ValueError("`words_dict` cannot be empty") |
|
for word in self.words_dict: |
|
threshold = self.words_dict[word] |
|
if threshold <= 0: |
|
raise ValueError( |
|
f"Threshold for word `{word}` should " |
|
f"be larger than 0, got `{threshold}`" |
|
) |
|
|
|
def is_terminated( |
|
self, messages: List[BaseMessage] |
|
) -> Tuple[bool, Optional[str]]: |
|
r"""Whether terminate the agent by checking the occurrence |
|
of specified words reached to preset thresholds. |
|
|
|
Args: |
|
messages (list): List of :obj:`BaseMessage` from a response. |
|
|
|
Returns: |
|
tuple: A tuple containing whether the agent should be |
|
terminated and a string of termination reason. |
|
""" |
|
if self._terminated: |
|
return True, self._termination_reason |
|
|
|
for i in range(len(messages)): |
|
if i >= len(self._word_count_dict): |
|
self._word_count_dict.append(defaultdict(int)) |
|
|
|
for word in self.words_dict: |
|
special_word = word if self.case_sensitive else word.lower() |
|
for i, message in enumerate(messages): |
|
if self.case_sensitive: |
|
content = message.content |
|
else: |
|
content = message.content.lower() |
|
if special_word in content: |
|
self._word_count_dict[i][word] += 1 |
|
|
|
num_reached: List[int] = [] |
|
all_reasons: List[List[str]] = [] |
|
for i in range(len(self._word_count_dict)): |
|
reached = 0 |
|
reasons: List[str] = [] |
|
for word, value in self._word_count_dict[i].items(): |
|
if value >= self.words_dict[word]: |
|
reached += 1 |
|
reason = ( |
|
f"Word `{word}` appears {value} times in the " |
|
f"{i + 1} message of the response which has " |
|
f"reached termination threshold " |
|
f"{self.words_dict[word]}." |
|
) |
|
reasons.append(reason) |
|
all_reasons.append(reasons) |
|
num_reached.append(reached) |
|
|
|
for i, reached in enumerate(num_reached): |
|
if self.mode == TerminationMode.ANY: |
|
if reached > 0: |
|
self._terminated = True |
|
self._termination_reason = "\n".join(all_reasons[i]) |
|
elif self.mode == TerminationMode.ALL: |
|
if reached >= len(self.words_dict): |
|
self._terminated = True |
|
self._termination_reason = "\n".join(all_reasons[i]) |
|
else: |
|
raise ValueError( |
|
f"Unsupported termination mode " f"`{self.mode}`" |
|
) |
|
return self._terminated, self._termination_reason |
|
|
|
def reset(self): |
|
r"""Reset the terminator.""" |
|
self._terminated = False |
|
self._termination_reason = None |
|
self._word_count_dict = defaultdict(int) |
|
|