ramimu's picture
Upload 586 files
1c72248 verified
import os
from typing import List
from toolkit.models.base_model import BaseModel
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import ModelConfig
from toolkit.paths import TOOLKIT_ROOT
import importlib
import pkgutil
from toolkit.models.wan21 import Wan21, Wan21I2V
from toolkit.models.cogview4 import CogView4
# Model classes that ship with the toolkit itself. get_all_models() starts its
# result from these before adding any classes discovered in extension folders.
BUILT_IN_MODELS = [Wan21, Wan21I2V, CogView4]
def get_all_models() -> List[BaseModel]:
    """Collect every model class available to the toolkit.

    The result starts with the classes in ``BUILT_IN_MODELS``. Then every
    module/package found directly inside the ``extensions`` and
    ``extensions_built_in`` folders under ``TOOLKIT_ROOT`` is imported. If
    such a module defines an ``AI_TOOLKIT_MODELS`` list, its entries are
    appended.

    Extensions that fail with ``ImportError`` are reported via ``print`` and
    skipped, so one broken extension does not block the others.

    Returns:
        A new list of model classes: built-ins first, then extension models
        in discovery order. A fresh list is built on every call.
    """
    extension_folders = ['extensions', 'extensions_built_in']
    # Copy rather than alias: extending BUILT_IN_MODELS in place would mutate
    # the module-level constant and make it grow (with duplicates) on every
    # call to this function.
    all_model_classes: List[BaseModel] = list(BUILT_IN_MODELS)
    for sub_dir in extension_folders:
        extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir)
        # iter_modules yields (finder, name, ispkg); only the name is needed.
        for _, name, _ in pkgutil.iter_modules([extensions_dir]):
            try:
                module = importlib.import_module(f"{sub_dir}.{name}")
            except ImportError as e:
                print(f"Failed to import the {name} module. Error: {str(e)}")
            else:
                # Extensions opt in by exposing a list named AI_TOOLKIT_MODELS;
                # anything else (missing, wrong type) is ignored.
                models = getattr(module, "AI_TOOLKIT_MODELS", None)
                if isinstance(models, list):
                    all_model_classes.extend(models)
    return all_model_classes
def get_model_class(config: ModelConfig):
    """Pick the model class registered for ``config.arch``.

    Classes from ``get_all_models()`` are examined in order. The first one
    whose ``arch`` attribute equals ``config.arch`` is returned. When none
    match, the legacy ``StableDiffusion`` class is used as the fallback.
    """
    candidates = get_all_models()
    matching = (
        model_cls for model_cls in candidates if model_cls.arch == config.arch
    )
    # next() returns the first match, or the legacy default if the generator
    # is exhausted.
    return next(matching, StableDiffusion)