Spaces:
Sleeping
Sleeping
Dit-document-layout-analysis
/
unilm
/decoding
/IAD
/fairseq
/examples
/simultaneous_translation
/eval
/evaluate.py
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
from agents import build_agent | |
from client import SimulSTEvaluationService, SimulSTLocalEvaluationService | |
from fairseq.registry import REGISTRIES | |
# Fallback address of the evaluation server, used as the defaults for the
# --hostname and --port command-line options.
DEFAULT_HOSTNAME, DEFAULT_PORT = "localhost", 12321
def get_args(argv=None):
    """Parse the command-line options for the evaluation script.

    Parsing happens in two passes:

    1. A lenient ``parse_known_args`` pass reads the built-in options
       declared here and ignores anything it does not recognise yet.
    2. For every entry in ``REGISTRIES`` whose name matches a parsed option,
       the selected class is looked up. If that class defines ``add_args``,
       it registers its own options on the parser. A final strict
       ``parse_args`` pass then parses everything.

    Args:
        argv: Sequence of argument strings to parse. The default ``None``
            makes argparse read ``sys.argv[1:]``, which was the only behavior
            before this parameter existed.

    Returns:
        argparse.Namespace holding all parsed options.

    Raises:
        SystemExit: via ``parser.error`` when a registry option names a
            choice that is not registered, and on any other argparse
            parsing error.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname"
    )
    parser.add_argument(
        "--port", type=int, default=DEFAULT_PORT, help="server port number"
    )
    parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type")
    parser.add_argument("--scorer-type", default="text", help="Scorer type")
    parser.add_argument(
        "--start-idx",
        type=int,
        default=0,
        help="Start index of the sentence to evaluate",
    )
    parser.add_argument(
        "--end-idx",
        type=int,
        # Unbounded by default; a user-supplied value is converted with int().
        default=float("inf"),
        help="End index of the sentence to evaluate",
    )
    parser.add_argument(
        "--scores", action="store_true", help="Request scores from server"
    )
    parser.add_argument("--reset-server", action="store_true", help="Reset the server")
    parser.add_argument(
        "--num-threads", type=int, default=10, help="Number of threads used by agent"
    )
    parser.add_argument(
        "--local", action="store_true", default=False, help="Local evaluation"
    )

    # First pass: tolerate options that registry classes have not added yet.
    args, _ = parser.parse_known_args(argv)

    for registry_name, registry_entry in REGISTRIES.items():
        choice = getattr(args, registry_name, None)
        if choice is None:
            # This registry is not selected by any option parsed here.
            continue
        registered = registry_entry["registry"]
        if choice not in registered:
            # Report a clean usage error instead of a raw KeyError traceback.
            parser.error(
                "invalid choice for {}: {!r} (available: {})".format(
                    registry_name,
                    choice,
                    ", ".join(sorted(map(str, registered))),
                )
            )
        cls = registered[choice]
        if hasattr(cls, "add_args"):
            cls.add_args(parser)

    # Second pass: strict parse, now including the registry-provided options.
    return parser.parse_args(argv)
def main():
    """Run one evaluation session from the command line.

    Steps:
      1. Parse options with ``get_args``.
      2. Open a session. With ``--local`` this is
         ``SimulSTLocalEvaluationService(args)``; otherwise it is
         ``SimulSTEvaluationService`` pointed at ``--hostname``/``--port``.
      3. With ``--reset-server``, call ``session.new_session()``.
      4. Build the agent with ``build_agent`` and pass ``--start-idx``,
         ``--end-idx`` and ``--num-threads`` to ``agent.decode``.
      5. Fetch the scores from the session once and print them.
    """
    args = get_args()

    if args.local:
        session = SimulSTLocalEvaluationService(args)
    else:
        session = SimulSTEvaluationService(args.hostname, args.port)

    if args.reset_server:
        session.new_session()

    # NOTE(review): --agent-type has a non-None default, so this guard only
    # matters if that default is ever removed.
    if args.agent_type is not None:
        agent = build_agent(args)
        agent.decode(session, args.start_idx, args.end_idx, args.num_threads)

    # Request the scores exactly once. Previously, --scores triggered an extra
    # get_scores() call whose result was discarded before a second call was
    # printed. Printing regardless of --scores keeps the original output.
    # NOTE(review): decide whether --scores should instead gate this request.
    print(session.get_scores())


if __name__ == "__main__":
    main()