|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, Generator, List, Optional, Set, Tuple |
|
|
|
from camel.messages import BaseMessage |
|
from camel.prompts import PromptTemplateGenerator, TextPrompt |
|
from camel.types import RoleType, TaskType |
|
|
|
|
|
class SystemMessageGenerator: |
|
r"""System message generator for agents. |
|
|
|
Args: |
|
task_type (TaskType, optional): The task type. |
|
(default: :obj:`TaskType.AI_SOCIETY`) |
|
sys_prompts (Optional[Dict[RoleType, str]], optional): The prompts of |
|
the system messages for each role type. (default: :obj:`None`) |
|
sys_msg_meta_dict_keys (Optional[Set[str]], optional): The set of keys |
|
of the meta dictionary used to fill the prompts. |
|
(default: :obj:`None`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
task_type: TaskType = TaskType.AI_SOCIETY, |
|
sys_prompts: Optional[Dict[RoleType, str]] = None, |
|
sys_msg_meta_dict_keys: Optional[Set[str]] = None, |
|
) -> None: |
|
self.sys_prompts: Dict[RoleType, str] |
|
|
|
if sys_prompts is not None: |
|
self.sys_prompts = sys_prompts |
|
self.sys_msg_meta_dict_keys = sys_msg_meta_dict_keys or set() |
|
else: |
|
assistant_prompt_template = ( |
|
PromptTemplateGenerator().get_system_prompt( |
|
task_type, |
|
RoleType.ASSISTANT, |
|
) |
|
) |
|
user_prompt_template = PromptTemplateGenerator().get_system_prompt( |
|
task_type, |
|
RoleType.USER, |
|
) |
|
critic_prompt_template = ( |
|
PromptTemplateGenerator().get_system_prompt( |
|
task_type, |
|
RoleType.CRITIC, |
|
) |
|
) |
|
embodiment_prompt_template = ( |
|
PromptTemplateGenerator().get_system_prompt( |
|
task_type, |
|
RoleType.EMBODIMENT, |
|
) |
|
) |
|
|
|
self.sys_prompts = dict() |
|
self.sys_prompts[RoleType.ASSISTANT] = assistant_prompt_template |
|
self.sys_prompts[RoleType.USER] = user_prompt_template |
|
self.sys_prompts[RoleType.CRITIC] = critic_prompt_template |
|
self.sys_prompts[RoleType.EMBODIMENT] = embodiment_prompt_template |
|
|
|
self.sys_msg_meta_dict_keys = ( |
|
assistant_prompt_template.key_words |
|
| user_prompt_template.key_words |
|
| critic_prompt_template.key_words |
|
| embodiment_prompt_template.key_words |
|
) |
|
|
|
if RoleType.DEFAULT not in self.sys_prompts: |
|
self.sys_prompts[RoleType.DEFAULT] = "You are a helpful assistant." |
|
|
|
def validate_meta_dict_keys(self, meta_dict: Dict[str, str]) -> None: |
|
r"""Validates the keys of the meta_dict. |
|
|
|
Args: |
|
meta_dict (Dict[str, str]): The dictionary to validate. |
|
""" |
|
if not set(meta_dict.keys()).issubset(self.sys_msg_meta_dict_keys): |
|
raise ValueError( |
|
"The keys of the meta_dict should be in " |
|
f"{self.sys_msg_meta_dict_keys}. " |
|
f"Got {set(meta_dict.keys())} instead." |
|
) |
|
|
|
def from_dict( |
|
self, |
|
meta_dict: Dict[str, str], |
|
role_tuple: Tuple[str, RoleType] = ("", RoleType.DEFAULT), |
|
) -> BaseMessage: |
|
r"""Generates a system message from a dictionary. |
|
|
|
Args: |
|
meta_dict (Dict[str, str]): The dictionary containing the |
|
information to generate the system message. |
|
role_tuple (Tuple[str, RoleType], optional): The tuple containing |
|
the role name and role type. (default: ("", RoleType.DEFAULT)) |
|
|
|
Returns: |
|
BaseMessage: The generated system message. |
|
""" |
|
self.validate_meta_dict_keys(meta_dict) |
|
role_name, role_type = role_tuple |
|
sys_prompt = self.sys_prompts[role_type] |
|
sys_prompt = sys_prompt.format(**meta_dict) |
|
return BaseMessage( |
|
role_name=role_name, |
|
role_type=role_type, |
|
meta_dict=meta_dict, |
|
content=sys_prompt, |
|
) |
|
|
|
def from_dicts( |
|
self, |
|
meta_dicts: List[Dict[str, str]], |
|
role_tuples: List[Tuple[str, RoleType]], |
|
) -> List[BaseMessage]: |
|
r"""Generates a list of system messages from a list of dictionaries. |
|
|
|
Args: |
|
meta_dicts (List[Dict[str, str]]): A list of dictionaries |
|
containing the information to generate the system messages. |
|
role_tuples (List[Tuple[str, RoleType]]): A list of tuples |
|
containing the role name and role type for each system message. |
|
|
|
Returns: |
|
List[BaseMessage]: A list of generated system messages. |
|
|
|
Raises: |
|
ValueError: If the number of meta_dicts and role_tuples are |
|
different. |
|
""" |
|
if len(meta_dicts) != len(role_tuples): |
|
raise ValueError( |
|
"The number of meta_dicts and role_types should be the same." |
|
) |
|
|
|
return [ |
|
self.from_dict(meta_dict, role_tuple) |
|
for meta_dict, role_tuple in zip(meta_dicts, role_tuples) |
|
] |
|
|
|
|
|
class RoleNameGenerator: |
|
r"""Role name generator for role-playing workers. |
|
|
|
Args: |
|
assistant_role_names_path (str, optional): The path to the file |
|
containing the assistant role names. |
|
(default: :obj:`"data/ai_society/assistant_roles.txt"`) |
|
user_role_names_path (str, optional): The path to the file |
|
containing the user role names. |
|
(default: :obj:`"data/ai_society/user_roles.txt"`) |
|
assistant_role_names (Optional[List[str]], optional): The list of |
|
assistant role names. (default: :obj:`None`) |
|
user_role_names (Optional[List[str]], optional): The list of user role |
|
names. (default: :obj:`None`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
assistant_role_names_path: str = "data/ai_society/assistant_roles.txt", |
|
user_role_names_path: str = "data/ai_society/user_roles.txt", |
|
assistant_role_names: Optional[List[str]] = None, |
|
user_role_names: Optional[List[str]] = None, |
|
) -> None: |
|
if assistant_role_names is None: |
|
with open(assistant_role_names_path, "r") as f: |
|
assistant_role_names_: List[str] = f.read().splitlines() |
|
self.assistant_role_names = [ |
|
" ".join(name.split(" ")[1:]) |
|
for name in assistant_role_names_ |
|
] |
|
else: |
|
self.assistant_role_names = assistant_role_names |
|
|
|
if user_role_names is None: |
|
with open(user_role_names_path, "r") as f: |
|
user_role_names_: List[str] = f.read().splitlines() |
|
self.user_role_names = [ |
|
" ".join(name.split(" ")[1:]) for name in user_role_names_ |
|
] |
|
else: |
|
self.user_role_names = user_role_names |
|
|
|
def from_role_files(self) -> Generator[Tuple, None, None]: |
|
r"""Generate role names from the file. |
|
|
|
Returns: |
|
Generator[Tuple, None, None]: A generator that yields tuples of |
|
assistant role names and user role names. |
|
""" |
|
for assistant_role_name in self.assistant_role_names: |
|
for user_role_name in self.user_role_names: |
|
yield (assistant_role_name, user_role_name) |
|
|
|
|
|
class AISocietyTaskPromptGenerator: |
|
r"""Task prompt generator for AI society tasks. |
|
|
|
Args: |
|
num_tasks (int, optional): The number of tasks to generate. |
|
(default: :obj:`10`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
num_tasks: int = 10, |
|
) -> None: |
|
self.generate_tasks_prompt = ( |
|
PromptTemplateGenerator().get_generate_tasks_prompt( |
|
TaskType.AI_SOCIETY |
|
) |
|
) |
|
|
|
self.num_tasks = num_tasks |
|
|
|
|
|
def from_role_files( |
|
self, |
|
assistant_role_names_path: str = "data/ai_society/assistant_roles.txt", |
|
user_role_names_path: str = "data/ai_society/user_roles.txt", |
|
) -> Generator[Tuple[str, Tuple[str, str]], None, None]: |
|
r"""Generate tasks from role files. |
|
|
|
Args: |
|
assistant_role_names_path (str, optional): The path to the file |
|
containing the assistant role names. |
|
(default: :obj:`"data/ai_society/assistant_roles.txt"`) |
|
user_role_names_path (str, optional): The path to the file |
|
containing the user role names. |
|
(default: :obj:`"data/ai_society/user_roles.txt"`) |
|
|
|
Returns: |
|
Generator[Tuple[str, Tuple[str, str]], None, None]: A generator |
|
that yields tuples of task prompts and role names. |
|
""" |
|
roles_generator = RoleNameGenerator( |
|
assistant_role_names_path, user_role_names_path |
|
).from_role_files() |
|
for role_1, role_2 in roles_generator: |
|
generate_tasks_prompt = self.generate_tasks_prompt.format( |
|
assistant_role=role_1, |
|
user_role=role_2, |
|
num_tasks=self.num_tasks, |
|
) |
|
|
|
yield (generate_tasks_prompt, (role_1, role_2)) |
|
|
|
def from_role_generator( |
|
self, role_generator: Generator[Tuple, None, None] |
|
) -> Generator[Tuple[str, Tuple[str, str]], None, None]: |
|
r"""Generate tasks from a role generator. |
|
|
|
Args: |
|
role_generator (Generator[Tuple, None, None]): A generator that |
|
yields tuples of role names. |
|
|
|
Returns: |
|
Generator[Tuple[str, Tuple[str, str]], None, None]: A generator |
|
that yields tuples of task prompts and role names. |
|
""" |
|
for role_1, role_2 in role_generator: |
|
generate_tasks_prompt = self.generate_tasks_prompt.format( |
|
assistant_role=role_1, |
|
user_role=role_2, |
|
num_tasks=self.num_tasks, |
|
) |
|
|
|
yield (generate_tasks_prompt, (role_1, role_2)) |
|
|
|
|
|
class SingleTxtGenerator: |
|
r"""Single text generator for role-playing workers. |
|
|
|
Args: |
|
text_file_path (str): The path to the file containing the text data. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
text_file_path: str, |
|
) -> None: |
|
with open(text_file_path, "r") as f: |
|
data_list: List[str] = f.read().splitlines() |
|
self.data_list = [ |
|
" ".join(name.split(" ")[1:]) for name in data_list |
|
] |
|
|
|
def from_role_files(self) -> Generator[str, None, None]: |
|
r"""Generate text from the file. |
|
|
|
Returns: |
|
Generator[str, None, None]: A generator that yields the text data. |
|
""" |
|
for data in self.data_list: |
|
yield data |
|
|
|
|
|
class CodeTaskPromptGenerator: |
|
r"""Code task prompt generator for code tasks. |
|
|
|
Args: |
|
num_tasks (int, optional): The number of tasks to generate. |
|
(default: :obj:`50`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
num_tasks: int = 50, |
|
) -> None: |
|
self.generate_tasks_prompt = ( |
|
PromptTemplateGenerator().get_generate_tasks_prompt(TaskType.CODE) |
|
) |
|
|
|
self.num_tasks = num_tasks |
|
|
|
def from_role_files( |
|
self, |
|
languages_path: str = "data/code/languages.txt", |
|
domains_path: str = "data/code/domains.txt", |
|
) -> Generator[Tuple[TextPrompt, str, str], None, None]: |
|
r"""Generate tasks from role files. |
|
|
|
Args: |
|
languages_path (str, optional): The path to the file containing |
|
the language names. (default: :obj:`"data/code/languages.txt"`) |
|
domains_path (str, optional): The path to the file containing |
|
the domain names. (default: :obj:`"data/code/domains.txt"`) |
|
|
|
Returns: |
|
Generator[Tuple[TextPrompt, str, str], None, None]: A generator |
|
that yields tuples of task prompts, language names, and domain |
|
names. |
|
""" |
|
language_generator = SingleTxtGenerator( |
|
languages_path |
|
).from_role_files() |
|
|
|
for language in language_generator: |
|
domains_generator = SingleTxtGenerator( |
|
domains_path |
|
).from_role_files() |
|
for domain in domains_generator: |
|
generated_tasks_prompt = self.generate_tasks_prompt.format( |
|
language=language, domain=domain, num_tasks=self.num_tasks |
|
) |
|
yield generated_tasks_prompt, language, domain |
|
|
|
def from_role_generator( |
|
self, role_generator: Generator[Tuple, None, None] |
|
) -> Generator[str, None, None]: |
|
r"""Generate tasks from a role generator. |
|
|
|
Args: |
|
role_generator (Generator[Tuple, None, None]): A generator that |
|
yields tuples of role names. |
|
|
|
Returns: |
|
Generator[str, None, None]: A generator that yields the task |
|
prompts. |
|
""" |
|
raise NotImplementedError |
|
|