Tzktz's picture
Upload 7664 files
6fc683c verified
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from agents import build_agent
from client import SimulSTEvaluationService, SimulSTLocalEvaluationService
from fairseq.registry import REGISTRIES
DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname"
)
parser.add_argument(
"--port", type=int, default=DEFAULT_PORT, help="server port number"
)
parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type")
parser.add_argument("--scorer-type", default="text", help="Scorer type")
parser.add_argument(
"--start-idx",
type=int,
default=0,
help="Start index of the sentence to evaluate",
)
parser.add_argument(
"--end-idx",
type=int,
default=float("inf"),
help="End index of the sentence to evaluate",
)
parser.add_argument(
"--scores", action="store_true", help="Request scores from server"
)
parser.add_argument("--reset-server", action="store_true", help="Reset the server")
parser.add_argument(
"--num-threads", type=int, default=10, help="Number of threads used by agent"
)
parser.add_argument(
"--local", action="store_true", default=False, help="Local evaluation"
)
args, _ = parser.parse_known_args()
for registry_name, REGISTRY in REGISTRIES.items():
choice = getattr(args, registry_name, None)
if choice is not None:
cls = REGISTRY["registry"][choice]
if hasattr(cls, "add_args"):
cls.add_args(parser)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
if args.local:
session = SimulSTLocalEvaluationService(args)
else:
session = SimulSTEvaluationService(args.hostname, args.port)
if args.reset_server:
session.new_session()
if args.agent_type is not None:
agent = build_agent(args)
agent.decode(session, args.start_idx, args.end_idx, args.num_threads)
if args.scores:
session.get_scores()
print(session.get_scores())