import os
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Dict, Tuple
from torch import Tensor
from transformers import (
AutoModel,
PreTrainedModel,
)
from transformers.modeling_outputs import ModelOutput
from config import Arguments
from logger_config import logger
from utils import dist_gather_tensor, select_grouped_indices, full_contrastive_scores_and_labels


@dataclass
class BiencoderOutput(ModelOutput):
q_reps: Optional[Tensor] = None
p_reps: Optional[Tensor] = None
loss: Optional[Tensor] = None
labels: Optional[Tensor] = None
scores: Optional[Tensor] = None


class BiencoderModel(nn.Module):
def __init__(self, args: Arguments,
lm_q: PreTrainedModel,
lm_p: PreTrainedModel):
super().__init__()
self.lm_q = lm_q
self.lm_p = lm_p
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
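        # with log_target=True, KLDivLoss expects both student and teacher inputs as log-probabilities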
self.kl_loss_fn = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
self.args = args
self.pooler = nn.Linear(self.lm_q.config.hidden_size, args.out_dimension) if args.add_pooler else nn.Identity()
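        # deferred import to avoid a circular dependency with trainers; the trainer is
        # attached externally so temperature warmup can read its global step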
from trainers import BiencoderTrainer
self.trainer: Optional[BiencoderTrainer] = None

    def forward(self, query: Optional[Dict[str, Tensor]] = None,
                passage: Optional[Dict[str, Tensor]] = None):
assert self.args.process_index >= 0
scores, labels, q_reps, p_reps, all_scores, all_labels = self._compute_scores(query, passage)
start = self.args.process_index * q_reps.shape[0]
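        # locate, for each local query, the columns of its own positive and hard-negative
        # passages inside the globally gathered score matrix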
group_indices = select_grouped_indices(scores=scores,
group_size=self.args.train_n_passages,
start=start * self.args.train_n_passages)
if not self.args.do_kd_biencoder:
# training biencoder from scratch
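            # use_scaled_loss: compute cross-entropy over the rows of all gathered queries and
            # rescale it (by world_size, or an explicit loss_scale) instead of using only local rows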
if self.args.use_scaled_loss:
loss = self.cross_entropy(all_scores, all_labels)
loss *= self.args.world_size if self.args.loss_scale <= 0 else self.args.loss_scale
else:
loss = self.cross_entropy(scores, labels)
else:
# training biencoder with kd
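            # distillation: align the student's distribution over each query's passage group
            # with the teacher scores provided in query['kd_labels']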
# batch_size x train_n_passage
group_scores = torch.gather(input=scores, dim=1, index=group_indices)
assert group_scores.shape[1] == self.args.train_n_passages
group_log_scores = torch.log_softmax(group_scores, dim=-1)
kd_log_target = torch.log_softmax(query['kd_labels'], dim=-1)
kd_loss = self.kl_loss_fn(input=group_log_scores, target=kd_log_target)
# (optionally) mask out hard negatives
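            # so the contrastive cross-entropy only uses in-batch negatives, leaving the
            # hard negatives to the distillation term above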
if self.training and self.args.kd_mask_hn:
scores = torch.scatter(input=scores, dim=1, index=group_indices[:, 1:], value=float('-inf'))
if self.args.use_scaled_loss:
ce_loss = self.cross_entropy(all_scores, all_labels)
ce_loss *= self.args.world_size if self.args.loss_scale <= 0 else self.args.loss_scale
else:
ce_loss = self.cross_entropy(scores, labels)
loss = self.args.kd_cont_loss_weight * ce_loss + kd_loss
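        # all_scores may hold extra columns beyond the gathered passages (e.g. additional pairs
        # scored under full_contrastive_loss); only the query-passage block is returned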
total_n_psg = self.args.world_size * q_reps.shape[0] * self.args.train_n_passages
return BiencoderOutput(loss=loss, q_reps=q_reps, p_reps=p_reps,
labels=labels.contiguous(),
scores=scores[:, :total_n_psg].contiguous())

    def _compute_scores(self, query: Optional[Dict[str, Tensor]] = None,
                        passage: Optional[Dict[str, Tensor]] = None) -> Tuple:
q_reps = self._encode(self.lm_q, query)
p_reps = self._encode(self.lm_p, passage)
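        # gather representations from all processes so in-batch negatives span the global batch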
all_q_reps = dist_gather_tensor(q_reps)
all_p_reps = dist_gather_tensor(p_reps)
assert all_p_reps.shape[0] == self.args.world_size * q_reps.shape[0] * self.args.train_n_passages
all_scores, all_labels = full_contrastive_scores_and_labels(
query=all_q_reps, key=all_p_reps,
use_all_pairs=self.args.full_contrastive_loss)
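        # with L2-normalized embeddings, scale similarities by an inverse temperature 1/t,
        # optionally warmed up linearly over the first warmup_steps training steps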
if self.args.l2_normalize:
if self.args.t_warmup:
scale = 1 / self.args.t * min(1.0, self.trainer.state.global_step / self.args.warmup_steps)
scale = max(1.0, scale)
else:
scale = 1 / self.args.t
all_scores = all_scores * scale
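        # keep only the score rows that belong to this process's own queries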
start = self.args.process_index * q_reps.shape[0]
local_query_indices = torch.arange(start, start + q_reps.shape[0], dtype=torch.long).to(q_reps.device)
# batch_size x (world_size x batch_size x train_n_passage)
scores = all_scores.index_select(dim=0, index=local_query_indices)
labels = all_labels.index_select(dim=0, index=local_query_indices)
return scores, labels, q_reps, p_reps, all_scores, all_labels

    def _encode(self, encoder: PreTrainedModel, input_dict: dict) -> Optional[torch.Tensor]:
if not input_dict:
return None
outputs = encoder(**{k: v for k, v in input_dict.items() if k not in ['kd_labels']}, return_dict=True)
hidden_state = outputs.last_hidden_state
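        # use the first token ([CLS]) hidden state as the sequence embedding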
embeds = hidden_state[:, 0]
embeds = self.pooler(embeds)
if self.args.l2_normalize:
embeds = F.normalize(embeds, dim=-1)
return embeds.contiguous()

    @classmethod
def build(cls, args: Arguments, **hf_kwargs):
# load local
if os.path.isdir(args.model_name_or_path):
if not args.share_encoder:
_qry_model_path = os.path.join(args.model_name_or_path, 'query_model')
_psg_model_path = os.path.join(args.model_name_or_path, 'passage_model')
if not os.path.exists(_qry_model_path):
_qry_model_path = args.model_name_or_path
_psg_model_path = args.model_name_or_path
logger.info(f'loading query model weight from {_qry_model_path}')
lm_q = AutoModel.from_pretrained(_qry_model_path, **hf_kwargs)
logger.info(f'loading passage model weight from {_psg_model_path}')
lm_p = AutoModel.from_pretrained(_psg_model_path, **hf_kwargs)
else:
logger.info(f'loading shared model weight from {args.model_name_or_path}')
lm_q = AutoModel.from_pretrained(args.model_name_or_path, **hf_kwargs)
lm_p = lm_q
# load pre-trained
else:
lm_q = AutoModel.from_pretrained(args.model_name_or_path, **hf_kwargs)
lm_p = copy.deepcopy(lm_q) if not args.share_encoder else lm_q
model = cls(args=args, lm_q=lm_q, lm_p=lm_p)
return model

    def save(self, output_dir: str):
if not self.args.share_encoder:
os.makedirs(os.path.join(output_dir, 'query_model'), exist_ok=True)
os.makedirs(os.path.join(output_dir, 'passage_model'), exist_ok=True)
self.lm_q.save_pretrained(os.path.join(output_dir, 'query_model'))
self.lm_p.save_pretrained(os.path.join(output_dir, 'passage_model'))
else:
self.lm_q.save_pretrained(output_dir)
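        # the optional projection pooler is not part of the HF encoder, so save it separately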
if self.args.add_pooler:
torch.save(self.pooler.state_dict(), os.path.join(output_dir, 'pooler.pt'))


class BiencoderModelForInference(BiencoderModel):
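    # inference-only variant: no loss functions or trainer state, just encodes queries/passages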
def __init__(self, args: Arguments,
lm_q: PreTrainedModel,
lm_p: PreTrainedModel):
nn.Module.__init__(self)
self.args = args
self.lm_q = lm_q
self.lm_p = lm_p
self.pooler = nn.Linear(self.lm_q.config.hidden_size, args.out_dimension) if args.add_pooler else nn.Identity()

    @torch.no_grad()
    def forward(self, query: Optional[Dict[str, Tensor]] = None,
                passage: Optional[Dict[str, Tensor]] = None):
q_reps = self._encode(self.lm_q, query)
p_reps = self._encode(self.lm_p, passage)
return BiencoderOutput(q_reps=q_reps, p_reps=p_reps)

    @classmethod
def build(cls, args: Arguments, **hf_kwargs):
model_name_or_path = args.model_name_or_path
# load local
if os.path.isdir(model_name_or_path):
_qry_model_path = os.path.join(model_name_or_path, 'query_model')
_psg_model_path = os.path.join(model_name_or_path, 'passage_model')
if os.path.exists(_qry_model_path):
                logger.info('found separate weights for query/passage encoders')
logger.info(f'loading query model weight from {_qry_model_path}')
lm_q = AutoModel.from_pretrained(_qry_model_path, **hf_kwargs)
logger.info(f'loading passage model weight from {_psg_model_path}')
lm_p = AutoModel.from_pretrained(_psg_model_path, **hf_kwargs)
else:
                logger.info('no separate encoder weights found, trying tied weights')
logger.info(f'loading model weight from {model_name_or_path}')
lm_q = AutoModel.from_pretrained(model_name_or_path, **hf_kwargs)
lm_p = lm_q
else:
            logger.info(f'trying to load tied weights from {model_name_or_path}')
lm_q = AutoModel.from_pretrained(model_name_or_path, **hf_kwargs)
lm_p = lm_q
model = cls(args=args, lm_q=lm_q, lm_p=lm_p)
pooler_path = os.path.join(args.model_name_or_path, 'pooler.pt')
if os.path.exists(pooler_path):
logger.info('loading pooler weights from local files')
state_dict = torch.load(pooler_path, map_location="cpu")
model.pooler.load_state_dict(state_dict)
else:
assert not args.add_pooler
logger.info('No pooler will be loaded')
return model