# ArneBinder's picture
# https://github.com/ArneBinder/pie-document-level/pull/312
# 3133b5e verified
from typing import Any, Dict
from pie_modules.models import * # noqa: F403
from pie_modules.taskmodules import * # noqa: F403
from pytorch_ie import AutoModel, AutoTaskModule, PyTorchIEModel, TaskModule
from pytorch_ie.models import * # noqa: F403
from pytorch_ie.taskmodules import * # noqa: F403
from transformers import PreTrainedModel, PreTrainedTokenizer
def load_model_from_pie_model(model_kwargs: Dict[str, Any]) -> PreTrainedModel:
    """Load a PIE model and return the model nested inside it.

    Args:
        model_kwargs: Keyword arguments forwarded unchanged to
            ``AutoModel.from_pretrained``.

    Returns:
        The object stored at ``.model.model`` of the loaded PIE model.
    """
    loaded: PyTorchIEModel = AutoModel.from_pretrained(**model_kwargs)
    # The returned model sits two attribute levels below the PIE model.
    # NOTE(review): assumes the loaded PIE model exposes `.model.model` — confirm
    # for every model type passed through here.
    wrapped = loaded.model
    return wrapped.model
def load_tokenizer_from_pie_taskmodule(
    taskmodule_kwargs: Dict[str, Any],
) -> PreTrainedTokenizer:
    """Load a PIE taskmodule and return its tokenizer.

    Args:
        taskmodule_kwargs: Keyword arguments forwarded unchanged to
            ``AutoTaskModule.from_pretrained``.

    Returns:
        The ``tokenizer`` attribute of the loaded taskmodule.
    """
    loaded: TaskModule = AutoTaskModule.from_pretrained(**taskmodule_kwargs)
    # NOTE(review): assumes the loaded taskmodule has a `tokenizer` attribute — confirm.
    tokenizer = loaded.tokenizer
    return tokenizer
def load_model_with_adapter(
    model_kwargs: Dict[str, Any],
    adapter_kwargs: Dict[str, Any],
) -> PreTrainedModel:
    """Load an adapter-capable model and attach an adapter to it.

    Args:
        model_kwargs: Keyword arguments forwarded unchanged to
            ``AutoAdapterModel.from_pretrained``.
        adapter_kwargs: Keyword arguments forwarded to the model's
            ``load_adapter`` method. ``set_active=True`` is always passed
            in addition, so this mapping must not contain ``set_active``
            itself (a duplicate keyword raises ``TypeError``).

    Returns:
        The loaded model, after ``load_adapter`` has been called on it.
    """
    # Function-scope import: `adapters` is only imported when this function runs.
    from adapters import AutoAdapterModel

    adapter_model = AutoAdapterModel.from_pretrained(**model_kwargs)
    # The return value of `load_adapter` is intentionally not used.
    adapter_model.load_adapter(set_active=True, **adapter_kwargs)
    return adapter_model