Pringled committed on
Commit 3b4c438 · 1 Parent(s): 471be58

Updated app with code for deduplication
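For reviewers skimming the diff: the app embeds texts with model2vec and then finds near-duplicates with reach's thresholded nearest-neighbor search. A minimal sketch of that flow outside Gradio, using only the library calls that app.py itself makes (the toy sentences are invented for illustration):

```python
# Minimal end-to-end sketch of the deduplication flow (no Gradio, no progress bars).
from model2vec import StaticModel
from reach import Reach

texts = [
    "the cat sat on the mat",
    "the cat sat on a mat",       # near-duplicate of the first sentence
    "stocks rallied on friday",
]

model = StaticModel.from_pretrained("minishlab/M2V_base_output")
embedding_matrix = model.encode(texts, show_progressbar=False)

reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
results = reach.nearest_neighbor_threshold(
    embedding_matrix, threshold=0.9, batch_size=1024, show_progressbar=False
)

# Same winner-takes-first logic as deduplicate() in app.py.
deduplicated_indices = set(range(len(texts)))
for i, similar_items in enumerate(results):
    if i not in deduplicated_indices:
        continue
    for item in similar_items:
        j = int(item[0])
        if j != i and j in deduplicated_indices:
            deduplicated_indices.remove(j)

print(sorted(deduplicated_indices))  # indices of the unique texts
```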

Files changed (1): app.py +414 -90
app.py CHANGED
@@ -10,12 +10,15 @@ import tqdm
 # Load the model at startup
 model = StaticModel.from_pretrained("minishlab/M2V_base_output")

-# Load the default datasets at startup
-default_dataset1_name = "ag_news"
+# Update default dataset to 'sst2' and set default threshold to 0.9
+default_dataset1_name = "sst2"
 default_dataset1_split = "train"
-default_dataset2_name = "ag_news"
-default_dataset2_split = "test"
+default_dataset2_name = "sst2"
+default_dataset2_split = "validation"
+default_text_column = "sentence"
+default_threshold = 0.9

+# Load the default datasets at startup
 ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
 ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)

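The defaults now point at sst2 with a `sentence` text column. A quick sanity check of those values with the `datasets` library (the printed column list is what sst2 is expected to expose, not something this diff guarantees):

```python
from datasets import load_dataset

# New defaults from this commit: sst2 train/validation, text column "sentence".
ds = load_dataset("sst2", split="validation")
print(ds.column_names)    # expected to include "sentence"
print(ds[0]["sentence"])  # the field the app now reads by default
```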
@@ -23,20 +26,28 @@ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int
     """
     Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
     """
-    reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
+    # Informative progress bar for building the index
+    progress.tqdm.write("Building search index...")
+    with progress.tqdm(total=1, desc="Building index") as p:
+        reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
+        p.update(1)

     deduplicated_indices = set(range(len(embedding_matrix)))
     duplicate_to_original_mapping = {}

+    # Informative progress bar for nearest neighbor search
+    progress.tqdm.write("Finding nearest neighbors...")
     results = reach.nearest_neighbor_threshold(
         embedding_matrix,
         threshold=threshold,
         batch_size=batch_size,
-        show_progressbar=True # Allow internal progress bar
+        show_progressbar=False # Disable internal progress bar
     )

-    # Process duplicates
-    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=len(embedding_matrix))):
+    total_items = len(embedding_matrix)
+    # Processing duplicates with a progress bar
+    progress.tqdm.write("Processing duplicates...")
+    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
        if i not in deduplicated_indices:
            continue

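The single-step bar wrapped around the `Reach(...)` constructor is a general trick for surfacing one long blocking call as a progress update. The same pattern with plain `tqdm`, outside Gradio (a sketch of the pattern only, not the `progress.tqdm` object the diff uses):

```python
from time import sleep
from tqdm import tqdm

# One long, non-iterable operation shown as a single-step progress bar.
with tqdm(total=1, desc="Building index") as p:
    sleep(0.5)   # stand-in for the blocking Reach(...) construction
    p.update(1)  # mark the single step as done
```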
@@ -53,19 +64,28 @@ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix
     """
     Deduplicate embeddings across two datasets and return the indices of duplicates between them.
     """
-    reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
+    # Informative progress bar for building the index
+    progress.tqdm.write("Building search index from Dataset 1...")
+    with progress.tqdm(total=1, desc="Building index for Dataset 1") as p:
+        reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
+        p.update(1)

     duplicate_indices_in_test = []
     duplicate_to_original_mapping = {}

+    # Informative progress bar for nearest neighbor search
+    progress.tqdm.write("Finding nearest neighbors between datasets...")
     results = reach.nearest_neighbor_threshold(
         embedding_matrix_2,
         threshold=threshold,
         batch_size=batch_size,
-        show_progressbar=True # Allow internal progress bar
+        show_progressbar=False # Disable internal progress bar
     )

-    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets", total=len(embedding_matrix_2))):
+    total_items = len(embedding_matrix_2)
+    # Processing duplicates with a progress bar
+    progress.tqdm.write("Processing duplicates across datasets...")
+    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)):
        similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]

        if similar_indices:
@@ -86,7 +106,7 @@ def perform_deduplication(
     dataset2_name="",
     dataset2_split="",
     dataset2_text_column="",
-    threshold=0.8,
+    threshold=default_threshold,
     progress=gr.Progress(track_tqdm=True)
 ):
     # Monkey-patch tqdm
@@ -102,89 +122,63 @@ def perform_deduplication(
         threshold = float(threshold)

         if deduplication_type == "Single dataset":
-            # Check if the dataset is the default one
+            # Load Dataset 1
             if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
                 ds = ds_default1
             else:
                 ds = load_dataset(dataset1_name, split=dataset1_split)

-            # Extract texts
-            texts = [example[dataset1_text_column] for example in ds]
+            # Extract texts with progress bar
+            progress.tqdm.write("Extracting texts from Dataset 1...")
+            texts = [example[dataset1_text_column] for example in progress.tqdm(ds, desc="Extracting texts", total=len(ds))]

-            # Compute embeddings
-            embedding_matrix = model.encode(texts, show_progressbar=True) # Enable internal progress bar
+            # Compute embeddings with progress bar
+            progress.tqdm.write("Computing embeddings for Dataset 1...")
+            embedding_matrix = model.encode(texts, show_progressbar=False) # Disable internal progress bar
+            embedding_matrix = progress.tqdm(embedding_matrix, desc="Computing embeddings", total=len(texts))

             # Deduplicate
-            deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)
-
-            # Prepare the results
-            num_duplicates = len(duplicate_to_original_mapping)
-            num_total = len(texts)
-            num_deduplicated = len(deduplicated_indices)
-
-            result_text = f"**Total documents:** {num_total}\n"
-            result_text += f"**Number of duplicates found:** {num_duplicates}\n"
-            result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
-
-            # Show deduplicated examples
-            result_text += "**Examples of duplicates found:**\n\n"
-            num_examples = min(5, num_duplicates)
-            for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
-                original_text = texts[original_idx]
-                duplicate_text = texts[duplicate_idx]
-                differences = display_word_differences(original_text, duplicate_text)
-                result_text += f"**Original text:**\n{original_text}\n\n"
-                result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
-                result_text += f"**Differences:**\n{differences}\n"
-                result_text += "-" * 50 + "\n\n"
+            result_text = deduplicate_and_prepare_results_single(
+                embedding_matrix, texts, threshold, progress
+            )

             return result_text

         elif deduplication_type == "Cross-dataset":
-            # Dataset 1
+            # Load Dataset 1
             if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
                 ds1 = ds_default1
             else:
                 ds1 = load_dataset(dataset1_name, split=dataset1_split)

-            # Dataset 2
+            # Load Dataset 2
             if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
                 ds2 = ds_default2
             else:
                 ds2 = load_dataset(dataset2_name, split=dataset2_split)

-            # Extract texts
-            texts1 = [example[dataset1_text_column] for example in ds1]
-            texts2 = [example[dataset2_text_column] for example in ds2]
+            # Extract texts from Dataset 1
+            progress.tqdm.write("Extracting texts from Dataset 1...")
+            texts1 = [example[dataset1_text_column] for example in progress.tqdm(ds1, desc="Extracting texts from Dataset 1", total=len(ds1))]
+
+            # Extract texts from Dataset 2
+            progress.tqdm.write("Extracting texts from Dataset 2...")
+            texts2 = [example[dataset2_text_column] for example in progress.tqdm(ds2, desc="Extracting texts from Dataset 2", total=len(ds2))]

-            # Compute embeddings
-            embedding_matrix1 = model.encode(texts1, show_progressbar=True) # Enable internal progress bar
-            embedding_matrix2 = model.encode(texts2, show_progressbar=True) # Enable internal progress bar
+            # Compute embeddings for Dataset 1
+            progress.tqdm.write("Computing embeddings for Dataset 1...")
+            embedding_matrix1 = model.encode(texts1, show_progressbar=False)
+            embedding_matrix1 = progress.tqdm(embedding_matrix1, desc="Computing embeddings for Dataset 1", total=len(texts1))
+
+            # Compute embeddings for Dataset 2
+            progress.tqdm.write("Computing embeddings for Dataset 2...")
+            embedding_matrix2 = model.encode(texts2, show_progressbar=False)
+            embedding_matrix2 = progress.tqdm(embedding_matrix2, desc="Computing embeddings for Dataset 2", total=len(texts2))

             # Deduplicate across datasets
-            duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
-                embedding_matrix1, embedding_matrix2, threshold, progress=progress)
-
-            num_duplicates = len(duplicate_indices_in_ds2)
-            num_total_ds2 = len(texts2)
-            num_unique_ds2 = num_total_ds2 - num_duplicates
-
-            result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
-            result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
-            result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
-
-            # Show deduplicated examples
-            result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
-            num_examples = min(5, num_duplicates)
-            for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
-                original_idx = duplicate_to_original_mapping[duplicate_idx]
-                original_text = texts1[original_idx]
-                duplicate_text = texts2[duplicate_idx]
-                differences = display_word_differences(original_text, duplicate_text)
-                result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
-                result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
-                result_text += f"**Differences:**\n{differences}\n"
-                result_text += "-" * 50 + "\n\n"
+            result_text = deduplicate_and_prepare_results_cross(
+                embedding_matrix1, embedding_matrix2, texts1, texts2, threshold, progress, dataset2_name, dataset2_split
+            )

            return result_text

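Both branches above format duplicate pairs with `display_word_differences`; the commented-out copy at the bottom of this file shows it is a small `ndiff` helper. Reproduced standalone so its output format is clear:

```python
from difflib import ndiff

# Same logic as display_word_differences in app.py: keep only the words
# that ndiff marks as added (+) or removed (-) between the two texts.
def display_word_differences(x: str, y: str) -> str:
    diff = ndiff(x.split(), y.split())
    return " ".join(word for word in diff if word.startswith(("+", "-")))

print(display_word_differences("the cat sat on the mat", "the cat sat on a mat"))
# -> - the + a
```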
@@ -200,52 +194,116 @@ def perform_deduplication(
         else:
             del Reach.tqdm # If it wasn't originally in Reach's __dict__

+def deduplicate_and_prepare_results_single(embedding_matrix, texts, threshold, progress):
+    # Deduplicate
+    deduplicated_indices, duplicate_to_original_mapping = deduplicate(
+        embedding_matrix, threshold, progress=progress
+    )
+
+    # Prepare the results
+    num_duplicates = len(duplicate_to_original_mapping)
+    num_total = len(texts)
+    num_deduplicated = len(deduplicated_indices)
+
+    result_text = f"**Total documents:** {num_total}\n"
+    result_text += f"**Number of duplicates found:** {num_duplicates}\n"
+    result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
+
+    # Show deduplicated examples
+    if num_duplicates > 0:
+        result_text += "**Examples of duplicates found:**\n\n"
+        num_examples = min(5, num_duplicates)
+        for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
+            original_text = texts[original_idx]
+            duplicate_text = texts[duplicate_idx]
+            differences = display_word_differences(original_text, duplicate_text)
+            result_text += f"**Original text:**\n{original_text}\n\n"
+            result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
+            result_text += f"**Differences:**\n{differences}\n"
+            result_text += "-" * 50 + "\n\n"
+    else:
+        result_text += "No duplicates found."
+
+    return result_text
+
+def deduplicate_and_prepare_results_cross(embedding_matrix1, embedding_matrix2, texts1, texts2, threshold, progress, dataset2_name, dataset2_split):
+    # Deduplicate across datasets
+    duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
+        embedding_matrix1, embedding_matrix2, threshold, progress=progress
+    )
+
+    num_duplicates = len(duplicate_indices_in_ds2)
+    num_total_ds2 = len(texts2)
+    num_unique_ds2 = num_total_ds2 - num_duplicates
+
+    result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
+    result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
+    result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
+
+    # Show deduplicated examples
+    if num_duplicates > 0:
+        result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
+        num_examples = min(5, num_duplicates)
+        for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
+            original_idx = duplicate_to_original_mapping[duplicate_idx]
+            original_text = texts1[original_idx]
+            duplicate_text = texts2[duplicate_idx]
+            differences = display_word_differences(original_text, duplicate_text)
+            result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
+            result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
+            result_text += f"**Differences:**\n{differences}\n"
+            result_text += "-" * 50 + "\n\n"
+    else:
+        result_text += "No duplicates found."
+
+    return result_text
+
 with gr.Blocks() as demo:
     gr.Markdown("# Semantic Deduplication")
-
+
     deduplication_type = gr.Radio(
         choices=["Single dataset", "Cross-dataset"],
         label="Deduplication Type",
         value="Single dataset"
     )
-
+
     with gr.Row():
-        dataset1_name = gr.Textbox(value="ag_news", label="Dataset 1 Name")
-        dataset1_split = gr.Textbox(value="train", label="Dataset 1 Split")
-        dataset1_text_column = gr.Textbox(value="text", label="Text Column Name")
-
+        dataset1_name = gr.Textbox(value=default_dataset1_name, label="Dataset 1 Name")
+        dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
+        dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
+
     dataset2_inputs = gr.Column(visible=False)
     with dataset2_inputs:
         gr.Markdown("### Dataset 2")
         with gr.Row():
-            dataset2_name = gr.Textbox(value="ag_news", label="Dataset 2 Name")
-            dataset2_split = gr.Textbox(value="test", label="Dataset 2 Split")
-            dataset2_text_column = gr.Textbox(value="text", label="Text Column Name")
-
+            dataset2_name = gr.Textbox(value=default_dataset2_name, label="Dataset 2 Name")
+            dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
+            dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
+
     threshold = gr.Slider(
         minimum=0.0,
         maximum=1.0,
-        value=0.8,
+        value=default_threshold,
         label="Similarity Threshold"
     )
-
+
     compute_button = gr.Button("Compute")
-
+
     output = gr.Markdown()
-
+
     # Function to update the visibility of dataset2_inputs
     def update_visibility(deduplication_type_value):
         if deduplication_type_value == "Cross-dataset":
             return gr.update(visible=True)
         else:
             return gr.update(visible=False)
-
+
     deduplication_type.change(
         update_visibility,
         inputs=deduplication_type,
         outputs=dataset2_inputs
     )
-
+
     compute_button.click(
         fn=perform_deduplication,
         inputs=[
@@ -302,7 +360,7 @@ demo.launch()
 # )

 # # Process duplicates
-# for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates")):
+# for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=len(embedding_matrix))):
 # if i not in deduplicated_indices:
 # continue

@@ -331,8 +389,7 @@ demo.launch()
 # show_progressbar=True # Allow internal progress bar
 # )

-# # Process duplicates
-# for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets")):
+# for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets", total=len(embedding_matrix_2))):
 # similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]

 # if similar_indices:
@@ -358,9 +415,11 @@ demo.launch()
 # ):
 # # Monkey-patch tqdm
 # original_tqdm = tqdm.tqdm
+# original_reach_tqdm = Reach.__dict__['tqdm'] if 'tqdm' in Reach.__dict__ else None
 # tqdm.tqdm = progress.tqdm
 # sys.modules['tqdm'].tqdm = progress.tqdm
 # sys.modules['tqdm.auto'].tqdm = progress.tqdm
+# Reach.tqdm = progress.tqdm # Monkey-patch reach's tqdm

 # try:
 # # Convert threshold to float
@@ -427,7 +486,8 @@ demo.launch()
 # embedding_matrix2 = model.encode(texts2, show_progressbar=True) # Enable internal progress bar

 # # Deduplicate across datasets
-# duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold, progress=progress)
+# duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
+# embedding_matrix1, embedding_matrix2, threshold, progress=progress)

 # num_duplicates = len(duplicate_indices_in_ds2)
 # num_total_ds2 = len(texts2)
@@ -458,6 +518,12 @@ demo.launch()
 # sys.modules['tqdm'].tqdm = original_tqdm
 # sys.modules['tqdm.auto'].tqdm = original_tqdm

+# # Restore reach's original tqdm
+# if original_reach_tqdm is not None:
+# Reach.tqdm = original_reach_tqdm
+# else:
+# del Reach.tqdm # If it wasn't originally in Reach's __dict__
+
 # with gr.Blocks() as demo:
 # gr.Markdown("# Semantic Deduplication")

@@ -520,3 +586,261 @@ demo.launch()
 # )

 # demo.launch()
+
+
+# # import gradio as gr
+# # from datasets import load_dataset
+# # import numpy as np
+# # from model2vec import StaticModel
+# # from reach import Reach
+# # from difflib import ndiff
+# # import sys
+# # import tqdm
+
+# # # Load the model at startup
+# # model = StaticModel.from_pretrained("minishlab/M2V_base_output")
+
+# # # Load the default datasets at startup
+# # default_dataset1_name = "ag_news"
+# # default_dataset1_split = "train"
+# # default_dataset2_name = "ag_news"
+# # default_dataset2_split = "test"
+
+# # ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
+# # ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
+
+# # def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[np.ndarray, dict[int, int]]:
+# # """
+# # Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
+# # """
+# # reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
+
+# # deduplicated_indices = set(range(len(embedding_matrix)))
+# # duplicate_to_original_mapping = {}
+
+# # results = reach.nearest_neighbor_threshold(
+# # embedding_matrix,
+# # threshold=threshold,
+# # batch_size=batch_size,
+# # show_progressbar=True # Allow internal progress bar
+# # )
+
+# # # Process duplicates
+# # for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates")):
+# # if i not in deduplicated_indices:
+# # continue
+
+# # similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
+
+# # for sim_idx in similar_indices:
+# # if sim_idx in deduplicated_indices:
+# # deduplicated_indices.remove(sim_idx)
+# # duplicate_to_original_mapping[sim_idx] = i
+
+# # return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
+
+# # def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[list[int], dict[int, int]]:
+# # """
+# # Deduplicate embeddings across two datasets and return the indices of duplicates between them.
+# # """
+# # reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
+
+# # duplicate_indices_in_test = []
+# # duplicate_to_original_mapping = {}
+
+# # results = reach.nearest_neighbor_threshold(
+# # embedding_matrix_2,
+# # threshold=threshold,
+# # batch_size=batch_size,
+# # show_progressbar=True # Allow internal progress bar
+# # )
+
+# # # Process duplicates
+# # for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets")):
+# # similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
+
+# # if similar_indices:
+# # duplicate_indices_in_test.append(i)
+# # duplicate_to_original_mapping[i] = similar_indices[0]
+
+# # return duplicate_indices_in_test, duplicate_to_original_mapping
+
+# # def display_word_differences(x: str, y: str) -> str:
+# # diff = ndiff(x.split(), y.split())
+# # return " ".join([word for word in diff if word.startswith(('+', '-'))])
+
+# # def perform_deduplication(
+# # deduplication_type,
+# # dataset1_name,
+# # dataset1_split,
+# # dataset1_text_column,
+# # dataset2_name="",
+# # dataset2_split="",
+# # dataset2_text_column="",
+# # threshold=0.8,
+# # progress=gr.Progress(track_tqdm=True)
+# # ):
+# # # Monkey-patch tqdm
+# # original_tqdm = tqdm.tqdm
+# # tqdm.tqdm = progress.tqdm
+# # sys.modules['tqdm'].tqdm = progress.tqdm
+# # sys.modules['tqdm.auto'].tqdm = progress.tqdm
+
+# # try:
+# # # Convert threshold to float
+# # threshold = float(threshold)
+
+# # if deduplication_type == "Single dataset":
+# # # Check if the dataset is the default one
+# # if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
+# # ds = ds_default1
+# # else:
+# # ds = load_dataset(dataset1_name, split=dataset1_split)
+
+# # # Extract texts
+# # texts = [example[dataset1_text_column] for example in ds]
+
+# # # Compute embeddings
+# # embedding_matrix = model.encode(texts, show_progressbar=True) # Enable internal progress bar
+
+# # # Deduplicate
+# # deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)
+
+# # # Prepare the results
+# # num_duplicates = len(duplicate_to_original_mapping)
+# # num_total = len(texts)
+# # num_deduplicated = len(deduplicated_indices)
+
+# # result_text = f"**Total documents:** {num_total}\n"
+# # result_text += f"**Number of duplicates found:** {num_duplicates}\n"
+# # result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
+
+# # # Show deduplicated examples
+# # result_text += "**Examples of duplicates found:**\n\n"
+# # num_examples = min(5, num_duplicates)
+# # for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
+# # original_text = texts[original_idx]
+# # duplicate_text = texts[duplicate_idx]
+# # differences = display_word_differences(original_text, duplicate_text)
+# # result_text += f"**Original text:**\n{original_text}\n\n"
+# # result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
+# # result_text += f"**Differences:**\n{differences}\n"
+# # result_text += "-" * 50 + "\n\n"
+
+# # return result_text
+
+# # elif deduplication_type == "Cross-dataset":
+# # # Dataset 1
+# # if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
+# # ds1 = ds_default1
+# # else:
+# # ds1 = load_dataset(dataset1_name, split=dataset1_split)
+
+# # # Dataset 2
+# # if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
+# # ds2 = ds_default2
+# # else:
+# # ds2 = load_dataset(dataset2_name, split=dataset2_split)
+
+# # # Extract texts
+# # texts1 = [example[dataset1_text_column] for example in ds1]
+# # texts2 = [example[dataset2_text_column] for example in ds2]
+
+# # # Compute embeddings
+# # embedding_matrix1 = model.encode(texts1, show_progressbar=True) # Enable internal progress bar
+# # embedding_matrix2 = model.encode(texts2, show_progressbar=True) # Enable internal progress bar
+
+# # # Deduplicate across datasets
+# # duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold, progress=progress)
+
+# # num_duplicates = len(duplicate_indices_in_ds2)
+# # num_total_ds2 = len(texts2)
+# # num_unique_ds2 = num_total_ds2 - num_duplicates
+
+# # result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
+# # result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
+# # result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
+
+# # # Show deduplicated examples
+# # result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
+# # num_examples = min(5, num_duplicates)
+# # for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
+# # original_idx = duplicate_to_original_mapping[duplicate_idx]
+# # original_text = texts1[original_idx]
+# # duplicate_text = texts2[duplicate_idx]
+# # differences = display_word_differences(original_text, duplicate_text)
+# # result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
+# # result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
+# # result_text += f"**Differences:**\n{differences}\n"
+# # result_text += "-" * 50 + "\n\n"
+
+# # return result_text
+
+# # finally:
+# # # Restore original tqdm
+# # tqdm.tqdm = original_tqdm
+# # sys.modules['tqdm'].tqdm = original_tqdm
+# # sys.modules['tqdm.auto'].tqdm = original_tqdm
+
+# # with gr.Blocks() as demo:
+# # gr.Markdown("# Semantic Deduplication")
+
+# # deduplication_type = gr.Radio(
+# # choices=["Single dataset", "Cross-dataset"],
+# # label="Deduplication Type",
+# # value="Single dataset"
+# # )
+
+# # with gr.Row():
+# # dataset1_name = gr.Textbox(value="ag_news", label="Dataset 1 Name")
+# # dataset1_split = gr.Textbox(value="train", label="Dataset 1 Split")
+# # dataset1_text_column = gr.Textbox(value="text", label="Text Column Name")
+
+# # dataset2_inputs = gr.Column(visible=False)
+# # with dataset2_inputs:
+# # gr.Markdown("### Dataset 2")
+# # with gr.Row():
+# # dataset2_name = gr.Textbox(value="ag_news", label="Dataset 2 Name")
+# # dataset2_split = gr.Textbox(value="test", label="Dataset 2 Split")
+# # dataset2_text_column = gr.Textbox(value="text", label="Text Column Name")
+
+# # threshold = gr.Slider(
+# # minimum=0.0,
+# # maximum=1.0,
+# # value=0.8,
+# # label="Similarity Threshold"
+# # )
+
+# # compute_button = gr.Button("Compute")
+
+# # output = gr.Markdown()
+
+# # # Function to update the visibility of dataset2_inputs
+# # def update_visibility(deduplication_type_value):
+# # if deduplication_type_value == "Cross-dataset":
+# # return gr.update(visible=True)
+# # else:
+# # return gr.update(visible=False)
+
+# # deduplication_type.change(
+# # update_visibility,
+# # inputs=deduplication_type,
+# # outputs=dataset2_inputs
+# # )
+
+# # compute_button.click(
+# # fn=perform_deduplication,
+# # inputs=[
+# # deduplication_type,
+# # dataset1_name,
+# # dataset1_split,
+# # dataset1_text_column,
+# # dataset2_name,
+# # dataset2_split,
+# # dataset2_text_column,
+# # threshold
+# # ],
+# # outputs=output
+# # )
+
+# # demo.launch()
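A recurring shape in the commented-out experiments above is the tqdm monkey-patch: save the original, reroute `tqdm.tqdm` to Gradio's `progress.tqdm`, and restore in a `finally`. The essentials of that pattern, reduced to a reusable sketch (the `run_with_patched_tqdm` helper is illustrative, not part of app.py):

```python
import sys
import tqdm

def run_with_patched_tqdm(work, replacement_tqdm):
    """Run work() with tqdm.tqdm rerouted to replacement_tqdm, restoring it afterwards."""
    original_tqdm = tqdm.tqdm
    tqdm.tqdm = replacement_tqdm
    sys.modules["tqdm"].tqdm = replacement_tqdm
    try:
        return work()
    finally:
        # Restore even if work() raises, mirroring the try/finally in app.py.
        tqdm.tqdm = original_tqdm
        sys.modules["tqdm"].tqdm = original_tqdm
```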