# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import sys
from copy import deepcopy
from functools import partial


USAGE = (
    "-" * 70
    + "\n"
    + "| Usage:                                                             |\n"
    + "| llamafactory-cli api -h: launch an OpenAI-style API server         |\n"
    + "| llamafactory-cli chat -h: launch a chat interface in CLI           |\n"
    + "| llamafactory-cli eval -h: evaluate models                          |\n"
    + "| llamafactory-cli export -h: merge LoRA adapters and export model   |\n"
    + "| llamafactory-cli train -h: train models                            |\n"
    + "| llamafactory-cli webchat -h: launch a chat interface in Web UI     |\n"
    + "| llamafactory-cli webui: launch LlamaBoard                          |\n"
    + "| llamafactory-cli version: show version info                        |\n"
    + "-" * 70
)


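# Dispatches the `llamafactory-cli <command>` invocations listed in USAGE above.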
def main():
    from . import launcher
    from .api.app import run_api
    from .chat.chat_model import run_chat
    from .eval.evaluator import run_eval
    from .extras import logging
    from .extras.env import VERSION, print_env
    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
    from .train.tuner import export_model, run_exp
    from .webui.interface import run_web_demo, run_web_ui

    logger = logging.get_logger(__name__)

    WELCOME = (
        "-" * 58
        + "\n"
        + f"| Welcome to LLaMA Factory, version {VERSION}"
        + " " * (21 - len(VERSION))
        + "|\n|"
        + " " * 56
        + "|\n"
        + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
        + "-" * 58
    )

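    # Map each subcommand to its entry function; "version" and "help" simply print the banners defined above.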
    COMMAND_MAP = {
        "api": run_api,
        "chat": run_chat,
        "env": print_env,
        "eval": run_eval,
        "export": export_model,
        "train": run_exp,
        "webchat": run_web_demo,
        "webui": run_web_ui,
        "version": partial(print, WELCOME),
        "help": partial(print, USAGE),
    }

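    # The first positional CLI argument selects the subcommand; default to "help" when none is given.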
    command = sys.argv.pop(1) if len(sys.argv) != 1 else "help"
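    # "train" is special-cased: when FORCE_TORCHRUN is set, or more than one device is visible
    # and Ray is not used, the run is re-launched under torchrun for distributed training.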
    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
        # launch distributed training
        nnodes = os.getenv("NNODES", "1")
        node_rank = os.getenv("NODE_RANK", "0")
        nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
        logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
        if int(nnodes) > 1:
            print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")

        env = deepcopy(os.environ)
        if is_env_enabled("OPTIM_TORCH", "1"):
            # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
            env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
            env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

        # NOTE: DO NOT USE shell=True to avoid security risk
        process = subprocess.run(
            (
                "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
                "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
            )
            .format(
                nnodes=nnodes,
                node_rank=node_rank,
                nproc_per_node=nproc_per_node,
                master_addr=master_addr,
                master_port=master_port,
                file_name=launcher.__file__,
                args=" ".join(sys.argv[1:]),
            )
            .split(),
            env=env,
            check=True,
        )
        sys.exit(process.returncode)
    elif command in COMMAND_MAP:
        COMMAND_MAP[command]()
    else:
        print(f"Unknown command: {command}.\n{USAGE}")


if __name__ == "__main__":
    from multiprocessing import freeze_support

    freeze_support()
    main()
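

# Example invocations (illustrative only; <path/to/config.yaml> is a placeholder, and the
# address/port values below are sample settings for the env vars read by main() above):
#   llamafactory-cli webui                        # launch LlamaBoard
#   llamafactory-cli train <path/to/config.yaml>  # training; uses torchrun automatically on multi-GPU hosts
#   FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 \
#     llamafactory-cli train <path/to/config.yaml>  # multi-node training via torchrun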