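"""Evaluation entry point: run direct, CoT, ReAct, or ReWOO methods on QA and reasoning datasets.

Example invocation (a sketch; the script filename is assumed, adjust to this repo's actual name):

    python run_eval.py --method rewoo --dataset hotpot_qa --sample_size 10 --save_result
"""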
import argparse
import os
import warnings
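# Silence third-party warnings so the evaluation logs stay readable.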
warnings.filterwarnings("ignore")

parser = argparse.ArgumentParser(description='Eval Arguments.')
parser.add_argument('--method',
                    type=str,
                    choices=['direct', 'cot', 'react', 'rewoo'],
                    help='Paradigm to use')

parser.add_argument('--dataset',
                    type=str,
                    choices=["hotpot_qa", "trivia_qa", "gsm8k", "physics_question",
                             "sports_understanding", "strategy_qa", "sotu_qa"],
                    help='Dataset to use')

parser.add_argument('--sample_size',
                    type=int,
                    default=10,
                    help='Number of examples to sample for evaluation')

parser.add_argument('--toolset',
                    nargs='+',
                    default=['Google', 'Wikipedia', 'WolframAlpha', 'Calculator', 'LLM'],
                    help='Tools available to ALMs.')

parser.add_argument('--base_lm',
                    type=str,
                    default='text-davinci-003',
                    help='Base language model to use. Can be text-davinci-003, gpt-3.5-turbo, or a path to an alpaca-lora checkpoint directory')

parser.add_argument('--planner_lm',
                    type=str,
                    help='Base LM for Planner. Defaults to base_lm')

parser.add_argument('--solver_lm',
                    type=str,
                    help='Base LM for Solver. Defaults to base_lm')

parser.add_argument('--save_result',
                    action='store_true',
                    help='Save result to file')

parser.add_argument('--seed',
                    type=int,
                    default=2024,
                    help='Random seed')

parser.add_argument('--key_path',
                    type=str,
                    default='./keys/',
                    help='Path where you store your openai.key and serpapi.key. Defaults to ./keys/')

args = parser.parse_args()

with open(os.path.join(args.key_path, 'openai.key'), 'r') as f:
    os.environ["OPENAI_API_KEY"] = f.read().strip()
with open(os.path.join(args.key_path, 'serpapi.key'), 'r') as f:
    os.environ["SERPAPI_API_KEY"] = f.read().strip()

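# Project imports are deliberately deferred until after the keys are exported:
# the model and tool wrappers below likely read these environment variables at import time.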
from algos.PWS import *
from algos.notool import IO, CoT
from algos.react import ReactBase, ReactExtraTool
from utils.DataLoader import DataLoader
from utils.Evaluator import Evaluator
from utils.util import *


def save_data(dataset, data, save_path):
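    """Attach the run's predictions and metrics to the dataset frame and write it to CSV."""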
    dataset["preds"] = data["preds"]
    dataset["em"] = data["em"]
    dataset["f1"] = data["f1"]
    dataset["acc"] = data["acc"]
    dataset["wall_time"] = data["wall_time"]
    dataset["total_tokens"] = data["total_tokens"]
    dataset["steps"] = data["steps"]
    dataset["tool_cost"] = data["tool_cost"]
    dataset["token_cost"] = data["token_cost"]
    dataset["total_cost"] = data["total_cost"]
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)  # ensure the output directory exists
    dataset.to_csv(save_path, index=False)
    return dataset


def main(args):
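    """Build the requested method, evaluate it on the sampled dataset, and optionally save results."""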
    dataset = DataLoader(args.dataset, seed=args.seed).load(sample_size=args.sample_size)
    if args.method == 'direct':
        # Direct prompting: the base LM answers in a single shot, no tools.
        method = IO(model_name=args.base_lm)
        evaluator = Evaluator(args.dataset, dataset, method)
    elif args.method == 'cot':
        # Chain-of-thought prompting with dataset-specific few-shot exemplars.
        method = CoT(model_name=args.base_lm, fewshot=DEFAULT_EXEMPLARS_COT[args.dataset])
        evaluator = Evaluator(args.dataset, dataset, method)
    elif args.method == 'react':
        # ReAct interleaves reasoning steps with tool calls; the open-domain QA
        # datasets use the base agent, other tasks get the extra-tool variant.
        if args.dataset in ['hotpot_qa', 'trivia_qa']:
            method = ReactBase(model_name=args.base_lm, fewshot=DEFAULT_EXEMPLARS_REACT[args.dataset], verbose=False)
        else:
            method = ReactExtraTool(model_name=args.base_lm, available_tools=args.toolset,
                                    fewshot=DEFAULT_EXEMPLARS_REACT[args.dataset], verbose=False)
        evaluator = Evaluator(args.dataset, dataset, method)
    elif args.method == 'rewoo':
        # ReWOO (PWS): a Planner drafts the whole tool plan up front and a
        # Solver composes the final answer; both default to base_lm.
        if args.planner_lm is None:
            args.planner_lm = args.base_lm
        if args.solver_lm is None:
            args.solver_lm = args.base_lm
        method = PWS_Base(planner_model=args.planner_lm, solver_model=args.solver_lm,
                          fewshot=DEFAULT_EXEMPLARS_PWS[args.dataset], available_tools=args.toolset)
        evaluator = Evaluator(args.dataset, dataset, method)
    else:
        raise NotImplementedError(f'Unknown method: {args.method}')

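    # Run the evaluation; data holds per-example predictions plus accuracy, cost, and latency metrics.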
    responses, data = evaluator.run()
    if args.save_result:
        save_data(dataset, data, f'./results/eval_{args.dataset}_{args.method}_{args.base_lm}.csv')
    print(responses)


if __name__ == '__main__':
    main(args)