import os
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from typing import Optional, Dict, Tuple
from torch import Tensor
from transformers import (
    AutoModel,
    PreTrainedModel,
)
from transformers.modeling_outputs import ModelOutput

from config import Arguments
from logger_config import logger
from utils import dist_gather_tensor, select_grouped_indices, full_contrastive_scores_and_labels


@dataclass
class BiencoderOutput(ModelOutput):
    q_reps: Optional[Tensor] = None
    p_reps: Optional[Tensor] = None
    loss: Optional[Tensor] = None
    labels: Optional[Tensor] = None
    scores: Optional[Tensor] = None


class BiencoderModel(nn.Module):
    def __init__(self, args: Arguments,
                 lm_q: PreTrainedModel,
                 lm_p: PreTrainedModel):
        super().__init__()
        self.lm_q = lm_q
        self.lm_p = lm_p
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
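        # with log_target=True, both input and target of the KL loss are log-probabilities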
        self.kl_loss_fn = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
        self.args = args
        self.pooler = nn.Linear(self.lm_q.config.hidden_size, args.out_dimension) if args.add_pooler else nn.Identity()
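
        # imported here rather than at module top, presumably to avoid a circular import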
        from trainers import BiencoderTrainer
        self.trainer: Optional[BiencoderTrainer] = None

    def forward(self, query: Dict[str, Tensor] = None,
                passage: Dict[str, Tensor] = None):
        assert self.args.process_index >= 0

        scores, labels, q_reps, p_reps, all_scores, all_labels = self._compute_scores(query, passage)

        start = self.args.process_index * q_reps.shape[0]
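        # column indices of each local query's own group of train_n_passages
        # candidates within the gathered score matrix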
        group_indices = select_grouped_indices(scores=scores,
                                               group_size=self.args.train_n_passages,
                                               start=start * self.args.train_n_passages)

        if not self.args.do_kd_biencoder:
            # training biencoder from scratch
            if self.args.use_scaled_loss:
                loss = self.cross_entropy(all_scores, all_labels)
                loss *= self.args.world_size if self.args.loss_scale <= 0 else self.args.loss_scale
            else:
                loss = self.cross_entropy(scores, labels)
        else:
            # training biencoder with kd
            # batch_size x train_n_passages
            group_scores = torch.gather(input=scores, dim=1, index=group_indices)
            assert group_scores.shape[1] == self.args.train_n_passages
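            # match the student's distribution over each candidate group to the
            # teacher's distribution derived from kd_labels (KL divergence)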
            group_log_scores = torch.log_softmax(group_scores, dim=-1)
            kd_log_target = torch.log_softmax(query['kd_labels'], dim=-1)
            kd_loss = self.kl_loss_fn(input=group_log_scores, target=kd_log_target)

            # (optionally) mask out hard negatives
            if self.training and self.args.kd_mask_hn:
                scores = torch.scatter(input=scores, dim=1, index=group_indices[:, 1:], value=float('-inf'))
            if self.args.use_scaled_loss:
                ce_loss = self.cross_entropy(all_scores, all_labels)
                ce_loss *= self.args.world_size if self.args.loss_scale <= 0 else self.args.loss_scale
            else:
                ce_loss = self.cross_entropy(scores, labels)
            loss = self.args.kd_cont_loss_weight * ce_loss + kd_loss

        total_n_psg = self.args.world_size * q_reps.shape[0] * self.args.train_n_passages

        return BiencoderOutput(loss=loss, q_reps=q_reps, p_reps=p_reps,
                               labels=labels.contiguous(),
                               scores=scores[:, :total_n_psg].contiguous())

    def _compute_scores(self, query: Dict[str, Tensor] = None,
                        passage: Dict[str, Tensor] = None) -> Tuple:
        q_reps = self._encode(self.lm_q, query)
        p_reps = self._encode(self.lm_p, passage)

        all_q_reps = dist_gather_tensor(q_reps)
        all_p_reps = dist_gather_tensor(p_reps)
        assert all_p_reps.shape[0] == self.args.world_size * q_reps.shape[0] * self.args.train_n_passages

        all_scores, all_labels = full_contrastive_scores_and_labels(
            query=all_q_reps, key=all_p_reps,
            use_all_pairs=self.args.full_contrastive_loss)

        if self.args.l2_normalize:
            if self.args.t_warmup:
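                # ramp the score scale (inverse temperature 1/t) up linearly
                # over warmup_steps, clamped to at least 1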
                scale = 1 / self.args.t * min(1.0, self.trainer.state.global_step / self.args.warmup_steps)
                scale = max(1.0, scale)
            else:
                scale = 1 / self.args.t
            all_scores = all_scores * scale

        start = self.args.process_index * q_reps.shape[0]
        local_query_indices = torch.arange(start, start + q_reps.shape[0], dtype=torch.long).to(q_reps.device)
        # batch_size x (world_size x batch_size x train_n_passages)
        scores = all_scores.index_select(dim=0, index=local_query_indices)
        labels = all_labels.index_select(dim=0, index=local_query_indices)

        return scores, labels, q_reps, p_reps, all_scores, all_labels

    def _encode(self, encoder: PreTrainedModel, input_dict: dict) -> Optional[torch.Tensor]:
        if not input_dict:
            return None
        # kd_labels are distillation targets, not encoder inputs
        outputs = encoder(**{k: v for k, v in input_dict.items() if k not in ['kd_labels']}, return_dict=True)
        hidden_state = outputs.last_hidden_state
        # CLS pooling: use the first token's hidden state as the representation
        embeds = hidden_state[:, 0]
        embeds = self.pooler(embeds)
        if self.args.l2_normalize:
            embeds = F.normalize(embeds, dim=-1)
        return embeds.contiguous()

    @classmethod
    def build(cls, args: Arguments, **hf_kwargs):
        # load from a local checkpoint directory
        if os.path.isdir(args.model_name_or_path):
            if not args.share_encoder:
                _qry_model_path = os.path.join(args.model_name_or_path, 'query_model')
                _psg_model_path = os.path.join(args.model_name_or_path, 'passage_model')
                if not os.path.exists(_qry_model_path):
                    _qry_model_path = args.model_name_or_path
                    _psg_model_path = args.model_name_or_path
                logger.info(f'loading query model weight from {_qry_model_path}')
                lm_q = AutoModel.from_pretrained(_qry_model_path, **hf_kwargs)
                logger.info(f'loading passage model weight from {_psg_model_path}')
                lm_p = AutoModel.from_pretrained(_psg_model_path, **hf_kwargs)
            else:
                logger.info(f'loading shared model weight from {args.model_name_or_path}')
                lm_q = AutoModel.from_pretrained(args.model_name_or_path, **hf_kwargs)
                lm_p = lm_q
        # otherwise load a pre-trained model by name
        else:
            lm_q = AutoModel.from_pretrained(args.model_name_or_path, **hf_kwargs)
            lm_p = copy.deepcopy(lm_q) if not args.share_encoder else lm_q

        model = cls(args=args, lm_q=lm_q, lm_p=lm_p)
        return model

    def save(self, output_dir: str):
        if not self.args.share_encoder:
            os.makedirs(os.path.join(output_dir, 'query_model'), exist_ok=True)
            os.makedirs(os.path.join(output_dir, 'passage_model'), exist_ok=True)
            self.lm_q.save_pretrained(os.path.join(output_dir, 'query_model'))
            self.lm_p.save_pretrained(os.path.join(output_dir, 'passage_model'))
        else:
            self.lm_q.save_pretrained(output_dir)

        if self.args.add_pooler:
            torch.save(self.pooler.state_dict(), os.path.join(output_dir, 'pooler.pt'))


class BiencoderModelForInference(BiencoderModel):
    def __init__(self, args: Arguments,
                 lm_q: PreTrainedModel,
                 lm_p: PreTrainedModel):
        # bypass BiencoderModel.__init__: inference needs no losses or trainer hook
        nn.Module.__init__(self)
        self.args = args
        self.lm_q = lm_q
        self.lm_p = lm_p
        self.pooler = nn.Linear(self.lm_q.config.hidden_size, args.out_dimension) if args.add_pooler else nn.Identity()

    def forward(self, query: Dict[str, Tensor] = None,
                passage: Dict[str, Tensor] = None):
        q_reps = self._encode(self.lm_q, query)
        p_reps = self._encode(self.lm_p, passage)
        return BiencoderOutput(q_reps=q_reps, p_reps=p_reps)

    @classmethod
    def build(cls, args: Arguments, **hf_kwargs):
        model_name_or_path = args.model_name_or_path
        # load from a local checkpoint directory
        if os.path.isdir(model_name_or_path):
            _qry_model_path = os.path.join(model_name_or_path, 'query_model')
            _psg_model_path = os.path.join(model_name_or_path, 'passage_model')
            if os.path.exists(_qry_model_path):
                logger.info('found separate weights for query/passage encoders')
                logger.info(f'loading query model weight from {_qry_model_path}')
                lm_q = AutoModel.from_pretrained(_qry_model_path, **hf_kwargs)
                logger.info(f'loading passage model weight from {_psg_model_path}')
                lm_p = AutoModel.from_pretrained(_psg_model_path, **hf_kwargs)
            else:
                logger.info('trying to load tied weights')
                logger.info(f'loading model weight from {model_name_or_path}')
                lm_q = AutoModel.from_pretrained(model_name_or_path, **hf_kwargs)
                lm_p = lm_q
        else:
            logger.info(f'trying to load tied weights from {model_name_or_path}')
            lm_q = AutoModel.from_pretrained(model_name_or_path, **hf_kwargs)
            lm_p = lm_q

        model = cls(args=args, lm_q=lm_q, lm_p=lm_p)

        pooler_path = os.path.join(args.model_name_or_path, 'pooler.pt')
        if os.path.exists(pooler_path):
            logger.info('loading pooler weights from local files')
            state_dict = torch.load(pooler_path, map_location="cpu")
            model.pooler.load_state_dict(state_dict)
        else:
            assert not args.add_pooler
            logger.info('No pooler will be loaded')
        return model
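

# --- Usage sketch (illustrative, not part of the original module): builds the
# inference biencoder around an off-the-shelf checkpoint and scores one
# query/passage pair. It assumes `Arguments` accepts these fields as keyword
# arguments and that its remaining fields have defaults; adjust to the real
# config before running.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    args = Arguments(model_name_or_path='bert-base-uncased',  # hypothetical checkpoint
                     share_encoder=True,
                     add_pooler=False,
                     l2_normalize=True)
    model = BiencoderModelForInference.build(args)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    queries = tokenizer(['what is a biencoder?'], return_tensors='pt', padding=True)
    passages = tokenizer(['A biencoder encodes queries and passages independently.'],
                         return_tensors='pt', padding=True)

    with torch.no_grad():
        output = model(query=dict(queries), passage=dict(passages))
    # embeddings are unit-normalized, so the dot product is cosine similarity
    print(output.q_reps @ output.p_reps.T)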