USTC975's picture
create the app
43c34cc
import logging
import os
import os.path as osp
from typing import Union
from colorama import Fore
from colorama import Style as CRStyle
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.styles import Style
from rich.console import Console
from agentreview.utility.utils import get_rebuttal_dir, load_llm_ac_decisions, \
save_llm_ac_decisions
from ..arena import Arena, TooManyInvalidActions
from ..backends.human import HumanBackendError
from ..const import AGENTREVIEW_LOGO
from ..environments import PaperReview, PaperDecision
# Get the ASCII art from https://patorjk.com/software/taag/#p=display&f=Big&t=Chat%20Arena
# Lookup table from a plain colour name to the matching colorama foreground
# escape sequence. The trailing notes record which participant the colour is
# meant for.
color_dict = dict(
    red=Fore.RED,
    green=Fore.GREEN,
    blue=Fore.BLUE,  # Paper Extractor
    light_red=Fore.LIGHTRED_EX,  # AC
    light_green=Fore.LIGHTGREEN_EX,  # Author
    yellow=Fore.YELLOW,  # R1
    magenta=Fore.MAGENTA,  # R2
    cyan=Fore.CYAN,
    white=Fore.WHITE,
    black=Fore.BLACK,
    light_yellow=Fore.LIGHTYELLOW_EX,
    light_blue=Fore.LIGHTBLUE_EX,
    light_magenta=Fore.LIGHTMAGENTA_EX,
    light_cyan=Fore.LIGHTCYAN_EX,
    light_white=Fore.LIGHTWHITE_EX,
    light_black=Fore.LIGHTBLACK_EX,
)
# Colour names that may be handed out to players, in `color_dict` order.
# "black", "white", "red" and "green" are never handed out, and neither is any
# name containing "grey".
visible_colors = list(
    filter(
        lambda name: not (
            name in ["black", "white", "red", "green"] or "grey" in name
        ),
        color_dict,  # ANSI_COLOR_NAMES.keys()
    )
)
# colorama supplies the ANSI colour codes used by this module; fail with an
# actionable installation hint when it is missing.
# NOTE(review): `from colorama import Fore` near the top of this file is
# unguarded and runs first, so this guard cannot fire while that import stays
# where it is.
try:
    import colorama
except ImportError as err:
    # Chain explicitly so the original import failure is kept as the cause.
    raise ImportError(
        "Please install colorama: `pip install colorama`"
    ) from err
# Step cap that ArenaCLI.launch falls back to when it runs non-interactively
# without an explicit `max_steps`.
MAX_STEPS = 20  # We should not need this parameter for paper reviews anyway

# Raise the root logger threshold to ERROR so lower-severity records are dropped.
logging.root.setLevel(logging.ERROR)
class ArenaCLI:
    """The CLI user interface for ChatArena.

    Wraps an ``Arena`` and drives it from the terminal: ``launch`` resets the
    arena, steps it (optionally asking the user for a command before every
    step), echoes newly produced messages and finally persists the results.
    """

    def __init__(self, arena: Arena):
        """Keep a reference to ``arena`` and to its ``args`` namespace."""
        self.arena = arena
        self.args = arena.args

    def launch(self, max_steps: Union[int, None] = None, interactive: bool = True) -> None:
        """Run the CLI.

        The arena is reset and then stepped until one of the following holds:
        the current timestep is terminal, the environment has moved past its
        last phase (``phase_index > 4`` for ``paper_review``,
        ``phase_index > 5`` for ``paper_decision``), the user quits,
        ``TooManyInvalidActions`` is raised, or ``max_steps`` steps were taken.
        Afterwards the review history or the AC decisions are saved.

        Args:
            max_steps: Maximum number of steps to run. ``None`` means no cap in
                interactive mode and ``MAX_STEPS`` in non-interactive mode.
            interactive: When true, prompt for a command before every step and
                ask the user for input whenever a human backend needs it.

        Raises:
            NotImplementedError: If the environment type is neither
                ``paper_review`` nor ``paper_decision``.
            HumanBackendError: If human input is required while running
                non-interactively.
            FileExistsError: If the review history file already exists.
        """
        if not interactive and max_steps is None:
            max_steps = MAX_STEPS

        console = Console()
        timestep = self.arena.reset()
        console.print("🎓AgentReview Initialized!", style="bold green")

        env: Union[PaperReview, PaperDecision] = self.arena.environment
        players = self.arena.players
        env_desc = self.arena.global_prompt
        num_players = env.num_players

        # Cycle through the palette so every player gets a colour even when
        # there are more players than visible colours; with fewer players this
        # is identical to taking the first `num_players` entries.
        player_colors = [
            visible_colors[idx % len(visible_colors)] for idx in range(num_players)
        ]
        name_to_color = dict(zip(env.player_names, player_colors))
        print("name_to_color: ", name_to_color)

        # System and Moderator messages are printed in red
        name_to_color["System"] = "red"
        name_to_color["Moderator"] = "red"

        console.print(
            f"[bold green underline]Environment ({env.type_name}) description:[/]\n{env_desc}"
        )

        # Log the player name, backend type and role description
        for player in players:
            player_color = color_dict[name_to_color[player.name]]
            player_name_str = f"[{player.name} ({player.backend.type_name})] Role Description:"
            logging.info(player_color + player_name_str + CRStyle.RESET_ALL)
            logging.info(player_color + player.role_desc + CRStyle.RESET_ALL)

        console.print(Fore.GREEN + "\n========= Arena Start! ==========\n" + CRStyle.RESET_ALL)

        step = 0
        while not timestep.terminal:
            # Stop once the environment has moved past its last phase.
            if env.type_name == "paper_review":
                last_phase = 4
            elif env.type_name == "paper_decision":
                # Phase 5: AC makes decisions
                last_phase = 5
            else:
                raise NotImplementedError(f"Unknown environment type: {env.type_name}")
            if env.phase_index > last_phase:
                break

            if interactive:
                command = prompt(
                    [("class:command", "command (n/r/q/s/h) > ")],
                    style=Style.from_dict({"command": "blue"}),
                    completer=WordCompleter(
                        [
                            "next",
                            "n",
                            "reset",
                            "r",
                            "exit",
                            "quit",
                            "q",
                            "help",
                            "h",
                            "save",
                            "s",
                        ]
                    ),
                )
                command = command.strip()

                if command in ("help", "h"):
                    console.print("Available commands:")
                    console.print("    [bold]next or n or <Enter>[/]: next step")
                    console.print("    [bold]exit or quit or q[/]: exit the game")
                    console.print("    [bold]help or h[/]: print this message")
                    console.print("    [bold]reset or r[/]: reset the game")
                    console.print("    [bold]save or s[/]: save the history to file")
                    continue
                elif command in ("exit", "quit", "q"):
                    break
                elif command in ("reset", "r"):
                    timestep = self.arena.reset()
                    console.print(
                        "\n========= Arena Reset! ==========\n", style="bold green"
                    )
                    continue
                elif command in ("next", "n", ""):
                    pass
                elif command in ("save", "s"):
                    # Prompt to get the file path
                    file_path = prompt(
                        [("class:command", "save file path > ")],
                        style=Style.from_dict({"command": "blue"}),
                    )
                    file_path = file_path.strip()
                    self.arena.save_history(file_path)
                    console.print(f"History saved to {file_path}", style="bold green")
                    # NOTE(review): after saving, execution falls through and
                    # the arena advances one step, exactly like `next` —
                    # confirm this is intended.
                else:
                    console.print(f"Invalid command: {command}", style="bold red")
                    continue

            try:
                timestep = self.arena.step()
            except HumanBackendError:
                # Handle human input and recover with the game update
                human_player_name = env.get_next_player()
                if interactive:
                    human_input = prompt(
                        [
                            (
                                "class:user_prompt",
                                f"Type your input for {human_player_name}: ",
                            )
                        ],
                        style=Style.from_dict({"user_prompt": "ansicyan underline"}),
                    )
                    # Feed the typed input straight into the environment so the
                    # conversation can continue.
                    timestep = env.step(human_player_name, human_input)
                else:
                    # cannot recover from this error in non-interactive mode
                    raise
            except TooManyInvalidActions as e:
                # Report the actual failure before leaving the loop.
                print(Fore.RED + f"Too many invalid actions: {e}" + CRStyle.RESET_ALL)
                break

            # The messages that are not yet logged
            messages = [msg for msg in env.get_observation() if not msg.logged]
            for msg in messages:
                message_str = f"[{msg.agent_name}->{msg.visible_to}]: {msg.content}"
                # NOTE(review): messages are echoed only when `skip_logging` is
                # truthy — confirm the polarity of this flag is intended.
                if self.args.skip_logging:
                    # markup=False: message content is free text and must not
                    # be parsed as rich console markup (a stray "[/...]" would
                    # otherwise raise or be swallowed).
                    console.print(
                        color_dict[name_to_color[msg.agent_name]] + message_str + CRStyle.RESET_ALL,
                        markup=False,
                    )
                msg.logged = True

            step += 1
            if max_steps is not None and step >= max_steps:
                break

        console.print("\n========= Arena Ended! ==========\n", style="bold red")
        self._save_results(env)

    def _save_results(self, env) -> None:
        """Persist the review history or append the AC decisions, depending on ``env.type_name``."""
        args = self.args

        if env.type_name == "paper_review":
            paper_id = self.arena.environment.paper_id
            rebuttal_dir = get_rebuttal_dir(output_dir=args.output_dir,
                                            paper_id=paper_id,
                                            experiment_name=args.experiment_name,
                                            model_name=args.model_name,
                                            conference=args.conference)
            os.makedirs(rebuttal_dir, exist_ok=True)

            path_review_history = f"{rebuttal_dir}/{paper_id}.json"
            # Refuse to overwrite an existing history file.
            if osp.exists(path_review_history):
                raise FileExistsError(f"History already exists!! ({path_review_history}). There must be something wrong with "
                                      f"the path to save the history ")
            self.arena.save_history(path_review_history)

        elif env.type_name == "paper_decision":
            # Load the decisions saved so far, append this run's, and save them back.
            ac_decisions = load_llm_ac_decisions(output_dir=args.output_dir,
                                                 conference=args.conference,
                                                 model_name=args.model_name,
                                                 ac_scoring_method=args.ac_scoring_method,
                                                 experiment_name=args.experiment_name,
                                                 num_papers_per_area_chair=args.num_papers_per_area_chair)
            ac_decisions += [env.ac_decisions]
            save_llm_ac_decisions(ac_decisions,
                                  output_dir=args.output_dir,
                                  conference=args.conference,
                                  model_name=args.model_name,
                                  ac_scoring_method=args.ac_scoring_method,
                                  experiment_name=args.experiment_name)