File size: 3,293 Bytes
009d93e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
import os
import yaml
from pipeline import Pipeline
from typing import Literal
import models
from models import *
from utils import *
from modules import *

def load_extraction_config(yaml_path):
    # 从文件路径读取 YAML 内容
    if not os.path.exists(yaml_path):
        print(f"Error: The config file '{yaml_path}' does not exist.")
        return {}

    with open(yaml_path, 'r') as file:
        config = yaml.safe_load(file)

    # 提取'extraction'配置的字典
    model_config = config.get('model', {})
    extraction_config = config.get('extraction', {})
    # model config
    model_name_or_path = model_config.get('model_name_or_path', "")
    model_category = model_config.get('category', "")
    api_key = model_config.get('api_key', "")
    base_url = model_config.get('base_url', "")

    # extraction config
    task = extraction_config.get('task', "")
    instruction = extraction_config.get('instruction', "")
    text = extraction_config.get('text', "")
    output_schema = extraction_config.get('output_schema', "")
    constraint = extraction_config.get('constraint', "")
    truth = extraction_config.get('truth', "")
    use_file = extraction_config.get('use_file', False)
    mode = extraction_config.get('mode', "quick")
    update_case = extraction_config.get('update_case', False)

    # 返回一个包含这些变量的字典
    return {
        "model": {
            "model_name_or_path": model_name_or_path,
            "category": model_category,
            "api_key": api_key,
            "base_url": base_url
        },
        "extraction": {
            "task": task,
            "instruction": instruction,
            "text": text,
            "output_schema": output_schema,
            "constraint": constraint,
            "truth": truth,
            "use_file": use_file,
            "mode": mode,
            "update_case": update_case
        }
    }


def main():
    # 创建命令行参数解析器
    parser = argparse.ArgumentParser(description='Run the extraction model.')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to the YAML configuration file.')

    # 解析命令行参数
    args = parser.parse_args()

    # 加载配置
    config = load_extraction_config(args.config)
    model_config = config['model']
    extraction_config = config['extraction']
    clazz = getattr(models, model_config['category'], None)
    if clazz is None:
        print(f"Error: The model category '{model_config['category']}' is not supported.")
        return
    if model_config['api_key'] == "":
        model = clazz(model_config['model_name_or_path'])
    else:
        model = clazz(model_config['model_name_or_path'], model_config['api_key'], model_config['base_url'])
    pipeline = Pipeline(model)
    result, trajectory, *_ = pipeline.get_extract_result(task=extraction_config['task'], instruction=extraction_config['instruction'], text=extraction_config['text'], output_schema=extraction_config['output_schema'], constraint=extraction_config['constraint'], use_file=extraction_config['use_file'], truth=extraction_config['truth'], mode=extraction_config['mode'], update_case=extraction_config['update_case'])
    return

if __name__ == "__main__":
    main()