Spaces:
Configuration error
Configuration error
Update cross_encoder_reranking_train.py
Browse files
cross_encoder_reranking_train.py
CHANGED
@@ -226,13 +226,6 @@ def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tenso
|
|
226 |
batch_size = last_hidden_states.shape[0]
|
227 |
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
|
228 |
|
229 |
-
def cls_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
230 |
-
"""Extract [CLS] token representations, accounting for left padding."""
|
231 |
-
# Get the index of the first non-padding token in each sequence
|
232 |
-
cls_indices = attention_mask.float().argmax(dim=1)
|
233 |
-
batch_size = last_hidden_states.size(0)
|
234 |
-
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), cls_indices]
|
235 |
-
|
236 |
def get_detailed_instruct(task_description: str, query: str) -> str:
|
237 |
"""Create an instruction-formatted query"""
|
238 |
return f'Instruct: {task_description}\nQuery: {query}'
|
@@ -274,7 +267,7 @@ def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=
|
|
274 |
|
275 |
# Get embeddings
|
276 |
outputs = model(**batch_dict)
|
277 |
-
embeddings =
|
278 |
|
279 |
# Normalize embeddings
|
280 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
|
|
226 |
batch_size = last_hidden_states.shape[0]
|
227 |
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
|
228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
def get_detailed_instruct(task_description: str, query: str) -> str:
|
230 |
"""Create an instruction-formatted query"""
|
231 |
return f'Instruct: {task_description}\nQuery: {query}'
|
|
|
267 |
|
268 |
# Get embeddings
|
269 |
outputs = model(**batch_dict)
|
270 |
+
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
271 |
|
272 |
# Normalize embeddings
|
273 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|