import os
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Dict, Tuple
from torch import Tensor
from transformers import (
AutoModel,
PreTrainedModel,
)
from transformers.modeling_outputs import ModelOutput
from config import Arguments
from logger_config import logger
from utils import dist_gather_tensor, select_grouped_indices, full_contrastive_scores_and_labels
@dataclass
class BiencoderOutput(ModelOutput):
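    """Container for biencoder outputs.

    q_reps / p_reps hold the query / passage embeddings, scores and labels the
    contrastive logits and their target indices, and loss the training loss.
    """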
q_reps: Optional[Tensor] = None
p_reps: Optional[Tensor] = None
loss: Optional[Tensor] = None
labels: Optional[Tensor] = None
scores: Optional[Tensor] = None
class BiencoderModel(nn.Module):
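    """Dual-encoder ("biencoder") retrieval model.

    Wraps a query encoder and a passage encoder, trains them with an in-batch
    (and cross-device) contrastive loss, and optionally distills from teacher
    scores passed in as query['kd_labels'].
    """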
def __init__(self, args: Arguments,
lm_q: PreTrainedModel,
lm_p: PreTrainedModel):
super().__init__()
self.lm_q = lm_q
self.lm_p = lm_p
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
self.kl_loss_fn = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
self.args = args
self.pooler = nn.Linear(self.lm_q.config.hidden_size, args.out_dimension) if args.add_pooler else nn.Identity()
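        # imported locally (rather than at module level), presumably to avoid
        # a circular import between models and trainers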
from trainers import BiencoderTrainer
self.trainer: Optional[BiencoderTrainer] = None
def forward(self, query: Dict[str, Tensor] = None,
passage: Dict[str, Tensor] = None):
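        """Encode the local batch, score it against all gathered passages and return the loss.

        When do_kd_biencoder is set, a KL-divergence term against the teacher
        scores in query['kd_labels'] is added to the contrastive loss.
        """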
assert self.args.process_index >= 0
scores, labels, q_reps, p_reps, all_scores, all_labels = self._compute_scores(query, passage)
start = self.args.process_index * q_reps.shape[0]
group_indices = select_grouped_indices(scores=scores,
group_size=self.args.train_n_passages,
start=start * self.args.train_n_passages)
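        # group_indices: for each local query, the columns of its own passage
        # group (positive + hard negatives) inside the gathered score matrix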
if not self.args.do_kd_biencoder:
# training biencoder from scratch
if self.args.use_scaled_loss:
loss = self.cross_entropy(all_scores, all_labels)
loss *= self.args.world_size if self.args.loss_scale <= 0 else self.args.loss_scale
else:
loss = self.cross_entropy(scores, labels)
else:
# training biencoder with kd
            # group_scores: batch_size x train_n_passages
group_scores = torch.gather(input=scores, dim=1, index=group_indices)
assert group_scores.shape[1] == self.args.train_n_passages
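            # kl_loss_fn was built with log_target=True, so both the student
            # scores and the teacher kd_labels are converted to log-probabilities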
group_log_scores = torch.log_softmax(group_scores, dim=-1)
kd_log_target = torch.log_softmax(query['kd_labels'], dim=-1)
kd_loss = self.kl_loss_fn(input=group_log_scores, target=kd_log_target)
# (optionally) mask out hard negatives
if self.training and self.args.kd_mask_hn:
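                # set the hard-negative columns (all but the first passage of
                # each group) to -inf so they do not contribute to the
                # cross-entropy term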
scores = torch.scatter(input=scores, dim=1, index=group_indices[:, 1:], value=float('-inf'))
if self.args.use_scaled_loss:
ce_loss = self.cross_entropy(all_scores, all_labels)
ce_loss *= self.args.world_size if self.args.loss_scale <= 0 else self.args.loss_scale
else:
ce_loss = self.cross_entropy(scores, labels)
loss = self.args.kd_cont_loss_weight * ce_loss + kd_loss
total_n_psg = self.args.world_size * q_reps.shape[0] * self.args.train_n_passages
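        # the gathered score matrix may contain extra columns when
        # full_contrastive_loss is enabled; return only the first total_n_psg
        # query-passage columns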
return BiencoderOutput(loss=loss, q_reps=q_reps, p_reps=p_reps,
labels=labels.contiguous(),
scores=scores[:, :total_n_psg].contiguous())
def _compute_scores(self, query: Dict[str, Tensor] = None,
passage: Dict[str, Tensor] = None) -> Tuple:
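        """Encode the local batch, gather representations from all ranks and
        build the full score matrix; returns both the local rows (scores,
        labels) and the full gathered versions (all_scores, all_labels)."""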
q_reps = self._encode(self.lm_q, query)
p_reps = self._encode(self.lm_p, passage)
all_q_reps = dist_gather_tensor(q_reps)
all_p_reps = dist_gather_tensor(p_reps)
assert all_p_reps.shape[0] == self.args.world_size * q_reps.shape[0] * self.args.train_n_passages
all_scores, all_labels = full_contrastive_scores_and_labels(
query=all_q_reps, key=all_p_reps,
use_all_pairs=self.args.full_contrastive_loss)
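        # with L2-normalized embeddings, sharpen the similarities by a
        # temperature 1 / t, optionally ramped up over warmup_steps
        # (floored at 1.0)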
if self.args.l2_normalize:
if self.args.t_warmup:
scale = 1 / self.args.t * min(1.0, self.trainer.state.global_step / self.args.warmup_steps)
scale = max(1.0, scale)
else:
scale = 1 / self.args.t
all_scores = all_scores * scale
start = self.args.process_index * q_reps.shape[0]
local_query_indices = torch.arange(start, start + q_reps.shape[0], dtype=torch.long).to(q_reps.device)
        # scores: batch_size x (world_size x batch_size x train_n_passages)
scores = all_scores.index_select(dim=0, index=local_query_indices)
labels = all_labels.index_select(dim=0, index=local_query_indices)
return scores, labels, q_reps, p_reps, all_scores, all_labels
def _encode(self, encoder: PreTrainedModel, input_dict: dict) -> Optional[torch.Tensor]:
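        """Run the encoder and return one (optionally pooled and L2-normalized) embedding per input."""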
if not input_dict:
return None
outputs = encoder(**{k: v for k, v in input_dict.items() if k not in ['kd_labels']}, return_dict=True)
hidden_state = outputs.last_hidden_state
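        # use the first token ([CLS]) embedding as the sequence representation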
embeds = hidden_state[:, 0]
embeds = self.pooler(embeds)
if self.args.l2_normalize:
embeds = F.normalize(embeds, dim=-1)
return embeds.contiguous()
@classmethod
def build(cls, args: Arguments, **hf_kwargs):
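        """Instantiate the encoders either from a local checkpoint directory
        (with optional query_model/ and passage_model/ sub-directories) or
        from a pretrained model name; the passage encoder is tied to the
        query encoder when share_encoder is set."""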
# load local
if os.path.isdir(args.model_name_or_path):
if not args.share_encoder:
_qry_model_path = os.path.join(args.model_name_or_path, 'query_model')
_psg_model_path = os.path.join(args.model_name_or_path, 'passage_model')
if not os.path.exists(_qry_model_path):
_qry_model_path = args.model_name_or_path
_psg_model_path = args.model_name_or_path
logger.info(f'loading query model weight from {_qry_model_path}')
lm_q = AutoModel.from_pretrained(_qry_model_path, **hf_kwargs)
logger.info(f'loading passage model weight from {_psg_model_path}')
lm_p = AutoModel.from_pretrained(_psg_model_path, **hf_kwargs)
else:
logger.info(f'loading shared model weight from {args.model_name_or_path}')
lm_q = AutoModel.from_pretrained(args.model_name_or_path, **hf_kwargs)
lm_p = lm_q
# load pre-trained
else:
lm_q = AutoModel.from_pretrained(args.model_name_or_path, **hf_kwargs)
lm_p = copy.deepcopy(lm_q) if not args.share_encoder else lm_q
model = cls(args=args, lm_q=lm_q, lm_p=lm_p)
return model
def save(self, output_dir: str):
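        """Save the encoder(s), and the pooler weights if one was added, under output_dir."""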
if not self.args.share_encoder:
os.makedirs(os.path.join(output_dir, 'query_model'), exist_ok=True)
os.makedirs(os.path.join(output_dir, 'passage_model'), exist_ok=True)
self.lm_q.save_pretrained(os.path.join(output_dir, 'query_model'))
self.lm_p.save_pretrained(os.path.join(output_dir, 'passage_model'))
else:
self.lm_q.save_pretrained(output_dir)
if self.args.add_pooler:
torch.save(self.pooler.state_dict(), os.path.join(output_dir, 'pooler.pt'))
class BiencoderModelForInference(BiencoderModel):
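    """Encoding-only variant of BiencoderModel used at inference time."""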
def __init__(self, args: Arguments,
lm_q: PreTrainedModel,
lm_p: PreTrainedModel):
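        # bypass BiencoderModel.__init__ on purpose: inference needs no loss
        # functions and no trainer reference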
nn.Module.__init__(self)
self.args = args
self.lm_q = lm_q
self.lm_p = lm_p
self.pooler = nn.Linear(self.lm_q.config.hidden_size, args.out_dimension) if args.add_pooler else nn.Identity()
@torch.no_grad()
def forward(self, query: Dict[str, Tensor] = None,
passage: Dict[str, Tensor] = None):
q_reps = self._encode(self.lm_q, query)
p_reps = self._encode(self.lm_p, passage)
return BiencoderOutput(q_reps=q_reps, p_reps=p_reps)
@classmethod
def build(cls, args: Arguments, **hf_kwargs):
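        """Load the encoder(s) and, if present, the locally saved pooler weights."""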
model_name_or_path = args.model_name_or_path
# load local
if os.path.isdir(model_name_or_path):
_qry_model_path = os.path.join(model_name_or_path, 'query_model')
_psg_model_path = os.path.join(model_name_or_path, 'passage_model')
if os.path.exists(_qry_model_path):
                logger.info('found separate weights for the query/passage encoders')
logger.info(f'loading query model weight from {_qry_model_path}')
lm_q = AutoModel.from_pretrained(_qry_model_path, **hf_kwargs)
logger.info(f'loading passage model weight from {_psg_model_path}')
lm_p = AutoModel.from_pretrained(_psg_model_path, **hf_kwargs)
else:
                logger.info('trying to load tied weights')
logger.info(f'loading model weight from {model_name_or_path}')
lm_q = AutoModel.from_pretrained(model_name_or_path, **hf_kwargs)
lm_p = lm_q
else:
            logger.info(f'trying to load tied weights from {model_name_or_path}')
lm_q = AutoModel.from_pretrained(model_name_or_path, **hf_kwargs)
lm_p = lm_q
model = cls(args=args, lm_q=lm_q, lm_p=lm_p)
        pooler_path = os.path.join(model_name_or_path, 'pooler.pt')
if os.path.exists(pooler_path):
logger.info('loading pooler weights from local files')
state_dict = torch.load(pooler_path, map_location="cpu")
model.pooler.load_state_dict(state_dict)
else:
assert not args.add_pooler
logger.info('No pooler will be loaded')
return model
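# A minimal usage sketch (illustrative only): the Arguments fields shown here
# are assumed from how they are read in this file, and 'path/to/checkpoint'
# plus the tokenizer call are placeholders, not part of this repository's API.
#
#   from transformers import AutoTokenizer
#
#   args = Arguments(model_name_or_path='path/to/checkpoint',
#                    add_pooler=False, l2_normalize=True)
#   tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
#   model = BiencoderModelForInference.build(args).eval()
#   batch = tokenizer(['example query'], return_tensors='pt')
#   q_reps = model(query=dict(batch)).q_reps   # shape: (1, hidden_size)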