import importlib
import os
import pkgutil
from typing import List, Type

from toolkit.models.base_model import BaseModel
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import ModelConfig
from toolkit.paths import TOOLKIT_ROOT
from toolkit.models.wan21 import Wan21, Wan21I2V
from toolkit.models.cogview4 import CogView4

# Model classes that ship with the toolkit itself (as opposed to extensions).
BUILT_IN_MODELS = [
    Wan21,
    Wan21I2V,
    CogView4,
]


def get_all_models() -> List[Type[BaseModel]]:
    """Collect the built-in model classes plus any exposed by extensions.

    Every module found in the ``extensions`` and ``extensions_built_in``
    folders under ``TOOLKIT_ROOT`` is imported as ``<folder>.<module>``. If
    the module defines ``AI_TOOLKIT_MODELS`` as a list, its entries are
    appended to the result.

    Modules that raise ``ImportError`` are reported on stdout and skipped.

    Returns:
        A new list starting with ``BUILT_IN_MODELS`` followed by the
        extension-provided entries, in discovery order.
    """
    extension_folders = ['extensions', 'extensions_built_in']

    # Start from a *copy* of the built-ins: extending BUILT_IN_MODELS itself
    # would grow the module-level list (with duplicates) on every call.
    all_model_classes: List[Type[BaseModel]] = list(BUILT_IN_MODELS)

    # Iterate over all modules/packages in each extension directory
    for sub_dir in extension_folders:
        extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir)
        for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
            try:
                # Import the module
                module = importlib.import_module(f"{sub_dir}.{name}")
                # Get the value of the AI_TOOLKIT_MODELS variable
                models = getattr(module, "AI_TOOLKIT_MODELS", None)
                # Only a list is accepted; anything else is ignored
                if isinstance(models, list):
                    all_model_classes.extend(models)
            except ImportError as e:
                print(f"Failed to import the {name} module. Error: {str(e)}")

    return all_model_classes


def get_model_class(config: ModelConfig):
    """Return the model class whose ``arch`` equals ``config.arch``.

    The candidates come from ``get_all_models()`` and the first match wins.
    If nothing matches, ``StableDiffusion`` is returned as the legacy default.
    """
    all_models = get_all_models()
    for ModelClass in all_models:
        if ModelClass.arch == config.arch:
            return ModelClass
    # default to the legacy model
    return StableDiffusion