File size: 3,524 Bytes
2224132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import argparse
import os
import warnings
warnings.filterwarnings("ignore")

# Command-line interface of the interactive runner. The namespace produced
# here is parsed at module level and later handed to main().
parser = argparse.ArgumentParser(description='Eval Arguments.')

# Required: without a method there is nothing to run, so fail at parse time
# with a usage message instead of after the user has typed a question.
parser.add_argument('--method',
                    type=str,
                    required=True,
                    choices=['direct', 'cot', 'react', 'rewoo'],
                    help='Paradigm to use')

# main() looks this attribute up when --method cot is selected; it was
# previously missing from the CLI.
# NOTE(review): valid values are the keys of DEFAULT_EXEMPLARS_COT, which is
# defined outside this file -- confirm and add `choices=` if appropriate.
parser.add_argument('--dataset',
                    type=str,
                    help='Dataset name used to pick the default CoT exemplar (only read by --method cot)')

parser.add_argument('--exemplar',
                    type=str,
                    help='Input exemplar')

parser.add_argument('--toolset',
                    nargs='+',
                    default=['Google', 'Wikipedia', 'WolframAlpha', 'Calculator', 'LLM'],
                    help='Tools available to ALMs.')

parser.add_argument('--base_lm',
                    type=str,
                    default='text-davinci-003',
                    help='Base language model to use. Can be text-davinci-003, gpt-3.5-turbo or directory to alpaca-lora')

# Both of these stay None when omitted; main() substitutes --base_lm.
parser.add_argument('--planner_lm',
                    type=str,
                    help='Base LM for Planner. Default to base_lm')

parser.add_argument('--solver_lm',
                    type=str,
                    help='Base LM for Solver. Default to base_lm')

parser.add_argument('--print_trajectory',
                    action='store_true',
                    help='Print reasoning traces to stdout (Only for ALMs)')

parser.add_argument('--key_path',
                    type=str,
                    default='./keys/',
                    help='Path where you store your openai.key and serpapi.key. Default to ./keys/')

# Parsed at import time on purpose: the key-loading code that follows needs
# args.key_path before the project modules are imported.
args = parser.parse_args()

# Export the API keys stored under --key_path as environment variables.
# This runs before the project imports that follow; presumably those modules
# (or the clients they create) read the keys from the environment -- TODO confirm.
#
# NOTE(review): never paste credentials into this file, not even inside a
# comment. A literal OpenAI key used to sit on this line; treat that key as
# leaked and revoke it.
for _key_file, _env_var in (('openai.key', 'OPENAI_API_KEY'),
                            ('serpapi.key', 'SERPAPI_API_KEY')):
    # A missing key file raises FileNotFoundError naming the expected path.
    with open(os.path.join(args.key_path, _key_file), 'r', encoding='utf-8') as f:
        # strip() drops the trailing newline most editors append to the file.
        os.environ[_env_var] = f.read().strip()

from algos.PWS import *
from algos.notool import IO, CoT
from algos.react import ReactExtraTool
from utils.util import *


def main(args):
    """Prompt for one task on stdin, run it with the chosen paradigm, print the answer.

    Args:
        args: Parsed command-line namespace. Reads ``method``, ``base_lm``,
            ``exemplar``, ``toolset``, ``planner_lm``, ``solver_lm`` and
            ``print_trajectory``. ``dataset`` is optional and only consulted
            for ``method == 'cot'``. The namespace is not modified.

    Raises:
        ValueError: ``method`` is ``'cot'`` and neither an exemplar nor a
            dataset was supplied.
        NotImplementedError: ``method`` is not a known paradigm.
    """
    task = input("Ask a question or give a task: ")

    if args.method == 'direct':
        method = IO(model_name=args.base_lm)
        response = method.run(task)
    elif args.method == 'cot':
        # The namespace may not carry a `dataset` attribute at all, so look it
        # up defensively instead of crashing with AttributeError.
        dataset = getattr(args, 'dataset', None)
        if args.exemplar is not None:
            # An explicit exemplar wins, matching the react/rewoo branches.
            fewshot = args.exemplar
        elif dataset is not None:
            fewshot = DEFAULT_EXEMPLARS_COT[dataset]
        else:
            raise ValueError("--method cot needs --exemplar or --dataset to choose a few-shot prompt")
        method = CoT(model_name=args.base_lm, fewshot=fewshot)
        response = method.run(task)
    elif args.method == 'react':
        fewshot = args.exemplar if args.exemplar is not None else fewshots.DEFAULT_REACT
        method = ReactExtraTool(model_name=args.base_lm, available_tools=args.toolset,
                                fewshot=fewshot, verbose=args.print_trajectory)
        response = method.run(task)
    elif args.method == 'rewoo':
        # Planner and solver each fall back to the shared base model.
        planner_lm = args.planner_lm if args.planner_lm is not None else args.base_lm
        solver_lm = args.solver_lm if args.solver_lm is not None else args.base_lm
        fewshot = args.exemplar if args.exemplar is not None else fewshots.TRIVIAQA_PWS
        method = PWS_Base(planner_model=planner_lm, solver_model=solver_lm,
                          fewshot=fewshot, available_tools=args.toolset)
        response = method.run(task)
        if args.print_trajectory:
            # Unlike react (which gets verbose= passed in), the rewoo trace is
            # printed here from the logs returned in the response.
            print("=== Planner ===" + '\n\n' + response["planner_log"] + '\n'
                  + "=== Solver ===" + '\n\n' + response["solver_log"])
    else:
        raise NotImplementedError(f"Unknown method: {args.method!r}")

    print(response["output"])


# Start the interactive prompt only when executed as a script. Note that the
# command line is parsed at module level, so `args` already exists here (and
# argument parsing also happens if this file is merely imported).
if __name__ == '__main__':
    main(args)