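"""Gradio app for semantic deduplication of Hugging Face datasets.

Texts are embedded with a model2vec static model; reach's nearest-neighbor
search then flags documents whose embedding similarity exceeds a threshold,
either within a single dataset or for one dataset against another.
"""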
import gradio as gr
from datasets import load_dataset
import numpy as np
from model2vec import StaticModel
from reach import Reach
from difflib import ndiff

# Load the static embedding model (model2vec) used for all similarity comparisons
model = StaticModel.from_pretrained("minishlab/M2V_base_output")

# Default parameters
default_dataset_name = "sst2"
default_dataset_split = "train"
default_text_column = "sentence"
default_threshold = 0.9

def deduplicate_embeddings(
    embeddings_a: np.ndarray,
    embeddings_b: np.ndarray | None = None,
    threshold: float = 0.9,
    batch_size: int = 1024,
    progress=None,
):
    """Deduplicate within one dataset or across two datasets.

    Returns the indices to keep (or, in cross-dataset mode, the duplicate
    indices in B) together with a mapping from each duplicate index to the
    index of the document it duplicates.
    """
    # Fall back to a plain pass-through when no Gradio progress tracker is
    # given, so the function also works outside the UI.
    track = progress.tqdm if progress is not None else (lambda it, **kwargs: it)

    if embeddings_b is None:
        # Self-deduplication: for every document, find near neighbors in the
        # same matrix and keep only the first occurrence of each duplicate group.
        reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
        duplicate_to_original = {}
        results = reach.nearest_neighbor_threshold(
            embeddings_a, threshold=threshold, batch_size=batch_size, show_progressbar=False
        )
        for i, similar_items in enumerate(track(results, desc="Processing duplicates", total=len(embeddings_a))):
            for sim_idx, _ in similar_items:
                sim_idx = int(sim_idx)
                if sim_idx != i and sim_idx not in duplicate_to_original:
                    duplicate_to_original[sim_idx] = i
        deduplicated_indices = set(range(len(embeddings_a))) - set(duplicate_to_original.keys())
        return deduplicated_indices, duplicate_to_original

    # Cross-dataset deduplication: flag every document in B that has a
    # sufficiently similar neighbor in A.
    reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
    duplicate_indices_in_b = []
    duplicate_to_original = {}
    results = reach.nearest_neighbor_threshold(
        embeddings_b, threshold=threshold, batch_size=batch_size, show_progressbar=False
    )
    for i, similar_items in enumerate(track(results, desc="Processing duplicates", total=len(embeddings_b))):
        if similar_items:
            duplicate_indices_in_b.append(i)
            duplicate_to_original[i] = int(similar_items[0][0])
    return duplicate_indices_in_b, duplicate_to_original
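
# Minimal standalone usage sketch (no Gradio), assuming a small in-memory
# list of texts rather than a Hub dataset:
#
#     embs = model.encode(["a cat", "a cat", "a dog"], show_progressbar=False)
#     kept, dup_map = deduplicate_embeddings(embs, threshold=0.95)
#     # `kept` holds indices to keep; `dup_map` maps duplicate -> original.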

def display_word_differences(x: str, y: str) -> str:
    """Return the word-level differences between two texts as +/- tokens."""
    diff = ndiff(x.split(), y.split())
    return " ".join(word for word in diff if word.startswith(("+", "-")))

def load_dataset_texts(dataset_name: str, dataset_split: str, text_column: str) -> list[str]:
    """Load the texts of one column from a Hugging Face dataset split."""
    ds = load_dataset(dataset_name, split=dataset_split)
    return [example[text_column] for example in ds]

def perform_deduplication(
    deduplication_type,
    dataset1_name,
    dataset1_split,
    dataset1_text_column,
    dataset2_name="",
    dataset2_split="",
    dataset2_text_column="",
    threshold=default_threshold,
    progress=gr.Progress(track_tqdm=True),
):
    try:
        threshold = float(threshold)

        # Load and process Dataset 1
        yield "Loading Dataset 1...", ""
        texts1 = load_dataset_texts(dataset1_name, dataset1_split, dataset1_text_column)
        yield "Computing embeddings for Dataset 1...", ""
        #embeddings1 = compute_embeddings(texts1, batch_size=64, progress=progress, desc="Dataset 1 embeddings")
        embeddings1 = model.encode(texts1, show_progressbar=True)
        if deduplication_type == "Single dataset":
            # Deduplicate within Dataset 1
            yield "Deduplicating within Dataset 1...", ""
            deduplicated_indices, duplicate_mapping = deduplicate_embeddings(
                embeddings1, threshold=threshold, progress=progress
            )

            num_duplicates = len(duplicate_mapping)
            result_text = (
                f"**Total documents:** {len(texts1)}\n\n"
                f"**Duplicates found:** {num_duplicates}\n\n"
                f"**Unique documents after deduplication:** {len(deduplicated_indices)}\n\n"
            )

            if num_duplicates > 0:
                result_text += "**Sample duplicates:**\n\n"
                for dup_idx, orig_idx in list(duplicate_mapping.items())[:5]:
                    orig_text = texts1[orig_idx]
                    dup_text = texts1[dup_idx]
                    differences = display_word_differences(orig_text, dup_text)
                    result_text += (
                        f"**Original:**\n{orig_text}\n\n"
                        f"**Duplicate:**\n{dup_text}\n\n"
                        f"**Differences:**\n{differences}\n"
                        + "-" * 50 + "\n\n"
                    )
            else:
                result_text += "No duplicates found."

            yield "Deduplication completed.", result_text

        else:
            # Load and process Dataset 2
            yield "Loading Dataset 2...", ""
            texts2 = load_dataset_texts(dataset2_name, dataset2_split, dataset2_text_column)
            yield "Computing embeddings for Dataset 2...", ""
            #embeddings2 = compute_embeddings(texts2, batch_size=64, progress=progress, desc="Dataset 2 embeddings")
            embeddings2 = model.encode(texts2, show_progressbar=True)
            # Deduplicate Dataset 2 against Dataset 1
            yield "Deduplicating Dataset 2 against Dataset 1...", ""
            duplicate_indices, duplicate_mapping = deduplicate_embeddings(
                embeddings1, embeddings_b=embeddings2, threshold=threshold, progress=progress
            )

            num_duplicates = len(duplicate_indices)
            result_text = (
                f"**Total documents in {dataset2_name}/{dataset2_split}:** {len(texts2)}\n\n"
                f"**Duplicates found in Dataset 2:** {num_duplicates}\n\n"
                f"**Unique documents after deduplication:** {len(texts2) - num_duplicates}\n\n"
            )

            if num_duplicates > 0:
                result_text += "**Sample duplicates from Dataset 2:**\n\n"
                for idx in duplicate_indices[:5]:
                    orig_text = texts1[duplicate_mapping[idx]]
                    dup_text = texts2[idx]
                    differences = display_word_differences(orig_text, dup_text)
                    result_text += (
                        f"**Original (Dataset 1):**\n{orig_text}\n\n"
                        f"**Duplicate (Dataset 2):**\n{dup_text}\n\n"
                        f"**Differences:**\n{differences}\n"
                        + "-" * 50 + "\n\n"
                    )
            else:
                result_text += "No duplicates found."

            yield "Deduplication completed.", result_text

    except Exception as e:
        yield f"An error occurred: {e}", ""
        raise  # bare re-raise preserves the traceback for Gradio's error display

with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
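    # UI: pick a mode, point at one or two Hub datasets, set a similarity
    # threshold, and stream status/results into the Markdown panes below.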
    gr.Markdown("# Semantic Deduplication")

    deduplication_type = gr.Radio(
        choices=["Single dataset", "Cross-dataset"],
        label="Deduplication Type",
        value="Single dataset",
    )

    with gr.Row():
        dataset1_name = gr.Textbox(value=default_dataset_name, label="Dataset 1 Name")
        dataset1_split = gr.Textbox(value=default_dataset_split, label="Dataset 1 Split")
        dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

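    # Dataset 2 inputs are hidden until "Cross-dataset" is selected.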
    dataset2_inputs = gr.Column(visible=False)
    with dataset2_inputs:
        gr.Markdown("### Dataset 2")
        with gr.Row():
            dataset2_name = gr.Textbox(value=default_dataset_name, label="Dataset 2 Name")
            dataset2_split = gr.Textbox(value=default_dataset_split, label="Dataset 2 Split")
            dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

    threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")
    compute_button = gr.Button("Compute")
    status_output = gr.Markdown(elem_id="status_output")
    result_output = gr.Markdown()

    def update_visibility(choice):
        return gr.update(visible=choice == "Cross-dataset")

    deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=dataset2_inputs)

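    # perform_deduplication is a generator; Gradio streams each yielded
    # (status, result) pair into the two Markdown outputs as it runs.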
    compute_button.click(
        fn=perform_deduplication,
        inputs=[
            deduplication_type,
            dataset1_name,
            dataset1_split,
            dataset1_text_column,
            dataset2_name,
            dataset2_split,
            dataset2_text_column,
            threshold,
        ],
        outputs=[status_output, result_output],
    )

demo.launch()
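
# Note: demo.launch(share=True) would additionally expose a temporary public URL.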