File size: 1,921 Bytes
690f890
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import argparse
import importlib
from typing import Dict, Any

def load_parser(module_name: str) -> argparse.ArgumentParser:
    """Import *module_name* and return the parser built by its ``get_parser()``.

    Args:
        module_name: Importable module path, as accepted by
            ``importlib.import_module``.

    Returns:
        Whatever the module's ``get_parser()`` returns.

    Raises:
        ValueError: If the imported module has no ``get_parser`` attribute.
        ImportError: Propagated from ``importlib.import_module`` when the
            module cannot be imported.
    """
    stage_module = importlib.import_module(module_name)
    if hasattr(stage_module, "get_parser"):
        return stage_module.get_parser()
    raise ValueError(f"{module_name} undefined get_parser()")

def filter_args(args: Dict[str, Any], parser: argparse.ArgumentParser) -> Dict[str, Any]:
    """Return the subset of *args* whose keys are destinations known to *parser*.

    Args:
        args: Mapping of argument name to value (e.g. ``vars(namespace)``).
        parser: Parser whose registered actions define the accepted names.

    Returns:
        A new dict holding only the entries of *args* whose key equals the
        ``dest`` of one of the parser's actions. The ``help`` action is
        excluded. Insertion order of *args* is preserved.
    """
    # argparse has no public accessor for registered actions, hence ``_actions``.
    recognised = {
        act.dest
        for act in parser._actions
        if act.dest and act.dest != "help"
    }
    selected = {}
    for key, value in args.items():
        if key in recognised:
            selected[key] = value
    return selected

def main():
    """Run the preprocessing stage, then the inference stage chosen by ``--base``.

    The command line is parsed in two passes. The first pass reads only
    ``--base`` to decide which inference module to load (``vace_ltx_inference``
    for ``ltx``, ``vace_wan_inference`` for ``wan``); preprocessing always uses
    ``vace_preproccess``. The actions of both stage parsers are then merged
    into the main parser and the full command line is parsed.

    Each stage's ``main`` receives a dict containing only the arguments its own
    parser declares. The preprocessing result is merged into the inference
    arguments via ``dict.update`` before inference runs. Both results are
    printed.

    Raises:
        ValueError: If a stage module does not define ``get_parser()``.
        SystemExit: Raised by argparse on invalid arguments or ``-h/--help``.
    """
    # Help is registered only after the stage parsers are merged; with the
    # default ``add_help=True`` the first pass would consume ``-h`` and print
    # a help text listing nothing but ``--base``.
    main_parser = argparse.ArgumentParser(add_help=False)
    main_parser.add_argument("--base", type=str, default='ltx', choices=['ltx', 'wan'])
    pipeline_args, _ = main_parser.parse_known_args()

    # Both bases share one preprocessing module; only inference differs.
    preprocess_name = "vace_preproccess"
    if pipeline_args.base == "ltx":
        inference_name = "vace_ltx_inference"
    else:
        inference_name = "vace_wan_inference"

    preprocess_parser = load_parser(preprocess_name)
    inference_parser = load_parser(inference_name)

    # Re-register every stage action on the main parser so one command line
    # can carry the options of both stages. The stage parsers' own help
    # actions are skipped.
    for parser in (preprocess_parser, inference_parser):
        for action in parser._actions:
            if action.dest != "help":
                main_parser._add_action(action)

    main_parser.add_argument(
        "-h", "--help", action="help", default=argparse.SUPPRESS,
        help="show this help message and exit",
    )

    cli_args = main_parser.parse_args()
    args_dict = vars(cli_args)

    # run preprocess
    preprocess_args = filter_args(args_dict, preprocess_parser)
    preprocess_output = importlib.import_module(preprocess_name).main(preprocess_args)
    print("preprocess_output:", preprocess_output)

    # run inference
    inference_args = filter_args(args_dict, inference_parser)
    # Preprocess results override same-named CLI values for inference.
    inference_args.update(preprocess_output)
    inference_output = importlib.import_module(inference_name).main(inference_args)
    print("inference_output:", inference_output)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()