Spaces:
Paused
Paused
import os | |
from typing import List | |
from toolkit.models.base_model import BaseModel | |
from toolkit.stable_diffusion_model import StableDiffusion | |
from toolkit.config_modules import ModelConfig | |
from toolkit.paths import TOOLKIT_ROOT | |
import importlib | |
import pkgutil | |
from toolkit.models.wan21 import Wan21, Wan21I2V | |
from toolkit.models.cogview4 import CogView4 | |
# Model classes that ship with the toolkit itself. get_all_models() places
# these ahead of any classes contributed by extension packages.
BUILT_IN_MODELS = [Wan21, Wan21I2V, CogView4]
def get_all_models() -> List[BaseModel]:
    """Collect every model class known to the toolkit.

    The result starts with the classes in ``BUILT_IN_MODELS``. It then
    appends the contents of the ``AI_TOOLKIT_MODELS`` list exported by each
    module or package found in the ``extensions`` and ``extensions_built_in``
    folders under ``TOOLKIT_ROOT``. Entries are model *classes*, not
    instances.

    An extension whose import raises ``ImportError`` is reported with a
    printed message and skipped. A module without an ``AI_TOOLKIT_MODELS``
    list (or with a non-list value) contributes nothing.

    Returns:
        A new list on every call. Mutating it does not affect
        ``BUILT_IN_MODELS``.
    """
    extension_folders = ['extensions', 'extensions_built_in']
    # Copy the built-in list. Extending the module-level list in place would
    # make BUILT_IN_MODELS grow (with duplicates) on every call.
    all_model_classes: List[BaseModel] = list(BUILT_IN_MODELS)
    for sub_dir in extension_folders:
        extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir)
        # The import below uses the folder name as the package prefix. This
        # presumably relies on TOOLKIT_ROOT being on sys.path — confirm.
        for _, name, _ in pkgutil.iter_modules([extensions_dir]):
            try:
                module = importlib.import_module(f"{sub_dir}.{name}")
            except ImportError as e:
                print(f"Failed to import the {name} module. Error: {str(e)}")
                continue
            # Extensions advertise their model classes through this
            # module-level variable; anything other than a list is ignored.
            models = getattr(module, "AI_TOOLKIT_MODELS", None)
            if isinstance(models, list):
                all_model_classes.extend(models)
    return all_model_classes
def get_model_class(config: ModelConfig):
    """Return the model class whose ``arch`` equals ``config.arch``.

    Candidates are searched in the order produced by ``get_all_models()``,
    and the first match wins. If no class matches, the legacy
    ``StableDiffusion`` class is returned.
    """
    candidates = get_all_models()
    return next(
        (model_cls for model_cls in candidates if model_cls.arch == config.arch),
        StableDiffusion,  # fallback: legacy model
    )