Pringled committed on
Commit 95530b9 · 1 Parent(s): c58907b
Files changed (1)
  1. app.py +284 -1420
app.py CHANGED
@@ -5,79 +5,72 @@ from model2vec import StaticModel
  from reach import Reach
  from difflib import ndiff
 
- # Load the model at startup
  model = StaticModel.from_pretrained("minishlab/M2V_base_output")
 
- # Default dataset parameters
- default_dataset1_name = "sst2"
- default_dataset1_split = "train"
- default_dataset2_name = "sst2"
- default_dataset2_split = "validation"
  default_text_column = "sentence"
  default_threshold = 0.9
 
- # Load the default datasets at startup
- ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
- ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
-
  def batch_iterable(iterable, batch_size):
-     """Helper function to create batches from an iterable."""
      for i in range(0, len(iterable), batch_size):
          yield iterable[i:i + batch_size]
 
- def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
      embeddings = []
      total_batches = (len(texts) + batch_size - 1) // batch_size
      for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
-         batch_embeddings = model.encode(batch_texts, show_progressbar=False)
-         embeddings.append(batch_embeddings)
          progress((i + 1) / total_batches, desc=desc)
      return np.concatenate(embeddings, axis=0)
 
- def deduplicate(
-     embedding_matrix: np.ndarray,
-     threshold: float,
      batch_size: int = 1024,
      progress=None
- ) -> tuple[np.ndarray, dict[int, int]]:
-     # Building the index
-     progress(0, desc="Building search index...")
-     reach = Reach(
-         vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
-     )
-
-     deduplicated_indices = set(range(len(embedding_matrix)))
-     duplicate_to_original_mapping = {}
-
-     # Finding nearest neighbors
-     progress(0, desc="Finding nearest neighbors...")
-     results = reach.nearest_neighbor_threshold(
-         embedding_matrix,
-         threshold=threshold,
-         batch_size=batch_size,
-         show_progressbar=False,  # Disable internal progress bar
-     )
-
-     # Processing duplicates with a progress bar
-     total_items = len(embedding_matrix)
-     for i, similar_items in enumerate(
-         progress.tqdm(results, desc="Processing duplicates", total=total_items)
-     ):
-         if i not in deduplicated_indices:
-             continue
-
-         similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
-
-         for sim_idx in similar_indices:
-             if sim_idx in deduplicated_indices:
-                 deduplicated_indices.remove(sim_idx)
-                 duplicate_to_original_mapping[sim_idx] = i
-
-     return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
 
  def display_word_differences(x: str, y: str) -> str:
      diff = ndiff(x.split(), y.split())
-     return " ".join([word for word in diff if word.startswith(("+", "-"))])
 
  def perform_deduplication(
      deduplication_type,
@@ -91,208 +84,86 @@ def perform_deduplication(
      progress=gr.Progress(track_tqdm=True),
  ):
      try:
-         # Convert threshold to float
          threshold = float(threshold)
 
-         # Initialize status message
-         status = ""
 
          if deduplication_type == "Single dataset":
-             # Load Dataset 1
-             status = "Loading Dataset 1..."
-             yield status, ""
-             if (
-                 dataset1_name == default_dataset1_name
-                 and dataset1_split == default_dataset1_split
-             ):
-                 ds = ds_default1
-             else:
-                 ds = load_dataset(dataset1_name, split=dataset1_split)
-
-             # Extract texts
-             status = "Extracting texts from Dataset 1..."
-             yield status, ""
-             texts = [example[dataset1_text_column] for example in ds]
-
-             # Compute embeddings
-             status = "Computing embeddings for Dataset 1..."
-             yield status, ""
-             embedding_matrix = compute_embeddings(
-                 texts,
-                 batch_size=64,
-                 progress=progress,
-                 desc="Computing embeddings for Dataset 1",
-             )
-
-             # Deduplicate
-             status = "Deduplicating embeddings..."
-             yield status, ""
-             deduplicated_indices, duplicate_to_original_mapping = deduplicate(
-                 embedding_matrix, threshold, progress=progress
              )
 
-             # Prepare the results
-             num_duplicates = len(duplicate_to_original_mapping)
-             num_total = len(texts)
-             num_deduplicated = len(deduplicated_indices)
-
-             result_text = f"**Total documents:** {num_total}\n"
-             result_text += f"**Number of duplicates found:** {num_duplicates}\n"
-             result_text += (
-                 f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
              )
 
-             # Show deduplicated examples
              if num_duplicates > 0:
-                 result_text += "**Examples of duplicates found:**\n\n"
-                 num_examples = min(5, num_duplicates)
-                 for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
-                     original_text = texts[original_idx]
-                     duplicate_text = texts[duplicate_idx]
-                     differences = display_word_differences(original_text, duplicate_text)
-                     result_text += f"**Original text:**\n{original_text}\n\n"
-                     result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
-                     result_text += f"**Differences:**\n{differences}\n"
-                     result_text += "-" * 50 + "\n\n"
              else:
                  result_text += "No duplicates found."
 
-             # Final status
-             status = "Deduplication completed."
-             yield status, result_text
-
-         elif deduplication_type == "Cross-dataset":
-             # Similar code for cross-dataset deduplication
-             # Load Dataset 1
-             status = "Loading Dataset 1..."
-             yield status, ""
-             if (
-                 dataset1_name == default_dataset1_name
-                 and dataset1_split == default_dataset1_split
-             ):
-                 ds1 = ds_default1
-             else:
-                 ds1 = load_dataset(dataset1_name, split=dataset1_split)
-
-             # Load Dataset 2
-             status = "Loading Dataset 2..."
-             yield status, ""
-             if (
-                 dataset2_name == default_dataset2_name
-                 and dataset2_split == default_dataset2_split
-             ):
-                 ds2 = ds_default2
-             else:
-                 ds2 = load_dataset(dataset2_name, split=dataset2_split)
-
-             # Extract texts from Dataset 1
-             status = "Extracting texts from Dataset 1..."
-             yield status, ""
-             texts1 = [example[dataset1_text_column] for example in ds1]
-
-             # Extract texts from Dataset 2
-             status = "Extracting texts from Dataset 2..."
-             yield status, ""
-             texts2 = [example[dataset2_text_column] for example in ds2]
-
-             # Compute embeddings for Dataset 1
-             status = "Computing embeddings for Dataset 1..."
-             yield status, ""
-             embedding_matrix1 = compute_embeddings(
-                 texts1,
-                 batch_size=64,
-                 progress=progress,
-                 desc="Computing embeddings for Dataset 1",
-             )
 
-             # Compute embeddings for Dataset 2
-             status = "Computing embeddings for Dataset 2..."
-             yield status, ""
-             embedding_matrix2 = compute_embeddings(
-                 texts2,
-                 batch_size=64,
-                 progress=progress,
-                 desc="Computing embeddings for Dataset 2",
              )
 
-             # Deduplicate across datasets
-             status = "Deduplicating embeddings across datasets..."
-             yield status, ""
-             duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
-                 embedding_matrix1, embedding_matrix2, threshold, progress=progress
              )
 
-             num_duplicates = len(duplicate_indices_in_ds2)
-             num_total_ds2 = len(texts2)
-             num_unique_ds2 = num_total_ds2 - num_duplicates
-
-             result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
-             result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
-             result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
-
-             # Show deduplicated examples
              if num_duplicates > 0:
-                 result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
-                 num_examples = min(5, num_duplicates)
-                 for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
-                     original_idx = duplicate_to_original_mapping[duplicate_idx]
-                     original_text = texts1[original_idx]
-                     duplicate_text = texts2[duplicate_idx]
-                     differences = display_word_differences(original_text, duplicate_text)
-                     result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
-                     result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
-                     result_text += f"**Differences:**\n{differences}\n"
-                     result_text += "-" * 50 + "\n\n"
              else:
                  result_text += "No duplicates found."
 
-             # Final status
-             status = "Deduplication completed."
-             yield status, result_text
 
      except Exception as e:
          yield f"An error occurred: {e}", ""
          raise e
 
- def deduplicate_across_datasets(
-     embedding_matrix_1: np.ndarray,
-     embedding_matrix_2: np.ndarray,
-     threshold: float,
-     batch_size: int = 1024,
-     progress=None
- ) -> tuple[list[int], dict[int, int]]:
-     # Building the index from Dataset 1
-     progress(0, desc="Building search index from Dataset 1...")
-     reach = Reach(
-         vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))]
-     )
-
-     duplicate_indices_in_test = []
-     duplicate_to_original_mapping = {}
-
-     # Finding nearest neighbors between datasets
-     progress(0, desc="Finding nearest neighbors between datasets...")
-     results = reach.nearest_neighbor_threshold(
-         embedding_matrix_2,
-         threshold=threshold,
-         batch_size=batch_size,
-         show_progressbar=False,  # Disable internal progress bar
-     )
-
-     total_items = len(embedding_matrix_2)
-     # Processing duplicates with a progress bar
-     for i, similar_items in enumerate(
-         progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)
-     ):
-         similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
-
-         if similar_indices:
-             duplicate_indices_in_test.append(i)
-             duplicate_to_original_mapping[i] = similar_indices[0]
-
-     return duplicate_indices_in_test, duplicate_to_original_mapping
-
- # Adjust the height of the status_output component using custom CSS
  with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
      gr.Markdown("# Semantic Deduplication")
 
@@ -303,38 +174,27 @@ with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
      )
 
      with gr.Row():
-         dataset1_name = gr.Textbox(value=default_dataset1_name, label="Dataset 1 Name")
-         dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
          dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
 
      dataset2_inputs = gr.Column(visible=False)
      with dataset2_inputs:
          gr.Markdown("### Dataset 2")
          with gr.Row():
-             dataset2_name = gr.Textbox(value=default_dataset2_name, label="Dataset 2 Name")
-             dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
              dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
 
-     threshold = gr.Slider(
-         minimum=0.0, maximum=1.0, value=default_threshold, label="Similarity Threshold"
-     )
-
      compute_button = gr.Button("Compute")
-
-     # Use 'gr.Markdown' with 'elem_id' and custom CSS to adjust height
      status_output = gr.Markdown(elem_id="status_output")
      result_output = gr.Markdown()
 
-     # Function to update the visibility of dataset2_inputs
-     def update_visibility(deduplication_type_value):
-         if deduplication_type_value == "Cross-dataset":
-             return gr.update(visible=True)
-         else:
-             return gr.update(visible=False)
 
-     deduplication_type.change(
-         update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
-     )
 
      compute_button.click(
          fn=perform_deduplication,
@@ -353,19 +213,17 @@ with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
 
  demo.launch()
 
-
  # import gradio as gr
  # from datasets import load_dataset
  # import numpy as np
  # from model2vec import StaticModel
  # from reach import Reach
  # from difflib import ndiff
- # import tqdm
 
  # # Load the model at startup
  # model = StaticModel.from_pretrained("minishlab/M2V_base_output")
 
- # # Update default dataset to 'sst2' and set default threshold to 0.9
  # default_dataset1_name = "sst2"
  # default_dataset1_split = "train"
  # default_dataset2_name = "sst2"
@@ -384,29 +242,42 @@ demo.launch()
 
  # def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
  #     embeddings = []
- #     for batch in progress.tqdm(batch_iterable(texts, batch_size), total=(len(texts) + batch_size - 1) // batch_size, desc=desc):
- #         batch_embeddings = model.encode(batch, show_progressbar=False)
  #         embeddings.append(batch_embeddings)
  #     return np.concatenate(embeddings, axis=0)
 
- # def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[np.ndarray, dict[int, int]]:
- #     """
- #     Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
- #     """
- #     reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
 
  #     deduplicated_indices = set(range(len(embedding_matrix)))
  #     duplicate_to_original_mapping = {}
 
  #     results = reach.nearest_neighbor_threshold(
  #         embedding_matrix,
  #         threshold=threshold,
  #         batch_size=batch_size,
- #         show_progressbar=False
  #     )
 
  #     total_items = len(embedding_matrix)
- #     for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
  #         if i not in deduplicated_indices:
  #             continue
 
@@ -419,35 +290,9 @@ demo.launch()
 
  #     return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
 
- # def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[list[int], dict[int, int]]:
- #     """
- #     Deduplicate embeddings across two datasets and return the indices of duplicates between them.
- #     """
- #     reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
-
- #     duplicate_indices_in_test = []
- #     duplicate_to_original_mapping = {}
-
- #     results = reach.nearest_neighbor_threshold(
- #         embedding_matrix_2,
- #         threshold=threshold,
- #         batch_size=batch_size,
- #         show_progressbar=False
- #     )
-
- #     total_items = len(embedding_matrix_2)
- #     for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)):
- #         similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
-
- #         if similar_indices:
- #             duplicate_indices_in_test.append(i)
- #             duplicate_to_original_mapping[i] = similar_indices[0]
-
- #     return duplicate_indices_in_test, duplicate_to_original_mapping
-
  # def display_word_differences(x: str, y: str) -> str:
  #     diff = ndiff(x.split(), y.split())
- #     return " ".join([word for word in diff if word.startswith(('+', '-'))])
 
  # def perform_deduplication(
  #     deduplication_type,
@@ -458,26 +303,61 @@ demo.launch()
  #     dataset2_split="",
  #     dataset2_text_column="",
  #     threshold=default_threshold,
- #     progress=gr.Progress(track_tqdm=True)
  # ):
  #     try:
  #         threshold = float(threshold)
 
  #         if deduplication_type == "Single dataset":
- #             ds = ds_default1 if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split else load_dataset(dataset1_name, split=dataset1_split)
- #             texts = [example[dataset1_text_column] for example in ds]
 
- #             embedding_matrix = compute_embeddings(texts, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
- #             deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)
 
  #             num_duplicates = len(duplicate_to_original_mapping)
  #             num_total = len(texts)
  #             num_deduplicated = len(deduplicated_indices)
 
  #             result_text = f"**Total documents:** {num_total}\n"
  #             result_text += f"**Number of duplicates found:** {num_duplicates}\n"
- #             result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
 
  #             if num_duplicates > 0:
  #                 result_text += "**Examples of duplicates found:**\n\n"
  #                 num_examples = min(5, num_duplicates)
492
  # else:
493
  # result_text += "No duplicates found."
494
 
495
- # yield result_text
 
 
496
 
497
  # elif deduplication_type == "Cross-dataset":
498
- # ds1 = ds_default1 if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split else load_dataset(dataset1_name, split=dataset1_split)
499
- # ds2 = ds_default2 if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split else load_dataset(dataset2_name, split=dataset2_split)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
 
 
 
 
501
  # texts1 = [example[dataset1_text_column] for example in ds1]
502
- # texts2 = [example[dataset2_text_column] for example in ds2]
503
 
504
- # embedding_matrix1 = compute_embeddings(texts1, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
505
- # embedding_matrix2 = compute_embeddings(texts2, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 2")
 
 
506
 
507
- # duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold, progress=progress)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
 
509
  # num_duplicates = len(duplicate_indices_in_ds2)
510
  # num_total_ds2 = len(texts2)
@@ -514,6 +445,7 @@ demo.launch()
  #             result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
  #             result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
 
  #             if num_duplicates > 0:
  #                 result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
  #                 num_examples = min(5, num_duplicates)
@@ -529,19 +461,60 @@ demo.launch()
  #             else:
  #                 result_text += "No duplicates found."
 
- #             yield result_text
 
  #     except Exception as e:
  #         yield f"An error occurred: {e}", ""
 
- # # Adjust the height of the status_output and result_output components
- # with gr.Blocks(css="#status_output { height: 300px; overflow: auto; } #result_output { height: 300px; overflow: auto; }") as demo:
  #     gr.Markdown("# Semantic Deduplication")
 
  #     deduplication_type = gr.Radio(
  #         choices=["Single dataset", "Cross-dataset"],
  #         label="Deduplication Type",
- #         value="Single dataset"
  #     )
 
  #     with gr.Row():
@@ -558,17 +531,16 @@ demo.launch()
  #             dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
 
  #     threshold = gr.Slider(
- #         minimum=0.0,
- #         maximum=1.0,
- #         value=default_threshold,
- #         label="Similarity Threshold"
  #     )
 
  #     compute_button = gr.Button("Compute")
 
  #     status_output = gr.Markdown(elem_id="status_output")
- #     result_output = gr.Markdown(elem_id="result_output")
 
  #     def update_visibility(deduplication_type_value):
  #         if deduplication_type_value == "Cross-dataset":
  #             return gr.update(visible=True)
@@ -576,9 +548,7 @@ demo.launch()
  #             return gr.update(visible=False)
 
  #     deduplication_type.change(
- #         update_visibility,
- #         inputs=deduplication_type,
- #         outputs=dataset2_inputs
  #     )
 
  #     compute_button.click(
@@ -591,1115 +561,9 @@ demo.launch()
  #             dataset2_name,
  #             dataset2_split,
  #             dataset2_text_column,
- #             threshold
  #         ],
- #         outputs=[status_output, result_output]
  #     )
 
  # demo.launch()
1517
- # # # # # Load Dataset 1
1518
- # # # # status = "Loading Dataset 1..."
1519
- # # # # yield status, ""
1520
- # # # # if (
1521
- # # # # dataset1_name == default_dataset1_name
1522
- # # # # and dataset1_split == default_dataset1_split
1523
- # # # # ):
1524
- # # # # ds1 = ds_default1
1525
- # # # # else:
1526
- # # # # ds1 = load_dataset(dataset1_name, split=dataset1_split)
1527
-
1528
- # # # # # Load Dataset 2
1529
- # # # # status = "Loading Dataset 2..."
1530
- # # # # yield status, ""
1531
- # # # # if (
1532
- # # # # dataset2_name == default_dataset2_name
1533
- # # # # and dataset2_split == default_dataset2_split
1534
- # # # # ):
1535
- # # # # ds2 = ds_default2
1536
- # # # # else:
1537
- # # # # ds2 = load_dataset(dataset2_name, split=dataset2_split)
1538
-
1539
- # # # # # Extract texts from Dataset 1
1540
- # # # # status = "Extracting texts from Dataset 1..."
1541
- # # # # yield status, ""
1542
- # # # # texts1 = [example[dataset1_text_column] for example in ds1]
1543
-
1544
- # # # # # Extract texts from Dataset 2
1545
- # # # # status = "Extracting texts from Dataset 2..."
1546
- # # # # yield status, ""
1547
- # # # # texts2 = [example[dataset2_text_column] for example in ds2]
1548
-
1549
- # # # # # Compute embeddings for Dataset 1
1550
- # # # # status = "Computing embeddings for Dataset 1..."
1551
- # # # # yield status, ""
1552
- # # # # embedding_matrix1 = compute_embeddings(
1553
- # # # # texts1,
1554
- # # # # batch_size=64,
1555
- # # # # progress=progress,
1556
- # # # # desc="Computing embeddings for Dataset 1",
1557
- # # # # )
1558
-
1559
- # # # # # Compute embeddings for Dataset 2
1560
- # # # # status = "Computing embeddings for Dataset 2..."
1561
- # # # # yield status, ""
1562
- # # # # embedding_matrix2 = compute_embeddings(
1563
- # # # # texts2,
1564
- # # # # batch_size=64,
1565
- # # # # progress=progress,
1566
- # # # # desc="Computing embeddings for Dataset 2",
1567
- # # # # )
1568
-
1569
- # # # # # Deduplicate across datasets
1570
- # # # # status = "Deduplicating embeddings across datasets..."
1571
- # # # # yield status, ""
1572
- # # # # duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
1573
- # # # # embedding_matrix1, embedding_matrix2, threshold, progress=progress
1574
- # # # # )
1575
-
1576
- # # # # num_duplicates = len(duplicate_indices_in_ds2)
1577
- # # # # num_total_ds2 = len(texts2)
1578
- # # # # num_unique_ds2 = num_total_ds2 - num_duplicates
1579
-
1580
- # # # # result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
1581
- # # # # result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
1582
- # # # # result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
1583
-
1584
- # # # # # Show deduplicated examples
1585
- # # # # if num_duplicates > 0:
1586
- # # # # result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
1587
- # # # # num_examples = min(5, num_duplicates)
1588
- # # # # for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
1589
- # # # # original_idx = duplicate_to_original_mapping[duplicate_idx]
1590
- # # # # original_text = texts1[original_idx]
1591
- # # # # duplicate_text = texts2[duplicate_idx]
1592
- # # # # differences = display_word_differences(original_text, duplicate_text)
1593
- # # # # result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
1594
- # # # # result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
1595
- # # # # result_text += f"**Differences:**\n{differences}\n"
1596
- # # # # result_text += "-" * 50 + "\n\n"
1597
- # # # # else:
1598
- # # # # result_text += "No duplicates found."
1599
-
1600
- # # # # # Final status
1601
- # # # # status = "Deduplication completed."
1602
- # # # # yield status, result_text
1603
-
1604
- # # # # except Exception as e:
1605
- # # # # yield f"An error occurred: {e}", ""
1606
- # # # # raise e
1607
-
1608
- # # # # def deduplicate_across_datasets(
1609
- # # # # embedding_matrix_1: np.ndarray,
1610
- # # # # embedding_matrix_2: np.ndarray,
1611
- # # # # threshold: float,
1612
- # # # # batch_size: int = 1024,
1613
- # # # # progress=None
1614
- # # # # ) -> tuple[list[int], dict[int, int]]:
1615
- # # # # # Building the index from Dataset 1
1616
- # # # # progress(0, desc="Building search index from Dataset 1...")
1617
- # # # # reach = Reach(
1618
- # # # # vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))]
1619
- # # # # )
1620
-
1621
- # # # # duplicate_indices_in_test = []
1622
- # # # # duplicate_to_original_mapping = {}
1623
-
1624
- # # # # # Finding nearest neighbors between datasets
1625
- # # # # progress(0, desc="Finding nearest neighbors between datasets...")
1626
- # # # # results = reach.nearest_neighbor_threshold(
1627
- # # # # embedding_matrix_2,
1628
- # # # # threshold=threshold,
1629
- # # # # batch_size=batch_size,
1630
- # # # # show_progressbar=False, # Disable internal progress bar
1631
- # # # # )
1632
-
1633
- # # # # total_items = len(embedding_matrix_2)
1634
- # # # # # Processing duplicates with a progress bar
1635
- # # # # for i, similar_items in enumerate(
1636
- # # # # progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)
1637
- # # # # ):
1638
- # # # # similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
1639
-
1640
- # # # # if similar_indices:
1641
- # # # # duplicate_indices_in_test.append(i)
1642
- # # # # duplicate_to_original_mapping[i] = similar_indices[0]
1643
-
1644
- # # # # return duplicate_indices_in_test, duplicate_to_original_mapping
1645
-
1646
- # # # # # Adjust the height of the status_output component using custom CSS
1647
- # # # # with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
1648
- # # # # gr.Markdown("# Semantic Deduplication")
1649
-
1650
- # # # # deduplication_type = gr.Radio(
1651
- # # # # choices=["Single dataset", "Cross-dataset"],
1652
- # # # # label="Deduplication Type",
1653
- # # # # value="Single dataset",
1654
- # # # # )
1655
-
1656
- # # # # with gr.Row():
1657
- # # # # dataset1_name = gr.Textbox(value=default_dataset1_name, label="Dataset 1 Name")
1658
- # # # # dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
1659
- # # # # dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
1660
-
1661
- # # # # dataset2_inputs = gr.Column(visible=False)
1662
- # # # # with dataset2_inputs:
1663
- # # # # gr.Markdown("### Dataset 2")
1664
- # # # # with gr.Row():
1665
- # # # # dataset2_name = gr.Textbox(value=default_dataset2_name, label="Dataset 2 Name")
1666
- # # # # dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
1667
- # # # # dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
1668
-
1669
- # # # # threshold = gr.Slider(
1670
- # # # # minimum=0.0, maximum=1.0, value=default_threshold, label="Similarity Threshold"
1671
- # # # # )
1672
-
1673
- # # # # compute_button = gr.Button("Compute")
1674
-
1675
- # # # # # Use 'gr.Markdown' with 'elem_id' and custom CSS to adjust height
1676
- # # # # status_output = gr.Markdown(elem_id="status_output")
1677
- # # # # result_output = gr.Markdown()
1678
-
1679
- # # # # # Function to update the visibility of dataset2_inputs
1680
- # # # # def update_visibility(deduplication_type_value):
1681
- # # # # if deduplication_type_value == "Cross-dataset":
1682
- # # # # return gr.update(visible=True)
1683
- # # # # else:
1684
- # # # # return gr.update(visible=False)
1685
-
1686
- # # # # deduplication_type.change(
1687
- # # # # update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
1688
- # # # # )
1689
-
1690
- # # # # compute_button.click(
1691
- # # # # fn=perform_deduplication,
1692
- # # # # inputs=[
1693
- # # # # deduplication_type,
1694
- # # # # dataset1_name,
1695
- # # # # dataset1_split,
1696
- # # # # dataset1_text_column,
1697
- # # # # dataset2_name,
1698
- # # # # dataset2_split,
1699
- # # # # dataset2_text_column,
1700
- # # # # threshold,
1701
- # # # # ],
1702
- # # # # outputs=[status_output, result_output],
1703
- # # # # )
1704
-
1705
- # # # # demo.launch()
 
5
  from reach import Reach
  from difflib import ndiff

+ # Load the model
  model = StaticModel.from_pretrained("minishlab/M2V_base_output")

+ # Default parameters
+ default_dataset_name = "sst2"
+ default_dataset_split = "train"
  default_text_column = "sentence"
  default_threshold = 0.9


  def batch_iterable(iterable, batch_size):
+     """Yield successive batches from an iterable."""
      for i in range(0, len(iterable), batch_size):
          yield iterable[i:i + batch_size]


+ def compute_embeddings(texts, batch_size, progress, desc):
+     """Compute embeddings for a list of texts with progress tracking."""
      embeddings = []
      total_batches = (len(texts) + batch_size - 1) // batch_size
      for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
+         embeddings.append(model.encode(batch_texts, show_progressbar=False))
          progress((i + 1) / total_batches, desc=desc)
      return np.concatenate(embeddings, axis=0)
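
A quick sketch (hypothetical, not part of app.py) of how compute_embeddings is meant to be driven: any callable accepting a fraction and a desc keyword can stand in for the gr.Progress object, and the helper name print_progress below is made up for illustration.

    def print_progress(fraction, desc=""):
        print(f"{desc}: {fraction:.0%}")

    vectors = compute_embeddings(["first text", "second text", "third text"],
                                 batch_size=2, progress=print_progress, desc="demo")
    # vectors is a (3, dim) numpy array; progress is called once per batch
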
+ def deduplicate_embeddings(
+     embeddings_a: np.ndarray,
+     embeddings_b: np.ndarray = None,
+     threshold: float = 0.9,
      batch_size: int = 1024,
      progress=None
+ ):
+     """Deduplicate within one dataset or across two datasets."""
+     if embeddings_b is None:
+         reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
+         duplicate_to_original = {}
+         results = reach.nearest_neighbor_threshold(
+             embeddings_a, threshold=threshold, batch_size=batch_size, show_progressbar=False
+         )
+         for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=len(embeddings_a))):
+             for sim_idx, _ in similar_items:
+                 sim_idx = int(sim_idx)
+                 if sim_idx != i and sim_idx not in duplicate_to_original:
+                     duplicate_to_original[sim_idx] = i
+         deduplicated_indices = set(range(len(embeddings_a))) - set(duplicate_to_original.keys())
+         return deduplicated_indices, duplicate_to_original
+     else:
+         reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
+         duplicate_indices_in_b = []
+         duplicate_to_original = {}
+         results = reach.nearest_neighbor_threshold(
+             embeddings_b, threshold=threshold, batch_size=batch_size, show_progressbar=False
+         )
+         for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=len(embeddings_b))):
+             if similar_items:
+                 duplicate_indices_in_b.append(i)
+                 duplicate_to_original[i] = int(similar_items[0][0])
+         return duplicate_indices_in_b, duplicate_to_original
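
A rough usage sketch (assumed, not taken from the repo): any object exposing a tqdm(iterable, desc=..., total=...) method can stand in for gr.Progress, so the hypothetical PlainProgress shim below is enough to call deduplicate_embeddings outside the UI.

    class PlainProgress:
        def tqdm(self, iterable, desc="", total=None):
            return iterable

    emb = model.encode(["the movie was great", "the movie was great", "something else"],
                       show_progressbar=False)
    unique_idx, dup_map = deduplicate_embeddings(emb, threshold=0.9, progress=PlainProgress())
    # dup_map maps each duplicate row index to the row it collapses onto
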
  def display_word_differences(x: str, y: str) -> str:
+     """Display differences between two texts."""
      diff = ndiff(x.split(), y.split())
+     return " ".join(word for word in diff if word.startswith(("+", "-")))
+
+ def load_dataset_texts(dataset_name, dataset_split, text_column):
+     """Load texts from a specified dataset."""
+     ds = load_dataset(dataset_name, split=dataset_split)
+     return [example[text_column] for example in ds]

  def perform_deduplication(
      deduplication_type,

      progress=gr.Progress(track_tqdm=True),
  ):
      try:
          threshold = float(threshold)

+         # Load and process Dataset 1
+         yield "Loading Dataset 1...", ""
+         texts1 = load_dataset_texts(dataset1_name, dataset1_split, dataset1_text_column)
+         yield "Computing embeddings for Dataset 1...", ""
+         embeddings1 = compute_embeddings(texts1, batch_size=64, progress=progress, desc="Dataset 1 embeddings")

          if deduplication_type == "Single dataset":
+             # Deduplicate within Dataset 1
+             yield "Deduplicating within Dataset 1...", ""
+             deduplicated_indices, duplicate_mapping = deduplicate_embeddings(
+                 embeddings1, threshold=threshold, progress=progress
              )

+             num_duplicates = len(duplicate_mapping)
+             result_text = (
+                 f"**Total documents:** {len(texts1)}\n"
+                 f"**Duplicates found:** {num_duplicates}\n"
+                 f"**Unique documents after deduplication:** {len(deduplicated_indices)}\n\n"
              )

              if num_duplicates > 0:
+                 result_text += "**Sample duplicates:**\n\n"
+                 for dup_idx, orig_idx in list(duplicate_mapping.items())[:5]:
+                     orig_text = texts1[orig_idx]
+                     dup_text = texts1[dup_idx]
+                     differences = display_word_differences(orig_text, dup_text)
+                     result_text += (
+                         f"**Original:**\n{orig_text}\n\n"
+                         f"**Duplicate:**\n{dup_text}\n\n"
+                         f"**Differences:**\n{differences}\n"
+                         + "-" * 50 + "\n\n"
+                     )
              else:
                  result_text += "No duplicates found."

+             yield "Deduplication completed.", result_text

+         else:
+             # Load and process Dataset 2
+             yield "Loading Dataset 2...", ""
+             texts2 = load_dataset_texts(dataset2_name, dataset2_split, dataset2_text_column)
+             yield "Computing embeddings for Dataset 2...", ""
+             embeddings2 = compute_embeddings(texts2, batch_size=64, progress=progress, desc="Dataset 2 embeddings")
+
+             # Deduplicate Dataset 2 against Dataset 1
+             yield "Deduplicating Dataset 2 against Dataset 1...", ""
+             duplicate_indices, duplicate_mapping = deduplicate_embeddings(
+                 embeddings1, embeddings_b=embeddings2, threshold=threshold, progress=progress
              )

+             num_duplicates = len(duplicate_indices)
+             result_text = (
+                 f"**Total documents in {dataset2_name}/{dataset2_split}:** {len(texts2)}\n"
+                 f"**Duplicates found in Dataset 2:** {num_duplicates}\n"
+                 f"**Unique documents after deduplication:** {len(texts2) - num_duplicates}\n\n"
              )

              if num_duplicates > 0:
+                 result_text += "**Sample duplicates from Dataset 2:**\n\n"
+                 for idx in duplicate_indices[:5]:
+                     orig_text = texts1[duplicate_mapping[idx]]
+                     dup_text = texts2[idx]
+                     differences = display_word_differences(orig_text, dup_text)
+                     result_text += (
+                         f"**Original (Dataset 1):**\n{orig_text}\n\n"
+                         f"**Duplicate (Dataset 2):**\n{dup_text}\n\n"
+                         f"**Differences:**\n{differences}\n"
+                         + "-" * 50 + "\n\n"
+                     )
              else:
                  result_text += "No duplicates found."

+             yield "Deduplication completed.", result_text

      except Exception as e:
          yield f"An error occurred: {e}", ""
          raise e

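Because perform_deduplication is a generator, every yield of a (status, result) pair streams straight into the two Markdown components wired up below. A minimal sketch of that pattern (hypothetical, not part of app.py):

    def steps():
        yield "Working...", ""
        yield "Done.", "**All finished**"
    # Binding a function like `steps` to a button with outputs=[status_output, result_output]
    # updates both components on every yield, which is how the click handler below streams
    # progress to the UI.
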
  with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
      gr.Markdown("# Semantic Deduplication")

      )

      with gr.Row():
+         dataset1_name = gr.Textbox(value=default_dataset_name, label="Dataset 1 Name")
+         dataset1_split = gr.Textbox(value=default_dataset_split, label="Dataset 1 Split")
          dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

      dataset2_inputs = gr.Column(visible=False)
      with dataset2_inputs:
          gr.Markdown("### Dataset 2")
          with gr.Row():
+             dataset2_name = gr.Textbox(value=default_dataset_name, label="Dataset 2 Name")
+             dataset2_split = gr.Textbox(value=default_dataset_split, label="Dataset 2 Split")
              dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

+     threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")

      compute_button = gr.Button("Compute")

      status_output = gr.Markdown(elem_id="status_output")
      result_output = gr.Markdown()

+     def update_visibility(choice):
+         return gr.update(visible=choice == "Cross-dataset")

+     deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=dataset2_inputs)

      compute_button.click(
          fn=perform_deduplication,

  demo.launch()