darpanaswal committed
Commit 3a53014 · verified · 1 Parent(s): 4202987

Update cross_encoder_reranking_train.py

Files changed (1):
  1. cross_encoder_reranking_train.py +1 -8
cross_encoder_reranking_train.py CHANGED
@@ -226,13 +226,6 @@ def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
     batch_size = last_hidden_states.shape[0]
     return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
 
-def cls_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
-    """Extract [CLS] token representations, accounting for left padding."""
-    # Get the index of the first non-padding token in each sequence
-    cls_indices = attention_mask.float().argmax(dim=1)
-    batch_size = last_hidden_states.size(0)
-    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), cls_indices]
-
 def get_detailed_instruct(task_description: str, query: str) -> str:
     """Create an instruction-formatted query"""
     return f'Instruct: {task_description}\nQuery: {query}'
@@ -274,7 +267,7 @@ def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=
 
     # Get embeddings
     outputs = model(**batch_dict)
-    embeddings = cls_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+    embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
 
     # Normalize embeddings
     embeddings = F.normalize(embeddings, p=2, dim=1)
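
For context, the commit switches pooling from the first ([CLS]-style) token to the last real token of each sequence. Only the tail of last_token_pool is visible in the hunk above; the sketch below is a minimal, self-contained reconstruction following the common last-token-pooling recipe, and the left-padding shortcut at the top is an assumption not shown in this diff.

import torch
from torch import Tensor

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # Assumed head (not visible in the hunk): with left padding, the last
    # position of every sequence is a real token, so take it directly.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    # Tail shown in the diff: with right padding, index the last non-padding token.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

# Toy check: 2 right-padded sequences, hidden size 4.
hidden = torch.randn(2, 5, 4)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])
pooled = last_token_pool(hidden, mask)
assert pooled.shape == (2, 4)
assert torch.equal(pooled[0], hidden[0, 2])  # last real token of the first sequence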
 
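
The normalization step in the second hunk feeds a cosine-similarity reranker. Below is a hedged sketch of how the pooled, normalized embeddings would typically be scored; the query/document split and the dot-product scoring are assumptions, as only the F.normalize call appears in the commit.

import torch
import torch.nn.functional as F

# Hypothetical pooled embeddings: one query and three candidate documents (hidden size 4).
query_emb = torch.randn(1, 4)
doc_embs = torch.randn(3, 4)

# L2-normalize so the dot product equals cosine similarity,
# mirroring the F.normalize(embeddings, p=2, dim=1) step in the diff.
query_emb = F.normalize(query_emb, p=2, dim=1)
doc_embs = F.normalize(doc_embs, p=2, dim=1)

# Score and rank documents by similarity to the query, highest first.
scores = (query_emb @ doc_embs.T).squeeze(0)   # shape (3,)
ranking = scores.argsort(descending=True).tolist()
print(ranking)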