jonluca's picture
add precomputed voices, reformat code, remove unused code
1373f78 unverified
raw
history blame contribute delete
1.42 kB
import os
import yaml
import torch
from transformers import AlbertConfig, AlbertModel
class CustomAlbert(AlbertModel):
def forward(self, *args, **kwargs):
# Call the original forward method
outputs = super().forward(*args, **kwargs)
# Only return the last_hidden_state
return outputs.last_hidden_state
def load_plbert(log_dir):
config_path = os.path.join(log_dir, "config.yml")
plbert_config = yaml.safe_load(open(config_path))
albert_base_configuration = AlbertConfig(**plbert_config["model_params"])
bert = CustomAlbert(albert_base_configuration)
ckpts = []
for f in os.listdir(log_dir):
if f.startswith("step_"):
ckpts.append(f)
iters = [
int(f.split("_")[-1].split(".")[0])
for f in ckpts
if os.path.isfile(os.path.join(log_dir, f))
]
iters = sorted(iters)[-1]
checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location="cpu")
state_dict = checkpoint["net"]
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
if name.startswith("encoder."):
name = name[8:] # remove `encoder.`
new_state_dict[name] = v
del new_state_dict["embeddings.position_ids"]
bert.load_state_dict(new_state_dict, strict=False)
return bert