# Copyright 2025 the LlamaFactory team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import asyncio import atexit import json from collections.abc import AsyncGenerator, AsyncIterator, Sequence from typing import TYPE_CHECKING, Any, Optional, Union import requests from typing_extensions import override from ..data import get_template_and_fix_tokenizer from ..extras import logging from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName from ..extras.misc import get_device_count, torch_gc from ..extras.packages import is_sglang_available from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ..model import load_config, load_tokenizer from ..model.model_utils.quantization import QuantizationMethod from .base_engine import BaseEngine, Response if is_sglang_available(): from sglang.utils import launch_server_cmd, terminate_process, wait_for_server # type: ignore if TYPE_CHECKING: from ..data.mm_plugin import AudioInput, ImageInput, VideoInput logger = logging.get_logger(__name__) class SGLangEngine(BaseEngine): """Inference engine for SGLang models. This class wraps the SGLang engine to provide a consistent interface for text generation that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for better interaction and performance. The engine launches a server process and communicates with it via HTTP requests. For more details on the SGLang HTTP server approach, see: https://docs.sglang.ai/backend/send_request.html """ def __init__( self, model_args: "ModelArguments", data_args: "DataArguments", finetuning_args: "FinetuningArguments", generating_args: "GeneratingArguments", ) -> None: self.name = EngineName.SGLANG self.model_args = model_args config = load_config(model_args) # may download model from ms hub if getattr(config, "quantization_config", None): # gptq models should use float16 quantization_config: dict[str, Any] = getattr(config, "quantization_config", None) quant_method = quantization_config.get("quant_method", "") if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto": model_args.infer_dtype = "float16" self.can_generate = finetuning_args.stage == "sft" tokenizer_module = load_tokenizer(model_args) self.tokenizer = tokenizer_module["tokenizer"] self.processor = tokenizer_module["processor"] self.tokenizer.padding_side = "left" self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args) self.template.mm_plugin.expand_mm_tokens = False # for sglang generate self.generating_args = generating_args.to_dict() launch_cmd = [ "python3 -m sglang.launch_server", f"--model-path {model_args.model_name_or_path}", f"--dtype {model_args.infer_dtype}", f"--context-length {model_args.sglang_maxlen}", f"--mem-fraction-static {model_args.sglang_mem_fraction}", f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}", f"--download-dir {model_args.cache_dir}", "--log-level error", ] launch_cmd = " ".join(launch_cmd) logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}") try: torch_gc() self.server_process, port = launch_server_cmd(launch_cmd) self.base_url = f"http://localhost:{port}" atexit.register(self._cleanup_server) logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}") wait_for_server(self.base_url, timeout=300) logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}") try: response = requests.get(f"{self.base_url}/get_model_info", timeout=5) if response.status_code == 200: model_info = response.json() logger.info(f"SGLang server model info: {model_info}") except Exception as e: logger.debug(f"Note: could not get model info: {str(e)}") except Exception as e: logger.error(f"Failed to start SGLang server: {str(e)}") self._cleanup_server() # make sure to clean up any started process raise RuntimeError(f"SGLang server initialization failed: {str(e)}.") def _cleanup_server(self): r"""Clean up the server process when the engine is destroyed.""" if hasattr(self, "server_process") and self.server_process: try: logger.info("Terminating SGLang server process") terminate_process(self.server_process) logger.info("SGLang server process terminated") except Exception as e: logger.warning(f"Error terminating SGLang server: {str(e)}") async def _generate( self, messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[list["ImageInput"]] = None, videos: Optional[list["VideoInput"]] = None, audios: Optional[list["AudioInput"]] = None, **input_kwargs, ) -> AsyncIterator[dict[str, Any]]: if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages): messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"] if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages): messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"] if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages): messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"] messages = self.template.mm_plugin.process_messages( messages, images or [], videos or [], audios or [], self.processor ) paired_messages = messages + [{"role": "assistant", "content": ""}] system = system or self.generating_args["default_system"] enable_thinking = input_kwargs.pop("enable_thinking", None) enable_thinking = enable_thinking if enable_thinking is not None else self.generating_args["enable_thinking"] prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking) prompt_length = len(prompt_ids) temperature: Optional[float] = input_kwargs.pop("temperature", None) top_p: Optional[float] = input_kwargs.pop("top_p", None) top_k: Optional[float] = input_kwargs.pop("top_k", None) num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1) repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None) skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None) max_length: Optional[int] = input_kwargs.pop("max_length", None) max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None) if num_return_sequences != 1: raise NotImplementedError("SGLang only supports n=1.") if "max_new_tokens" in self.generating_args: max_tokens = self.generating_args["max_new_tokens"] elif "max_length" in self.generating_args: if self.generating_args["max_length"] > prompt_length: max_tokens = self.generating_args["max_length"] - prompt_length else: max_tokens = 1 if max_length: max_tokens = max_length - prompt_length if max_length > prompt_length else 1 if max_new_tokens: max_tokens = max_new_tokens sampling_params = { "temperature": temperature if temperature is not None else self.generating_args["temperature"], "top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0 "top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must > 0 "stop": stop, "stop_token_ids": self.template.get_stop_token_ids(self.tokenizer), "max_new_tokens": max_tokens, "repetition_penalty": ( repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"] ) or 1.0, # repetition_penalty must > 0 "skip_special_tokens": skip_special_tokens if skip_special_tokens is not None else self.generating_args["skip_special_tokens"], } def stream_request(): json_data = { "input_ids": prompt_ids, "sampling_params": sampling_params, "stream": True, } response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True) if response.status_code != 200: raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}") for chunk in response.iter_lines(decode_unicode=False): chunk = str(chunk.decode("utf-8")) if chunk == "data: [DONE]": break if chunk and chunk.startswith("data:"): yield json.loads(chunk[5:].strip("\n")) return await asyncio.to_thread(stream_request) @override async def chat( self, messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, videos: Optional[Sequence["VideoInput"]] = None, audios: Optional[Sequence["AudioInput"]] = None, **input_kwargs, ) -> list["Response"]: final_output = None generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs) for request_output in generator: final_output = request_output results = [ Response( response_text=final_output["text"], response_length=final_output["meta_info"]["completion_tokens"], prompt_length=final_output["meta_info"]["prompt_tokens"], finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length", ) ] return results @override async def stream_chat( self, messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[list["ImageInput"]] = None, videos: Optional[list["VideoInput"]] = None, audios: Optional[list["AudioInput"]] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: generated_text = "" generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs) for result in generator: delta_text = result["text"][len(generated_text) :] generated_text = result["text"] yield delta_text @override async def get_scores( self, batch_input: list[str], **input_kwargs, ) -> list[float]: raise NotImplementedError("SGLang engine does not support `get_scores`.") def __del__(self): r"""Ensure server is cleaned up when object is deleted.""" self._cleanup_server() try: atexit.unregister(self._cleanup_server) except Exception: pass