import sys

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer  # type: ignore

# Convert prithivida/Splade_PP_en_v2 to ONNX.
# Based on this info:
# - https://github.com/naver/splade/issues/47
# - https://github.com/castorini/anserini/blob/master/docs/onnx-conversion.md


class TransformerRep(torch.nn.Module):
    """Wraps the masked-LM model and exposes its raw logits."""

    def __init__(self):
        super().__init__()
        self.model = AutoModelForMaskedLM.from_pretrained('prithivida/Splade_PP_en_v2')
        self.model.eval()  # type: ignore
        self.fp16 = True

    def encode(self, input_ids, token_type_ids, attention_mask):
        # Returns the MLM logits, shape (batch, sequence, vocab).
        return self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )[0]


class SpladeModel(torch.nn.Module):
    """Turns MLM logits into a sparse SPLADE vector (term indices + weights)."""

    def __init__(self):
        super().__init__()
        self.model = TransformerRep()
        self.agg = "max"
        self.model.eval()

    def forward(self, input_ids, token_type_ids, attention_mask):
        with torch.cuda.amp.autocast():  # type: ignore
            with torch.no_grad():
                # Logits for the first (and only) sequence in the batch;
                # the export below assumes batch size 1.
                lm_logits = self.model.encode(input_ids, token_type_ids, attention_mask)[0]
                # SPLADE aggregation: max over the sequence dimension of
                # log(1 + relu(logits)), masked by the attention mask.
                vec, _ = torch.max(
                    torch.log(1 + torch.relu(lm_logits)) * attention_mask.unsqueeze(-1),
                    dim=1
                )
                # Keep only the non-zero vocabulary entries.
                indices = vec.nonzero().squeeze()
                weights = vec.squeeze()[indices]
        # Column 1 holds the vocabulary ids and their corresponding weights.
        return indices[:, 1], weights[:, 1]


if __name__ == '__main__':
    if len(sys.argv) != 2:
        print('Usage:', sys.argv[0], '<output-onnx-file>')
        sys.exit(1)

    # Convert the model to TorchScript with a dummy batch of size 1.
    model = SpladeModel()
    input_ids = torch.randint(1, 100, size=(1, 50))
    token_type_ids = torch.full((1, 50), 0)
    attention_mask = torch.full((1, 50), 1)
    traced_model = torch.jit.trace(model, (input_ids, token_type_ids, attention_mask))

    # Mark batch and sequence dimensions as dynamic so the exported model
    # accepts inputs of any length.
    dyn_axis = {
        'input_ids': {0: 'batch_size', 1: 'sequence'},
        'attention_mask': {0: 'batch_size', 1: 'sequence'},
        'token_type_ids': {0: 'batch_size', 1: 'sequence'},
        'output_idx': {0: 'batch_size', 1: 'sequence'},
        'output_weights': {0: 'batch_size', 1: 'sequence'}
    }

    torch.onnx.export(
        traced_model,
        (input_ids, token_type_ids, attention_mask),  # type: ignore
        f=sys.argv[1],
        input_names=['input_ids', 'token_type_ids', 'attention_mask'],
        output_names=['output_idx', 'output_weights'],
        dynamic_axes=dyn_axis,
        do_constant_folding=True,
        opset_version=15,
        verbose=False,
    )
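

# Optional sanity check: a minimal sketch (not part of the conversion itself)
# showing how the exported file could be loaded with onnxruntime and run on a
# single query. It assumes `onnxruntime` is installed; the function name
# `verify_onnx_export` and the example query are illustrative only, and the
# function is not invoked by the script above.
def verify_onnx_export(onnx_path: str, text: str = 'what is splade?') -> dict:
    import onnxruntime as ort  # type: ignore

    tokenizer = AutoTokenizer.from_pretrained('prithivida/Splade_PP_en_v2')
    session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    # The tokenizer returns int64 numpy arrays, matching the traced inputs.
    encoded = tokenizer(text, return_tensors='np')
    idx, weights = session.run(
        ['output_idx', 'output_weights'],
        {
            'input_ids': encoded['input_ids'],
            'token_type_ids': encoded['token_type_ids'],
            'attention_mask': encoded['attention_mask'],
        },
    )
    # Map vocabulary ids back to tokens so the sparse vector is readable.
    terms = tokenizer.convert_ids_to_tokens(idx.tolist())
    return dict(zip(terms, weights.tolist()))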