Add entire pipeline
- UI/__pycache__/gradio.cpython-310.pyc +0 -0
- UI/__pycache__/gradio.cpython-311.pyc +0 -0
- UI/gradio.py +516 -0
- __pycache__/app.cpython-310.pyc +0 -0
- app.py +21 -0
- environment.yml +245 -0
- metrics/distortion.py +370 -0
- renderers/__pycache__/highlighter.cpython-310.pyc +0 -0
- renderers/__pycache__/highlighter.cpython-311.pyc +0 -0
- renderers/__pycache__/plot_3d.cpython-310.pyc +0 -0
- renderers/__pycache__/plot_3d.cpython-311.pyc +0 -0
- renderers/__pycache__/tree.cpython-310.pyc +0 -0
- renderers/__pycache__/tree.cpython-311.pyc +0 -0
- renderers/highlighter.py +162 -0
- renderers/plot_3d.py +126 -0
- renderers/tree.py +490 -0
- utils/__init__.py +5 -0
- utils/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/__pycache__/config.cpython-310.pyc +0 -0
- utils/__pycache__/config.cpython-311.pyc +0 -0
- utils/__pycache__/entailment.cpython-310.pyc +0 -0
- utils/__pycache__/entailment.cpython-311.pyc +0 -0
- utils/__pycache__/masking_methods.cpython-310.pyc +0 -0
- utils/__pycache__/masking_methods.cpython-311.pyc +0 -0
- utils/__pycache__/non_melting_point.cpython-310.pyc +0 -0
- utils/__pycache__/non_melting_point.cpython-311.pyc +0 -0
- utils/__pycache__/paraphraser.cpython-310.pyc +0 -0
- utils/__pycache__/paraphraser.cpython-311.pyc +0 -0
- utils/__pycache__/sampling.cpython-310.pyc +0 -0
- utils/__pycache__/sampling.cpython-311.pyc +0 -0
- utils/__pycache__/watermark.cpython-310.pyc +0 -0
- utils/__pycache__/watermark.cpython-311.pyc +0 -0
- utils/config.py +18 -0
- utils/config.yaml +48 -0
- utils/entailment.py +107 -0
- utils/masking_methods.py +304 -0
- utils/non_melting_point.py +137 -0
- utils/old/masking/masking_methods.py +355 -0
- utils/old/masking/masking_methods_new_work.py +447 -0
- utils/old/masking/masking_methods_ok_working.py +257 -0
- utils/old/masking/masking_methods_v1_working.py +233 -0
- utils/old/masking_methods_final_copy.py +619 -0
- utils/old/non_melting_points_v1.py +244 -0
- utils/old/sampling/sampling.py +330 -0
- utils/old/sampling/sampling_methods.py +291 -0
- utils/old/sampling/sampling_methods_v1.py +146 -0
- utils/old/sampling/sampling_methods_v2.py +112 -0
- utils/old/sampling_final_copy.py +168 -0
- utils/paraphraser.py +75 -0
UI/__pycache__/gradio.cpython-310.pyc
ADDED
Binary file (6.61 kB)
UI/__pycache__/gradio.cpython-311.pyc
ADDED
Binary file (27.3 kB)
UI/gradio.py
ADDED
@@ -0,0 +1,516 @@
import gradio as gr
from utils.watermark import Watermarker
from utils.config import load_config
from renderers.highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html
from renderers.tree import generate_subplot1, generate_subplot2
from pathlib import Path
import time
from typing import Dict, List, Tuple, Any
import plotly.graph_objects as go


class WatermarkerInterface:
    def __init__(self, config):

        self.pipeline = Watermarker(config)
        self.common_grams = {}
        self.highlight_info = []
        self.masked_sentences = []

    def handle_paraphrase(self, prompt: str) -> Tuple[str, str, str, str]:
        """Wrapper for paraphrasing that includes highlighting"""
        start_time = time.time()

        # Run paraphrasing
        self.pipeline.Paraphrase(prompt)

        # Step 1: Process the original sentence first
        seen_ngrams = {}  # Stores first occurrence index of each n-gram
        original_indexed_ngrams = []  # Final indexed list for original

        original_sentence = self.pipeline.user_prompt
        original_ngrams = self.pipeline.common_grams.get(original_sentence, {})

        # Step 1.1: Extract n-grams and their first occurrence index
        ngram_occurrences = [
            (min(indices, key=lambda x: x[0])[0], gram)  # Get first index
            for gram, indices in original_ngrams.items()
        ]

        # Step 1.2: Sort n-grams based on their first occurrence
        ngram_occurrences.sort()

        # Step 1.3: Assign sequential indices
        for idx, (position, gram) in enumerate(ngram_occurrences, start=1):
            seen_ngrams[gram] = idx  # Assign sequential index
            original_indexed_ngrams.append((idx, gram))

        print("Original Indexed N-grams:", original_indexed_ngrams)

        # Generate highlight_info
        colors = ["red", "blue", "green", "purple", "orange"]
        highlight_info = [
            (ngram, colors[i % len(colors)])
            for i, (index, ngram) in enumerate(original_indexed_ngrams)
        ]
        common_grams = original_indexed_ngrams
        self.highlight_info = highlight_info
        self.common_grams = common_grams

        # Step 2: Process paraphrased sentences and match indices
        paraphrase_indexed_ngrams = {}

        for sentence in self.pipeline.paraphrased_sentences:
            sentence_ngrams = []  # Stores n-grams for this sentence
            sentence_ngrams_dict = self.pipeline.common_grams.get(sentence, {})

            for gram, indices in sentence_ngrams_dict.items():
                first_occurrence = min(indices, key=lambda x: x[0])[0]

                # Use the original's index if it exists, otherwise assign a new one
                if gram in seen_ngrams:
                    index = seen_ngrams[gram]  # Use the same index as original
                else:
                    index = len(seen_ngrams) + 1  # Assign new index
                    seen_ngrams[gram] = index  # Store it

                sentence_ngrams.append((index, gram))

            sentence_ngrams.sort()
            paraphrase_indexed_ngrams[sentence] = sentence_ngrams

        print("Paraphrase Indexed N-grams:", paraphrase_indexed_ngrams)

        # Step 3: Generate highlighted versions using the renderer
        highlighted_prompt = highlight_common_words(
            common_grams,
            [self.pipeline.user_prompt],
            "Original Prompt with Highlighted Common Sequences"
        )

        highlighted_accepted = highlight_common_words_dict(
            common_grams,
            self.pipeline.selected_sentences,
            "Accepted Paraphrased Sentences with Entailment Scores"
        )

        highlighted_discarded = highlight_common_words_dict(
            common_grams,
            self.pipeline.discarded_sentences,
            "Discarded Paraphrased Sentences with Entailment Scores"
        )

        execution_time = f"<div class='execution-time'>Step 1 completed in {time.time() - start_time:.2f} seconds</div>"
        self.highlight_info = highlight_info
        self.common_grams = common_grams

        return highlighted_prompt, highlighted_accepted, highlighted_discarded, execution_time

    def handle_masking(self) -> Tuple[List[go.Figure], str]:
        """Wrapper for masking that generates visualization trees"""
        start_time = time.time()

        masking_results = self.pipeline.Masking()
        trees = []
        highlight_info = self.highlight_info
        common_grams = self.common_grams
        sentence_to_masked = {}

        # Create a consolidated figure with all strategies
        original_sentence = None

        # First pass - gather all sentences and strategies
        for strategy, sentence_dict in masking_results.items():
            for sent, data in sentence_dict.items():
                if sent not in sentence_to_masked:
                    sentence_to_masked[sent] = []
                try:
                    if not isinstance(data, dict):
                        print(f"[ERROR] Data is not a dictionary for {sent} with strategy {strategy}")
                        continue

                    masked_sentence = data.get("masked_sentence", "")
                    if masked_sentence:
                        sentence_to_masked[sent].append((masked_sentence, strategy))
                except Exception as e:
                    print(f"Error processing {strategy} for sentence {sent}: {e}")

        for original_sentence, masked_sentences_data in sentence_to_masked.items():
            if not masked_sentences_data:
                continue
            masked_sentences = [ms[0] for ms in masked_sentences_data]
            strategies = [ms[1] for ms in masked_sentences_data]
            try:
                fig = generate_subplot1(
                    original_sentence,
                    masked_sentences,
                    strategies,
                    highlight_info,
                    common_grams
                )
                trees.append(fig)
            except Exception as e:
                print(f"Error generating multi-strategy tree: {e}")
                trees.append(go.Figure())

        # Pad with empty plots if needed
        while len(trees) < 10:
            trees.append(go.Figure())

        execution_time = f"<div class='execution-time'>Step 2 completed in {time.time() - start_time:.2f} seconds</div>"

        return trees[:10] + [execution_time]

    def handle_sampling(self) -> Tuple[List[go.Figure], str]:
        """Wrapper for sampling that generates visualization trees"""
        start_time = time.time()
        sampling_results = self.pipeline.Sampling()
        trees = []

        # Group sentences by original sentence
        organized_results = {}

        # Generate trees for each sampled sentence
        for sampling_strategy, masking_dict in sampling_results.items():
            for masking_strategy, sentences in masking_dict.items():
                for original_sentence, data in sentences.items():
                    if original_sentence not in organized_results:
                        organized_results[original_sentence] = {}

                    if masking_strategy not in organized_results[original_sentence]:
                        organized_results[original_sentence][masking_strategy] = {
                            "masked_sentence": data.get("masked_sentence", ""),  # Corrected reference
                            "sampled_sentences": {}
                        }

                    # Add this sampling result
                    organized_results[original_sentence][masking_strategy]["sampled_sentences"][sampling_strategy] = data.get("sampled_sentence", "")

        for original_sentence, data in organized_results.items():
            masked_sentences = []
            all_sampled_sentences = []

            for masking_strategy, masking_data in list(data.items())[:3]:  # Ensure this iteration is safe
                masked_sentence = masking_data.get("masked_sentence", "")
                if masked_sentence:
                    masked_sentences.append(masked_sentence)

                for sampling_strategy, sampled_sentence in masking_data.get("sampled_sentences", {}).items():
                    if sampled_sentence:
                        all_sampled_sentences.append(sampled_sentence)

            if masked_sentences:
                try:
                    fig = generate_subplot2(
                        masked_sentences,
                        all_sampled_sentences,
                        self.highlight_info,
                        self.common_grams
                    )
                    trees.append(fig)
                except Exception as e:
                    print(f"Error generating subplot for {original_sentence}: {e}")
                    trees.append(go.Figure())

        while len(trees) < 10:
            trees.append(go.Figure())

        execution_time = f"<div class='execution-time'>Step 3 completed in {time.time() - start_time:.2f} seconds</div>"

        return trees[:10] + [execution_time]

    def handle_reparaphrasing(self) -> Tuple[List[str], str]:
        """Wrapper for re-paraphrasing that formats results as HTML"""
        start_time = time.time()

        results = self.pipeline.re_paraphrasing()
        html_outputs = []

        # Generate HTML for each batch of re-paraphrased sentences
        for sampling_strategy, masking_dict in results.items():
            for masking_strategy, sentences in masking_dict.items():
                for original_sent, data in sentences.items():
                    if data["re_paraphrased_sentences"]:
                        html = reparaphrased_sentences_html(data["re_paraphrased_sentences"])
                        html_outputs.append(html)

        # Pad with empty HTML if needed
        while len(html_outputs) < 120:
            html_outputs.append("")

        execution_time = f"<div class='execution-time'>Step 4 completed in {time.time() - start_time:.2f} seconds</div>"

        return html_outputs[:120] + [execution_time]


def create_gradio_interface(config):
    """Creates the Gradio interface with the updated pipeline"""
    interface = WatermarkerInterface(config)

    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        # CSS to enable scrolling for reparaphrased sentences and sampling plots
        demo.css = """
        /* Set fixed height for the reparaphrased tabs container only */
        .gradio-container .tabs[id="reparaphrased-tabs"],
        .gradio-container .tabs[id="sampling-tabs"] {
            overflow-x: hidden;
            white-space: normal;
            border-radius: 8px;
            max-height: 600px; /* Set fixed height for the entire tabs component */
            overflow-y: auto; /* Enable vertical scrolling inside the container */
        }

        /* Tab content styling for reparaphrased and sampling tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tabitem,
        .gradio-container .tabs[id="sampling-tabs"] .tabitem {
            overflow-x: hidden;
            white-space: normal;
            display: block;
            border-radius: 8px;
        }

        /* Make the tab navigation fixed at the top for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav {
            display: flex;
            overflow-x: auto;
            white-space: nowrap;
            scrollbar-width: thin;
            border-radius: 8px;
            scrollbar-color: #888 #f1f1f1;
            position: sticky;
            top: 0;
            background: white;
            z-index: 100;
        }

        /* Dropdown menu for scrollable tabs styling */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown {
            position: relative;
            display: inline-block;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content {
            display: none;
            position: absolute;
            background-color: #f9f9f9;
            min-width: 160px;
            box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
            z-index: 1;
            max-height: 300px;
            overflow-y: auto;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content {
            display: block;
        }

        /* Scrollbar styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar {
            height: 8px;
            border-radius: 8px;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-track,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-track {
            background: #f1f1f1;
            border-radius: 8px;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-thumb,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-thumb {
            background: #888;
            border-radius: 8px;
        }

        /* Tab button styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-item,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-item {
            flex: 0 0 auto;
            border-radius: 8px;
        }

        /* Plot container styling specifically for sampling tabs */
        .gradio-container .tabs[id="sampling-tabs"] .plot-container {
            min-height: 600px;
            max-height: 1800px;
            overflow-y: auto;
        }

        /* Ensure text wraps in HTML components */
        .gradio-container .prose {
            white-space: normal;
            word-wrap: break-word;
            overflow-wrap: break-word;
        }

        /* Dropdown button styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button {
            background-color: #f0f0f0;
            border: 1px solid #ddd;
            border-radius: 4px;
            padding: 5px 10px;
            cursor: pointer;
            margin: 2px;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button:hover,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button:hover {
            background-color: #e0e0e0;
        }

        /* Style dropdown content items for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div {
            padding: 8px 12px;
            cursor: pointer;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div:hover,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div:hover {
            background-color: #e0e0e0;
        }

        /* Custom styling for execution time display */
        .execution-time {
            text-align: right;
            padding: 8px 16px;
            font-family: inherit;
            color: #555;
            font-size: 0.9rem;
            font-style: italic;
            margin-left: auto;
            width: 100%;
            border-top: 1px solid #eee;
            margin-top: 8px;
        }

        /* Layout for section headers with execution time */
        .section-header {
            display: flex;
            justify-content: space-between;
            align-items: center;
            width: 100%;
            margin-bottom: 12px;
        }

        .section-header h3 {
            margin: 0;
        }
        """
        gr.Markdown("# **AIISC Watermarking Model**")

        with gr.Column():
            gr.Markdown("## Input Prompt")
            user_input = gr.Textbox(
                label="Enter Your Prompt",
                placeholder="Type your text here..."
            )

        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 1: Paraphrasing, LCS and Entailment Analysis")
            with gr.Column(scale=1):
                step1_time = gr.HTML()

        paraphrase_button = gr.Button("Generate Paraphrases")
        highlighted_user_prompt = gr.HTML(label="Highlighted User Prompt")

        with gr.Tabs():
            with gr.TabItem("Accepted Paraphrased Sentences"):
                highlighted_accepted_sentences = gr.HTML()
            with gr.TabItem("Discarded Paraphrased Sentences"):
                highlighted_discarded_sentences = gr.HTML()

        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 2: Where to Mask?")
            with gr.Column(scale=1):
                step2_time = gr.HTML()

        masking_button = gr.Button("Apply Masking")
        gr.Markdown("### Masked Sentence Trees")
        tree1_plots = []
        with gr.Tabs() as tree1_tabs:
            for i in range(10):
                with gr.TabItem(f"Masked Sentence {i+1}"):
                    tree1 = gr.Plot()
                    tree1_plots.append(tree1)

        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 3: How to Mask?")
            with gr.Column(scale=1):
                step3_time = gr.HTML()

        sampling_button = gr.Button("Sample Words")
        gr.Markdown("### Sampled Sentence Trees")

        tree2_plots = []
        # Add elem_id to make this tab container scrollable
        with gr.Tabs(elem_id="sampling-tabs") as tree2_tabs:
            for i in range(10):
                with gr.TabItem(f"Sampled Sentence {i+1}"):
                    # Add a custom class to the container to enable proper styling
                    with gr.Column(elem_classes=["plot-container"]):
                        tree2 = gr.Plot()
                        tree2_plots.append(tree2)

        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 4: Re-paraphrasing")
            with gr.Column(scale=1):
                step4_time = gr.HTML()

        reparaphrase_button = gr.Button("Re-paraphrase")
        gr.Markdown("### Reparaphrased Sentences")
        reparaphrased_sentences_tabs = []
        with gr.Tabs(elem_id="reparaphrased-tabs") as reparaphrased_tabs:
            for i in range(120):
                with gr.TabItem(f"Reparaphrased Batch {i+1}"):
                    reparaphrased_sent_html = gr.HTML()
                    reparaphrased_sentences_tabs.append(reparaphrased_sent_html)

        # Connect the interface functions to the buttons
        paraphrase_button.click(
            interface.handle_paraphrase,
            inputs=user_input,
            outputs=[
                highlighted_user_prompt,
                highlighted_accepted_sentences,
                highlighted_discarded_sentences,
                step1_time
            ]
        )

        masking_button.click(
            interface.handle_masking,
            inputs=None,
            outputs=tree1_plots + [step2_time]
        )

        sampling_button.click(
            interface.handle_sampling,
            inputs=None,
            outputs=tree2_plots + [step3_time]
        )

        reparaphrase_button.click(
            interface.handle_reparaphrasing,
            inputs=None,
            outputs=reparaphrased_sentences_tabs + [step4_time]
        )

    return demo

if __name__ == "__main__":
    project_root = Path(__file__).parent.parent
    config_path = project_root / "utils" / "config.yaml"
    config = load_config(config_path)['PECCAVI_TEXT']

    create_gradio_interface(config).launch()
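Note on the click handlers above: each handler returns a plain list whose length must match the number of wired Gradio output components (10 plot slots plus one timing HTML for masking and sampling, 120 HTML slots plus one timing HTML for re-paraphrasing), which is why results are padded with empty figures or strings and then truncated. A minimal, self-contained sketch of that fixed-slot pattern follows; the helper name pad_outputs is illustrative and is not defined in this repository.

import plotly.graph_objects as go

def pad_outputs(items, n_slots, filler):
    """Pad or truncate `items` so exactly `n_slots` values are returned to Gradio."""
    padded = list(items) + [filler] * max(0, n_slots - len(items))
    return padded[:n_slots]

# e.g. three generated figures feeding ten gr.Plot components plus one gr.HTML timing slot
figures = [go.Figure() for _ in range(3)]
outputs = pad_outputs(figures, 10, go.Figure()) + ["<div class='execution-time'>done</div>"]
assert len(outputs) == 11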
__pycache__/app.cpython-310.pyc
ADDED
Binary file (747 Bytes)
app.py
ADDED
@@ -0,0 +1,21 @@

import gradio as gr
from UI.gradio import create_gradio_interface

from pathlib import Path
from utils.config import load_config

project_root = Path(__file__).resolve().parent
config_path = project_root / "utils" / "config.yaml"
config = load_config(config_path)['PECCAVI_TEXT']

def main():
    """
    This function is the entry point for the PECCAVI Watermarking Model.

    It creates the Gradio interface for the model and runs it.
    """
    create_gradio_interface(config).launch()

if __name__ == "__main__":
    main()
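Both app.py and the __main__ block of UI/gradio.py read utils/config.yaml through utils.config.load_config and then select the PECCAVI_TEXT section. The 18-line utils/config.py itself is not shown in this view; the snippet below is only a hypothetical sketch of what such a YAML-based loader typically looks like, not the committed implementation.

# Hypothetical sketch; the actual utils/config.py is not visible in this diff view.
import yaml

def load_config(config_path):
    """Read a YAML file and return its contents as a nested dict."""
    with open(config_path, "r") as f:
        return yaml.safe_load(f)

# Usage mirroring app.py:
# config = load_config("utils/config.yaml")["PECCAVI_TEXT"]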
environment.yml
ADDED
@@ -0,0 +1,245 @@
name: panda
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - asttokens=2.4.1=pyhd8ed1ab_0
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.8.30=hbcca054_0
  - comm=0.2.2=pyhd8ed1ab_0
  - debugpy=1.8.6=py310hf71b8c6_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - exceptiongroup=1.2.2=pyhd8ed1ab_0
  - executing=2.1.0=pyhd8ed1ab_0
  - ipykernel=6.29.5=pyh3099207_0
  - ipython=8.27.0=pyh707e725_0
  - jedi=0.19.1=pyhd8ed1ab_0
  - jupyter_client=8.6.3=pyhd8ed1ab_0
  - jupyter_core=5.7.2=pyh31011fe_1
  - krb5=1.21.3=h143b758_0
  - ld_impl_linux-64=2.40=h12ee557_0
  - libedit=3.1.20230828=h5eee18b_0
  - libffi=3.4.4=h6a678d5_1
  - libgcc=14.1.0=h77fa898_1
  - libgcc-ng=14.1.0=h69a702a_1
  - libgomp=14.1.0=h77fa898_1
  - libsodium=1.0.20=h4ab18f5_0
  - libstdcxx=14.1.0=hc0a3c3a_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - matplotlib-inline=0.1.7=pyhd8ed1ab_0
  - ncurses=6.4=h6a678d5_0
  - nest-asyncio=1.6.0=pyhd8ed1ab_0
  - openssl=3.3.2=hb9d3cd8_0
  - packaging=24.1=pyhd8ed1ab_0
  - parso=0.8.4=pyhd8ed1ab_0
  - pexpect=4.9.0=pyhd8ed1ab_0
  - pickleshare=0.7.5=py_1003
  - pip=24.2=py310h06a4308_0
  - platformdirs=4.3.6=pyhd8ed1ab_0
  - prompt-toolkit=3.0.48=pyha770c72_0
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.3=pyhd8ed1ab_0
  - pygments=2.18.0=pyhd8ed1ab_0
  - python=3.10.14=h955ad1f_1
  - python_abi=3.10=2_cp310
  - pyzmq=26.2.0=py310h71f11fc_2
  - readline=8.2=h5eee18b_0
  - setuptools=75.1.0=py310h06a4308_0
  - sqlite=3.45.3=h5eee18b_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tk=8.6.14=h39e8969_0
  - tornado=6.4.1=py310ha75aee5_1
  - traitlets=5.14.3=pyhd8ed1ab_0
  - typing_extensions=4.12.2=pyha770c72_0
  - wcwidth=0.2.13=pyhd8ed1ab_0
  - wheel=0.44.0=py310h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - zeromq=4.3.5=ha4adb4c_5
  - zlib=1.2.13=h5eee18b_1
  - pip:
    - absl-py==2.1.0
    - accelerate==0.33.0
    - aiofiles==23.2.1
    - aiohappyeyeballs==2.3.5
    - aiohttp==3.10.3
    - aiosignal==1.3.1
    - altgraph==0.17.4
    - annotated-types==0.7.0
    - anyio==4.6.0
    - astunparse==1.6.3
    - async-timeout==4.0.3
    - attrs==24.2.0
    - av==12.0.0
    - backports-tarfile==1.2.0
    - beautifulsoup4==4.12.3
    - build==1.2.2
    - cachetools==5.5.0
    - certifi==2024.7.4
    - cffi==1.17.1
    - charset-normalizer==3.3.2
    - clean-fid==0.1.35
    - click==8.1.7
    - colorama==0.4.6
    - contextlib2==21.6.0
    - contourpy==1.2.1
    - cryptography==43.0.1
    - cycler==0.12.1
    - datasets==2.21.0
    - diffusers==0.27.2
    - dill==0.3.8
    - docker-pycreds==0.4.0
    - docutils==0.21.2
    - fastapi==0.115.0
    - ffmpy==0.4.0
    - filelock==3.15.4
    - flatbuffers==24.3.25
    - fonttools==4.53.1
    - frozenlist==1.4.1
    - fsspec==2024.6.1
    - gast==0.4.0
    - gdown==5.2.0
    - gitdb==4.0.11
    - gitpython==3.1.43
    - google-auth==2.35.0
    - google-auth-oauthlib==0.4.6
    - google-pasta==0.2.0
    - gradio==4.44.0
    - gradio-client==1.3.0
    - grpcio==1.65.4
    - h11==0.14.0
    - h5py==3.11.0
    - httpcore==1.0.6
    - httpx==0.27.2
    - huggingface-hub==0.25.2
    - idna==3.7
    - imageio==2.35.0
    - importlib-metadata==8.2.0
    - importlib-resources==6.4.5
    - jaraco-classes==3.4.0
    - jaraco-context==6.0.1
    - jaraco-functools==4.1.0
    - jeepney==0.8.0
    - jinja2==3.1.4
    - joblib==1.4.2
    - json-with-comments==1.2.7
    - keras==3.5.0
    - keras-preprocessing==1.1.2
    - keyring==25.4.1
    - kiwisolver==1.4.5
    - kornia==0.7.4
    - kornia-rs==0.1.7
    - lazy-loader==0.4
    - libclang==18.1.1
    - markdown==3.6
    - markdown-it-py==3.0.0
    - markupsafe==2.1.5
    - matplotlib==3.9.2
    - mdurl==0.1.2
    - ml-collections==0.1.1
    - ml-dtypes==0.4.0
    - more-itertools==10.5.0
    - multidict==6.0.5
    - multiprocess==0.70.16
    - namex==0.0.8
    - networkx==3.3
    - nh3==0.2.18
    - nltk==3.9.1
    - numpy==1.26.4
    - nvidia-cublas-cu11==11.10.3.66
    - nvidia-cuda-nvrtc-cu11==11.7.99
    - nvidia-cuda-runtime-cu11==11.7.99
    - nvidia-cudnn-cu11==8.5.0.96
    - oauthlib==3.2.2
    - opencv-python==4.10.0.84
    - opencv-python-headless==4.10.0.84
    - opt-einsum==3.3.0
    - optree==0.12.1
    - orjson==3.10.7
    - pandas==2.2.2
    - pillow==10.4.0
    - pkginfo==1.10.0
    - plotly==5.24.1
    - protobuf==4.25.5
    - psutil==5.9.8
    - pyarrow==17.0.0
    - pyasn1==0.6.1
    - pyasn1-modules==0.4.1
    - pycparser==2.22
    - pydantic==2.9.2
    - pydantic-core==2.23.4
    - pydub==0.25.1
    - pyinstaller==6.10.0
    - pyinstaller-hooks-contrib==2024.8
    - pyparsing==3.1.2
    - pyproject-hooks==1.1.0
    - pysocks==1.7.1
    - python-dateutil==2.9.0.post0
    - python-multipart==0.0.12
    - pytorch-msssim==1.0.0
    - pytorchcv==0.0.73
    - pytz==2023.3.post1
    - pyyaml==6.0.2
    - readme-renderer==44.0
    - regex==2024.7.24
    - requests==2.32.3
    - requests-oauthlib==2.0.0
    - requests-toolbelt==1.0.0
    - rfc3986==2.0.0
    - rich==13.7.1
    - rsa==4.9
    - ruff==0.6.9
    - safetensors==0.4.4
    - saliency==0.2.1
    - scikit-image==0.24.0
    - scikit-learn==1.6.0
    - scipy==1.14.0
    - secretstorage==3.3.3
    - semantic-version==2.10.0
    - sentence-transformers==3.3.1
    - sentry-sdk==2.15.0
    - setproctitle==1.3.3
    - shapely==2.0.5
    - shellingham==1.5.4
    - six==1.12.0
    - smmap==5.0.1
    - sniffio==1.3.1
    - soupsieve==2.6
    - spaces==0.30.2
    - starlette==0.38.6
    - tenacity==9.0.0
    - tensorboard==2.17.1
    - tensorboard-data-server==0.7.2
    - tensorboard-plugin-wit==1.8.1
    - tensorflow==2.17.0
    - tensorflow-estimator==2.10.0
    - tensorflow-hub==0.16.1
    - tensorflow-intel==0.0.1
    - tensorflow-io-gcs-filesystem==0.31.0
    - termcolor==1.1.0
    - tf-keras==2.17.0
    - threadpoolctl==3.5.0
    - tifffile==2024.8.10
    - timm==1.0.10
    - tokenizers==0.19.1
    - tomli==2.0.1
    - tomlkit==0.12.0
    - torch==1.13.1
    - torchvision==0.14.1
    - tqdm==4.66.5
    - transformers==4.43.3
    - twine==5.1.1
    - typer==0.12.5
    - tzdata==2024.1
    - urllib3==2.2.2
    - uvicorn==0.31.0
    - wandb==0.18.3
    - websockets==12.0
    - werkzeug==3.0.4
    - wrapt==1.11.2
    - xxhash==3.4.1
    - yarl==1.9.4
    - zipp==3.20.0
prefix: /home/ashhar21137/miniconda3/envs/panda
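The environment is named panda and pins Python 3.10 with a mixed conda/pip dependency set; it can presumably be recreated with the standard conda workflow (conda env create -f environment.yml, then conda activate panda), which is ordinary conda usage rather than anything documented in this commit.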
metrics/distortion.py
ADDED
@@ -0,0 +1,370 @@
import os
import sys
from tqdm import tqdm
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from bert_score import BERTScorer
from bert_score.utils import model2layers
from nltk.tokenize import word_tokenize
from Levenshtein import distance as levenshtein_distance
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from config.config import load_config
config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml')
config = load_config(config_path)['PECCAVI_TEXT']['Metrics']

class SentenceDistortionCalculator:
    """
    A class to calculate and analyze distortion metrics between an original sentence and modified sentences.
    """
    def __init__(self, config, original_sentence, paraphrased_sentences):
        """
        Initialize the calculator with the original sentence and a list of modified sentences.
        """
        self.original_sentence = original_sentence
        self.paraphrased_sentences = paraphrased_sentences

        self.levenshtein_distances = {}
        self.bert_scores = {}
        self.mover_scores = {}

        self.normalized_levenshtein = {}
        self.normalized_bert_scores = {}
        self.normalized_mover_scores = {}
        self.combined_distortions = {}

        self.tokenizer = GPT2TokenizerFast.from_pretrained(config['Distortion'])
        self.model = GPT2LMHeadModel.from_pretrained(config['Distortion'])
        self.model.eval()

    def calculate_all_metrics(self):
        """
        Calculate all distortion metrics for each modified sentence.
        """
        for idx, modified_sentence in tqdm(enumerate(self.paraphrased_sentences), total=len(self.paraphrased_sentences), desc="Calculating Metrics"):
            key = f"Sentence_{idx+1}"
            self.levenshtein_distances[key] = self._calculate_levenshtein_distance(modified_sentence)
            self.bert_scores[key] = self._calculate_bert_score(modified_sentence)
            self.mover_scores[key] = self._calculate_mover_score(modified_sentence)

    def normalize_metrics(self):
        """
        Normalize all metrics to be between 0 and 1.
        """
        for _ in tqdm(range(1), desc="Normalizing Metrics"):  # Wrap the normalization process in tqdm
            self.normalized_levenshtein = self._normalize_dict(self.levenshtein_distances)
            self.normalized_bert_scores = self._normalize_dict(self.bert_scores)
            self.normalized_mover_scores = self._normalize_dict(self.mover_scores)

    def calculate_combined_distortion(self):
        """
        Calculate the combined distortion using the root mean square of the normalized metrics.
        """
        for _ in tqdm(range(1), desc="Calculating Combined Distortion"):
            for key in self.normalized_levenshtein.keys():
                rms = np.sqrt(
                    (
                        self.normalized_levenshtein[key] ** 2 +
                        self.normalized_bert_scores[key] ** 2 +
                        self.normalized_mover_scores[key] ** 2
                    ) / 3
                )
                self.combined_distortions[key] = rms

    def plot_metrics(self):
        """
        Plot each normalized metric and the combined distortion in separate graphs.
        """
        keys = list(self.normalized_levenshtein.keys())
        indices = np.arange(len(keys))

        # Prepare data for plotting
        metrics = {
            'Levenshtein Distance': [self.normalized_levenshtein[key] for key in keys],
            'BERTScore': [self.normalized_bert_scores[key] for key in keys],
            'MOVERscore': [self.normalized_mover_scores[key] for key in keys],
            'Combined Distortion': [self.combined_distortions[key] for key in keys]
        }

        # Plot each metric separately
        for metric_name, values in tqdm(metrics.items(), desc="Plotting Metrics"):
            plt.figure(figsize=(12, 6))
            plt.plot(indices, values, marker='o', color=np.random.rand(3,))
            plt.xlabel('Sentence Index')
            plt.ylabel('Normalized Value (0-1)')
            plt.title(f'Normalized {metric_name}')
            plt.grid(True)
            plt.tight_layout()
            plt.show()

    def _calculate_levenshtein_distance(self, modified_sentence):
        """
        Calculate the word-level Levenshtein distance between the original and modified sentence.
        """
        words1 = word_tokenize(self.original_sentence)
        words2 = word_tokenize(modified_sentence)
        lev_distance = levenshtein_distance(words1, words2)
        return (lev_distance / max(len(words1), len(words2)))

    def _calculate_bert_score(self, modified_sentence):
        """
        Compute the BERTScore similarity between the original and modified sentence.
        Returns 1 - F1 score to represent dissimilarity.
        """
        if not hasattr(self, 'original_sentence'):
            raise ValueError("original_sentence is not set. Please set self.original_sentence before calling this function.")
        if not isinstance(modified_sentence, str):
            raise ValueError("modified_sentence must be a string.")

        model_type = "microsoft/deberta-xlarge-mnli"
        num_layers = model2layers[model_type]

        if not hasattr(self, "cached_bertscorer"):
            self.cached_bertscorer = BERTScorer(
                model_type=model_type,
                num_layers=num_layers,
                batch_size=1,  # Single sentence comparison
                nthreads=4,
                all_layers=False,
                idf=False,
                device="cuda" if torch.cuda.is_available() else "cpu",
                lang="en"
            )

        # Compute BERTScore
        _, _, F1 = self.cached_bertscorer.score(
            cands=[modified_sentence],
            refs=[self.original_sentence],
            verbose=False,
            batch_size=1
        )

        return 1 - F1.item()  # Return dissimilarity score

    def _calculate_mover_score(self, modified_sentence, model_name='all-MiniLM-L6-v2'):
        """Compute MoverScore using word-level embeddings."""
        if not self.original_sentence:
            raise ValueError("Original sentence not provided.")

        # Tokenize sentences
        original_tokens = self.original_sentence.split()
        modified_tokens = modified_sentence.split()
        model = SentenceTransformer(model_name)

        # Compute word embeddings
        original_embeddings = model.encode(original_tokens, convert_to_numpy=True)
        modified_embeddings = model.encode(modified_tokens, convert_to_numpy=True)

        # Compute cost matrix (cosine distance)
        cost_matrix = cdist(original_embeddings, modified_embeddings, metric='cosine')

        # Solve optimal transport problem (Hungarian Algorithm)
        row_ind, col_ind = linear_sum_assignment(cost_matrix)

        # Compute IDF weights
        vectorizer = TfidfVectorizer()
        vectorizer.fit([self.original_sentence, modified_sentence])
        idf_values = dict(zip(vectorizer.get_feature_names_out(), vectorizer.idf_))

        # Apply IDF weighting to aligned word pairs
        idf_weights_original = np.array([idf_values.get(word.lower(), 1.0) for word in original_tokens])
        idf_weights_modified = np.array([idf_values.get(word.lower(), 1.0) for word in modified_tokens])
        combined_idf_weights = (idf_weights_original[row_ind] + idf_weights_modified[col_ind]) / 2
        weighted_score = np.sum((1 - cost_matrix[row_ind, col_ind]) * combined_idf_weights) / np.sum(combined_idf_weights)

        return 1 - weighted_score  # Higher score = more dissimilar

    def _normalize_dict(self, metric_dict):
        """
        Normalize the values in a dictionary to be between 0 and 1.
        """
        values = np.array(list(metric_dict.values()))
        min_val = values.min()
        max_val = values.max()
        if max_val - min_val == 0:
            normalized_values = np.zeros_like(values)
        else:
            normalized_values = (values - min_val) / (max_val - min_val)
        return dict(zip(metric_dict.keys(), normalized_values))

    def get_normalized_metrics(self):
        """
        Get all normalized metrics as a dictionary.
        """
        return {
            'Min Edit Distance': self.normalized_levenshtein,
            'BERTScore': self.normalized_bert_scores,
            'Mover Score': self.normalized_mover_scores
        }

    def get_combined_distortions(self):
        """
        Get the dictionary of combined distortion values.
        """
        return self.combined_distortions

# Example usage
if __name__ == "__main__":

    config = load_config(config_path)['PECCAVI_TEXT']['Metrics']

    # Original sentence
    original_sentence = "The quick brown fox jumps over the lazy dog"

    # Paraphrased sentences
    paraphrased_sentences = [
        # Original 1: "A swift auburn fox leaps across a sleepy canine."
        "The swift auburn fox leaps across a sleepy canine.",
        "A quick auburn fox leaps across a sleepy canine.",
        "A swift ginger fox leaps across a sleepy canine.",
        "A swift auburn fox bounds across a sleepy canine.",
        "A swift auburn fox leaps across a tired canine.",
        "Three swift auburn foxes leap across a sleepy canine.",
        "The vulpine specimen rapidly traverses over a dormant dog.",
        "Like lightning, the russet hunter soars over the drowsy guardian.",
        "Tha quick ginger fox jumps o'er the lazy hound, ye ken.",
        "One rapid Vulpes vulpes traverses the path of a quiescent canine.",
        "A swift auburn predator navigates across a lethargic pet.",
        "Subject A (fox) demonstrates velocity over Subject B (dog).",

        # Original 2: "The agile russet fox bounds over an idle hound."
        "Some agile russet foxes bound over an idle hound.",
        "The nimble russet fox bounds over an idle hound.",
        "The agile brown fox bounds over an idle hound.",
        "The agile russet fox jumps over an idle hound.",
        "The agile russet fox bounds over a lazy hound.",
        "Two agile russet foxes bound over an idle hound.",
        "A dexterous vulpine surpasses a stationary canine.",
        "Quick as thought, the copper warrior sails over the guardian.",
        "Tha nimble reddish fox jumps o'er the doggo, don't ya know.",
        "A dexterous V. vulpes exceeds the plane of an inactive canine.",
        "An agile russet hunter maneuvers above a resting hound.",
        "Test subject F-1 achieves displacement superior to subject D-1.",

        # Original 3: "A nimble mahogany vulpine vaults above a drowsy dog."
        "The nimble mahogany vulpine vaults above a drowsy dog.",
        "A swift mahogany vulpine vaults above a drowsy dog.",
        "A nimble reddish vulpine vaults above a drowsy dog.",
        "A nimble mahogany fox vaults above a drowsy dog.",
        "A nimble mahogany vulpine leaps above a drowsy dog.",
        "Four nimble mahogany vulpines vault above a drowsy dog.",
        "An agile specimen of reddish fur surpasses a somnolent canine.",
        "Fleet as wind, the earth-toned hunter soars over the sleepy guard.",
        "Tha quick brown beastie jumps o'er the tired pup, aye.",
        "Single V. vulpes demonstrates vertical traverse over C. familiaris.",
        "A nimble rust-colored predator crosses above a drowsy pet.",
        "Observed: Subject Red executes vertical motion over Subject Gray.",

        # Original 4: "The speedy copper-colored fox hops over the lethargic pup."
        "A speedy copper-colored fox hops over the lethargic pup.",
        "The quick copper-colored fox hops over the lethargic pup.",
        "The speedy bronze fox hops over the lethargic pup.",
        "The speedy copper-colored fox jumps over the lethargic pup.",
        "The speedy copper-colored fox hops over the tired pup.",
        "Multiple speedy copper-colored foxes hop over the lethargic pup.",
        "A rapid vulpine of bronze hue traverses an inactive young canine.",
        "Swift as a dart, the metallic hunter bounds over the lazy puppy.",
        "Tha fast copper beastie leaps o'er the sleepy wee dog.",
        "1 rapid V. vulpes crosses above 1 juvenile C. familiaris.",
        "A fleet copper-toned predator moves past a sluggish young dog.",
        "Field note: Adult fox subject exceeds puppy subject vertically.",

        # Original 5: "A rapid tawny fox springs over a sluggish dog."
        "The rapid tawny fox springs over a sluggish dog.",
        "A quick tawny fox springs over a sluggish dog.",
        "A rapid golden fox springs over a sluggish dog.",
        "A rapid tawny fox jumps over a sluggish dog.",
        "A rapid tawny fox springs over a lazy dog.",
        "Six rapid tawny foxes spring over a sluggish dog.",
        "An expeditious yellowish vulpine surpasses a torpid canine.",
        "Fast as a bullet, the golden hunter vaults over the idle guard.",
        "Tha swift yellowy fox jumps o'er the lazy mutt, aye.",
        "One V. vulpes displays rapid transit over one inactive C. familiaris.",
        "A speedy yellow-brown predator bypasses a motionless dog.",
        "Log entry: Vulpine subject achieves swift vertical displacement.",

        # Original 6: "The fleet-footed chestnut fox soars above an indolent canine."
        "A fleet-footed chestnut fox soars above an indolent canine.",
        "The swift chestnut fox soars above an indolent canine.",
        "The fleet-footed brown fox soars above an indolent canine.",
        "The fleet-footed chestnut fox leaps above an indolent canine.",
        "The fleet-footed chestnut fox soars above a lazy canine.",
        "Several fleet-footed chestnut foxes soar above an indolent canine.",
        "A rapid brown vulpine specimen traverses a lethargic domestic dog.",
        "Graceful as a bird, the nutbrown hunter flies over the lazy guard.",
        "Tha quick brown beastie sails o'er the sleepy hound, ken.",
        "Single agile V. vulpes achieves elevation above stationary canine.",
        "A nimble brown predator glides over an unmoving domestic animal.",
        "Research note: Brown subject displays superior vertical mobility.",

        # Original 7: "A fast ginger fox hurdles past a slothful dog."
        "The fast ginger fox hurdles past a slothful dog.",
        "A quick ginger fox hurdles past a slothful dog.",
        "A fast red fox hurdles past a slothful dog.",
        "A fast ginger fox jumps past a slothful dog.",
        "A fast ginger fox hurdles past a lazy dog.",
        "Five fast ginger foxes hurdle past a slothful dog.",
        "A rapid orange vulpine bypasses a lethargic canine.",
        "Quick as lightning, the flame-colored hunter races past the lazy guard.",
        "Tha swift ginger beastie leaps past the tired doggy, ye see.",
        "1 rapid orange V. vulpes surpasses 1 inactive C. familiaris.",
        "A speedy red-orange predator overtakes a motionless dog.",
        "Data point: Orange subject demonstrates rapid transit past Gray subject.",

        # Original 8: "The spry rusty-colored fox jumps across a dozing hound."
        "A spry rusty-colored fox jumps across a dozing hound.",
        "The agile rusty-colored fox jumps across a dozing hound.",
        "The spry reddish fox jumps across a dozing hound.",
        "The spry rusty-colored fox leaps across a dozing hound.",
        "The spry rusty-colored fox jumps across a sleeping hound.",
        "Multiple spry rusty-colored foxes jump across a dozing hound.",
        "An agile rust-toned vulpine traverses a somnolent canine.",
        "Nimble as thought, the copper hunter bounds over the resting guard.",
        "Tha lively rust-colored beastie hops o'er the snoozin' hound.",
        "Single dexterous V. vulpes crosses path of dormant C. familiaris.",
        "A lithe rust-tinted predator moves past a slumbering dog.",
        "Observation: Russet subject exhibits agility over dormant subject.",

        # Original 9: "A quick tan fox leaps over an inactive dog."
        "The quick tan fox leaps over an inactive dog.",
        "A swift tan fox leaps over an inactive dog.",
        "A quick beige fox leaps over an inactive dog.",
        "A quick tan fox jumps over an inactive dog.",
        "A quick tan fox leaps over a motionless dog.",
        "Seven quick tan foxes leap over an inactive dog.",
        "A rapid light-brown vulpine surpasses a stationary canine.",
        "Fast as wind, the sand-colored hunter soars over the still guard.",
        "Tha nimble tan beastie jumps o'er the quiet doggy, aye.",
        "One agile fawn V. vulpes traverses one immobile C. familiaris.",
        "A fleet tan-colored predator bypasses an unmoving dog.",
        "Field report: Tan subject demonstrates movement over static subject.",

        # Original 10: "The brisk auburn vulpine bounces over a listless canine."
        "Some brisk auburn vulpines bounce over a listless canine.",
        "The quick auburn vulpine bounces over a listless canine.",
        "The brisk russet vulpine bounces over a listless canine.",
        "The brisk auburn fox bounces over a listless canine.",
        "The brisk auburn vulpine jumps over a listless canine.",
        "Five brisk auburn vulpines bounce over a listless canine.",
        "The expeditious specimen supersedes a quiescent Canis lupus.",
        "Swift as wind, the russet hunter vaults over the idle guardian.",
        "Tha quick ginger beastie hops o'er the lazy mutt, aye.",
        "One V. vulpes achieves displacement over inactive C. familiaris.",
        "A high-velocity auburn predator traverses an immobile animal.",
        "Final observation: Red subject shows mobility over Gray subject."
    ]

    distortion_calculator = SentenceDistortionCalculator(config, original_sentence, paraphrased_sentences)
    for _ in tqdm(range(1)):
        distortion_calculator.calculate_all_metrics()
        distortion_calculator.normalize_metrics()
        distortion_calculator.calculate_combined_distortion()
        distortion_calculator.plot_metrics()
        print("Normalized Metrics:", distortion_calculator.get_normalized_metrics())
        print("Combined Distortion:", distortion_calculator.get_combined_distortions())
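calculate_combined_distortion folds the three min-max-normalized metrics into a single root-mean-square value per sentence. Below is a small, self-contained sketch of just that normalize-then-RMS step; the toy dictionaries are made up for illustration and the helper name normalize mirrors, but is not, the class's _normalize_dict.

import numpy as np

def normalize(d):
    """Min-max normalize dict values to [0, 1], mirroring _normalize_dict."""
    vals = np.array(list(d.values()), dtype=float)
    span = vals.max() - vals.min()
    normed = np.zeros_like(vals) if span == 0 else (vals - vals.min()) / span
    return dict(zip(d.keys(), normed))

lev = normalize({"Sentence_1": 0.2, "Sentence_2": 0.5})
bert = normalize({"Sentence_1": 0.1, "Sentence_2": 0.4})
mover = normalize({"Sentence_1": 0.3, "Sentence_2": 0.6})

combined = {
    k: np.sqrt((lev[k] ** 2 + bert[k] ** 2 + mover[k] ** 2) / 3)
    for k in lev
}
print(combined)  # Sentence_1 -> 0.0, Sentence_2 -> 1.0 after min-max scaling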
renderers/__pycache__/highlighter.cpython-310.pyc
ADDED
Binary file (4.98 kB)
renderers/__pycache__/highlighter.cpython-311.pyc
ADDED
Binary file (6.79 kB)
renderers/__pycache__/plot_3d.cpython-310.pyc
ADDED
Binary file (4.34 kB)
renderers/__pycache__/plot_3d.cpython-311.pyc
ADDED
Binary file (6 kB)
renderers/__pycache__/tree.cpython-310.pyc
ADDED
Binary file (10.6 kB)
renderers/__pycache__/tree.cpython-311.pyc
ADDED
Binary file (21.1 kB)
renderers/highlighter.py
ADDED
@@ -0,0 +1,162 @@
import re

def highlight_common_words(common_words, sentences, title):
    """
    Highlight common words in sentences by adding color-coded background and unique IDs.

    Args:
        common_words (list of tuples): List of tuples where each tuple contains a word's index and the word.
        sentences (list of str): List of sentences to search through.
        title (str): The title for the HTML output.

    Returns:
        str: HTML string with the highlighted sentences.
    """
    color_map = {}
    color_index = 0
    highlighted_html = []

    # Process each sentence
    for idx, sentence in enumerate(sentences, start=1):
        sentence_with_idx = f"{idx}. {sentence}"
        highlighted_sentence = sentence_with_idx

        # Highlight common words in each sentence
        for index, word in common_words:
            if word not in color_map:
                color_map[word] = f'hsl({color_index * 60 % 360}, 70%, 80%)'
                color_index += 1

            # Escape word and create regex pattern to match whole word
            escaped_word = re.escape(word)
            pattern = rf'\b{escaped_word}\b'

            # Replace the word with highlighted version
            highlighted_sentence = re.sub(
                pattern,
                lambda m, idx=index, color=color_map[word]: (
                    f'<span style="background-color: {color}; font-weight: bold;'
                    f' padding: 2px 4px; border-radius: 2px; position: relative;">'
                    f'<span style="background-color: black; color: white; border-radius: 50%;'
                    f' padding: 2px 5px; margin-right: 5px;">{idx}</span>'
                    f'{m.group(0)}'
                    f'</span>'
                ),
                highlighted_sentence,
                flags=re.IGNORECASE
            )

        highlighted_html.append(highlighted_sentence)

    # Format the HTML output with the title
    final_html = "<br><br>".join(highlighted_html)
    return f'''
    <div style="border: solid 1px #FFFFFF; padding: 16px; background-color: #000000; color: #FFFFFF; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 8px;">
        <h3 style="margin-top: 0; font-size: 1em; color: #FFFFFF;">{title}</h3>
        <div style="background-color: #333333; line-height: 1.6; padding: 15px; border-radius: 8px; color: #FFFFFF;">{final_html}</div>
    </div>
    '''

def highlight_common_words_dict(common_words, sentences, title):
    """
    Highlight common words in sentences (from a dictionary) by adding color-coded background and unique IDs.

    Args:
        common_words (list of tuples): List of tuples where each tuple contains a word's index and the word.
        sentences (dict): A dictionary of sentences where the key is the sentence and the value is an entailment score.
        title (str): The title for the HTML output.

    Returns:
        str: HTML string with the highlighted sentences and their entailment scores.
    """
    color_map = {}
    color_index = 0
    highlighted_html = []

    # Process each sentence and its score
    for idx, (sentence, score) in enumerate(sentences.items(), start=1):
        sentence_with_idx = f"{idx}. {sentence}"
        highlighted_sentence = sentence_with_idx

        # Highlight common words in each sentence
        for index, word in common_words:
            if word not in color_map:
                color_map[word] = f'hsl({color_index * 60 % 360}, 70%, 80%)'
                color_index += 1

            # Escape word and create regex pattern to match whole word
            escaped_word = re.escape(word)
            pattern = rf'\b{escaped_word}\b'

            # Replace the word with highlighted version
            highlighted_sentence = re.sub(
                pattern,
                lambda m, idx=index, color=color_map[word]: (
                    f'<span style="background-color: {color}; font-weight: bold;'
                    f' padding: 1px 2px; border-radius: 2px; position: relative;">'
                    f'<span style="background-color: black; color: white; border-radius: 50%;'
                    f' padding: 1px 3px; margin-right: 3px; font-size: 0.8em;">{idx}</span>'
                    f'{m.group(0)}'
                    f'</span>'
                ),
                highlighted_sentence,
                flags=re.IGNORECASE
            )

        # Add the entailment score
        highlighted_html.append(
            f'<div style="margin-bottom: 5px;">'
            f'{highlighted_sentence}'
            f'<div style="display: inline-block; margin-left: 5px; padding: 3px 5px; border-radius: 3px; '
            f'background-color: #333333; color: white; font-size: 0.9em;">'
            f'Entailment Score: {score}</div></div>'
        )

    # Format the HTML output with the title
    final_html = "<br>".join(highlighted_html)
    return f'''
    <div style="background-color: #000000; color: #FFFFFF;border: solid 1px #FFFFFF; border-radius: 8px;">
        <h3 style="margin-top: 0; font-size: 1em; color: #FFFFFF;">{title}</h3>
        <div style="background-color: #333333; line-height: 1.6; padding: 15px; border-radius: 8px; color: #FFFFFF;">{final_html}</div>
    </div>
    '''

def reparaphrased_sentences_html(sentences):
    """
    Create an HTML representation of sentences with numbering.

    Args:
        sentences (list of str): List of sentences to format.

    Returns:
        str: HTML string with numbered sentences.
    """
    formatted_sentences = []

    # Process each sentence
    for idx, sentence in enumerate(sentences, start=1):
        sentence_with_idx = f"{idx}. {sentence}"
        formatted_sentences.append(sentence_with_idx)

    # Format the HTML output
    final_html = "<br><br>".join(formatted_sentences)
    return f'''
    <div style="border: solid 1px #FFFFFF; background-color: #000000; color: #FFFFFF;
                box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 8px;">
        <div style="background-color: #333333; line-height: 1.6; padding: 15px; border-radius: 8px;">{final_html}</div>
    </div>
    '''

if __name__ == "__main__":
    # Example usage
    common_words = [(1, "highlight"), (2, "numbering")]
    sentences = ["This is a test to highlight words.", "Numbering is important for clarity."]

    # Test highlight_common_words
    highlighted_html = highlight_common_words(common_words, sentences, "Test Highlighting")
    print(highlighted_html)

    # Test highlight_common_words_dict
    sentences_with_scores = {"Highlight words in this text.": 0.95, "Number sentences for clarity.": 0.8}
    highlighted_html_dict = highlight_common_words_dict(common_words, sentences_with_scores, "Test Dict Highlighting")
    print(highlighted_html_dict)
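All three helpers return raw HTML strings rather than rendering anything themselves. A minimal preview sketch, assuming illustrative inputs and a hypothetical output file name (neither is part of the module), that writes the result to disk so the styling can be checked in a browser:

from renderers.highlighter import highlight_common_words

# Illustrative (index, word) pairs and sentence list for the sketch only.
html = highlight_common_words(
    [(1, "fox"), (2, "dog")],
    ["The quick brown fox jumps over the lazy dog."],
    "Preview",
)

# Hypothetical output path, chosen only for this sketch.
with open("highlight_preview.html", "w", encoding="utf-8") as f:
    f.write(html)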
renderers/plot_3d.py
ADDED
@@ -0,0 +1,126 @@
"""
This file contains the code to plot a 3d tree
"""
import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import griddata

def gen_three_D_plot(detectability_val, distortion_val, euclidean_val):
    """
    Generates a 3D surface plot showing the relationship between detectability, distortion,
    and Euclidean distance, with a focus on highlighting the "sweet spot" based on a composite score.

    The function takes three sets of values: detectability, distortion, and Euclidean distance,
    normalizes them to a [0, 1] range, and computes a composite score that combines these three metrics.
    The "sweet spot" is the point where the composite score is maximized. This sweet spot is plotted
    as a red marker on the 3D surface plot.

    The function then uses a grid interpolation method (`griddata`) to generate a smooth surface
    for the Euclidean distance over the detectability and distortion values. The result is a surface plot
    where the contours represent different Euclidean distances.

    Args:
        detectability_val (list or array): A list or array of detectability scores.
        distortion_val (list or array): A list or array of distortion scores.
        euclidean_val (list or array): A list or array of Euclidean distances.

    Returns:
        plotly.graph_objects.Figure: A Plotly figure object representing the 3D surface plot,
        with contour lines and a marker for the sweet spot.

    Raises:
        ValueError: If `griddata` fails to generate a valid interpolation, which could happen if the
        input data does not allow for a proper interpolation.

    Example:
        # Example of usage:
        detectability_vals = [0.1, 0.3, 0.5, 0.7, 0.9]
        distortion_vals = [0.2, 0.4, 0.6, 0.8, 1.0]
        euclidean_vals = [0.5, 0.3, 0.2, 0.4, 0.6]

        fig = gen_three_D_plot(detectability_vals, distortion_vals, euclidean_vals)
        fig.show()  # Displays the plot in a web browser

    Notes:
        - The composite score is calculated as:
          `composite_score = norm_detectability - (norm_distortion + norm_euclidean)`,
          where the goal is to maximize detectability and minimize distortion and Euclidean distance.
        - The `griddata` call uses nearest-neighbour interpolation to fill the surface grid.
        - The function uses the "Plasma" colorscale for the surface plot, which provides a perceptually uniform color scheme.
    """

    detectability = np.array(detectability_val)
    distortion = np.array(distortion_val)
    euclidean = np.array(euclidean_val)

    # Normalize the values to range [0, 1]
    norm_detectability = (detectability - min(detectability)) / (max(detectability) - min(detectability))
    norm_distortion = (distortion - min(distortion)) / (max(distortion) - min(distortion))
    norm_euclidean = (euclidean - min(euclidean)) / (max(euclidean) - min(euclidean))

    # Composite score: maximize detectability, minimize distortion and Euclidean distance
    composite_score = norm_detectability - (norm_distortion + norm_euclidean)

    # Find the index of the maximum score (sweet spot)
    sweet_spot_index = np.argmax(composite_score)

    # Sweet spot values
    sweet_spot_detectability = detectability[sweet_spot_index]
    sweet_spot_distortion = distortion[sweet_spot_index]
    sweet_spot_euclidean = euclidean[sweet_spot_index]

    # Create a meshgrid from the data
    x_grid, y_grid = np.meshgrid(np.linspace(min(detectability), max(detectability), 30),
                                 np.linspace(min(distortion), max(distortion), 30))

    # Interpolate z values (Euclidean distances) to fit the grid using 'nearest' method
    z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='nearest')

    if z_grid is None:
        raise ValueError("griddata could not generate a valid interpolation. Check your input data.")

    # Create the 3D contour plot with the Plasma color scale
    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={
            "z": {"show": True, "start": min(euclidean), "end": max(euclidean), "size": 0.1, "usecolormap": True}
        },
        colorscale='Plasma'
    ))

    # Add a marker for the sweet spot
    fig.add_trace(go.Scatter3d(
        x=[sweet_spot_detectability],
        y=[sweet_spot_distortion],
        z=[sweet_spot_euclidean],
        mode='markers+text',
        marker=dict(size=10, color='red', symbol='circle'),
        text=["Sweet Spot"],
        textposition="top center"
    ))

    # Set axis labels
    fig.update_layout(
        scene=dict(
            xaxis_title='Detectability Score',
            yaxis_title='Distortion Score',
            zaxis_title='Euclidean Distance'
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )

    return fig

if __name__ == "__main__":
    # Example input data
    detectability_vals = [0.1, 0.3, 0.5, 0.7, 0.9]
    distortion_vals = [0.2, 0.4, 0.6, 0.8, 1.0]
    euclidean_vals = [0.5, 0.3, 0.2, 0.4, 0.6]

    # Call the function with example data
    fig = gen_three_D_plot(detectability_vals, distortion_vals, euclidean_vals)

    # Show the plot
    fig.show()
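For the example data in the __main__ block, the composite score can be checked by hand: after min-max scaling, detectability and distortion both become [0, 0.25, 0.5, 0.75, 1.0] and the Euclidean distances become [0.75, 0.25, 0.0, 0.5, 1.0], so the composite score is [-0.75, -0.25, 0.0, -0.5, -1.0] and the sweet spot is the third point (detectability 0.5, distortion 0.6, Euclidean distance 0.2). A short sketch that verifies this without drawing the figure:

import numpy as np

detectability = np.array([0.1, 0.3, 0.5, 0.7, 0.9])
distortion = np.array([0.2, 0.4, 0.6, 0.8, 1.0])
euclidean = np.array([0.5, 0.3, 0.2, 0.4, 0.6])

# Same min-max normalization and composite score as gen_three_D_plot.
norm = lambda a: (a - a.min()) / (a.max() - a.min())
composite = norm(detectability) - (norm(distortion) + norm(euclidean))

# Prints index 2, i.e. the (0.5, 0.6, 0.2) point, as the sweet spot.
print(int(np.argmax(composite)))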
renderers/tree.py
ADDED
@@ -0,0 +1,490 @@
import plotly.graph_objects as go
import textwrap
import re
from collections import defaultdict

def generate_subplot1(paraphrased_sentence, masked_sentences, strategies, highlight_info, common_grams):
    """
    Generates a subplot visualizing paraphrased and masked sentences in a tree structure.
    Highlights common words with specific colors and applies Longest Common Subsequence (LCS) numbering.

    Args:
        paraphrased_sentence (str): The paraphrased sentence to be visualized.
        masked_sentences (list of str): A list of masked sentences to be visualized.
        strategies (list of str, optional): List of strategies used for each masked sentence.
        highlight_info (list of tuples): A list of tuples where each tuple contains a word and its associated color for highlighting.
        common_grams (list of tuples): A list of tuples containing an index and a common word or phrase for LCS numbering.

    Returns:
        plotly.graph_objects.Figure: A Plotly figure representing the tree structure with highlighted words and labeled edges.
    """
    # Combine nodes into one list with appropriate labels
    if isinstance(masked_sentences, str):
        masked_sentences = [masked_sentences]
    nodes = [paraphrased_sentence] + masked_sentences
    nodes[0] += ' L0'  # Paraphrased sentence is level 0
    if len(nodes) < 2:
        print("[ERROR] Insufficient nodes for visualization")
        return go.Figure()

    for i in range(1, len(nodes)):
        nodes[i] += ' L1'  # masked sentences are level 1

    def apply_lcs_numbering(sentence, common_grams):
        """
        Applies LCS numbering to the sentence based on the common_grams.

        Args:
            sentence (str): The sentence to which the LCS numbering should be applied.
            common_grams (list of tuples): A list of common grams to be replaced with LCS numbers.

        Returns:
            str: The sentence with LCS numbering applied.
        """
        for idx, lcs in common_grams:
            sentence = re.sub(rf"\b{lcs}\b", f"({idx}){lcs}", sentence)
        return sentence

    # Apply LCS numbering
    nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]

    def highlight_words(sentence, color_map):
        """
        Highlights words in the sentence based on the color_map.

        Args:
            sentence (str): The sentence where the words will be highlighted.
            color_map (dict): A dictionary mapping words to their colors.

        Returns:
            str: The sentence with highlighted words.
        """
        for word, color in color_map.items():
            sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
        return sentence

    # Clean and wrap nodes, and highlight specified words globally
    cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
    global_color_map = dict(highlight_info)
    highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
    wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=55)) for node in highlighted_nodes]

    def get_levels_and_edges(nodes, strategies=None):
        """
        Determines tree levels and creates edges dynamically.

        Args:
            nodes (list of str): The nodes representing the sentences.
            strategies (list of str, optional): The strategies used for each edge.

        Returns:
            tuple: A tuple containing two dictionaries:
                - levels: A dictionary mapping node indices to their levels.
                - edges: A list of edges where each edge is represented by a tuple of node indices.
        """
        levels = {}
        edges = []
        for i, node in enumerate(nodes):
            level = int(node.split()[-1][1])
            levels[i] = level

        # Add edges from L0 to all L1 nodes
        root_node = next((i for i, level in levels.items() if level == 0), 0)
        for i, level in levels.items():
            if level == 1:
                edges.append((root_node, i))

        return levels, edges

    # Get levels and dynamic edges
    levels, edges = get_levels_and_edges(nodes, strategies)
    max_level = max(levels.values(), default=0)

    # Calculate positions
    positions = {}
    level_heights = defaultdict(int)
    for node, level in levels.items():
        level_heights[level] += 1

    y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
    x_gap = 2
    l1_y_gap = 10

    for node, level in levels.items():
        if level == 1:
            positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
        else:
            positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
        y_offsets[level] += 1

    def color_highlighted_words(node, color_map):
        """
        Colors the highlighted words in the node text.

        Args:
            node (str): The node text to be highlighted.
            color_map (dict): A dictionary mapping words to their colors.

        Returns:
            str: The node text with highlighted words.
        """
        parts = re.split(r'(\{\{.*?\}\})', node)
        colored_parts = []
        for part in parts:
            match = re.match(r'\{\{(.*?)\}\}', part)
            if match:
                word = match.group(1)
                color = color_map.get(word, 'black')
                colored_parts.append(f"<span style='color: {color};'>{word}</span>")
            else:
                colored_parts.append(part)
        return ''.join(colored_parts)

    # Define the text for each edge
    default_edge_texts = [
        "Highest Entropy Masking", "Pseudo-random Masking", "Random Masking",
        "Greedy Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
        "Inverse Transform Sampling", "Greedy Sampling", "Temperature Sampling",
        "Exponential Minimum Sampling", "Inverse Transform Sampling", "Greedy Sampling",
        "Temperature Sampling", "Exponential Minimum Sampling", "Inverse Transform Sampling"
    ]

    if len(nodes) < 2:
        print("[ERROR] Insufficient nodes for visualization")
        return go.Figure()

    # Create figure
    fig1 = go.Figure()

    # Add nodes to the figure
    for i, node in enumerate(wrapped_nodes):
        colored_node = color_highlighted_words(node, global_color_map)
        x, y = positions[i]
        fig1.add_trace(go.Scatter(
            x=[-x],  # Reflect the x coordinate
            y=[y],
            mode='markers',
            marker=dict(size=20, color='blue', line=dict(color='black', width=2)),
            hoverinfo='none'
        ))
        fig1.add_annotation(
            x=-x,  # Reflect the x coordinate
            y=y,
            text=colored_node,
            showarrow=False,
            xshift=15,
            align="center",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=2,
            borderpad=4,
            bgcolor='white',
            width=400,
            height=100
        )

    # Add edges and text above each edge
    for i, edge in enumerate(edges):
        x0, y0 = positions[edge[0]]
        x1, y1 = positions[edge[1]]

        # Use strategy if available, otherwise use default edge text
        if strategies and i < len(strategies):
            edge_text = strategies[i]
        else:
            edge_text = default_edge_texts[i % len(default_edge_texts)]

        fig1.add_trace(go.Scatter(
            x=[-x0, -x1],  # Reflect the x coordinates
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))

        # Calculate the midpoint of the edge
        mid_x = (-x0 + -x1) / 2
        mid_y = (y0 + y1) / 2

        # Adjust y position to shift text upwards
        text_y_position = mid_y + 0.8  # Increase this value to shift the text further upwards

        # Add text annotation above the edge
        fig1.add_annotation(
            x=mid_x,
            y=text_y_position,
            text=edge_text,  # Use the text specific to this edge
            showarrow=False,
            font=dict(size=12),
            align="center"
        )

    fig1.update_layout(
        showlegend=False,
        margin=dict(t=50, b=50, l=50, r=50),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=800 + max_level * 200,  # Adjusted width to accommodate more levels
        height=300 + len(nodes) * 100,  # Adjusted height to accommodate more levels
        plot_bgcolor='rgba(240,240,240,0.2)',
        paper_bgcolor='white'
    )

    return fig1

def generate_subplot2(masked_sentences, sampled_sentences, highlight_info, common_grams):
    """
    Generates a subplot visualizing multiple masked sentences and their sampled variants in a tree structure.
    Each masked sentence will have multiple sampled sentences derived from it using different sampling techniques.

    Args:
        masked_sentences (list of str): A list of masked sentences to be visualized as root nodes.
        sampled_sentences (list of str): A list of sampled sentences derived from masked sentences.
        highlight_info (list of tuples): A list of tuples where each tuple contains a word and its associated color for highlighting.
        common_grams (list of tuples): A list of tuples containing an index and a common word or phrase for LCS numbering.

    Returns:
        plotly.graph_objects.Figure: A Plotly figure representing the tree structure with highlighted words and labeled edges.
    """
    # Define sampling techniques
    sampling_techniques = [
        "Greedy Sampling",
        "Temperature Sampling",
        "Exponential Minimum Sampling",
        "Inverse Transform Sampling"
    ]

    # Calculate total number of nodes
    num_masked = len(masked_sentences)
    num_sampled_per_masked = len(sampling_techniques)
    total_nodes = num_masked + (num_masked * num_sampled_per_masked)

    # Combine all sentences into nodes list with appropriate labels
    nodes = []
    # Level 0: masked sentences (root nodes)
    nodes.extend([s + ' L0' for s in masked_sentences])

    # Level 1: sampled sentences (branch nodes)
    # For each masked sentence, we should have samples from each technique
    sampled_nodes = []

    # Validate if we have the expected number of sampled sentences
    expected_sampled_count = num_masked * num_sampled_per_masked
    if len(sampled_sentences) < expected_sampled_count:
        # If insufficient samples provided, pad with placeholder sentences
        print(f"Warning: Expected {expected_sampled_count} sampled sentences, but got {len(sampled_sentences)}")
        while len(sampled_sentences) < expected_sampled_count:
            sampled_sentences.append(f"Placeholder sampled sentence {len(sampled_sentences) + 1}")

    # Add all sampled sentences with level information
    for s in sampled_sentences[:expected_sampled_count]:
        sampled_nodes.append(s + ' L1')

    nodes.extend(sampled_nodes)

    def apply_lcs_numbering(sentence, common_grams):
        """
        Applies LCS numbering to the sentence based on the common_grams.
        """
        for idx, lcs in common_grams:
            sentence = re.sub(rf"\b{lcs}\b", f"({idx}){lcs}", sentence)
        return sentence

    # Apply LCS numbering
    nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]

    def highlight_words(sentence, color_map):
        """
        Highlights words in the sentence based on the color_map.
        """
        for word, color in color_map.items():
            sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
        return sentence

    # Helper function to color highlighted words
    def color_highlighted_words(node, color_map):
        """
        Colors the highlighted words in the node text.
        """
        parts = re.split(r'(\{\{.*?\}\})', node)
        colored_parts = []
        for part in parts:
            match = re.match(r'\{\{(.*?)\}\}', part)
            if match:
                word = match.group(1)
                color = color_map.get(word, 'black')
                colored_parts.append(f"<span style='color: {color};'>{word}</span>")
            else:
                colored_parts.append(part)
        return ''.join(colored_parts)

    # Clean nodes, highlight words, and wrap text
    cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
    global_color_map = dict(highlight_info)
    highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
    wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=80)) for node in highlighted_nodes]

    # Generate edges based on the tree structure
    def get_levels_and_edges(nodes):
        levels = {}
        edges = []

        # Extract level info from node labels
        for i, node in enumerate(nodes):
            level = int(node.split()[-1][1])
            levels[i] = level

        # Create edges from masked sentences to their sampled variants
        for masked_idx in range(num_masked):
            # For each masked sentence, create edges to its sampled variants
            for technique_idx in range(num_sampled_per_masked):
                sampled_idx = num_masked + (masked_idx * num_sampled_per_masked) + technique_idx
                if sampled_idx < len(nodes):
                    edges.append((masked_idx, sampled_idx))

        return levels, edges

    levels, edges = get_levels_and_edges(nodes)

    # Calculate positions with improved spacing
    positions = {}

    # Calculate horizontal spacing for the root nodes (masked sentences)
    root_x_spacing = 0  # All root nodes at x=0
    root_y_spacing = 8.0  # Vertical spacing between root nodes

    # Calculate positions for sampled nodes
    sampled_x = 3  # X position for all sampled nodes

    # Calculate y positions for root nodes (masked sentences)
    root_y_start = -(num_masked - 1) * root_y_spacing / 2
    for i in range(num_masked):
        positions[i] = (root_x_spacing, root_y_start + i * root_y_spacing)

    # Calculate y positions for sampled nodes
    for masked_idx in range(num_masked):
        root_y = positions[masked_idx][1]  # Y position of parent masked sentence

        # Calculate y-spacing for children of this root
        children_y_spacing = 1.5  # Vertical spacing between children of the same root
        children_y_start = root_y - (num_sampled_per_masked - 1) * children_y_spacing / 2

        # Position each child
        for technique_idx in range(num_sampled_per_masked):
            child_idx = num_masked + (masked_idx * num_sampled_per_masked) + technique_idx
            child_y = children_y_start + technique_idx * children_y_spacing
            positions[child_idx] = (sampled_x, child_y)

    # Create figure
    fig2 = go.Figure()

    # Add nodes
    for i, node in enumerate(wrapped_nodes):
        x, y = positions[i]

        # Define node color based on level
        node_color = 'blue' if levels[i] == 0 else 'green'

        # Add the node marker
        fig2.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=20, color=node_color, line=dict(color='black', width=2)),
            hoverinfo='none'
        ))

        # Add node label with highlighting
        colored_node = color_highlighted_words(node, global_color_map)

        fig2.add_annotation(
            x=x,
            y=y,
            text=colored_node,
            showarrow=False,
            xshift=15,
            align="left",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=2,
            borderpad=4,
            bgcolor='white',
            width=400,
            height=100
        )

    # Add edges with labels
    for i, (src, dst) in enumerate(edges):
        x0, y0 = positions[src]
        x1, y1 = positions[dst]

        # Draw the edge
        fig2.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))

        # Add sampling technique label
        # Determine which sampling technique this is
        parent_idx = src
        technique_count = sum(1 for k, (s, _) in enumerate(edges) if s == parent_idx and k < i)
        technique_label = sampling_techniques[technique_count % len(sampling_techniques)]

        # Calculate midpoint for the label
        mid_x = (x0 + x1) / 2
        mid_y = (y0 + y1) / 2

        # Add slight offset to avoid overlap
        label_offset = 0.1

        fig2.add_annotation(
            x=mid_x,
            y=mid_y + label_offset,
            text=technique_label,
            showarrow=False,
            font=dict(size=8),
            align="center"
        )

    # Update layout
    fig2.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1200,  # Adjusted width to accommodate more levels
        height=2000,  # Adjusted height to accommodate more levels
        plot_bgcolor='rgba(240,240,240,0.2)',
        paper_bgcolor='white'
    )

    return fig2

if __name__ == "__main__":
    paraphrased_sentence = "The quick brown fox jumps over the lazy dog."
    masked_sentences = [
        "A fast brown fox leaps over the lazy dog.",
        "A quick brown fox hops over a lazy dog."
    ]
    highlight_info = [
        ("quick", "red"),
        ("brown", "green"),
        ("fox", "blue"),
        ("lazy", "purple")
    ]
    common_grams = [
        (1, "quick brown fox"),
        (2, "lazy dog")
    ]

    # Pass None for strategies so the default edge labels are used
    fig1 = generate_subplot1(paraphrased_sentence, masked_sentences, None, highlight_info, common_grams)
    fig1.show()

    sampled_sentence = ["A fast brown fox jumps over a lazy dog."]

    fig2 = generate_subplot2(masked_sentences, sampled_sentence, highlight_info, common_grams)
    fig2.show()
utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
from utils.watermark import Watermarker
from utils.paraphraser import Paraphraser
from utils.entailment import EntailmentAnalyzer
from utils.sampling import SamplingProcessor
from utils.config import load_config
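With these re-exports in place, callers can pull the pipeline components from the package root. A minimal sketch (the config path assumes the file added at utils/config.yaml in this commit):

from utils import EntailmentAnalyzer, load_config

config = load_config("utils/config.yaml")
analyzer = EntailmentAnalyzer(config["PECCAVI_TEXT"]["Entailment"])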
utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (404 Bytes). View file
|
|
utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (509 Bytes). View file
|
|
utils/__pycache__/config.cpython-310.pyc
ADDED
Binary file (594 Bytes). View file
|
|
utils/__pycache__/config.cpython-311.pyc
ADDED
Binary file (971 Bytes). View file
|
|
utils/__pycache__/entailment.cpython-310.pyc
ADDED
Binary file (3.69 kB). View file
|
|
utils/__pycache__/entailment.cpython-311.pyc
ADDED
Binary file (5.33 kB). View file
|
|
utils/__pycache__/masking_methods.cpython-310.pyc
ADDED
Binary file (11.1 kB). View file
|
|
utils/__pycache__/masking_methods.cpython-311.pyc
ADDED
Binary file (23.5 kB). View file
|
|
utils/__pycache__/non_melting_point.cpython-310.pyc
ADDED
Binary file (5.05 kB). View file
|
|
utils/__pycache__/non_melting_point.cpython-311.pyc
ADDED
Binary file (9.08 kB). View file
|
|
utils/__pycache__/paraphraser.cpython-310.pyc
ADDED
Binary file (2.85 kB). View file
|
|
utils/__pycache__/paraphraser.cpython-311.pyc
ADDED
Binary file (4.89 kB). View file
|
|
utils/__pycache__/sampling.cpython-310.pyc
ADDED
Binary file (5.06 kB). View file
|
|
utils/__pycache__/sampling.cpython-311.pyc
ADDED
Binary file (9.2 kB). View file
|
|
utils/__pycache__/watermark.cpython-310.pyc
ADDED
Binary file (11.8 kB). View file
|
|
utils/__pycache__/watermark.cpython-311.pyc
ADDED
Binary file (20.1 kB). View file
|
|
utils/config.py
ADDED
@@ -0,0 +1,18 @@
"""
This file loads config from config.yaml
"""

import yaml

def load_config(path):
    """
    Function to load config from config.yaml
    """
    try:
        with open(path, "r") as file:
            config = yaml.safe_load(file)
            return config
    except FileNotFoundError:
        raise FileNotFoundError("Config file not found")
    except Exception as e:
        raise e
utils/config.yaml
ADDED
@@ -0,0 +1,48 @@
# This is the official config file.
PECCAVI_TEXT:
  Entailment:
    task: "text-classification"
    model: "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"

  Masking:
    task: "fill-mask"
    tokenizer: "bert-base-uncased"
    model: "bert-base-uncased"
    # tokenizer: "bert-large-cased-whole-word-masking"
    # model: "bert-large-cased-whole-word-masking"

  Vocabulary:
    tokenizer: "bert-base-uncased"
    model: "bert-base-uncased"
    # permissible_ratio: 0.5
    # tokenizer: "bert-large-cased-whole-word-masking"
    # model: "bert-large-cased-whole-word-masking"
    permissible_ratio: 1.0

  Sampling:
    tokenizer: "bert-base-uncased"
    model: "bert-base-uncased"
    # tokenizer: "bert-large-cased-whole-word-masking"
    # model: "bert-large-cased-whole-word-masking"

  Metrics:
    EuclideanDistance: "sentence-transformers/all-MiniLM-L6-v2"
    Distortion: "gpt2"

  Detector:
    tokenizer: "bert-base-uncased"
    model: "bert-base-uncased"
    # tokenizer: "bert-large-cased-whole-word-masking"
    # model: "bert-large-cased-whole-word-masking"

  Paraphrase:
    tokenizer: "humarin/chatgpt_paraphraser_on_T5_base"
    model: "humarin/chatgpt_paraphraser_on_T5_base"
    num_beams: 10
    num_beam_groups: 10
    num_return_sequences: 10
    repetition_penalty: 10.0
    diversity_penalty: 3.0
    no_repeat_ngram_size: 2
    temperature: 0.7
    max_length: 64
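load_config returns this file as a plain nested dict, so each component reads its own sub-section. A minimal sketch of pulling out the Paraphrase settings (the path assumes the file sits at utils/config.yaml, as added in this commit):

from utils.config import load_config

config = load_config("utils/config.yaml")
paraphrase_cfg = config["PECCAVI_TEXT"]["Paraphrase"]
print(paraphrase_cfg["model"], paraphrase_cfg["num_return_sequences"])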
utils/entailment.py
ADDED
@@ -0,0 +1,107 @@
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
from transformers import pipeline
from typing import List
from utils.config import load_config


class EntailmentAnalyzer:
    # def __init__(self, config_path: str):
    def __init__(self, config):
        """
        Initialize the EntailmentAnalyzer with the config file path.

        Args:
            config_path: The path to the configuration file.
        """
        # self.config = load_config(config_path)['PECCAVI_TEXT']['Entailment']
        self.config = config
        self.entailment_pipeline = pipeline(task=self.config['task'], model=self.config['model'])

    def check_entailment(self, premise: str, hypothesis: str) -> float:
        """
        Check entailment between the premise and hypothesis.

        Args:
            premise: The premise sentence.
            hypothesis: The hypothesis sentence.

        Returns:
            float: The entailment score.
        """
        results = self.entailment_pipeline(f"{premise} [SEP] {hypothesis}", top_k=None)
        entailment_score = next(item['score'] for item in results if item['label'] == 'entailment')
        return entailment_score

    def analyze_entailment(self, original_sentence: str, paraphrased_sentences: List[str], threshold: float) -> tuple:
        """
        Analyze entailment scores for paraphrased sentences. If no selected sentences are found,
        lower the threshold and rerun the analysis.

        Args:
            original_sentence: The original sentence.
            paraphrased_sentences: List of paraphrased sentences.
            threshold: Minimum score to select a sentence.

        Returns:
            tuple: A dictionary of all scores, selected sentences, and discarded sentences.
        """
        all_sentences = {}
        selected_sentences = {}
        discarded_sentences = {}

        # Loop to reduce threshold if no sentences are selected
        while not selected_sentences:
            for paraphrased_sentence in paraphrased_sentences:
                entailment_score = self.check_entailment(original_sentence, paraphrased_sentence)

                all_sentences[paraphrased_sentence] = entailment_score
                if entailment_score >= threshold:
                    selected_sentences[paraphrased_sentence] = entailment_score
                else:
                    discarded_sentences[paraphrased_sentence] = entailment_score

            # If no sentences are selected, lower the threshold
            if not selected_sentences:
                print(f"No selected sentences found. Lowering the threshold by 0.1 (from {threshold} to {threshold - 0.1}).")
                threshold -= 0.1
                if threshold <= 0:
                    print("Threshold has reached 0. No sentences meet the criteria.")
                    break

        return all_sentences, selected_sentences, discarded_sentences


if __name__ == "__main__":
    config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml')

    config_path = '/home/ashhar21137/text_wm/scratch/utils/config/config.yaml'

    config = load_config(config_path)

    entailment_analyzer = EntailmentAnalyzer(config['PECCAVI_TEXT']['Entailment'])

    all_sentences, selected_sentences, discarded_sentences = entailment_analyzer.analyze_entailment(
        "The weather is nice today",
        [
            "The climate is pleasant today",
            "It's a good day weather-wise",
            "Today, the weather is terrible",
            "What a beautiful day it is",
            "The sky is clear and the weather is perfect",
            "It's pouring rain outside today",
            "The weather isn't bad today",
            "A lovely day for outdoor activities"
        ],
        0.7
    )

    print("----------------------- All Sentences -----------------------")
    print(all_sentences)
    print("----------------------- Discarded Sentences -----------------------")
    print(discarded_sentences)
    print("----------------------- Selected Sentences -----------------------")
    print(selected_sentences)
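Because the constructor takes only the Entailment sub-section (a dict with 'task' and 'model'), the analyzer can also be built without any YAML on disk. A minimal sketch reusing the model named in config.yaml:

from utils.entailment import EntailmentAnalyzer

analyzer = EntailmentAnalyzer({
    "task": "text-classification",
    "model": "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli",
})
score = analyzer.check_entailment("The weather is nice today", "The climate is pleasant today")
print(score)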
utils/masking_methods.py
ADDED
@@ -0,0 +1,304 @@
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
import logging
|
4 |
+
from transformers import BertTokenizer, BertForMaskedLM
|
5 |
+
from nltk.corpus import stopwords
|
6 |
+
import nltk
|
7 |
+
from transformers import RobertaTokenizer, RobertaForMaskedLM
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
# Set logging to WARNING for a cleaner terminal.
|
11 |
+
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
# Ensure stopwords are downloaded
|
15 |
+
try:
|
16 |
+
nltk.data.find('corpora/stopwords')
|
17 |
+
except LookupError:
|
18 |
+
nltk.download('stopwords')
|
19 |
+
|
20 |
+
class MaskingProcessor:
|
21 |
+
def __init__(self, tokenizer, model):
|
22 |
+
self.tokenizer = tokenizer
|
23 |
+
self.model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
24 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
25 |
+
self.stop_words = set(stopwords.words('english'))
|
26 |
+
tqdm.write(f"[MaskingProcessor] Initialized on device: {self.device}")
|
27 |
+
|
28 |
+
def remove_stopwords(self, words):
|
29 |
+
return [word for word in words if word.lower() not in self.stop_words]
|
30 |
+
|
31 |
+
def adjust_ngram_indices(self, original_words, common_ngrams):
|
32 |
+
logger.info("Adjusting n-gram indices.")
|
33 |
+
non_stop_words = self.remove_stopwords(original_words)
|
34 |
+
original_to_non_stop = []
|
35 |
+
non_stop_idx = 0
|
36 |
+
for original_idx, word in enumerate(original_words):
|
37 |
+
if word.lower() not in self.stop_words:
|
38 |
+
original_to_non_stop.append((original_idx, non_stop_idx))
|
39 |
+
non_stop_idx += 1
|
40 |
+
adjusted_ngrams = {}
|
41 |
+
for ngram, positions in common_ngrams.items():
|
42 |
+
adjusted_positions = []
|
43 |
+
for start, end in positions:
|
44 |
+
try:
|
45 |
+
new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
|
46 |
+
new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
|
47 |
+
adjusted_positions.append((new_start, new_end))
|
48 |
+
except StopIteration:
|
49 |
+
continue
|
50 |
+
adjusted_ngrams[ngram] = adjusted_positions
|
51 |
+
return adjusted_ngrams
|
52 |
+
|
53 |
+
def mask_sentence_random(self, sentence, common_ngrams):
|
54 |
+
tqdm.write(f"[MaskingProcessor] Masking (random) sentence: {sentence}")
|
55 |
+
original_words = sentence.split()
|
56 |
+
has_punctuation = False
|
57 |
+
punctuation = ''
|
58 |
+
if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
|
59 |
+
has_punctuation = True
|
60 |
+
punctuation = original_words[-1][-1]
|
61 |
+
original_words = original_words[:-1]
|
62 |
+
|
63 |
+
non_stop_words = self.remove_stopwords(original_words)
|
64 |
+
adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
|
65 |
+
mask_indices = []
|
66 |
+
|
67 |
+
ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
|
68 |
+
if ngram_positions:
|
69 |
+
first_ngram_start = ngram_positions[0][0]
|
70 |
+
if first_ngram_start > 0:
|
71 |
+
mask_index_before_ngram = random.randint(0, first_ngram_start-1)
|
72 |
+
mask_indices.append(mask_index_before_ngram)
|
73 |
+
|
74 |
+
for i in range(len(ngram_positions) - 1):
|
75 |
+
end_prev = ngram_positions[i][1]
|
76 |
+
start_next = ngram_positions[i + 1][0]
|
77 |
+
if start_next > end_prev + 1:
|
78 |
+
mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
|
79 |
+
mask_indices.append(mask_index_between_ngrams)
|
80 |
+
|
81 |
+
last_ngram_end = ngram_positions[-1][1]
|
82 |
+
if last_ngram_end < len(non_stop_words) - 1:
|
83 |
+
mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
|
84 |
+
mask_indices.append(mask_index_after_ngram)
|
85 |
+
|
86 |
+
non_stop_to_original = {}
|
87 |
+
non_stop_idx = 0
|
88 |
+
for orig_idx, word in enumerate(original_words):
|
89 |
+
if word.lower() not in self.stop_words:
|
90 |
+
non_stop_to_original[non_stop_idx] = orig_idx
|
91 |
+
non_stop_idx += 1
|
92 |
+
|
93 |
+
original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
|
94 |
+
masked_words = original_words.copy()
|
95 |
+
for idx in original_mask_indices:
|
96 |
+
masked_words[idx] = self.tokenizer.mask_token
|
97 |
+
|
98 |
+
if has_punctuation:
|
99 |
+
masked_words.append(punctuation)
|
100 |
+
|
101 |
+
logger.info(f"Masked sentence (random): {' '.join(masked_words)}")
|
102 |
+
return " ".join(masked_words), original_mask_indices
|
103 |
+
|
104 |
+
def mask_sentence_pseudorandom(self, sentence, common_ngrams):
|
105 |
+
logger.info(f"Masking sentence using pseudorandom strategy: {sentence}")
|
106 |
+
random.seed(3)
|
107 |
+
original_words = sentence.split()
|
108 |
+
has_punctuation = False
|
109 |
+
punctuation = ''
|
110 |
+
if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
|
111 |
+
has_punctuation = True
|
112 |
+
punctuation = original_words[-1][-1]
|
113 |
+
original_words = original_words[:-1]
|
114 |
+
|
115 |
+
non_stop_words = self.remove_stopwords(original_words)
|
116 |
+
adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
|
117 |
+
mask_indices = []
|
118 |
+
ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
|
119 |
+
|
120 |
+
if ngram_positions:
|
121 |
+
first_ngram_start = ngram_positions[0][0]
|
122 |
+
if first_ngram_start > 0:
|
123 |
+
mask_index_before_ngram = random.randint(0, first_ngram_start-1)
|
124 |
+
mask_indices.append(mask_index_before_ngram)
|
125 |
+
for i in range(len(ngram_positions) - 1):
|
126 |
+
end_prev = ngram_positions[i][1]
|
127 |
+
start_next = ngram_positions[i + 1][0]
|
128 |
+
if start_next > end_prev + 1:
|
129 |
+
mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
|
130 |
+
mask_indices.append(mask_index_between_ngrams)
|
131 |
+
last_ngram_end = ngram_positions[-1][1]
|
132 |
+
if last_ngram_end < len(non_stop_words) - 1:
|
133 |
+
mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
|
134 |
+
mask_indices.append(mask_index_after_ngram)
|
135 |
+
|
136 |
+
non_stop_to_original = {}
|
137 |
+
non_stop_idx = 0
|
138 |
+
for orig_idx, word in enumerate(original_words):
|
139 |
+
if word.lower() not in self.stop_words:
|
140 |
+
non_stop_to_original[non_stop_idx] = orig_idx
|
141 |
+
non_stop_idx += 1
|
142 |
+
|
143 |
+
original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
|
144 |
+
masked_words = original_words.copy()
|
145 |
+
for idx in original_mask_indices:
|
146 |
+
masked_words[idx] = self.tokenizer.mask_token
|
147 |
+
|
148 |
+
if has_punctuation:
|
149 |
+
masked_words.append(punctuation)
|
150 |
+
|
151 |
+
logger.info(f"Masked sentence (pseudorandom): {' '.join(masked_words)}")
|
152 |
+
return " ".join(masked_words), original_mask_indices
|
153 |
+
|
154 |
+
def mask_sentence_entropy(self, sentence, common_ngrams):
|
155 |
+
logger.info(f"Masking sentence using entropy strategy: {sentence}")
|
156 |
+
original_words = sentence.split()
|
157 |
+
has_punctuation = False
|
158 |
+
punctuation = ''
|
159 |
+
if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
|
160 |
+
has_punctuation = True
|
161 |
+
punctuation = original_words[-1][-1]
|
162 |
+
original_words = original_words[:-1]
|
163 |
+
|
164 |
+
non_stop_words = self.remove_stopwords(original_words)
|
165 |
+
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
        mask_indices = []
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

        # Map positions in the stopword-free word list back to indices in the original sentence.
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        if ngram_positions:
            # Before the first common n-gram: mask the highest-entropy candidate.
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                candidate_positions = range(0, first_ngram_start)
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos])) for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])
            # Between consecutive common n-grams.
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    candidate_positions = range(end_prev + 1, start_next)
                    entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos])) for pos in candidate_positions]
                    mask_indices.append(max(entropies, key=lambda x: x[1])[0])
            # After the last common n-gram.
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos])) for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        if has_punctuation:
            masked_words.append(punctuation)

        logger.info(f"Masked sentence (entropy): {' '.join(masked_words)}")
        return " ".join(masked_words), original_mask_indices

    def calculate_mask_logits(self, original_sentence, original_mask_indices):
        logger.info(f"Calculating mask logits for sentence: {original_sentence}")
        words = original_sentence.split()
        mask_logits = {}
        for idx in original_mask_indices:
            masked_words = words.copy()
            masked_words[idx] = self.tokenizer.mask_token
            masked_sentence = " ".join(masked_words)
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits
            mask_logits_tensor = logits[0, mask_token_index, :]
            top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1)
            # Keep the 50 highest-scoring whole-word candidates, skipping subword pieces and duplicates.
            top_tokens = []
            top_logits = []
            seen_words = set()
            for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
                token = self.tokenizer.convert_ids_to_tokens(token_id.item())
                if token.startswith('##'):
                    continue
                word = self.tokenizer.convert_tokens_to_string([token]).strip()
                if word and word not in seen_words:
                    seen_words.add(word)
                    top_tokens.append(word)
                    top_logits.append(logit.item())
                if len(top_tokens) == 50:
                    break
            mask_logits[idx] = {
                "tokens": top_tokens,
                "logits": top_logits
            }
        logger.info("Completed calculating mask logits.")
        return mask_logits

    def calculate_word_entropy(self, sentence, word_position):
        logger.info(f"Calculating word entropy for position {word_position} in sentence: {sentence}")
        words = sentence.split()
        masked_words = words.copy()
        masked_words[word_position] = self.tokenizer.mask_token
        masked_sentence = " ".join(masked_words)
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits
        probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))
        logger.info(f"Computed entropy: {entropy.item()}")
        return entropy.item()

    def process_sentences(self, sentences_list, common_grams, method="random"):
        tqdm.write(f"[MaskingProcessor] Processing sentences using method: {method}")
        results = {}
        for sentence, ngrams in tqdm(common_grams.items(), desc="Masking Sentences"):
            # Detach trailing punctuation so it is never chosen as a masking candidate.
            words = sentence.split()
            last_word = words[-1]
            if any(last_word.endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
                words[-1] = last_word[:-1]
                punctuation = last_word[-1]
                processed_sentence = " ".join(words) + " " + punctuation
            else:
                processed_sentence = sentence

            if method == "random":
                masked_sentence, original_mask_indices = self.mask_sentence_random(processed_sentence, ngrams)
            elif method == "pseudorandom":
                masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(processed_sentence, ngrams)
            else:  # entropy
                masked_sentence, original_mask_indices = self.mask_sentence_entropy(processed_sentence, ngrams)

            logits = self.calculate_mask_logits(processed_sentence, original_mask_indices)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }
            logger.info(f"Processed sentence: {sentence}")
        tqdm.write("[MaskingProcessor] Completed processing sentences.")
        return results

if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over small cat the lazy dog everyday again and again .",
    ]
    result_dict = {
        'The quick brown fox jumps over small cat the lazy dog everyday again and again .': {
            'brown fox': [(2, 3)],
            'cat': [(7, 7)],
            'dog': [(10, 10)]
        }
    }
    processor = MaskingProcessor(
        BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking"),
        BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
    )
    # Pass the variables defined above (the original referenced undefined names sentences_list / common_grams).
    results_entropy = processor.process_sentences(sentences, result_dict, method="random")
    for sentence, output in results_entropy.items():
        logger.info(f"Original Sentence (Random): {sentence}")
        logger.info(f"Masked Sentence (Random): {output['masked_sentence']}")
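
Note on how the pieces fit together: `process_sentences` expects its second argument in exactly the shape produced by `NgramProcessor.find_filtered_ngrams` below, i.e. a dict mapping each sentence to {n-gram: [(start_word_idx, end_word_idx), ...]}. A minimal wiring sketch, not part of the diff; the import paths and module names are assumptions based on the file layout in this commit:

    from utils.non_melting_point import NgramProcessor
    from utils.masking_methods import MaskingProcessor   # assumed module path for the masking class above
    from transformers import BertTokenizer, BertForMaskedLM

    sentences = ["The quick brown fox jumps over the lazy dog ."]
    common_grams = NgramProcessor().find_filtered_ngrams(sentences)   # {sentence: {ngram: [(start, end)]}}
    masker = MaskingProcessor(
        BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking"),
        BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking"),
    )
    results = masker.process_sentences(sentences, common_grams, method="entropy")
    print(results[sentences[0]]["masked_sentence"])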
utils/non_melting_point.py
ADDED
@@ -0,0 +1,137 @@
import nltk
import logging
from nltk.corpus import stopwords
from nltk.util import ngrams
from collections import Counter
import re
from tqdm import tqdm

# Set logging to WARNING for minimal console output.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

class NgramProcessor:
    def __init__(self):
        try:
            nltk.data.find('corpora/stopwords')
        except LookupError:
            nltk.download('stopwords')
        self.stop_words = set(stopwords.words('english'))
        tqdm.write("[NgramProcessor] Initialized with stopwords.")

    def remove_stopwords(self, text):
        # No need for extensive logging inside this helper.
        words = re.findall(r'\w+', text.lower())
        filtered_words = [word for word in words if word not in self.stop_words]
        return ' '.join(filtered_words)

    def is_exact_match(self, ngram, sentences):
        logger.info(f"Checking exact match for ngram: {ngram}")
        result = all(ngram in sentence for sentence in sentences)
        logger.info(f"Exact match result for '{ngram}': {result}")
        return result

    def is_substring_of_any(self, ngram, common_ngrams):
        logger.info(f"Checking if ngram: {ngram} is substring of any common ngram.")
        result = any(ngram in other_ngram for other_ngram in common_ngrams if ngram != other_ngram)
        logger.info(f"Substring check result for '{ngram}': {result}")
        return result

    def find_filtered_ngrams(self, sentences):
        from collections import Counter
        tqdm.write("[NgramProcessor] Cleaning sentences...")
        sentences_cleaned = [self.remove_stopwords(sentence)
                             for sentence in tqdm(sentences, desc="Cleaning Sentences")]
        ngram_lengths = [4, 3, 2, 1]
        common_ngrams = []
        result = {}
        # Collect n-grams (longest first) that appear in every cleaned sentence and are not
        # substrings of an already accepted longer n-gram.
        for n in ngram_lengths:
            ngrams_list = [list(ngrams(sentence.split(), n)) for sentence in sentences_cleaned]
            ngrams_counter = Counter(ngrams_list[0])
            for ngram in ngrams_counter:
                ngram_str = ' '.join(ngram)
                if any(word in self.stop_words for word in ngram_str.split()):
                    continue
                if self.is_exact_match(ngram_str, sentences_cleaned) and not self.is_substring_of_any(ngram_str, common_ngrams):
                    common_ngrams.append(ngram_str)
        # Map each common n-gram back to word indices in the original (uncleaned) sentences.
        for sentence, cleaned_sentence in tqdm(zip(sentences, sentences_cleaned),
                                               total=len(sentences),
                                               desc="Mapping N-grams"):
            sentence_result = {}
            original_words = sentence.split()
            cleaned_words = cleaned_sentence.split()
            index_map = {}
            cleaned_idx = 0
            for orig_idx, word in enumerate(original_words):
                if word.lower() not in self.stop_words:
                    index_map[cleaned_idx] = orig_idx
                    cleaned_idx += 1
            for ngram in common_ngrams:
                ngram_words = ngram.split()
                indices = []
                for i in range(len(cleaned_words) - len(ngram_words) + 1):
                    if cleaned_words[i:i + len(ngram_words)] == ngram_words:
                        if i in index_map:
                            start_idx = index_map[i]
                            end_idx = index_map.get(i + len(ngram_words) - 1, start_idx)
                            if end_idx - start_idx == len(ngram_words) - 1:
                                indices.append((start_idx, end_idx))
                if indices:
                    sentence_result[ngram] = indices
            result[sentence] = sentence_result
        return result

    # def find_relative_order(self, sentence, common_ngrams):
    #     from tqdm import tqdm
    #     relative_order = []
    #     for ngram in tqdm(common_ngrams, desc="Ordering N-grams", leave=False):
    #         index = sentence.find(ngram)
    #         if index != -1:
    #             relative_order.append((index, ngram))
    #     return sorted(relative_order)

    def find_relative_order(self, sentence, common_ngrams):
        from tqdm import tqdm
        sentence = sentence.lower()
        relative_order = []

        for ngram in tqdm(common_ngrams, desc="Ordering N-grams", leave=False):
            index = sentence.find(ngram.lower())
            if index != -1:
                relative_order.append((index, ngram))

        sorted_pairs = sorted(relative_order)
        return [(i + 1, ngram) for i, (_, ngram) in enumerate(sorted_pairs)]

# Example usage
if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over the lazy dog .",
        "A speedy brown fox jumps over a lazy dog.",
        "A swift brown fox leaps over the lethargic dog.",
    ]
    processor = NgramProcessor()
    common_ngrams = processor.find_filtered_ngrams(sentences)
    print(common_ngrams)
    # modified_output = list({
    #     (indices[0][0], gram)
    #     for grams in common_ngrams.values()
    #     for gram, indices in grams.items()
    # })
    # print(modified_output)
    logger.info(f"Common n-grams and their indices per sentence: {common_ngrams}")
    for sentence in sentences:
        order = processor.find_relative_order(sentence, common_ngrams[sentence])
        logger.info(f"Sentence: {sentence} -> Order: {order}")


"""
{
 'The quick brown fox jumps over the lazy dog.': {'brown fox': [(1, 2)], 'dog': [(5, 5)]},
 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(1, 2)], 'dog': [(5, 5)]},
 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(1, 2)], 'dog': [(5, 5)]}
}
"""
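
The trickiest step above is the cleaned-to-original index mapping: after stop words are removed, position i in the cleaned word list is mapped back to its original word index via `index_map`, so the (start, end) pairs in the result always refer to the original sentence. A standalone illustration (not part of the diff; the toy stop-word set is an assumption for the example):

    stop_words = {"the", "over"}   # toy set for illustration only
    sentence = "The quick brown fox jumps over the lazy dog ."
    original_words = sentence.split()
    index_map = {}
    cleaned_idx = 0
    for orig_idx, word in enumerate(original_words):
        if word.lower() not in stop_words:
            index_map[cleaned_idx] = orig_idx
            cleaned_idx += 1
    # 'brown fox' sits at cleaned positions (1, 2) -> original positions (2, 3)
    print(index_map[1], index_map[2])  # 2 3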
utils/old/masking/masking_methods.py
ADDED
@@ -0,0 +1,355 @@
import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk

# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

class MaskingProcessor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def adjust_ngram_indices(self, words, common_ngrams, remove_stopwords):
        """
        Adjust indices of common n-grams after removing stop words.

        Args:
            words (list): List of words in the original sentence.
            common_ngrams (dict): Common n-grams and their indices.

        Returns:
            dict: Adjusted common n-grams and their indices.
        """
        if not remove_stopwords:
            return common_ngrams

        non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        adjusted_ngrams = {}

        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = non_stop_word_indices.index(start)
                    new_end = non_stop_word_indices.index(end)
                    adjusted_positions.append((new_start, new_end))
                except ValueError:
                    continue  # Skip if indices cannot be mapped
            adjusted_ngrams[ngram] = adjusted_positions

        return adjusted_ngrams

    # def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
    #     """
    #     Mask one word before the first common n-gram, one between two n-grams,
    #     and one after the last common n-gram (random selection).
    #
    #     Args:
    #         original_sentence (str): Original sentence
    #         common_ngrams (dict): Common n-grams and their indices
    #
    #     Returns:
    #         str: Masked sentence with original stop words retained
    #     """
    #     words = original_sentence.split()
    #     if remove_stopwords:
    #         non_stop_words = [word for word in words if word.lower() not in self.stop_words]
    #         non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
    #     else:
    #         non_stop_words = words
    #         non_stop_word_indices = list(range(len(words)))
    #     # non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
    #     adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
    #
    #     mask_indices = []
    #     # Handle before the first common n-gram
    #     if adjusted_ngrams:
    #         first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
    #         if first_ngram_start > 0:
    #             mask_indices.append(random.randint(0, first_ngram_start - 1))
    #
    #         # Handle between common n-grams
    #         ngram_positions = list(adjusted_ngrams.values())
    #         for i in range(len(ngram_positions) - 1):
    #             end_prev = ngram_positions[i][-1][1]
    #             start_next = ngram_positions[i + 1][0][0]
    #             if start_next > end_prev + 1:
    #                 mask_indices.append(random.randint(end_prev + 1, start_next - 1))
    #
    #         # Handle after the last common n-gram
    #         last_ngram_end = ngram_positions[-1][-1][1]
    #         if last_ngram_end < len(non_stop_words) - 1:
    #             mask_indices.append(random.randint(last_ngram_end + 1, len(non_stop_words) - 1))
    #
    #     # Mask the chosen indices
    #     original_masked_sentence = words[:]
    #     # for idx in mask_indices:
    #     #     if idx not in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
    #     #         non_stop_words[idx] = self.tokenizer.mask_token
    #     #         original_masked_sentence[idx] = self.tokenizer.mask_token
    #     for idx in mask_indices:
    #         if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
    #             continue  # Skip if index belongs to common n-grams
    #         if remove_stopwords:
    #             original_idx = non_stop_word_indices[idx]  # Map back to original indices
    #             original_masked_sentence[original_idx] = self.tokenizer.mask_token
    #         else:
    #             original_masked_sentence[idx] = self.tokenizer.mask_token
    #
    #     return " ".join(original_masked_sentence)

    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (random selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices
            remove_stopwords (bool): Whether to remove stop words

        Returns:
            str: Masked sentence with original stop words retained
        """
        words = original_sentence.split()
        if remove_stopwords:
            non_stop_words = [word for word in words if word.lower() not in self.stop_words]
            non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        else:
            non_stop_words = words
            non_stop_word_indices = list(range(len(words)))

        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)

        # Collect all indices corresponding to common n-grams
        common_ngram_indices = {
            idx for ngram_positions in adjusted_ngrams.values()
            for start, end in ngram_positions
            for idx in range(start, end + 1)
        }

        mask_indices = []
        # Handle before the first common n-gram
        if adjusted_ngrams:
            first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
            if first_ngram_start > 0:
                potential_indices = [i for i in range(first_ngram_start) if i not in common_ngram_indices]
                if potential_indices:
                    mask_indices.append(random.choice(potential_indices))

            # Handle between common n-grams
            ngram_positions = list(adjusted_ngrams.values())
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                potential_indices = [i for i in range(end_prev + 1, start_next) if i not in common_ngram_indices]
                if potential_indices:
                    mask_indices.append(random.choice(potential_indices))

            # Handle after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                potential_indices = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i not in common_ngram_indices]
                if potential_indices:
                    mask_indices.append(random.choice(potential_indices))

        # Mask the chosen indices
        original_masked_sentence = words[:]
        for idx in mask_indices:
            if remove_stopwords:
                original_idx = non_stop_word_indices[idx]  # Map back to original indices
                original_masked_sentence[original_idx] = self.tokenizer.mask_token
            else:
                original_masked_sentence[idx] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (highest entropy selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices

        Returns:
            str: Masked sentence with original stop words retained
        """
        words = original_sentence.split()
        # non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
        if remove_stopwords:
            non_stop_words = [word for word in words if word.lower() not in self.stop_words]
            non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        else:
            non_stop_words = words
            non_stop_word_indices = list(range(len(words)))
        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
        entropy_scores = {}

        for idx, word in enumerate(non_stop_words):
            if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
                continue  # Skip words in common n-grams

            masked_sentence = non_stop_words[:idx] + [self.tokenizer.mask_token] + non_stop_words[idx + 1:]
            masked_sentence = " ".join(masked_sentence)
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            filtered_logits = logits[0, mask_token_index, :]
            probs = torch.softmax(filtered_logits, dim=-1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()  # Add epsilon to prevent log(0)
            entropy_scores[idx] = entropy

        mask_indices = []

        # Handle before the first common n-gram
        if adjusted_ngrams:
            first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
            candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Handle between common n-grams
            ngram_positions = list(adjusted_ngrams.values())
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                candidates = [i for i in range(end_prev + 1, start_next) if i in entropy_scores]
                if candidates:
                    mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Handle after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            candidates = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

        # Mask the chosen indices
        original_masked_sentence = words[:]
        # for idx in mask_indices:
        #     non_stop_words[idx] = self.tokenizer.mask_token
        #     original_masked_sentence[idx] = self.tokenizer.mask_token

        for idx in mask_indices:
            if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
                continue  # Skip if index belongs to common n-grams
            if remove_stopwords:
                original_idx = non_stop_word_indices[idx]  # Map back to original indices
                original_masked_sentence[original_idx] = self.tokenizer.mask_token
            else:
                original_masked_sentence[idx] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens

        Returns:
            dict: Masked token indices and their logits
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits

    def process_sentences(self, original_sentences, result_dict, method="random", remove_stopwords=False):
        """
        Process a list of sentences and calculate logits for masked tokens using the specified method.

        Args:
            original_sentences (list): List of original sentences
            result_dict (dict): Common n-grams and their indices for each sentence
            method (str): Masking method ("random" or "entropy")

        Returns:
            dict: Masked sentences and their logits for each sentence
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
            elif method == "entropy":
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
            else:
                raise ValueError("Invalid method. Choose 'random' or 'entropy'.")

            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }

        return results

# Example usage
if __name__ == "__main__":
    # !!! Working both the cases regardless if the stopword is removed or not
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A speedy brown fox jumps over a lazy dog.",
        "A swift brown fox leaps over the lethargic dog."
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
    }

    processor = MaskingProcessor()
    results_random = processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=True)
    # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)

    for sentence, output in results_random.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # # print(f"Mask Logits (Random): {output['mask_logits']}")
        # print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
        # print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
        # print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
        print('--------------------------------')
        # for mask_idx, logits in output["mask_logits"].items():
        #     print(f"Logits for [MASK] at position {mask_idx}:")
        #     print(f' logits : {logits[:5]}')  # List of logits for all vocabulary tokens

    # result_dict = {
    #     "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
    # }

    # print('--------------------------------')
    # for sentence, output in results_entropy.items():
    #     print(f"Original Sentence (Entropy): {sentence}")
    #     print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
    #     # print(f"Mask Logits (Entropy): {output['mask_logits']}")
    #     print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
    #     print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
    #     print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
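
The entropy score used by `mask_sentence_entropy` is the Shannon entropy of the MLM's predicted distribution at the [MASK] position, H = -sum(p * log p); a higher value means the model is less certain about the gap, so that position carries more information to replace. A minimal, self-contained sketch (illustrative values only, not repository code):

    import torch

    logits = torch.tensor([2.0, 1.0, 0.5, 0.1])              # stand-in for MLM logits at the [MASK] position
    probs = torch.softmax(logits, dim=-1)
    entropy = -torch.sum(probs * torch.log(probs + 1e-10))   # same epsilon trick as in the code above
    print(entropy.item())                                    # higher entropy -> less certain -> better masking candidate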
utils/old/masking/masking_methods_new_work.py
ADDED
@@ -0,0 +1,447 @@
import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk

# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

class MaskingProcessor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def remove_stopwords(self, words):
        """
        Remove stopwords from the given list of words.

        Args:
            words (list): List of words.

        Returns:
            list: List of non-stop words.
        """
        return [word for word in words if word.lower() not in self.stop_words]

    def adjust_ngram_indices(self, original_words, common_ngrams):
        """
        Adjust indices of common n-grams after removing stopwords.

        Args:
            original_words (list): Original list of words.
            common_ngrams (dict): Common n-grams and their indices.

        Returns:
            dict: Adjusted common n-grams with updated indices.
        """
        non_stop_words = self.remove_stopwords(original_words)
        original_to_non_stop = []
        non_stop_idx = 0

        for original_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                original_to_non_stop.append((original_idx, non_stop_idx))
                non_stop_idx += 1

        adjusted_ngrams = {}
        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
                    new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
                    adjusted_positions.append((new_start, new_end))
                except StopIteration:
                    continue  # Skip if indices cannot be mapped
            adjusted_ngrams[ngram] = adjusted_positions

        return adjusted_ngrams

    def mask_sentence_random(self, sentence, common_ngrams):
        """
        Mask words in the sentence based on the specified rules after removing stopwords.
        """
        original_words = sentence.split()
        print(f' ---- original_words : {original_words} ----- ')
        non_stop_words = self.remove_stopwords(original_words)
        print(f' ---- non_stop_words : {non_stop_words} ----- ')
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
        print(f' ---- common_ngrams : {common_ngrams} ----- ')
        print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')

        mask_indices = []

        # Extract n-gram positions in non-stop words
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

        # Mask a word before the first common n-gram
        if ngram_positions:
            print(f' ---- ngram_positions : {ngram_positions} ----- ')
            first_ngram_start = ngram_positions[0][0]
            print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
            if first_ngram_start > 0:
                mask_index_before_ngram = random.randint(0, first_ngram_start - 1)
                print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
                mask_indices.append(mask_index_before_ngram)

            # Mask words between common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                print(f' ---- end_prev : {end_prev} ----- ')  # end index of the previous n-gram
                start_next = ngram_positions[i + 1][0]
                print(f' ---- start_next : {start_next} ----- ')
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
                    mask_indices.append(mask_index_between_ngrams)

            # Mask a word after the last common n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
                mask_indices.append(mask_index_after_ngram)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        # Map mask indices from non-stop word positions to original positions
        print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        print(f' ---- original_mask_indices : {original_mask_indices} ----- ')

        # Apply masks to the original sentence
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        return " ".join(masked_words)

    def mask_sentence_pseudorandom(self, sentence, common_ngrams):
        """
        Mask words in the sentence based on the specified rules after removing stopwords.
        """
        random.seed(42)
        original_words = sentence.split()
        print(f' ---- original_words : {original_words} ----- ')
        non_stop_words = self.remove_stopwords(original_words)
        print(f' ---- non_stop_words : {non_stop_words} ----- ')
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
        print(f' ---- common_ngrams : {common_ngrams} ----- ')
        print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')

        mask_indices = []

        # Extract n-gram positions in non-stop words
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

        # Mask a word before the first common n-gram
        if ngram_positions:
            print(f' ---- ngram_positions : {ngram_positions} ----- ')
            first_ngram_start = ngram_positions[0][0]
            print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
            if first_ngram_start > 0:
                mask_index_before_ngram = random.randint(0, first_ngram_start - 1)
                print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
                mask_indices.append(mask_index_before_ngram)

            # Mask words between common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                print(f' ---- end_prev : {end_prev} ----- ')
                start_next = ngram_positions[i + 1][0]
                print(f' ---- start_next : {start_next} ----- ')
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
                    mask_indices.append(mask_index_between_ngrams)

            # Mask a word after the last common n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
                mask_indices.append(mask_index_after_ngram)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        # Map mask indices from non-stop word positions to original positions
        print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        print(f' ---- original_mask_indices : {original_mask_indices} ----- ')

        # Apply masks to the original sentence
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        return " ".join(masked_words)

    def calculate_word_entropy(self, sentence, word_position):
        """
        Calculate entropy for a specific word position in the sentence.

        Args:
            sentence (str): The input sentence
            word_position (int): Position of the word to calculate entropy for

        Returns:
            float: Entropy value for the word
        """
        words = sentence.split()
        masked_words = words.copy()
        masked_words[word_position] = self.tokenizer.mask_token
        masked_sentence = " ".join(masked_words)

        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        # Get probabilities for the masked position
        probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
        # Calculate entropy: -sum(p * log(p))
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))

        return entropy.item()

    def mask_sentence_entropy(self, sentence, common_ngrams):
        """
        Mask words in the sentence based on entropy, following n-gram positioning rules.

        Args:
            sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices

        Returns:
            str: Masked sentence
        """
        original_words = sentence.split()
        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        original_to_non_stop = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                original_to_non_stop[orig_idx] = non_stop_idx
                non_stop_idx += 1

        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
        mask_indices = []

        if ngram_positions:
            # Handle words before first n-gram
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                # Calculate entropy for all candidate positions
                candidate_positions = range(0, first_ngram_start)
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                # Select position with highest entropy
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Handle words between n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    candidate_positions = range(end_prev + 1, start_next)
                    entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                                 for pos in candidate_positions]
                    mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Handle words after last n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

        # Map mask indices to original sentence positions and apply masks
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        return " ".join(masked_words)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.

        Returns:
            dict: Masked token indices and their logits.
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits

    def process_sentences(self, sentences, result_dict, method="random"):
        """
        Process sentences and calculate logits for masked tokens.

        Args:
            sentences (list): List of sentences
            result_dict (dict): Dictionary of common n-grams
            method (str): Masking method ("random" or "entropy")

        Returns:
            dict: Masked sentences and logits for each sentence
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams)
            elif method == "pseudorandom":
                masked_sentence = self.mask_sentence_pseudorandom(sentence, ngrams)
            else:  # entropy
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams)

            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }

        return results


if __name__ == "__main__":
    # !!! Working both the cases regardless if the stopword is removed or not
    sentences = [
        "The quick brown fox jumps over the lazy dog everyday.",
        # "A speedy brown fox jumps over a lazy dog.",
        # "A swift brown fox leaps over the lethargic dog."
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        # 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        # 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
    }

    processor = MaskingProcessor()
    # results_random = processor.process_sentences(sentences, result_dict)
    results_entropy = processor.process_sentences(sentences, result_dict, method="random")

    # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)

    for sentence, output in results_entropy.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # print(f"Mask Logits (Random): {output['mask_logits']}")
        print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
        print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
        print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
        print('--------------------------------')
        for mask_idx, logits in output["mask_logits"].items():
            print(f"Logits for [MASK] at position {mask_idx}:")
            print(f' logits : {logits[:5]}')  # List of logits for all vocabulary tokens
            print(f' len(logits) : {len(logits)}')


# -------------------------------------------------------------------------------------------
# def mask_sentence(self, sentence, common_ngrams):
#     """
#     Mask words in the sentence based on the specified rules after removing stopwords.
#
#     Args:
#         sentence (str): Original sentence.
#         common_ngrams (dict): Common n-grams and their indices.
#
#     Returns:
#         str: Masked sentence.
#     """
#     original_words = sentence.split()
#     print(f' ---- original_words : {original_words} ----- ')
#     non_stop_words = self.remove_stopwords(original_words)
#     print(f' ---- non_stop_words : {non_stop_words} ----- ')
#     adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
#     print(f' ---- common_ngrams : {common_ngrams} ----- ')
#     print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')
#
#     mask_indices = []
#
#     # Extract n-gram positions in non-stop words
#     ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
#     print(f' ---- ngram_positions : {ngram_positions} ----- ')
#     # Mask a word before the first common n-gram
#     if ngram_positions:
#         first_ngram_start = ngram_positions[0][0]
#         print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
#         if first_ngram_start > 0:
#             mask_index_before_ngram = random.randint(0, first_ngram_start-1)
#             print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
#             mask_indices.append(mask_index_before_ngram)
#
#         # Mask words between common n-grams
#         for i in range(len(ngram_positions) - 1):
#             end_prev = ngram_positions[i][1]
#             print(f' ---- end_prev : {end_prev} ----- ')
#             start_next = ngram_positions[i + 1][0]
#             print(f' ---- start_next : {start_next} ----- ')
#             if start_next > end_prev + 1:
#                 mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
#                 print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
#                 mask_indices.append(mask_index_between_ngrams)
#
#         # Mask a word after the last common n-gram
#         last_ngram_end = ngram_positions[-1][1]
#         print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
#         if last_ngram_end < len(non_stop_words) - 1:
#             mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
#             print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
#             mask_indices.append(mask_index_after_ngram)
#
#     # Map mask indices back to original sentence
#     adjusted_indices = [
#         orig for orig, non_stop in enumerate(original_words)
#         if non_stop in mask_indices
#     ]
#
#     # Apply masks to the original sentence
#     for idx in adjusted_indices:
#         original_words[idx] = self.tokenizer.mask_token
#
#     return " ".join(original_words)
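
The only difference between `mask_sentence_random` and `mask_sentence_pseudorandom` above is the `random.seed(42)` call, which makes the gap selection repeatable across runs. A tiny standalone demonstration (the `pick_gap` helper is hypothetical, written just for this note, not part of the repository):

    import random

    def pick_gap(first_ngram_start, seed=None):
        # Pick a random word index before the first common n-gram, optionally reseeding first.
        if seed is not None:
            random.seed(seed)
        return random.randint(0, first_ngram_start - 1)

    print(pick_gap(4, seed=42), pick_gap(4, seed=42))  # identical picks on every run
    print(pick_gap(4), pick_gap(4))                    # unseeded picks may differ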
utils/old/masking/masking_methods_ok_working.py
ADDED
@@ -0,0 +1,257 @@
import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk

# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

class MaskingProcessor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def adjust_ngram_indices(self, words, common_ngrams, remove_stopwords):
        """
        Adjust indices of common n-grams after removing stop words.

        Args:
            words (list): List of words in the original sentence.
            common_ngrams (dict): Common n-grams and their indices.

        Returns:
            dict: Adjusted common n-grams and their indices.
        """
        if not remove_stopwords:
            return common_ngrams

        non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        adjusted_ngrams = {}

        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = non_stop_word_indices.index(start)
                    new_end = non_stop_word_indices.index(end)
                    adjusted_positions.append((new_start, new_end))
                except ValueError:
                    continue  # Skip if indices cannot be mapped
            adjusted_ngrams[ngram] = adjusted_positions

        return adjusted_ngrams

    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (random selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices

        Returns:
            str: Masked sentence with original stop words retained
        """
        words = original_sentence.split()
        non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)

        mask_indices = []
        # Handle before the first common n-gram
        if adjusted_ngrams:
            first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
            if first_ngram_start > 0:
                mask_indices.append(random.randint(0, first_ngram_start - 1))

            # Handle between common n-grams
            ngram_positions = list(adjusted_ngrams.values())
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                if start_next > end_prev + 1:
                    mask_indices.append(random.randint(end_prev + 1, start_next - 1))

            # Handle after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                mask_indices.append(random.randint(last_ngram_end + 1, len(non_stop_words) - 1))

        # Mask the chosen indices
        original_masked_sentence = words[:]
        for idx in mask_indices:
            if idx not in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
                non_stop_words[idx] = self.tokenizer.mask_token
                original_masked_sentence[idx] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (highest entropy selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices

        Returns:
            str: Masked sentence with original stop words retained
        """
        words = original_sentence.split()
        non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
        entropy_scores = {}

        for idx, word in enumerate(non_stop_words):
            if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
                continue  # Skip words in common n-grams

            masked_sentence = non_stop_words[:idx] + [self.tokenizer.mask_token] + non_stop_words[idx + 1:]
            masked_sentence = " ".join(masked_sentence)
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            filtered_logits = logits[0, mask_token_index, :]
            probs = torch.softmax(filtered_logits, dim=-1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()  # Add epsilon to prevent log(0)
            entropy_scores[idx] = entropy

        mask_indices = []

        # Handle before the first common n-gram
        if adjusted_ngrams:
            first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
            candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Handle between common n-grams
            ngram_positions = list(adjusted_ngrams.values())
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                candidates = [i for i in range(end_prev + 1, start_next) if i in entropy_scores]
                if candidates:
                    mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Handle after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            candidates = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

        # Mask the chosen indices
        original_masked_sentence = words[:]
        for idx in mask_indices:
            non_stop_words[idx] = self.tokenizer.mask_token
            original_masked_sentence[idx] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens

        Returns:
            dict: Masked token indices and their logits
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits

    def process_sentences(self, original_sentences, result_dict, method="random", remove_stopwords=False):
        """
        Process a list of sentences and calculate logits for masked tokens using the specified method.

        Args:
            original_sentences (list): List of original sentences
            result_dict (dict): Common n-grams and their indices for each sentence
            method (str): Masking method ("random" or "entropy")

        Returns:
            dict: Masked sentences and their logits for each sentence
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
            elif method == "entropy":
|
199 |
+
masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
|
200 |
+
else:
|
201 |
+
raise ValueError("Invalid method. Choose 'random' or 'entropy'.")
|
202 |
+
|
203 |
+
logits = self.calculate_mask_logits(masked_sentence)
|
204 |
+
results[sentence] = {
|
205 |
+
"masked_sentence": masked_sentence,
|
206 |
+
"mask_logits": logits
|
207 |
+
}
|
208 |
+
|
209 |
+
return results
|
210 |
+
|
211 |
+
# Example usage
|
212 |
+
if __name__ == "__main__":
|
213 |
+
# !!! Working both the cases regardless if the stopword is removed or not
|
214 |
+
sentences = [
|
215 |
+
"The quick brown fox jumps over the lazy dog.",
|
216 |
+
"A quick brown dog outpaces a lazy fox.",
|
217 |
+
"Quick brown animals leap over lazy obstacles."
|
218 |
+
]
|
219 |
+
|
220 |
+
result_dict = {
|
221 |
+
"The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
|
222 |
+
"A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
|
223 |
+
"Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
|
224 |
+
}
|
225 |
+
|
226 |
+
# result_dict = {
|
227 |
+
# "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
|
228 |
+
# "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
|
229 |
+
# "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
|
230 |
+
# }
|
231 |
+
|
232 |
+
processor = MaskingProcessor()
|
233 |
+
results_random = processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
|
234 |
+
# results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)
|
235 |
+
|
236 |
+
for sentence, output in results_random.items():
|
237 |
+
print(f"Original Sentence (Random): {sentence}")
|
238 |
+
print(f"Masked Sentence (Random): {output['masked_sentence']}")
|
239 |
+
# print(f"Mask Logits (Random): {output['mask_logits']}")
|
240 |
+
print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
|
241 |
+
print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
|
242 |
+
print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
|
243 |
+
print('--------------------------------')
|
244 |
+
for mask_idx, logits in output["mask_logits"].items():
|
245 |
+
print(f"Logits for [MASK] at position {mask_idx}:")
|
246 |
+
print(f' logits : {logits[:5]}') # List of logits for all vocabulary tokens
|
247 |
+
|
248 |
+
|
249 |
+
|
250 |
+
# print('--------------------------------')
|
251 |
+
# for sentence, output in results_entropy.items():
|
252 |
+
# print(f"Original Sentence (Entropy): {sentence}")
|
253 |
+
# print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
|
254 |
+
# # print(f"Mask Logits (Entropy): {output['mask_logits']}")
|
255 |
+
# print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
|
256 |
+
# print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
|
257 |
+
# print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
|
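Editor's note: the masking above depends on the index remapping that adjust_ngram_indices performs once stopwords are dropped. The following is a minimal standalone sketch (my illustration, not part of the committed files) of that mapping, using the same example sentence as the script above; the variable names are assumptions made for the example only.

# Sketch: how original word positions shift after stopword removal.
import nltk
from nltk.corpus import stopwords

try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

stop_words = set(stopwords.words('english'))
words = "The quick brown fox jumps over the lazy dog.".split()

# Map each original position to its position in the stopword-free list,
# which is what the adjusted n-gram spans above index into.
orig_to_non_stop = {}
non_stop_idx = 0
for orig_idx, word in enumerate(words):
    if word.lower() not in stop_words:
        orig_to_non_stop[orig_idx] = non_stop_idx
        non_stop_idx += 1

print(orig_to_non_stop)  # expected to look like {1: 0, 2: 1, 3: 2, 4: 3, 7: 4, 8: 5}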
utils/old/masking/masking_methods_v1_working.py
ADDED
@@ -0,0 +1,233 @@
import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk

# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# THIS IS WORKING WHEN THE COORDINATES ARE WITHOUT REMOVING STOPWORDS
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!


# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

class MaskingProcessor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords=False):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (random selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices

        Returns:
            str: Masked sentence
        """
        if remove_stopwords:
            words = original_sentence.split()
            words = [word for word in words if word not in self.stop_words]
        else:
            words = original_sentence.split()

        mask_indices = []
        # Handle before the first common n-gram
        if common_ngrams:
            first_ngram_start = list(common_ngrams.values())[0][0][0]
            if first_ngram_start > 0:
                mask_indices.append(random.randint(0, first_ngram_start - 1))

            # Handle between common n-grams
            ngram_positions = list(common_ngrams.values())
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                if start_next > end_prev + 1:
                    mask_indices.append(random.randint(end_prev + 1, start_next - 1))

            # Handle after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            if last_ngram_end < len(words) - 1:
                mask_indices.append(random.randint(last_ngram_end + 1, len(words) - 1))

        # Mask the chosen indices
        for idx in mask_indices:
            if idx not in [index for ngram_indices in common_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
                words[idx] = self.tokenizer.mask_token

        return " ".join(words)

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords=False):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (highest entropy selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices

        Returns:
            str: Masked sentence
        """
        if remove_stopwords:
            words = original_sentence.split()
            words = [word for word in words if word not in self.stop_words]
        else:
            words = original_sentence.split()
        entropy_scores = {}

        for idx, word in enumerate(words):
            if idx in [index for ngram_indices in common_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
                continue  # Skip words in common n-grams

            masked_sentence = words[:idx] + [self.tokenizer.mask_token] + words[idx + 1:]
            masked_sentence = " ".join(masked_sentence)
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            filtered_logits = logits[0, mask_token_index, :]
            probs = torch.softmax(filtered_logits, dim=-1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()  # Add epsilon to prevent log(0)
            entropy_scores[idx] = entropy

        mask_indices = []

        # Handle before the first common n-gram
        if common_ngrams:
            first_ngram_start = list(common_ngrams.values())[0][0][0]
            candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Handle between common n-grams
            ngram_positions = list(common_ngrams.values())
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                candidates = [i for i in range(end_prev + 1, start_next) if i in entropy_scores]
                if candidates:
                    mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Handle after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            candidates = [i for i in range(last_ngram_end + 1, len(words)) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

        # Mask the chosen indices
        for idx in mask_indices:
            words[idx] = self.tokenizer.mask_token

        return " ".join(words)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens

        Returns:
            dict: Masked token indices and their logits
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits

    def process_sentences(self, original_sentences, result_dict, remove_stopwords=False, method="random"):
        """
        Process a list of sentences and calculate logits for masked tokens using the specified method.

        Args:
            original_sentences (list): List of original sentences
            result_dict (dict): Common n-grams and their indices for each sentence
            method (str): Masking method ("random" or "entropy")

        Returns:
            dict: Masked sentences and their logits for each sentence
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams)
            elif method == "entropy":
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams)
            else:
                raise ValueError("Invalid method. Choose 'random' or 'entropy'.")

            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }

        return results

# Example usage
if __name__ == "__main__":
    # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    # THIS IS WORKING WHEN THE COORDINATES ARE WITHOUT REMOVING STOPWORDS

    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown animals leap over lazy obstacles."
    ]

    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
        "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
        "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
    }

    # result_dict = {
    #     "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
    # }

    processor = MaskingProcessor()
    results_random = processor.process_sentences(sentences, result_dict, remove_stopwords=True, method="random")
    results_entropy = processor.process_sentences(sentences, result_dict, remove_stopwords=True, method="entropy")

    for sentence, output in results_random.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # print(f"Mask Logits (Random): {output['mask_logits']}")

    for sentence, output in results_entropy.items():
        print(f"Original Sentence (Entropy): {sentence}")
        print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
        # print(f"Mask Logits (Entropy): {output['mask_logits']}")




'''
result_dict = {
    "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
}

'''
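Editor's note: the banner comment in masking_methods_v1_working.py stresses that the n-gram coordinates must index the raw sentence.split() token list (stopwords still present). The snippet below is my own illustrative sanity check of that assumption, not part of the committed files; the helper name coordinates_match is invented for the example.

# Sketch: verify that a (start, end) span in result_dict really points at the
# named n-gram in the raw, stopword-inclusive word list the masker splits on.
def coordinates_match(sentence, ngram, span):
    words = sentence.split()
    start, end = span
    candidate = " ".join(words[start:end + 1]).lower().rstrip(".,!?;:")
    return candidate == ngram.lower()

print(coordinates_match("The quick brown fox jumps over the lazy dog.", "quick brown", (1, 2)))  # True
print(coordinates_match("The quick brown fox jumps over the lazy dog.", "lazy", (7, 7)))         # True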
utils/old/masking_methods_final_copy.py
ADDED
@@ -0,0 +1,619 @@
import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk
from transformers import RobertaTokenizer, RobertaForMaskedLM


# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

class MaskingProcessor:
    # def __init__(self, tokenizer, model):
    def __init__(self):
        # self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")

        # self.tokenizer = tokenizer
        # self.model = model

        self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
        self.model = BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")

        # self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        # self.model = RobertaForMaskedLM.from_pretrained("roberta-base")

        self.stop_words = set(stopwords.words('english'))

    def remove_stopwords(self, words):
        """
        Remove stopwords from the given list of words.

        Args:
            words (list): List of words.

        Returns:
            list: List of non-stop words.
        """
        return [word for word in words if word.lower() not in self.stop_words]

    def adjust_ngram_indices(self, original_words, common_ngrams):
        """
        Adjust indices of common n-grams after removing stopwords.

        Args:
            original_words (list): Original list of words.
            common_ngrams (dict): Common n-grams and their indices.

        Returns:
            dict: Adjusted common n-grams with updated indices.
        """
        non_stop_words = self.remove_stopwords(original_words)
        original_to_non_stop = []
        non_stop_idx = 0

        for original_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                original_to_non_stop.append((original_idx, non_stop_idx))
                non_stop_idx += 1

        adjusted_ngrams = {}
        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
                    new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
                    adjusted_positions.append((new_start, new_end))
                except StopIteration:
                    continue  # Skip if indices cannot be mapped
            adjusted_ngrams[ngram] = adjusted_positions

        return adjusted_ngrams

    def mask_sentence_random(self, sentence, common_ngrams):
        """
        Mask words in the sentence based on the specified rules after removing stopwords.
        """
        # Split sentence into words
        original_words = sentence.split()

        # Handle punctuation at the end
        has_punctuation = False
        punctuation = None
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]

        print(f' ---- original_words : {original_words} ----- ')

        # Process words without punctuation
        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        # Rest of the existing function code...
        mask_indices = []
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

        if ngram_positions:
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                mask_index_before_ngram = random.randint(0, first_ngram_start - 1)
                mask_indices.append(mask_index_before_ngram)

            # Mask words between common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    mask_indices.append(mask_index_between_ngrams)

            # Mask a word after the last common n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                mask_indices.append(mask_index_after_ngram)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        # Map mask indices and apply masks
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token
            # masked_words[idx] = '<mask>'  # for roberta

        # Add back punctuation if it existed
        if has_punctuation:
            masked_words.append(punctuation)

        print(f' ***** masked_words at end : {masked_words} ***** ')
        print(f' ***** original_mask_indices : {original_mask_indices} ***** ')
        print(f' ***** TESTING : {" ".join(masked_words)} ***** ')

        return " ".join(masked_words), original_mask_indices

    def mask_sentence_pseudorandom(self, sentence, common_ngrams):
        """
        Mask words in the sentence based on the specified rules after removing stopwords.
        """
        # Split sentence into words
        random.seed(3)
        original_words = sentence.split()

        # Handle punctuation at the end
        has_punctuation = False
        punctuation = None
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]

        print(f' ---- original_words : {original_words} ----- ')

        # Process words without punctuation
        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        # Rest of the existing function code...
        mask_indices = []
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

        if ngram_positions:
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                mask_index_before_ngram = random.randint(0, first_ngram_start - 1)
                mask_indices.append(mask_index_before_ngram)

            # Mask words between common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    mask_indices.append(mask_index_between_ngrams)

            # Mask a word after the last common n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                mask_indices.append(mask_index_after_ngram)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        # Map mask indices and apply masks
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token
            # masked_words[idx] = '<mask>'  # for roberta

        # Add back punctuation if it existed
        if has_punctuation:
            masked_words.append(punctuation)

        print(f' ***** masked_words at end : {masked_words} ***** ')
        print(f' ***** original_mask_indices : {original_mask_indices} ***** ')
        print(f' ***** TESTING : {" ".join(masked_words)} ***** ')

        return " ".join(masked_words), original_mask_indices


    def calculate_word_entropy(self, sentence, word_position):
        """
        Calculate entropy for a specific word position in the sentence.

        Args:
            sentence (str): The input sentence
            word_position (int): Position of the word to calculate entropy for

        Returns:
            float: Entropy value for the word
        """
        words = sentence.split()
        masked_words = words.copy()
        masked_words[word_position] = self.tokenizer.mask_token
        masked_sentence = " ".join(masked_words)

        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        # Get probabilities for the masked position
        probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
        # Calculate entropy: -sum(p * log(p))
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))

        return entropy.item()

    def mask_sentence_entropy(self, sentence, common_ngrams):
        """
        Mask words in the sentence based on entropy, following n-gram positioning rules.

        Args:
            sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices

        Returns:
            str: Masked sentence
        """
        # Split sentence into words
        original_words = sentence.split()

        # Handle punctuation at the end
        has_punctuation = False
        punctuation = None
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]

        # Process words without punctuation
        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        original_to_non_stop = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                original_to_non_stop[orig_idx] = non_stop_idx
                non_stop_idx += 1

        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
        mask_indices = []

        if ngram_positions:
            # Handle words before first n-gram
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                candidate_positions = range(0, first_ngram_start)
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Handle words between n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    candidate_positions = range(end_prev + 1, start_next)
                    entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                                 for pos in candidate_positions]
                    mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Handle words after last n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

        # Map mask indices to original sentence positions and apply masks
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        # Add back punctuation if it existed
        if has_punctuation:
            masked_words.append(punctuation)

        return " ".join(masked_words), original_mask_indices

    def calculate_mask_logits(self, original_sentence, original_mask_indices):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            original_sentence (str): Original sentence without masks
            original_mask_indices (list): List of indices to mask

        Returns:
            dict: Masked token indices and their logits
        """
        print('==========================================================================================================')
        words = original_sentence.split()
        print(f' ##### calculate_mask_logits >> words : {words} ##### ')
        mask_logits = {}

        for idx in original_mask_indices:
            # Create a copy of words and mask the current position
            print(f' ---- idx : {idx} ----- ')
            masked_words = words.copy()
            masked_words[idx] = '[MASK]'
            # masked_words[idx] = '<mask>'  # for roberta
            masked_sentence = " ".join(masked_words)
            print(f' ---- masked_sentence : {masked_sentence} ----- ')

            # Calculate logits for the current mask
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            # Extract logits for the masked position
            mask_logits_tensor = logits[0, mask_token_index, :]

            # Get top logits and corresponding tokens
            top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1)  # Get more candidates

            # Convert token IDs to words and filter out subword tokens
            top_tokens = []
            top_logits = []
            seen_words = set()  # To keep track of unique words

            for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
                token = self.tokenizer.convert_ids_to_tokens(token_id.item())

                # Skip if it's a subword token (starts with ##)
                if token.startswith('##'):
                    continue

                # Convert token to proper word
                word = self.tokenizer.convert_tokens_to_string([token]).strip()

                # Only add if it's a new word and not empty
                if word and word not in seen_words:
                    seen_words.add(word)
                    top_tokens.append(word)
                    top_logits.append(logit.item())

                # Break if we have 50 unique complete words
                if len(top_tokens) == 50:
                    break

            # print(f' ---- top_tokens : {top_tokens} ----- ')

            # Store results
            mask_logits[idx] = {
                "tokens": top_tokens,
                "logits": top_logits
            }

        return mask_logits

    # def calculate_mask_logits(self, original_sentence, original_mask_indices):
    #     """
    #     Calculate logits for masked tokens in the sentence using BERT.

    #     Args:
    #         original_sentence (str): Original sentence without masks
    #         original_mask_indices (list): List of indices to mask

    #     Returns:
    #         dict: Masked token indices and their logits
    #     """
    #     words = original_sentence.split()
    #     print(f' ##### calculate_mask_logits >> words : {words} ##### ')
    #     mask_logits = {}

    #     for idx in original_mask_indices:
    #         # Create a copy of words and mask the current position
    #         print(f' ---- idx : {idx} ----- ')
    #         masked_words = words.copy()
    #         print(f' ---- words : {masked_words} ----- ')
    #         # masked_words[idx] = self.tokenizer.mask_token
    #         masked_words[idx] = '[MASK]'
    #         print(f' ---- masked_words : {masked_words} ----- ')
    #         masked_sentence = " ".join(masked_words)
    #         print(f' ---- masked_sentence : {masked_sentence} ----- ')

    #         # Calculate logits for the current mask
    #         input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
    #         mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

    #         with torch.no_grad():
    #             outputs = self.model(input_ids)
    #             logits = outputs.logits

    #         # Extract logits for the masked position
    #         mask_logits_tensor = logits[0, mask_token_index, :]

    #         # Get top 50 logits and corresponding tokens
    #         top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 50, dim=-1)

    #         # Convert token IDs to words
    #         top_tokens = [self.tokenizer.convert_ids_to_tokens(token_id.item()) for token_id in top_mask_indices[0]]
    #         print(f' ---- top_tokens : {top_tokens} ----- ')

    #         # Store results
    #         mask_logits[idx] = {
    #             "tokens": top_tokens,
    #             "logits": top_mask_logits.tolist()
    #         }

    #     return mask_logits


    def process_sentences(self, sentences, result_dict, method="random"):
        """
        Process sentences and calculate logits for masked tokens.
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            # Split punctuation from the last word before processing
            words = sentence.split()
            last_word = words[-1]
            if any(last_word.endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
                # Split the last word and punctuation
                words[-1] = last_word[:-1]
                punctuation = last_word[-1]
                # Rejoin with space before punctuation to treat it as separate token
                processed_sentence = " ".join(words) + " " + punctuation
            else:
                processed_sentence = sentence

            if method == "random":
                masked_sentence, original_mask_indices = self.mask_sentence_random(processed_sentence, ngrams)
            elif method == "pseudorandom":
                masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(processed_sentence, ngrams)
            else:  # entropy
                masked_sentence, original_mask_indices = self.mask_sentence_entropy(processed_sentence, ngrams)

            logits = self.calculate_mask_logits(processed_sentence, original_mask_indices)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }

        return results



if __name__ == "__main__":
    # !!! Working in both cases, regardless of whether stopwords are removed or not
    sentences = [
        "The quick brown fox jumps over small cat the lazy dog everyday again and again .",
        # "A speedy brown fox jumps over a lazy dog.",
        # "A swift brown fox leaps over the lethargic dog."
    ]
    result_dict = {
        'The quick brown fox jumps over small cat the lazy dog everyday again and again .': {'brown fox': [(2, 3)], 'cat': [(7, 7)], 'dog': [(10, 10)]},
        # 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        # 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
    }

    processor = MaskingProcessor()
    # results_random = processor.process_sentences(sentences, result_dict)
    results_entropy = processor.process_sentences(sentences, result_dict, method="random")

    '''
    results structure :
    results = {
        "The quick brown fox jumps over the lazy dog everyday.":
        {   # Original sentence as key
            "masked_sentence": str,  # The sentence with [MASK] tokens
            "mask_logits":
            {   # Dictionary of mask positions and their predictions
                1:
                {   # Position of mask in sentence
                    "tokens" (words) : list,              # List of top 50 predicted tokens
                    "logits" (probabilities) : list       # Corresponding logits for those tokens
                },
                7:
                {
                    "tokens" (words) : list,
                    "logits" (probabilities) : list
                },
                10:
                {
                    "tokens (words)": list,
                    "logits (probabilities)": list
                }
            }
        }
    }

    '''
    # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)

    for sentence, output in results_entropy.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # print(f"Mask Logits (Random): {output['mask_logits']}")
        # print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
        # print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
        # print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
        # print('--------------------------------')
        # for mask_idx, logits in output["mask_logits"].items():
        #     print(f"Logits for [MASK] at position {mask_idx}:")
        #     print(f' logits : {logits[:5]}')  # List of logits for all vocabulary tokens
        #     print(f' len(logits) : {len(logits)}')


# ------------------------------------------------------------------------------------------------
# def mask_sentence_random(self, sentence, common_ngrams):
#     """
#     Mask words in the sentence based on the specified rules after removing stopwords.
#     """
#     original_words = sentence.split()
#     # print(f' ---- original_words : {original_words} ----- ')
#     non_stop_words = self.remove_stopwords(original_words)
#     # print(f' ---- non_stop_words : {non_stop_words} ----- ')
#     adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
#     # print(f' ---- common_ngrams : {common_ngrams} ----- ')
#     # print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')

#     mask_indices = []

#     # Extract n-gram positions in non-stop words
#     ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

#     # Mask a word before the first common n-gram
#     if ngram_positions:
#         # print(f' ---- ngram_positions : {ngram_positions} ----- ')
#         first_ngram_start = ngram_positions[0][0]
#         # print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
#         if first_ngram_start > 0:
#             mask_index_before_ngram = random.randint(0, first_ngram_start-1)
#             # print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
#             mask_indices.append(mask_index_before_ngram)

#         # Mask words between common n-grams
#         for i in range(len(ngram_positions) - 1):
#             end_prev = ngram_positions[i][1]
#             # print(f' ---- end_prev : {end_prev} ----- ')
#             start_next = ngram_positions[i + 1][0]
#             # print(f' ---- start_next : {start_next} ----- ')
#             if start_next > end_prev + 1:
#                 mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
#                 # print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
#                 mask_indices.append(mask_index_between_ngrams)

#         # Mask a word after the last common n-gram
#         last_ngram_end = ngram_positions[-1][1]
#         if last_ngram_end < len(non_stop_words) - 1:
#             # print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
#             mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
#             # print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
#             mask_indices.append(mask_index_after_ngram)

#     # Create mapping from non-stop words to original indices
#     non_stop_to_original = {}
#     non_stop_idx = 0
#     for orig_idx, word in enumerate(original_words):
#         if word.lower() not in self.stop_words:
#             non_stop_to_original[non_stop_idx] = orig_idx
#             non_stop_idx += 1

#     # Map mask indices from non-stop word positions to original positions
#     # print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
#     original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
#     # print(f' ---- original_mask_indices : {original_mask_indices} ----- ')

#     # Apply masks to the original sentence
#     masked_words = original_words.copy()
#     for idx in original_mask_indices:
#         masked_words[idx] = self.tokenizer.mask_token

#     return " ".join(masked_words), original_mask_indices
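Editor's note: the triple-quoted comment in this file documents the results structure (per-sentence masked sentence plus, for each mask position, parallel "tokens" and "logits" lists). The snippet below is my own consumption sketch, not part of the committed files; the function name top_candidates and the variable names are assumptions for illustration.

# Sketch: turn the stored raw logits into probabilities and keep, for every
# masked position, the k most likely replacement words.
import torch

def top_candidates(results, k=5):
    summary = {}
    for sentence, output in results.items():
        per_position = {}
        for pos, entry in output["mask_logits"].items():
            probs = torch.softmax(torch.tensor(entry["logits"]), dim=-1)
            ranked = sorted(zip(entry["tokens"], probs.tolist()), key=lambda t: t[1], reverse=True)
            per_position[pos] = ranked[:k]
        summary[sentence] = per_position
    return summary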
utils/old/non_melting_points_v1.py
ADDED
@@ -0,0 +1,244 @@
1 |
+
import nltk
|
2 |
+
from nltk.corpus import stopwords
|
3 |
+
from nltk.util import ngrams
|
4 |
+
from collections import Counter
|
5 |
+
import re
|
6 |
+
|
7 |
+
class NgramProcessor:
|
8 |
+
def __init__(self):
|
9 |
+
try:
|
10 |
+
nltk.data.find('corpora/stopwords')
|
11 |
+
except LookupError:
|
12 |
+
nltk.download('stopwords')
|
13 |
+
|
14 |
+
self.stop_words = set(stopwords.words('english'))
|
15 |
+
|
16 |
+
def remove_stopwords(self, text):
|
17 |
+
"""
|
18 |
+
Remove stopwords using NLTK's stopword list
|
19 |
+
|
20 |
+
Args:
|
21 |
+
text (str): Input text
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
str: Cleaned text with stopwords removed
|
25 |
+
"""
|
26 |
+
words = re.findall(r'\w+', text.lower())
|
27 |
+
filtered_words = [word for word in words if word not in self.stop_words]
|
28 |
+
return ' '.join(filtered_words)
|
29 |
+
|
30 |
+
def is_exact_match(self, ngram, sentences):
|
31 |
+
"""
|
32 |
+
Check if the given n-gram has an exact match in all sentences
|
33 |
+
|
34 |
+
Args:
|
35 |
+
ngram (str): The n-gram to search for
|
36 |
+
sentences (list): List of sentences to search in
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
bool: True if n-gram has exact match in all sentences, False otherwise
|
40 |
+
"""
|
41 |
+
return all(ngram in sentence for sentence in sentences)
|
42 |
+
|
43 |
+
def is_substring_of_any(self, ngram, common_ngrams):
|
44 |
+
"""
|
45 |
+
Check if the given n-gram is an exact substring of any previously found common n-grams
|
46 |
+
|
47 |
+
Args:
|
48 |
+
ngram (str): The n-gram to check
|
49 |
+
common_ngrams (list): List of previously found common n-grams
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
bool: True if ngram is a substring of any common_ngrams, False otherwise
|
53 |
+
"""
|
54 |
+
return any(ngram in other_ngram for other_ngram in common_ngrams if ngram != other_ngram)
|
55 |
+
|
56 |
+
def find_filtered_ngrams(self, sentences):
|
57 |
+
"""
|
58 |
+
Find all n-grams that have exact matches across all sentences,
|
59 |
+
excluding those that are part of larger common n-grams
|
60 |
+
|
61 |
+
Args:
|
62 |
+
sentences (list): List of sentences to analyze
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
list: List of tuples where each tuple contains the n-gram and its indices in each sentence
|
66 |
+
"""
|
67 |
+
original_sentences = sentences[:]
|
68 |
+
sentences = [self.remove_stopwords(sentence) for sentence in sentences]
|
69 |
+
ngram_lengths = [4, 3, 2, 1] # Quadgram, trigram, bigram, unigram
|
70 |
+
common_ngrams = []
|
71 |
+
|
72 |
+
for n in ngram_lengths:
|
73 |
+
ngrams_list = [list(ngrams(sentence.split(), n)) for sentence in sentences]
|
74 |
+
ngrams_counter = Counter(ngrams_list[0])
|
75 |
+
|
76 |
+
for ngram in ngrams_counter:
|
77 |
+
ngram_str = ' '.join(ngram)
|
78 |
+
if self.is_exact_match(ngram_str, sentences) and not self.is_substring_of_any(ngram_str, [ng[0] for ng in common_ngrams]):
|
79 |
+
indices = []
|
80 |
+
for original_sentence in original_sentences:
|
81 |
+
words = original_sentence.split()
|
82 |
+
ngram_indices = [
|
83 |
+
(i, i + n - 1) for i in range(len(words) - n + 1)
|
84 |
+
if ' '.join(words[i:i + n]).lower() == ngram_str
|
85 |
+
]
|
86 |
+
indices.append(ngram_indices)
|
87 |
+
common_ngrams.append((ngram_str, indices))
|
88 |
+
|
89 |
+
return common_ngrams
|
90 |
+
|
91 |
+
def find_relative_order(self, sentence, common_ngrams):
|
92 |
+
"""
|
93 |
+
Find the relative order of the common n-grams in the sentence
|
94 |
+
|
95 |
+
Args:
|
96 |
+
sentence (str): Sentence in which to find the relative order
|
97 |
+
common_ngrams (list): List of common n-grams
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
list: List of tuples with the relative position and the n-gram
|
101 |
+
"""
|
102 |
+
relative_order = []
|
103 |
+
for ngram, _ in common_ngrams:
|
104 |
+
index = sentence.find(ngram)
|
105 |
+
if index != -1:
|
106 |
+
relative_order.append((index, ngram))
|
107 |
+
|
108 |
+
return sorted(relative_order)
|
109 |
+
|
110 |
+
# Example usage
|
111 |
+
if __name__ == "__main__":
|
112 |
+
sentences = [
|
113 |
+
"The quick brown fox jumps over the lazy dog.",
|
114 |
+
"A quick brown dog outpaces a lazy fox.",
|
115 |
+
"Quick brown animals leap over lazy obstacles."
|
116 |
+
]
|
117 |
+
|
118 |
+
processor = NgramProcessor()
|
119 |
+
common_ngrams = processor.find_filtered_ngrams(sentences)
|
120 |
+
print("Common n-grams and their indices:")
|
121 |
+
for ngram, indices in common_ngrams:
|
122 |
+
print(f"{ngram}: {indices}")
|
123 |
+
|
124 |
+
for sentence in sentences:
|
125 |
+
relative_order = processor.find_relative_order(sentence, common_ngrams)
|
126 |
+
print(f"Relative order in sentence '{sentence}':", relative_order)
|
127 |
+
|
128 |
+
|
129 |
+
|
# import nltk
# from nltk.corpus import stopwords
# from nltk.util import ngrams
# from collections import Counter
# import re

# class NgramProcessor:
#     def __init__(self):
#         try:
#             nltk.data.find('corpora/stopwords')
#         except LookupError:
#             nltk.download('stopwords')

#         self.stop_words = set(stopwords.words('english'))

#     def remove_stopwords(self, text):
#         """
#         Remove stopwords using NLTK's stopword list

#         Args:
#             text (str): Input text

#         Returns:
#             str: Cleaned text with stopwords removed
#         """
#         words = re.findall(r'\w+', text.lower())
#         filtered_words = [word for word in words if word not in self.stop_words]
#         return ' '.join(filtered_words)

#     def is_exact_match(self, ngram, sentences):
#         """
#         Check if the given n-gram has an exact match in all sentences

#         Args:
#             ngram (str): The n-gram to search for
#             sentences (list): List of sentences to search in

#         Returns:
#             bool: True if n-gram has exact match in all sentences, False otherwise
#         """
#         return all(ngram in sentence for sentence in sentences)

#     def is_substring_of_any(self, ngram, common_ngrams):
#         """
#         Check if the given n-gram is an exact substring of any previously found common n-grams

#         Args:
#             ngram (str): The n-gram to check
#             common_ngrams (list): List of previously found common n-grams

#         Returns:
#             bool: True if ngram is a substring of any common_ngrams, False otherwise
#         """
#         return any(ngram in other_ngram for other_ngram in common_ngrams if ngram != other_ngram)

#     def find_filtered_ngrams(self, sentences):
#         """
#         Find all n-grams that have exact matches across all sentences,
#         excluding those that are part of larger common n-grams

#         Args:
#             sentences (list): List of sentences to analyze

#         Returns:
#             list: List of all common n-grams in order of their appearance in the first sentence
#         """
#         sentences = [self.remove_stopwords(sentence) for sentence in sentences]
#         ngram_lengths = [4, 3, 2, 1]  # Quadgram, trigram, bigram, unigram
#         common_ngrams = []

#         for n in ngram_lengths:
#             ngrams_list = [list(ngrams(sentence.split(), n)) for sentence in sentences]
#             ngrams_counter = Counter(ngrams_list[0])

#             for ngram in ngrams_counter:
#                 ngram_str = ' '.join(ngram)
#                 if self.is_exact_match(ngram_str, sentences) and not self.is_substring_of_any(ngram_str, common_ngrams):
#                     common_ngrams.append(ngram_str)

#         return common_ngrams

#     def find_relative_order(self, sentence, common_ngrams):
#         """
#         Find the relative order of the common n-grams in the sentence

#         Args:
#             sentence (str): Sentence in which to find the relative order
#             common_ngrams (list): List of common n-grams

#         Returns:
#             list: List of tuples with the relative position and the n-gram
#         """
#         relative_order = []
#         for ngram in common_ngrams:
#             index = sentence.find(ngram)
#             if index != -1:
#                 relative_order.append((index, ngram))

#         return sorted(relative_order)

# # Example usage
# if __name__ == "__main__":
#     sentences = [
#         "The quick brown fox jumps over the lazy dog.",
#         "A quick brown dog outpaces a lazy fox.",
#         "Quick brown animals leap over lazy obstacles."
#     ]

#     processor = NgramProcessor()
#     common_ngrams = processor.find_filtered_ngrams(sentences)
#     print("Common n-grams:", common_ngrams)

#     for sentence in sentences:
#         relative_order = processor.find_relative_order(sentence, common_ngrams)
#         print(f"Relative order in sentence '{sentence}':", relative_order)
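Editor's note: the commented-out NgramProcessor above implements the core idea of finding "non-melting" n-grams, i.e. word sequences that survive verbatim in every paraphrase. Below is a minimal standalone sketch of that idea for reference; the function and variable names are illustrative only and are not part of the pipeline.

from collections import Counter
from nltk.util import ngrams

def common_ngrams(sentences, max_n=4):
    """Return n-grams (as strings) present verbatim in every sentence, longest first."""
    found = []
    for n in range(max_n, 0, -1):
        # candidate n-grams come from the first sentence only
        candidates = Counter(ngrams(sentences[0].lower().split(), n))
        for gram in candidates:
            text = " ".join(gram)
            in_all = all(text in s.lower() for s in sentences)
            covered = any(text in longer for longer in found)  # skip pieces of longer matches
            if in_all and not covered:
                found.append(text)
    return found

# Example: shared n-grams across three paraphrases
print(common_ngrams([
    "The quick brown fox jumps over the lazy dog.",
    "A quick brown dog outpaces a lazy fox.",
    "Quick brown animals leap over lazy obstacles.",
]))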
utils/old/sampling/sampling.py
ADDED
@@ -0,0 +1,330 @@
import torch
import random
from masking_methods import MaskingProcessor
import nltk
from nltk.corpus import words
import torch.nn.functional as F


class SamplingProcessor:
    def __init__(self, tokenizer):
        """
        Initialize the SamplingProcessor.

        Args:
            tokenizer: BERT tokenizer instance
        """
        self.tokenizer = tokenizer
        self.subtoken_prefix = self._get_subtoken_prefix()
        self.subtoken_ids = self._get_subtoken_ids()
        try:
            nltk.data.find('corpora/words')
        except LookupError:
            nltk.download('words')
        self.english_words = set(words.words())

    # def _get_subtoken_prefix(self):
    #     """
    #     Identify the subtoken prefix based on the tokenizer.
    #
    #     Returns:
    #         str: The prefix used for subtokens (e.g., "##" for BERT).
    #     """
    #     # This method assumes that the tokenizer uses a consistent subtoken prefix.
    #     # Adjust accordingly if using different tokenizers.
    #     # For BERT's WordPiece tokenizer:
    #     if hasattr(self.tokenizer, "init_kwargs") and "wordpiece_prefix" in self.tokenizer.init_kwargs:
    #         return self.tokenizer.init_kwargs["wordpiece_prefix"]
    #     elif hasattr(self.tokenizer, "prefix_tokens"):
    #         return self.tokenizer.prefix_tokens
    #     else:
    #         # Default to BERT's subtoken prefix
    #         return "##"

    def _get_subtoken_prefix(self):
        """
        Identify the subtoken prefix based on the tokenizer.

        Returns:
            str: The prefix used for subtokens (e.g., "##" for BERT).
        """
        # This method assumes that the tokenizer uses a consistent subtoken prefix.
        # Adjust accordingly if using different tokenizers.
        # For BERT's WordPiece tokenizer:
        if hasattr(self.tokenizer, "init_kwargs") and "wordpiece_prefix" in self.tokenizer.init_kwargs:
            return self.tokenizer.init_kwargs["wordpiece_prefix"]
        elif hasattr(self.tokenizer, "prefix_tokens"):
            return self.tokenizer.prefix_tokens
        else:
            # Default to BERT's subtoken prefix
            return "##"

    # def _get_subtoken_ids(self):
    #     """
    #     Retrieve all token IDs that correspond to subtokens.
    #
    #     Returns:
    #         set: A set of subtoken IDs.
    #     """
    #     vocab = self.tokenizer.get_vocab()
    #     subtoken_ids = set()
    #     for token, idx in vocab.items():
    #         if token.startswith(self.subtoken_prefix):
    #             subtoken_ids.add(idx)
    #     return subtoken_ids

    def _get_subtoken_ids(self):
        """
        Retrieve all token IDs that correspond to subtokens.

        Returns:
            list: A list of subtoken IDs.
        """
        vocab = self.tokenizer.get_vocab()
        subtoken_ids = []
        for token, idx in vocab.items():
            if token.startswith(self.subtoken_prefix):
                subtoken_ids.append(idx)
        return subtoken_ids  # Changed from set to list

    def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
        tokens = self.tokenizer.tokenize(masked_sentence)

        for mask_pos in sorted(mask_logits_dict.keys()):
            try:
                # Get logits and squeeze extra dimension
                mask_logits = torch.tensor(mask_logits_dict[mask_pos]).squeeze(0)  # Remove the extra dimension

                # Create a mask for valid tokens (no special tokens, no subwords)
                valid_mask = torch.zeros_like(mask_logits, dtype=torch.bool)
                for idx in range(len(mask_logits)):
                    token = self.tokenizer.convert_ids_to_tokens([idx])[0]
                    # Only allow regular words (no special tokens, no subwords)
                    if token.isalpha() and not token.startswith('[') and not token.startswith('##'):
                        valid_mask[idx] = True

                # Get valid logits
                valid_logits = mask_logits[valid_mask]
                valid_indices = torch.where(valid_mask)[0]

                if len(valid_logits) == 0:
                    print(f"Warning: No valid tokens found for position {mask_pos}")
                    continue

                if sampling_technique == "inverse_transform":
                    probs = torch.softmax(valid_logits / temperature, dim=-1)
                    cumulative_probs = torch.cumsum(probs, dim=-1)
                    random_prob = random.random()
                    sampled_idx = torch.where(cumulative_probs >= random_prob)[0][0].item()
                    sampled_index = valid_indices[sampled_idx].item()

                elif sampling_technique == "exponential_minimum":
                    probs = torch.softmax(valid_logits / temperature, dim=-1)
                    exp_probs = torch.exp(-torch.log(probs))
                    random_probs = torch.rand_like(exp_probs)
                    sampled_idx = torch.argmax(random_probs * exp_probs).item()
                    sampled_index = valid_indices[sampled_idx].item()

                elif sampling_technique == "temperature":
                    valid_logits = torch.clamp(valid_logits, min=-1e8, max=1e8)
                    probs = torch.softmax(valid_logits / temperature, dim=-1)
                    if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                        raise ValueError("The computed probabilities contain NaN or inf values.")
                    probs = torch.max(probs, torch.tensor(1e-8))
                    probs = probs / torch.sum(probs)
                    sampled_idx = torch.multinomial(probs, 1)[0].item()
                    sampled_index = valid_indices[sampled_idx].item()

                elif sampling_technique == 'greedy':
                    sampled_idx = torch.argmax(valid_logits).item()
                    sampled_index = valid_indices[sampled_idx].item()

                # Replace mask with sampled token
                sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
                tokens[mask_pos] = sampled_token

            except Exception as e:
                print(f"Error sampling for position {mask_pos}: {str(e)}")
                continue

        return self.tokenizer.convert_tokens_to_string(tokens)

    def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
        """
        Process all masked sentences in the results dictionary.

        Args:
            results_dict (dict): Dictionary containing masked sentences and their logits
            sampling_technique (str): Sampling method to use
            temperature (float): Temperature parameter for sampling

        Returns:
            dict: Dictionary containing original, masked, and sampled sentences
        """
        processed_results = {}

        for original_sentence, data in results_dict.items():
            masked_sentence = data["masked_sentence"]
            mask_logits = data["mask_logits"]

            sampled_sentence = self.sample_tokens(
                mask_logits,
                masked_sentence,
                sampling_technique,
                temperature
            )

            processed_results[original_sentence] = {
                "masked_sentence": masked_sentence,
                "sampled_sentence": sampled_sentence
            }

        return processed_results

if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over the lazy dog everyday.",
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
    }

    # First, mask the sentences
    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict)

    # Then, sample replacements for the masks
    sampling_processor = SamplingProcessor(masking_processor.tokenizer)

    # Try different sampling techniques
    sampling_techniques = ["temperature", "greedy", "inverse_transform", "exponential_minimum"]

    for technique in sampling_techniques:
        print(f"\nSampling using {technique}:")
        sampled_results = sampling_processor.process_masked_sentences(
            masking_results,
            sampling_technique=technique,
            temperature=1.0
        )

        for original_sentence, result in sampled_results.items():
            print(f"Original: {original_sentence}")
            print(f"Masked: {result['masked_sentence']}")
            print(f"Sampled: {result['sampled_sentence']}")
            print("---")

# --------------------------------------------------------------------------------------------------
# def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0, top_k=100):
#     words = masked_sentence.split()
#     mask_positions = sorted(mask_logits_dict.keys())

#     for mask_pos in mask_positions:
#         mask_logits = torch.tensor(mask_logits_dict[mask_pos])

#         try:
#             if sampling_technique == "inverse_transform":
#                 probs = torch.softmax(mask_logits / temperature, dim=-1)
#                 cumulative_probs = torch.cumsum(probs, dim=-1)
#                 random_prob = random.random()
#                 sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()

#             elif sampling_technique == "exponential_minimum":
#                 probs = torch.softmax(mask_logits / temperature, dim=-1)
#                 exp_probs = torch.exp(-torch.log(probs))
#                 random_probs = torch.rand_like(exp_probs)
#                 sampled_index = torch.argmax(random_probs * exp_probs).item()

#             elif sampling_technique == "temperature":
#                 mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
#                 probs = torch.softmax(mask_logits / temperature, dim=-1)
#                 if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
#                     raise ValueError("The computed probabilities contain NaN or inf values.")
#                 probs = torch.max(probs, torch.tensor(1e-8))
#                 probs = probs / torch.sum(probs)
#                 sampled_index = torch.multinomial(probs, 1)[0].item()

#             elif sampling_technique == 'greedy':
#                 sampled_index = torch.argmax(mask_logits).item()

#             else:
#                 raise ValueError(f"Unknown sampling technique: {sampling_technique}")

#             # Replace mask with sampled token
#             sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
#             words[mask_pos] = sampled_token

#         except Exception as e:
#             print(f"Error sampling for position {mask_pos}: {str(e)}")
#             continue

#     return " ".join(words)

## MORE WEIRD RESULTS
# def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0, top_k=100):
#     words = masked_sentence.split()
#     mask_positions = sorted(mask_logits_dict.keys())

#     for mask_pos in mask_positions:
#         mask_logits = torch.tensor(mask_logits_dict[mask_pos])

#         try:
#             # Create a mask for valid tokens (no special tokens, no subwords)
#             valid_mask = torch.zeros_like(mask_logits, dtype=torch.bool)
#             for idx in range(len(mask_logits)):
#                 token = self.tokenizer.convert_ids_to_tokens([idx])[0]
#                 # Only allow regular words (no special tokens, no subwords)
#                 if token.isalpha() and not token.startswith('[') and not token.startswith('##'):
#                     valid_mask[idx] = True

#             # Get valid logits
#             valid_logits = mask_logits[valid_mask]
#             valid_indices = torch.where(valid_mask)[0]

#             if len(valid_logits) == 0:
#                 print(f"Warning: No valid tokens found for position {mask_pos}")
#                 continue

#             if sampling_technique == "inverse_transform":
#                 probs = torch.softmax(valid_logits / temperature, dim=-1)
#                 cumulative_probs = torch.cumsum(probs, dim=-1)
#                 random_prob = random.random()
#                 sampled_idx = torch.where(cumulative_probs >= random_prob)[0][0].item()
#                 sampled_index = valid_indices[sampled_idx].item()

#             elif sampling_technique == "exponential_minimum":
#                 probs = torch.softmax(valid_logits / temperature, dim=-1)
#                 exp_probs = torch.exp(-torch.log(probs))
#                 random_probs = torch.rand_like(exp_probs)
#                 sampled_idx = torch.argmax(random_probs * exp_probs).item()
#                 sampled_index = valid_indices[sampled_idx].item()

#             elif sampling_technique == "temperature":
#                 valid_logits = torch.clamp(valid_logits, min=-1e8, max=1e8)
#                 probs = torch.softmax(valid_logits / temperature, dim=-1)
#                 if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
#                     raise ValueError("The computed probabilities contain NaN or inf values.")
#                 probs = torch.max(probs, torch.tensor(1e-8))
#                 probs = probs / torch.sum(probs)
#                 sampled_idx = torch.multinomial(probs, 1)[0].item()
#                 sampled_index = valid_indices[sampled_idx].item()

#             elif sampling_technique == 'greedy':
#                 sampled_idx = torch.argmax(valid_logits).item()
#                 sampled_index = valid_indices[sampled_idx].item()

#             else:
#                 raise ValueError(f"Unknown sampling technique: {sampling_technique}")

#             # Replace mask with sampled token
#             sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
#             words[mask_pos] = sampled_token

#         except Exception as e:
#             print(f"Error sampling for position {mask_pos}: {str(e)}")
#             continue

#     return " ".join(words)
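Editor's note: the same four sampling branches recur in every file in this folder. As a compact reference, here is a hedged, self-contained sketch of what each branch reduces to when applied to a single logits vector. Names are illustrative and not part of the pipeline; the "exponential_minimum" rule mirrors the rand/prob trick used above rather than a textbook method.

import torch

def pick_index(logits, technique="temperature", temperature=1.0):
    """Select one index from a 1-D logits tensor with the given strategy."""
    probs = torch.softmax(logits / temperature, dim=-1)
    if technique == "greedy":
        return torch.argmax(logits).item()
    if technique == "temperature":
        return torch.multinomial(probs, 1).item()
    if technique == "inverse_transform":
        # invert the CDF at a uniform random point
        cdf = torch.cumsum(probs, dim=-1)
        return torch.searchsorted(cdf, torch.rand(1)).clamp(max=logits.numel() - 1).item()
    if technique == "exponential_minimum":
        # same rule the classes above use: argmax of uniform noise weighted by 1/p
        return torch.argmax(torch.rand_like(probs) / probs).item()
    raise ValueError(f"Unknown technique: {technique}")

logits = torch.randn(10)
for technique in ["greedy", "temperature", "inverse_transform", "exponential_minimum"]:
    print(technique, pick_index(logits, technique))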
utils/old/sampling/sampling_methods.py
ADDED
@@ -0,0 +1,291 @@
from transformers import BertTokenizer, BertForMaskedLM
import torch
import random
from masking_methods import MaskingProcessor
from transformers import pipeline

class SamplingProcessorWithModel:
    def __init__(self, model_name='bert-base-uncased'):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForMaskedLM.from_pretrained(model_name)
        self.model.eval()  # Set the model to evaluation mode

    def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
        """
        Fills each mask in the masked sentence using the specified sampling technique.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            str: Sentence with the masks filled.
        """
        input_ids = self.tokenizer.encode(masked_sentence, return_tensors="pt")

        while self.tokenizer.mask_token_id in input_ids[0]:
            # Find indices of all [MASK] tokens
            mask_indices = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            # Process the first [MASK] token in the sequence
            mask_index = mask_indices[0].item()

            # Get logits from the model
            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            # Extract logits for the [MASK] token
            mask_logits = logits[0, mask_index]

            if sampling_technique == "inverse_transform":
                probs = torch.softmax(mask_logits / temperature, dim=-1)
                cumulative_probs = torch.cumsum(probs, dim=-1)
                random_prob = random.random()
                sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()

            elif sampling_technique == "exponential_minimum":
                probs = torch.softmax(mask_logits / temperature, dim=-1)
                exp_probs = torch.exp(-torch.log(probs))
                random_probs = torch.rand_like(exp_probs)
                sampled_index = torch.argmax(random_probs * exp_probs).item()

            elif sampling_technique == "temperature":
                mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
                probs = torch.softmax(mask_logits / temperature, dim=-1)
                if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                    raise ValueError("The computed probabilities contain NaN or inf values.")
                probs = torch.max(probs, torch.tensor(1e-8, device=mask_logits.device))
                probs = probs / torch.sum(probs)
                probs = probs.flatten()
                if probs.size(0) > 1:
                    sampled_index = torch.multinomial(probs, 1).item()
                else:
                    sampled_index = torch.argmax(probs).item()

            elif sampling_technique == 'greedy':
                sampled_index = torch.argmax(mask_logits).item()

            else:
                raise ValueError(f"Unknown sampling technique: {sampling_technique}")

            # Replace the first [MASK] with the selected token
            input_ids[0, mask_index] = sampled_index

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

    def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
        """
        Fills each mask in the masked sentence using the specified sampling technique.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            str: Sentence with the masks filled.
        """
        while '[MASK]' in masked_sentence:
            # Get predictions for the first [MASK]
            predictions = self.unmasker(masked_sentence)

            # Ensure predictions is a list of dictionaries
            if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
                raise ValueError("Unexpected structure in predictions from the pipeline.")

            # Extract logits (scores) from the predictions
            logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)

            if sampling_technique == "inverse_transform":
                probs = torch.softmax(logits / temperature, dim=-1)
                cumulative_probs = torch.cumsum(probs, dim=-1)
                random_prob = random.random()
                sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()

            elif sampling_technique == "exponential_minimum":
                probs = torch.softmax(logits / temperature, dim=-1)
                exp_probs = torch.exp(-torch.log(probs))
                random_probs = torch.rand_like(exp_probs)
                sampled_index = torch.argmax(random_probs * exp_probs).item()

            elif sampling_technique == "temperature":
                logits = torch.clamp(logits, min=-1e8, max=1e8)
                probs = torch.softmax(logits / temperature, dim=-1)
                if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                    raise ValueError("The computed probabilities contain NaN or inf values.")
                probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
                probs = probs / torch.sum(probs)
                probs = probs.flatten()
                if probs.size(0) > 1:
                    sampled_index = torch.multinomial(probs, 1).item()
                else:
                    sampled_index = torch.argmax(probs).item()

            elif sampling_technique == 'greedy':
                sampled_index = torch.argmax(logits).item()

            else:
                raise ValueError(f"Unknown sampling technique: {sampling_technique}")

            # Replace the first [MASK] with the selected word
            sampled_token = predictions[sampled_index]['token_str']
            masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)

        return masked_sentence


# Example usage
if __name__ == "__main__":
    from transformers import BertTokenizer

    # Define sentences and result_dict
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown dog leaps over lazy the fox."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
    }

    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)

    # Use SamplingProcessor
    sampling_processor = SamplingProcessorWithModel()

    # Iterate through masking results to apply sampling
    for sentence, result in masking_results.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {result['masked_sentence']}")
        masked_sentence = result["masked_sentence"]

        # Apply different sampling techniques
        for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
            print(f"Sampling Technique: {technique}")
            filled_sentence = sampling_processor.fill_masked_sentence(
                masked_sentence=masked_sentence,
                sampling_technique=technique,
                temperature=1.0  # Adjust temperature as needed
            )
            print(f"Filled Sentence: {filled_sentence}\n")
        print('--------------------------------')


# from transformers import pipeline
# import torch
# import random
# from masking_methods import MaskingProcessor


# class SamplingProcessorWithPipeline:
#     def __init__(self, model_name='bert-base-uncased'):
#         self.unmasker = pipeline('fill-mask', model=model_name)
#         self.tokenizer = self.unmasker.tokenizer

#     def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
#         """
#         Fills each mask in the masked sentence using the specified sampling technique.

#         Args:
#             masked_sentence (str): Sentence with [MASK] tokens.
#             sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
#             temperature (float): Temperature parameter for sampling methods.

#         Returns:
#             str: Sentence with the masks filled.
#         """
#         while '[MASK]' in masked_sentence:
#             # Get predictions for the first [MASK]
#             predictions = self.unmasker(masked_sentence)
#             print(f' predictions : {predictions}')
#             print(f' type of predictions : {type(predictions)}')

#             # Ensure predictions is a list of dictionaries for the first [MASK]
#             if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
#                 raise ValueError("Unexpected structure in predictions from the pipeline.")

#             # Extract logits (scores) from the predictions
#             logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)

#             if sampling_technique == "inverse_transform":
#                 probs = torch.softmax(logits / temperature, dim=-1)
#                 cumulative_probs = torch.cumsum(probs, dim=-1)
#                 random_prob = random.random()
#                 sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()

#             elif sampling_technique == "exponential_minimum":
#                 probs = torch.softmax(logits / temperature, dim=-1)
#                 exp_probs = torch.exp(-torch.log(probs))
#                 random_probs = torch.rand_like(exp_probs)
#                 sampled_index = torch.argmax(random_probs * exp_probs).item()

#             elif sampling_technique == "temperature":
#                 logits = torch.clamp(logits, min=-1e8, max=1e8)
#                 probs = torch.softmax(logits / temperature, dim=-1)
#                 if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
#                     raise ValueError("The computed probabilities contain NaN or inf values.")
#                 probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
#                 probs = probs / torch.sum(probs)
#                 probs = probs.flatten()
#                 if probs.size(0) > 1:
#                     sampled_index = torch.multinomial(probs, 1).item()
#                 else:
#                     sampled_index = torch.argmax(probs).item()

#             elif sampling_technique == 'greedy':
#                 sampled_index = torch.argmax(logits).item()

#             else:
#                 raise ValueError(f"Unknown sampling technique: {sampling_technique}")

#             # Replace the first [MASK] with the selected word
#             sampled_token = predictions[sampled_index]['token_str']
#             masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)

#         return masked_sentence


# # Example usage
# if __name__ == "__main__":
#     from transformers import BertTokenizer

#     # Define sentences and result_dict
#     sentences = [
#         "The quick brown fox jumps over the lazy dog.",
#         "A quick brown dog outpaces a lazy fox.",
#         "Quick brown animals leap over lazy obstacles."
#     ]
#     result_dict = {
#         "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
#         "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
#         "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
#     }

#     masking_processor = MaskingProcessor()
#     masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)

#     # Use SamplingProcessor
#     sampling_processor = SamplingProcessorWithPipeline()

#     # Iterate through masking results to apply sampling
#     for sentence, result in masking_results.items():
#         print(f"Original Sentence (Random): {sentence}")
#         print(f"Masked Sentence (Random): {result['masked_sentence']}")
#         masked_sentence = result["masked_sentence"]

#         # Apply different sampling techniques
#         for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
#             print(f"Sampling Technique: {technique}")
#             filled_sentence = sampling_processor.fill_masked_sentence(
#                 masked_sentence=masked_sentence,
#                 sampling_technique=technique,
#                 temperature=1.0  # Adjust temperature as needed
#             )
#             print(f"Filled Sentence: {filled_sentence}\n")
#             print('--------------------------------')
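Editor's note: SamplingProcessorWithModel consumes per-mask logits taken directly from a masked language model. A minimal sketch of where those logits come from, reduced to a single greedy top-5 lookup rather than the full iterative loop (standard transformers API, no pipeline-specific assumptions):

from transformers import BertTokenizer, BertForMaskedLM
import torch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
model.eval()

text = "The quick brown [MASK] jumps over the lazy dog."
input_ids = tokenizer.encode(text, return_tensors="pt")
mask_index = torch.where(input_ids == tokenizer.mask_token_id)[1][0].item()

with torch.no_grad():
    logits = model(input_ids).logits[0, mask_index]  # vocabulary-sized logits for the [MASK] slot

top5 = torch.topk(logits, 5).indices.tolist()
print(tokenizer.convert_ids_to_tokens(top5))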
utils/old/sampling/sampling_methods_v1.py
ADDED
@@ -0,0 +1,146 @@
import torch
import random
from masking_methods import MaskingProcessor

class SamplingProcessor:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def fill_masked_sentence(self, original_sentence, mask_logits, sampling_technique, temperature=1.0):
        """
        Fills each mask in the masked sentence using the specified sampling technique.

        Args:
            original_sentence (str): The original masked sentence.
            mask_logits (dict): Logits for each [MASK] token.
            sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            str: Sentence with the masks filled.
        """
        sentence_tokens = self.tokenizer.tokenize(original_sentence)
        mask_token_indices = [i for i, token in enumerate(sentence_tokens) if token == self.tokenizer.mask_token]

        if len(mask_token_indices) != len(mask_logits):
            raise ValueError("Mismatch between number of [MASK] tokens and logits provided.")

        for mask_idx, filtered_logits in zip(mask_token_indices, mask_logits.values()):
            # Convert logits to a tensor
            filtered_logits = torch.tensor(filtered_logits)
            # filtered_logits, _ = torch.sort(filtered_logits, descending=True)
            # print(f' type of filtered_logits : {type(filtered_logits)}')
            # filtered_logits = filtered_logits[:5]

            if sampling_technique == "inverse_transform":
                probs = torch.softmax(filtered_logits / temperature, dim=-1)
                cumulative_probs = torch.cumsum(probs, dim=-1)
                random_prob = random.random()
                sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()

            elif sampling_technique == "exponential_minimum":
                probs = torch.softmax(filtered_logits / temperature, dim=-1)
                exp_probs = torch.exp(-torch.log(probs))
                random_probs = torch.rand_like(exp_probs)
                sampled_index = torch.argmax(random_probs * exp_probs).item()

            elif sampling_technique == "temperature":
                filtered_logits = torch.clamp(filtered_logits, min=-1e8, max=1e8)
                probs = torch.softmax(filtered_logits / temperature, dim=-1)
                if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                    raise ValueError("The computed probabilities contain NaN or inf values.")
                probs = torch.max(probs, torch.tensor(1e-8, device=filtered_logits.device))
                probs = probs / torch.sum(probs)
                probs = probs.flatten()
                if probs.size(0) > 1:
                    sampled_index = torch.multinomial(probs, 1).item()
                else:
                    sampled_index = torch.argmax(probs).item()

            elif sampling_technique == 'greedy':
                sampled_index = torch.argmax(filtered_logits).item()

            else:
                raise ValueError(f"Unknown sampling technique: {sampling_technique}")

            sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
            sentence_tokens[mask_idx] = sampled_token

        return self.tokenizer.convert_tokens_to_string(sentence_tokens)

    def process_samples(self, masked_sentences, mask_logits, sampling_technique, temperature=1.0):
        """
        Process multiple masked sentences and fill their masks using the specified sampling technique.

        Args:
            masked_sentences (list): List of masked sentences.
            mask_logits (dict): Logits for each [MASK] token in each sentence.
            sampling_technique (str): Sampling technique to use.
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            list: List of sentences with masks filled.
        """
        filled_sentences = []
        for sentence, logits in zip(masked_sentences, mask_logits):
            filled_sentence = self.fill_masked_sentence(sentence, logits, sampling_technique, temperature)
            filled_sentences.append(filled_sentence)
        return filled_sentences

# Example usage
if __name__ == "__main__":
    from transformers import BertTokenizer

    # tokenizer = BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    processor = SamplingProcessor(tokenizer)

    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown dog leaps over lazy the fox."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
    }

    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
    # masked_sentence = "The [MASK] brown fox jumps [MASK] the lazy dog."
    # mask_logits = {
    #     1: torch.randn(len(tokenizer)),  # Example logits for first [MASK]
    #     5: torch.randn(len(tokenizer)),  # Example logits for second [MASK]
    # }

    # Iterate through masking results to apply sampling
    for sentence, result in masking_results.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {result['masked_sentence']}")
        # print(f"Mask Logits (Random): {output['mask_logits']}")
        print(f' type(result["mask_logits"]) : {type(result["mask_logits"])}')
        print(f' length of result["mask_logits"] : {len(result["mask_logits"])}')
        print(f' result["mask_logits"].keys() : {result["mask_logits"].keys()}')
        masked_sentence = result["masked_sentence"]
        mask_logits = result["mask_logits"]

        print(f"Original Masked Sentence: {masked_sentence}")

        # Apply different sampling techniques
        for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
            print(f"Sampling Technique: {technique}")

            # Fill the masks using the sampling processor
            filled_sentence = processor.fill_masked_sentence(
                original_sentence=masked_sentence,
                mask_logits=mask_logits,
                sampling_technique=technique,
                temperature=1.0  # Adjust temperature as needed
            )

            print(f"Filled Sentence: {filled_sentence}\n")
        print('--------------------------------')
utils/old/sampling/sampling_methods_v2.py
ADDED
@@ -0,0 +1,112 @@
from transformers import pipeline
import torch
import random
from masking_methods import MaskingProcessor


class SamplingProcessorWithPipeline:
    def __init__(self, model_name='bert-base-uncased'):
        self.unmasker = pipeline('fill-mask', model=model_name)
        self.tokenizer = self.unmasker.tokenizer

    def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
        """
        Fills each mask in the masked sentence using the specified sampling technique.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            str: Sentence with the masks filled.
        """
        while '[MASK]' in masked_sentence:
            # Get predictions for the first [MASK]
            predictions = self.unmasker(masked_sentence)
            print(f' predictions : {predictions}')
            print(f' type of predictions : {type(predictions)}')

            # Ensure predictions is a list of dictionaries
            if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
                raise ValueError("Unexpected structure in predictions from the pipeline.")

            # Extract logits (scores) from the predictions
            logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)

            if sampling_technique == "inverse_transform":
                probs = torch.softmax(logits / temperature, dim=-1)
                cumulative_probs = torch.cumsum(probs, dim=-1)
                random_prob = random.random()
                sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()

            elif sampling_technique == "exponential_minimum":
                probs = torch.softmax(logits / temperature, dim=-1)
                exp_probs = torch.exp(-torch.log(probs))
                random_probs = torch.rand_like(exp_probs)
                sampled_index = torch.argmax(random_probs * exp_probs).item()

            elif sampling_technique == "temperature":
                logits = torch.clamp(logits, min=-1e8, max=1e8)
                probs = torch.softmax(logits / temperature, dim=-1)
                if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                    raise ValueError("The computed probabilities contain NaN or inf values.")
                probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
                probs = probs / torch.sum(probs)
                probs = probs.flatten()
                if probs.size(0) > 1:
                    sampled_index = torch.multinomial(probs, 1).item()
                else:
                    sampled_index = torch.argmax(probs).item()

            elif sampling_technique == 'greedy':
                sampled_index = torch.argmax(logits).item()

            else:
                raise ValueError(f"Unknown sampling technique: {sampling_technique}")

            # Replace the first [MASK] with the selected word
            sampled_token = predictions[sampled_index]['token_str']
            masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)

        return masked_sentence


# Example usage
if __name__ == "__main__":
    from transformers import BertTokenizer

    # Define sentences and result_dict
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown dog leaps over lazy the fox."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
    }

    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)

    # Use SamplingProcessor
    sampling_processor = SamplingProcessorWithPipeline()

    # Iterate through masking results to apply sampling
    for sentence, result in masking_results.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {result['masked_sentence']}")
        masked_sentence = result["masked_sentence"]

        # Apply different sampling techniques
        for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
            print(f"Sampling Technique: {technique}")
            filled_sentence = sampling_processor.fill_masked_sentence(
                masked_sentence=masked_sentence,
                sampling_technique=technique,
                temperature=1.0  # Adjust temperature as needed
            )
            print(f"Filled Sentence: {filled_sentence}\n")
        print('--------------------------------')
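Editor's note: this pipeline-based variant treats the fill-mask pipeline's confidence scores as if they were logits. For reference, this is the shape of what the pipeline actually returns for a single [MASK]: a list of dicts whose 'score' values are already softmax probabilities over the top-k candidates, not raw vocabulary logits. A short sketch (standard transformers API):

from transformers import pipeline

unmasker = pipeline("fill-mask", model="bert-base-uncased")
preds = unmasker("The quick brown [MASK] jumps over the lazy dog.")
# each prediction is a dict like {'score': ..., 'token': ..., 'token_str': ..., 'sequence': ...}
for p in preds[:3]:
    print(round(p["score"], 3), p["token_str"])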
utils/old/sampling_final_copy.py
ADDED
@@ -0,0 +1,168 @@
import torch
import random
from masking_methods import MaskingProcessor

class SamplingProcessor:
    def __init__(self, tokenizer):
        """
        Initialize the SamplingProcessor.

        Args:
            tokenizer: BERT tokenizer instance
        """
        self.tokenizer = tokenizer

    def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
        """
        Sample tokens for each mask in the sentence using the specified sampling technique.

        Args:
            mask_logits_dict (dict): Dictionary of mask positions and their logits/tokens
            masked_sentence (str): Sentence with [MASK] tokens
            sampling_technique (str): Sampling method to use
            temperature (float): Temperature parameter for sampling

        Returns:
            str: Sentence with sampled tokens replacing masks
        """
        words = masked_sentence.split()

        # Convert positions and logits to sorted list to process masks in order
        mask_positions = sorted(mask_logits_dict.keys())

        for mask_pos in mask_positions:
            mask_data = mask_logits_dict[mask_pos]
            mask_logits = torch.tensor(mask_data['logits'])
            candidate_tokens = mask_data['tokens']

            try:
                if sampling_technique == "inverse_transform":
                    probs = torch.softmax(mask_logits / temperature, dim=-1)
                    cumulative_probs = torch.cumsum(probs, dim=-1)
                    random_prob = random.random()
                    sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()

                elif sampling_technique == "exponential_minimum":
                    probs = torch.softmax(mask_logits / temperature, dim=-1)
                    exp_probs = torch.exp(-torch.log(probs))
                    random_probs = torch.rand_like(exp_probs)
                    sampled_index = torch.argmax(random_probs * exp_probs).item()

                elif sampling_technique == "temperature":
                    mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
                    probs = torch.softmax(mask_logits / temperature, dim=-1)
                    if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                        raise ValueError("The computed probabilities contain NaN or inf values.")
                    probs = torch.max(probs, torch.tensor(1e-8))
                    probs = probs / torch.sum(probs)
                    probs = probs.flatten()
                    if probs.size(0) > 1:
                        sampled_index = torch.multinomial(probs, 1).item()
                    else:
                        sampled_index = torch.argmax(probs).item()

                elif sampling_technique == 'greedy':
                    sampled_index = torch.argmax(mask_logits).item()

                else:
                    raise ValueError(f"Unknown sampling technique: {sampling_technique}")

                # Use the sampled index to get the corresponding token
                sampled_token = candidate_tokens[sampled_index]
                # Remove ## if it's a subword token
                sampled_token = sampled_token.replace('##', '')
                words[mask_pos] = sampled_token

            except Exception as e:
                print(f"Error sampling for position {mask_pos}: {str(e)}")
                continue

        return " ".join(words)

    def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
        """
        Process all masked sentences in the results dictionary.

        Args:
            results_dict (dict): Dictionary containing masked sentences and their logits
            sampling_technique (str): Sampling method to use
            temperature (float): Temperature parameter for sampling

        Returns:
            dict: Dictionary containing original, masked, and sampled sentences
        """
        processed_results = {}

        for original_sentence, data in results_dict.items():
            masked_sentence = data["masked_sentence"]
            mask_logits = data["mask_logits"]

            sampled_sentence = self.sample_tokens(
                mask_logits,
                masked_sentence,
                sampling_technique,
                temperature
            )

            processed_results[original_sentence] = {
                "masked_sentence": masked_sentence,
                "sampled_sentence": sampled_sentence
            }

        return processed_results


if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over the lazy dog everyday.",
        "A speedy brown fox jumps over a lazy dog.",
        "A swift brown fox leaps over the lethargic dog."
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
    }

    # First, mask the sentences
    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict)

    # Then, sample replacements for the masks
    sampling_processor = SamplingProcessor(masking_processor.tokenizer)

    # Try different sampling techniques
    sampling_techniques = ["temperature", "greedy", "inverse_transform", "exponential_minimum"]

    for technique in sampling_techniques:
        print(f"\nSampling using {technique}:")
        sampled_results = sampling_processor.process_masked_sentences(
            masking_results,
            sampling_technique=technique,
            temperature=1.0
        )

        '''
        {
            "original_sentence_1":
            {
                "masked_sentence": "sentence with [MASK] tokens",
                "sampling_method1": "sentence with sampled tokens",
            },
            "original_sentence_2":
            {
                "masked_sentence": "sentence with [MASK] tokens",
                "sampling_method": "sentence with sampled tokens"
            },
            # ... and so on for each input sentence
        },
        '''

        for original_sentence, result in sampled_results.items():
            print(f"Original: {original_sentence}")
            print(f"Masked: {result['masked_sentence']}")
            print(f"Sampled: {result['sampled_sentence']}")
            print("---")
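Editor's note: unlike utils/old/sampling/sampling.py above, this final copy expects each mask position to carry a pre-filtered candidate list plus the logits for just those candidates, and positions index into masked_sentence.split(). A hand-built illustration of that input layout (all tokens and numbers below are made up for the example, not produced by the pipeline):

mask_logits = {
    2: {"tokens": ["swift", "speedy", "sly"], "logits": [3.1, 2.7, 0.4]},
    8: {"tokens": ["dog", "hound", "wolf"], "logits": [4.0, 1.2, 0.8]},
}
masked_sentence = "The quick [MASK] fox jumps over the lazy [MASK] everyday."
# sampler = SamplingProcessor(tokenizer)
# print(sampler.sample_tokens(mask_logits, masked_sentence, sampling_technique="greedy"))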
utils/paraphraser.py
ADDED
@@ -0,0 +1,75 @@
1 |
+
"""
|
2 |
+
This file contains the code to generate paraphrases of sentences.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import logging
|
7 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
8 |
+
from tqdm import tqdm # for progress bars
|
9 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
10 |
+
|
11 |
+
from utils.config import load_config
|
12 |
+
# config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml')
|
13 |
+
# config = load_config(config_path)['PECCAVI_TEXT']['Paraphrase']
|
14 |
+
|
15 |
+
# Configure logging to show only warnings or above on the terminal.
|
16 |
+
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
class Paraphraser:
|
20 |
+
"""
|
21 |
+
Paraphraser class to generate paraphrases of sentences.
|
22 |
+
"""
|
23 |
+
def __init__(self, config):
|
24 |
+
self.config = config
|
25 |
+
import torch
|
26 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
27 |
+
tqdm.write(f"[Paraphraser] Initializing on device: {self.device}")
|
28 |
+
self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
|
29 |
+
self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model']).to(self.device)
|
30 |
+
self.num_beams = config['num_beams']
|
31 |
+
self.num_beam_groups = config['num_beam_groups']
|
32 |
+
self.num_return_sequences = config['num_return_sequences']
|
33 |
+
self.repetition_penalty = config['repetition_penalty']
|
34 |
+
self.diversity_penalty = config['diversity_penalty']
|
35 |
+
self.no_repeat_ngram_size = config['no_repeat_ngram_size']
|
36 |
+
self.temperature = config['temperature']
|
37 |
+
self.max_length = config['max_length']
|
38 |
+
|
39 |
+
def paraphrase(self, sentence: str, num_return_sequences: int=None, num_beams: int=None, num_beam_groups: int=None):
|
40 |
+
tqdm.write(f"[Paraphraser] Starting paraphrase for sentence: {sentence}")
|
41 |
+
if num_return_sequences is None:
|
42 |
+
num_return_sequences = self.num_return_sequences
|
43 |
+
if num_beams is None:
|
44 |
+
num_beams = self.num_beams
|
45 |
+
if num_beam_groups is None:
|
46 |
+
num_beam_groups = self.num_beam_groups
|
47 |
+
|
48 |
+
inputs = self.tokenizer.encode("paraphrase: " + sentence,
|
49 |
+
return_tensors="pt",
|
50 |
+
max_length=self.max_length,
|
51 |
+
truncation=True).to(self.device)
|
52 |
+
outputs = self.model.generate(
|
53 |
+
inputs,
|
54 |
+
max_length=self.max_length,
|
55 |
+
num_beams=num_beams,
|
56 |
+
num_beam_groups=num_beam_groups,
|
57 |
+
num_return_sequences=num_return_sequences,
|
58 |
+
repetition_penalty=self.repetition_penalty,
|
59 |
+
diversity_penalty=self.diversity_penalty,
|
60 |
+
no_repeat_ngram_size=self.no_repeat_ngram_size,
|
61 |
+
temperature=self.temperature
|
62 |
+
)
|
63 |
+
paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True)
|
64 |
+
for output in tqdm(outputs, desc="Decoding Paraphrases")]
|
65 |
+
tqdm.write(f"[Paraphraser] Paraphrase completed. {len(paraphrases)} outputs generated.")
|
66 |
+
return paraphrases
|
67 |
+
|
68 |
+
if __name__ == "__main__":
|
69 |
+
config_path = '/home/jigyasu/PECCAVI-Text/utils/config.yaml'
|
70 |
+
config = load_config(config_path)
|
71 |
+
paraphraser = Paraphraser(config['PECCAVI_TEXT']['Paraphrase'])
|
72 |
+
sentence = "The quick brown fox jumps over the lazy dog."
|
73 |
+
paraphrases = paraphraser.paraphrase(sentence)
|
74 |
+
for paraphrase in paraphrases:
|
75 |
+
print(paraphrase)
|
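Editor's note: the Paraphraser only reads the keys it copies in __init__, so any config stanza handed to it needs at least the fields below. The checkpoint name and the numeric values here are placeholders for illustration only; the actual values live in utils/config.yaml and are not reproduced here.

paraphrase_config = {
    "tokenizer": "humarin/chatgpt_paraphraser_on_T5_base",  # placeholder checkpoint
    "model": "humarin/chatgpt_paraphraser_on_T5_base",      # placeholder checkpoint
    "num_beams": 10,
    "num_beam_groups": 10,        # diverse beam search: num_beams must be divisible by this
    "num_return_sequences": 5,
    "repetition_penalty": 10.0,
    "diversity_penalty": 3.0,     # must be > 0 when num_beam_groups > 1
    "no_repeat_ngram_size": 2,
    "temperature": 0.7,
    "max_length": 128,
}

# paraphraser = Paraphraser(paraphrase_config)
# print(paraphraser.paraphrase("The quick brown fox jumps over the lazy dog."))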