Spaces:
Sleeping
Sleeping
import logging | |
from argparse import ArgumentParser | |
from transformers.commands import BaseTransformersCLICommand | |
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline | |
logger = logging.getLogger(__name__) # pylint: disable=invalid-name | |
def try_infer_format_from_ext(path: str):
    """Guess the pipeline data format from the end of ``path``.

    Args:
        path: Input file path; may be empty or ``None``.

    Returns:
        ``"pipe"`` when ``path`` is falsy, otherwise the first entry of
        ``PipelineDataFormat.SUPPORTED_FORMATS`` that ``path`` ends with.

    Raises:
        Exception: If ``path`` is non-empty and ends with none of the
            supported formats.
    """
    if path:
        # Lazily scan the supported formats and stop at the first suffix match.
        candidates = (fmt for fmt in PipelineDataFormat.SUPPORTED_FORMATS if path.endswith(fmt))
        matched = next(candidates, None)
        if matched is not None:
            return matched

        template = (
            "Unable to determine file format from file extension {}. "
            "Please provide the format through --format {}"
        )
        raise Exception(template.format(path, PipelineDataFormat.SUPPORTED_FORMATS))

    # No path was given: fall back to the "pipe" format.
    return "pipe"
def run_command_factory(args):
    """Build a ``RunCommand`` from the parsed ``run`` sub-command arguments.

    Args:
        args: Parsed arguments exposing ``task``, ``model``, ``config``,
            ``tokenizer``, ``device``, ``format``, ``input``, ``output``,
            ``column`` and ``overwrite``.

    Returns:
        A ``RunCommand`` wrapping the instantiated pipeline and the data
        reader created for the requested (or inferred) format.
    """
    nlp = pipeline(
        task=args.task,
        model=args.model or None,  # a falsy model is passed on as None
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )

    # "infer" means: derive the format from the input path.
    if args.format == "infer":
        data_format = try_infer_format_from_ext(args.input)
    else:
        data_format = args.format

    reader = PipelineDataFormat.from_str(
        format=data_format,
        output_path=args.output,
        input_path=args.input,
        # Fall back to the pipeline's own input names when no column is given.
        column=args.column or nlp.default_input_names,
        overwrite=args.overwrite,
    )
    return RunCommand(nlp, reader)
class RunCommand(BaseTransformersCLICommand):
    """CLI ``run`` command: push every entry of a data reader through a pipeline and save the results."""

    def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
        """Keep the pipeline to execute and the reader that supplies inputs and stores outputs.

        Args:
            nlp: Pipeline invoked on each entry yielded by ``reader``.
            reader: Iterable data source that also persists the collected outputs.
        """
        self._nlp = nlp
        self._reader = reader

    # The function takes no ``self``/``cls``; declaring it static keeps calls made
    # through an instance from passing the instance as ``parser``.
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``run`` sub-command and its options on ``parser``.

        Args:
            parser: Object exposing ``add_parser``.
                NOTE(review): ``add_parser`` is provided by the object returned from
                ``ArgumentParser.add_subparsers()`` - confirm against the caller.
        """
        run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
        run_parser.add_argument("--task", choices=SUPPORTED_TASKS.keys(), help="Task to run")
        run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
        run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
        run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
        run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
        run_parser.add_argument(
            "--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
        )
        run_parser.add_argument(
            "--column",
            type=str,
            help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)",
        )
        run_parser.add_argument(
            "--format",
            type=str,
            default="infer",
            choices=PipelineDataFormat.SUPPORTED_FORMATS,
            help="Input format to read from",
        )
        run_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
        # The CLI entry point builds the command object through this factory.
        run_parser.set_defaults(func=run_command_factory)

    def run(self):
        """Run the pipeline over every reader entry, then save the collected outputs.

        Entries are unpacked as keyword arguments when the reader reports
        ``is_multi_columns``; otherwise they are passed as a single positional
        argument. A ``dict`` result is appended as one item, any other result is
        concatenated onto the output list.
        """
        nlp, outputs = self._nlp, []

        for entry in self._reader:
            output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
            if isinstance(output, dict):
                outputs.append(output)
            else:
                outputs += output

        # Saving data
        if self._nlp.binary_output:
            binary_path = self._reader.save_binary(outputs)
            logger.warning("Current pipeline requires output to be in binary format, saving at %s", binary_path)
        else:
            self._reader.save(outputs)