darpanaswal committed on
Commit
92e8f21
·
verified ·
1 Parent(s): 6bb238e

Update cross_encoder_reranking_train.py

Browse files
Files changed (1) hide show
  1. cross_encoder_reranking_train.py +2 -11
cross_encoder_reranking_train.py CHANGED
@@ -37,17 +37,8 @@ def cluster_and_rank(texts, threshold=0.8):
37
  return texts
38
 
39
  embeddings = embed_text_list(texts)
40
- similarity_matrix = np.dot(embeddings, np.transpose(embeddings))
41
- max_sim = np.max(similarity_matrix)
42
- distance_matrix = max_sim - similarity_matrix
43
-
44
- clustering = AgglomerativeClustering(
45
- n_clusters=None,
46
- distance_threshold=max_sim-threshold, # lower threshold = tighter clusters
47
- metric='precomputed',
48
- linkage='average'
49
- )
50
- labels = clustering.fit_predict(distance_matrix)
51
 
52
  clustered_texts = {}
53
  for label, text in zip(labels, texts):
 
37
  return texts
38
 
39
  embeddings = embed_text_list(texts)
40
+ clustering = AgglomerativeClustering(n_clusters=None, distance_threshold=1-threshold, metric = "cosine", linkage='average')
41
+ labels = clustering.fit_predict(embeddings)
 
 
 
 
 
 
 
 
 
42
 
43
  clustered_texts = {}
44
  for label, text in zip(labels, texts):