jgyasu committed on
Commit 060ac52 · 1 Parent(s): 80bc0f8

Add entire pipeline

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. UI/__pycache__/gradio.cpython-310.pyc +0 -0
  2. UI/__pycache__/gradio.cpython-311.pyc +0 -0
  3. UI/gradio.py +516 -0
  4. __pycache__/app.cpython-310.pyc +0 -0
  5. app.py +21 -0
  6. environment.yml +245 -0
  7. metrics/distortion.py +370 -0
  8. renderers/__pycache__/highlighter.cpython-310.pyc +0 -0
  9. renderers/__pycache__/highlighter.cpython-311.pyc +0 -0
  10. renderers/__pycache__/plot_3d.cpython-310.pyc +0 -0
  11. renderers/__pycache__/plot_3d.cpython-311.pyc +0 -0
  12. renderers/__pycache__/tree.cpython-310.pyc +0 -0
  13. renderers/__pycache__/tree.cpython-311.pyc +0 -0
  14. renderers/highlighter.py +162 -0
  15. renderers/plot_3d.py +126 -0
  16. renderers/tree.py +490 -0
  17. utils/__init__.py +5 -0
  18. utils/__pycache__/__init__.cpython-310.pyc +0 -0
  19. utils/__pycache__/__init__.cpython-311.pyc +0 -0
  20. utils/__pycache__/config.cpython-310.pyc +0 -0
  21. utils/__pycache__/config.cpython-311.pyc +0 -0
  22. utils/__pycache__/entailment.cpython-310.pyc +0 -0
  23. utils/__pycache__/entailment.cpython-311.pyc +0 -0
  24. utils/__pycache__/masking_methods.cpython-310.pyc +0 -0
  25. utils/__pycache__/masking_methods.cpython-311.pyc +0 -0
  26. utils/__pycache__/non_melting_point.cpython-310.pyc +0 -0
  27. utils/__pycache__/non_melting_point.cpython-311.pyc +0 -0
  28. utils/__pycache__/paraphraser.cpython-310.pyc +0 -0
  29. utils/__pycache__/paraphraser.cpython-311.pyc +0 -0
  30. utils/__pycache__/sampling.cpython-310.pyc +0 -0
  31. utils/__pycache__/sampling.cpython-311.pyc +0 -0
  32. utils/__pycache__/watermark.cpython-310.pyc +0 -0
  33. utils/__pycache__/watermark.cpython-311.pyc +0 -0
  34. utils/config.py +18 -0
  35. utils/config.yaml +48 -0
  36. utils/entailment.py +107 -0
  37. utils/masking_methods.py +304 -0
  38. utils/non_melting_point.py +137 -0
  39. utils/old/masking/masking_methods.py +355 -0
  40. utils/old/masking/masking_methods_new_work.py +447 -0
  41. utils/old/masking/masking_methods_ok_working.py +257 -0
  42. utils/old/masking/masking_methods_v1_working.py +233 -0
  43. utils/old/masking_methods_final_copy.py +619 -0
  44. utils/old/non_melting_points_v1.py +244 -0
  45. utils/old/sampling/sampling.py +330 -0
  46. utils/old/sampling/sampling_methods.py +291 -0
  47. utils/old/sampling/sampling_methods_v1.py +146 -0
  48. utils/old/sampling/sampling_methods_v2.py +112 -0
  49. utils/old/sampling_final_copy.py +168 -0
  50. utils/paraphraser.py +75 -0
UI/__pycache__/gradio.cpython-310.pyc ADDED
Binary file (6.61 kB)

UI/__pycache__/gradio.cpython-311.pyc ADDED
Binary file (27.3 kB)
 
UI/gradio.py ADDED
@@ -0,0 +1,516 @@
1
+ import gradio as gr
2
+ from utils.watermark import Watermarker
3
+ from utils.config import load_config
4
+ from renderers.highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html
5
+ from renderers.tree import generate_subplot1, generate_subplot2
6
+ from pathlib import Path
7
+ import time
8
+ from typing import Dict, List, Tuple, Any
9
+ import plotly.graph_objects as go
10
+
11
+ class WatermarkerInterface:
12
+ def __init__(self, config):
13
+
14
+ self.pipeline = Watermarker(config)
15
+ self.common_grams = {}
16
+ self.highlight_info = []
17
+ self.masked_sentences = []
18
+
19
+ def handle_paraphrase(self, prompt: str) -> Tuple[str, str, str, str]:
20
+ """Wrapper for paraphrasing that includes highlighting"""
21
+ start_time = time.time()
22
+
23
+ # Run paraphrasing
24
+ self.pipeline.Paraphrase(prompt)
25
+
26
+ # Step 1: Process the original sentence first
27
+ seen_ngrams = {} # Stores first occurrence index of each n-gram
28
+ original_indexed_ngrams = [] # Final indexed list for original
29
+
30
+ original_sentence = self.pipeline.user_prompt
31
+ original_ngrams = self.pipeline.common_grams.get(original_sentence, {})
32
+
33
+ # Step 1.1: Extract n-grams and their first occurrence index
34
+ ngram_occurrences = [
35
+ (min(indices, key=lambda x: x[0])[0], gram) # Get first index
36
+ for gram, indices in original_ngrams.items()
37
+ ]
38
+
39
+ # Step 1.2: Sort n-grams based on their first occurrence
40
+ ngram_occurrences.sort()
41
+
42
+ # Step 1.3: Assign sequential indices
43
+ for idx, (position, gram) in enumerate(ngram_occurrences, start=1):
44
+ seen_ngrams[gram] = idx # Assign sequential index
45
+ original_indexed_ngrams.append((idx, gram))
46
+
47
+ print("Original Indexed N-grams:", original_indexed_ngrams)
48
+
49
+ #generate highlight_info
50
+ colors = ["red", "blue", "green", "purple", "orange"]
51
+ highlight_info = [
52
+ (ngram, colors[i % len(colors)])
53
+ for i, (index, ngram) in enumerate(original_indexed_ngrams)
54
+ ]
55
+ common_grams = original_indexed_ngrams
56
+ self.highlight_info = highlight_info
57
+ self.common_grams = common_grams
58
+
59
+ # Step 2: Process paraphrased sentences and match indices
60
+ paraphrase_indexed_ngrams = {}
61
+
62
+ for sentence in self.pipeline.paraphrased_sentences:
63
+ sentence_ngrams = [] # Stores n-grams for this sentence
64
+ sentence_ngrams_dict = self.pipeline.common_grams.get(sentence, {})
65
+
66
+ for gram, indices in sentence_ngrams_dict.items():
67
+ first_occurrence = min(indices, key=lambda x: x[0])[0]
68
+
69
+ # Use the original's index if exists, otherwise assign a new one
70
+ if gram in seen_ngrams:
71
+ index = seen_ngrams[gram] # Use the same index as original
72
+ else:
73
+ index = len(seen_ngrams) + 1 # Assign new index
74
+ seen_ngrams[gram] = index # Store it
75
+
76
+ sentence_ngrams.append((index, gram))
77
+
78
+ sentence_ngrams.sort()
79
+ paraphrase_indexed_ngrams[sentence] = sentence_ngrams
80
+
81
+ print("Paraphrase Indexed N-grams:", paraphrase_indexed_ngrams)
82
+
83
+ # Step 3: Generate highlighted versions using the renderer
84
+ highlighted_prompt = highlight_common_words(
85
+ common_grams,
86
+ [self.pipeline.user_prompt],
87
+ "Original Prompt with Highlighted Common Sequences"
88
+ )
89
+
90
+ highlighted_accepted = highlight_common_words_dict(
91
+ common_grams,
92
+ self.pipeline.selected_sentences,
93
+ "Accepted Paraphrased Sentences with Entailment Scores"
94
+ )
95
+
96
+ highlighted_discarded = highlight_common_words_dict(
97
+ common_grams,
98
+ self.pipeline.discarded_sentences,
99
+ "Discarded Paraphrased Sentences with Entailment Scores"
100
+ )
101
+
102
+ execution_time = f"<div class='execution-time'>Step 1 completed in {time.time() - start_time:.2f} seconds</div>"
103
+ self.highlight_info = highlight_info
104
+ self.common_grams = common_grams
105
+
106
+ return highlighted_prompt, highlighted_accepted, highlighted_discarded, execution_time
107
+
108
+ def handle_masking(self) -> List[Any]:
109
+ """Wrapper for masking that generates visualization trees"""
110
+ start_time = time.time()
111
+
112
+ masking_results = self.pipeline.Masking()
113
+ trees = []
114
+ highlight_info = self.highlight_info
115
+ common_grams = self.common_grams
116
+ sentence_to_masked = {}
117
+
118
+ # Create a consolidated figure with all strategies
119
+ original_sentence = None
120
+
121
+ # First pass - gather all sentences and strategies
122
+ for strategy, sentence_dict in masking_results.items():
123
+ for sent, data in sentence_dict.items():
124
+ if sent not in sentence_to_masked:
125
+ sentence_to_masked[sent] = []
126
+ try:
127
+ if not isinstance(data, dict):
128
+ print(f"[ERROR] Data is not a dictionary for {sent} with strategy {strategy}")
129
+ continue
130
+
131
+ masked_sentence = data.get("masked_sentence", "")
132
+ if masked_sentence:
133
+ sentence_to_masked[sent].append((masked_sentence, strategy))
134
+ except Exception as e:
135
+ print(f"Error processing {strategy} for sentence {sent}: {e}")
136
+
137
+ for original_sentence, masked_sentences_data in sentence_to_masked.items():
138
+ if not masked_sentences_data:
139
+ continue
140
+ masked_sentences = [ms[0] for ms in masked_sentences_data]
141
+ strategies = [ms[1] for ms in masked_sentences_data]
142
+ try:
143
+
144
+ fig = generate_subplot1(
145
+ original_sentence,
146
+ masked_sentences,
147
+ strategies,
148
+ highlight_info,
149
+ common_grams
150
+ )
151
+ trees.append(fig)
152
+ except Exception as e:
153
+ print(f"Error generating multi-strategy tree: {e}")
154
+ trees.append(go.Figure())
155
+
156
+ # Pad with empty plots if needed
157
+ while len(trees) < 10:
158
+ trees.append(go.Figure())
159
+
160
+ execution_time = f"<div class='execution-time'>Step 2 completed in {time.time() - start_time:.2f} seconds</div>"
161
+
162
+ return trees[:10] + [execution_time]
163
+
164
+ def handle_sampling(self) -> List[Any]:
165
+ """Wrapper for sampling that generates visualization trees"""
166
+ start_time = time.time()
167
+ sampling_results = self.pipeline.Sampling()
168
+ trees = []
169
+
170
+ # Group sentences by original sentence
171
+ organized_results = {}
172
+
173
+ # Generate trees for each sampled sentence
174
+ for sampling_strategy, masking_dict in sampling_results.items():
175
+ for masking_strategy, sentences in masking_dict.items():
176
+ for original_sentence, data in sentences.items():
177
+ if original_sentence not in organized_results:
178
+ organized_results[original_sentence] = {}
179
+
180
+ if masking_strategy not in organized_results[original_sentence]:
181
+ organized_results[original_sentence][masking_strategy] = {
182
+ "masked_sentence": data.get("masked_sentence", ""), # Corrected reference
183
+ "sampled_sentences": {}
184
+ }
185
+
186
+ # Add this sampling result
187
+ organized_results[original_sentence][masking_strategy]["sampled_sentences"][sampling_strategy] = data.get("sampled_sentence", "")
188
+
189
+ for original_sentence, data in organized_results.items():
190
+ masked_sentences = []
191
+ all_sampled_sentences = []
192
+
193
+ for masking_strategy, masking_data in list(data.items())[:3]: # Ensure this iteration is safe
194
+ masked_sentence = masking_data.get("masked_sentence", "")
195
+ if masked_sentence:
196
+ masked_sentences.append(masked_sentence)
197
+
198
+ for sampling_strategy, sampled_sentence in masking_data.get("sampled_sentences", {}).items():
199
+ if sampled_sentence:
200
+ all_sampled_sentences.append(sampled_sentence)
201
+
202
+ if masked_sentences:
203
+ try:
204
+ fig = generate_subplot2(
205
+ masked_sentences,
206
+ all_sampled_sentences,
207
+ self.highlight_info,
208
+ self.common_grams
209
+ )
210
+ trees.append(fig)
211
+ except Exception as e:
212
+ print(f"Error generating subplot for {original_sentence}: {e}")
213
+ trees.append(go.Figure())
214
+
215
+ while len(trees) < 10:
216
+ trees.append(go.Figure())
217
+
218
+ execution_time = f"<div class='execution-time'>Step 3 completed in {time.time() - start_time:.2f} seconds</div>"
219
+
220
+ return trees[:10] + [execution_time]
221
+
222
+ def handle_reparaphrasing(self) -> List[str]:
223
+ """Wrapper for re-paraphrasing that formats results as HTML"""
224
+ start_time = time.time()
225
+
226
+ results = self.pipeline.re_paraphrasing()
227
+ html_outputs = []
228
+
229
+ # Generate HTML for each batch of re-paraphrased sentences
230
+ for sampling_strategy, masking_dict in results.items():
231
+ for masking_strategy, sentences in masking_dict.items():
232
+ for original_sent, data in sentences.items():
233
+ if data["re_paraphrased_sentences"]:
234
+ html = reparaphrased_sentences_html(data["re_paraphrased_sentences"])
235
+ html_outputs.append(html)
236
+
237
+ # Pad with empty HTML if needed
238
+ while len(html_outputs) < 120:
239
+ html_outputs.append("")
240
+
241
+ execution_time = f"<div class='execution-time'>Step 4 completed in {time.time() - start_time:.2f} seconds</div>"
242
+
243
+ return html_outputs[:120] + [execution_time]
244
+
245
+
246
+ def create_gradio_interface(config):
247
+ """Creates the Gradio interface with the updated pipeline"""
248
+ interface = WatermarkerInterface(config)
249
+
250
+ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
251
+ #CSS to enable scrolling for reparaphrased sentences and sampling plots
252
+ demo.css = """
253
+ /* Set fixed height for the reparaphrased tabs container only */
254
+ .gradio-container .tabs[id="reparaphrased-tabs"],
255
+ .gradio-container .tabs[id="sampling-tabs"] {
256
+ overflow-x: hidden;
257
+ white-space: normal;
258
+ border-radius: 8px;
259
+ max-height: 600px; /* Set fixed height for the entire tabs component */
260
+ overflow-y: auto; /* Enable vertical scrolling inside the container */
261
+ }
262
+
263
+ /* Tab content styling for reparaphrased and sampling tabs */
264
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tabitem,
265
+ .gradio-container .tabs[id="sampling-tabs"] .tabitem {
266
+ overflow-x: hidden;
267
+ white-space: normal;
268
+ display: block;
269
+ border-radius: 8px;
270
+ }
271
+
272
+ /* Make the tab navigation fixed at the top for scrollable tabs */
273
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav,
274
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav {
275
+ display: flex;
276
+ overflow-x: auto;
277
+ white-space: nowrap;
278
+ scrollbar-width: thin;
279
+ border-radius: 8px;
280
+ scrollbar-color: #888 #f1f1f1;
281
+ position: sticky;
282
+ top: 0;
283
+ background: white;
284
+ z-index: 100;
285
+ }
286
+
287
+ /* Dropdown menu for scrollable tabs styling */
288
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown,
289
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown {
290
+ position: relative;
291
+ display: inline-block;
292
+ }
293
+
294
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content,
295
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content {
296
+ display: none;
297
+ position: absolute;
298
+ background-color: #f9f9f9;
299
+ min-width: 160px;
300
+ box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
301
+ z-index: 1;
302
+ max-height: 300px;
303
+ overflow-y: auto;
304
+ }
305
+
306
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content,
307
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content {
308
+ display: block;
309
+ }
310
+
311
+ /* Scrollbar styling for scrollable tabs */
312
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar,
313
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar {
314
+ height: 8px;
315
+ border-radius: 8px;
316
+ }
317
+
318
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-track,
319
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-track {
320
+ background: #f1f1f1;
321
+ border-radius: 8px;
322
+ }
323
+
324
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-thumb,
325
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-thumb {
326
+ background: #888;
327
+ border-radius: 8px;
328
+ }
329
+
330
+ /* Tab button styling for scrollable tabs */
331
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-item,
332
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-item {
333
+ flex: 0 0 auto;
334
+ border-radius: 8px;
335
+ }
336
+
337
+ /* Plot container styling specifically for sampling tabs */
338
+ .gradio-container .tabs[id="sampling-tabs"] .plot-container {
339
+ min-height: 600px;
340
+ max-height: 1800px;
341
+ overflow-y: auto;
342
+ }
343
+
344
+ /* Ensure text wraps in HTML components */
345
+ .gradio-container .prose {
346
+ white-space: normal;
347
+ word-wrap: break-word;
348
+ overflow-wrap: break-word;
349
+ }
350
+
351
+ /* Dropdown button styling for scrollable tabs */
352
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button,
353
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button {
354
+ background-color: #f0f0f0;
355
+ border: 1px solid #ddd;
356
+ border-radius: 4px;
357
+ padding: 5px 10px;
358
+ cursor: pointer;
359
+ margin: 2px;
360
+ }
361
+
362
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button:hover,
363
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button:hover {
364
+ background-color: #e0e0e0;
365
+ }
366
+
367
+ /* Style dropdown content items for scrollable tabs */
368
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div,
369
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div {
370
+ padding: 8px 12px;
371
+ cursor: pointer;
372
+ }
373
+
374
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div:hover,
375
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div:hover {
376
+ background-color: #e0e0e0;
377
+ }
378
+
379
+ /* Custom styling for execution time display */
380
+ .execution-time {
381
+ text-align: right;
382
+ padding: 8px 16px;
383
+ font-family: inherit;
384
+ color: #555;
385
+ font-size: 0.9rem;
386
+ font-style: italic;
387
+ margin-left: auto;
388
+ width: 100%;
389
+ border-top: 1px solid #eee;
390
+ margin-top: 8px;
391
+ }
392
+
393
+ /* Layout for section headers with execution time */
394
+ .section-header {
395
+ display: flex;
396
+ justify-content: space-between;
397
+ align-items: center;
398
+ width: 100%;
399
+ margin-bottom: 12px;
400
+ }
401
+
402
+ .section-header h3 {
403
+ margin: 0;
404
+ }
405
+ """
406
+ gr.Markdown("# **AIISC Watermarking Model**")
407
+
408
+ with gr.Column():
409
+ gr.Markdown("## Input Prompt")
410
+ user_input = gr.Textbox(
411
+ label="Enter Your Prompt",
412
+ placeholder="Type your text here..."
413
+ )
414
+
415
+ with gr.Row():
416
+ with gr.Column(scale=3):
417
+ gr.Markdown("## Step 1: Paraphrasing, LCS and Entailment Analysis")
418
+ with gr.Column(scale=1):
419
+ step1_time = gr.HTML()
420
+
421
+ paraphrase_button = gr.Button("Generate Paraphrases")
422
+ highlighted_user_prompt = gr.HTML(label="Highlighted User Prompt")
423
+
424
+ with gr.Tabs():
425
+ with gr.TabItem("Accepted Paraphrased Sentences"):
426
+ highlighted_accepted_sentences = gr.HTML()
427
+ with gr.TabItem("Discarded Paraphrased Sentences"):
428
+ highlighted_discarded_sentences = gr.HTML()
429
+
430
+ with gr.Row():
431
+ with gr.Column(scale=3):
432
+ gr.Markdown("## Step 2: Where to Mask?")
433
+ with gr.Column(scale=1):
434
+ step2_time = gr.HTML()
435
+
436
+ masking_button = gr.Button("Apply Masking")
437
+ gr.Markdown("### Masked Sentence Trees")
438
+ tree1_plots = []
439
+ with gr.Tabs() as tree1_tabs:
440
+ for i in range(10):
441
+ with gr.TabItem(f"Masked Sentence {i+1}"):
442
+ tree1 = gr.Plot()
443
+ tree1_plots.append(tree1)
444
+
445
+ with gr.Row():
446
+ with gr.Column(scale=3):
447
+ gr.Markdown("## Step 3: How to Mask?")
448
+ with gr.Column(scale=1):
449
+ step3_time = gr.HTML()
450
+
451
+ sampling_button = gr.Button("Sample Words")
452
+ gr.Markdown("### Sampled Sentence Trees")
453
+
454
+ tree2_plots = []
455
+ # Add elem_id to make this tab container scrollable
456
+ with gr.Tabs(elem_id="sampling-tabs") as tree2_tabs:
457
+ for i in range(10):
458
+ with gr.TabItem(f"Sampled Sentence {i+1}"):
459
+ # Add a custom class to the container to enable proper styling
460
+ with gr.Column(elem_classes=["plot-container"]):
461
+ tree2 = gr.Plot()
462
+ tree2_plots.append(tree2)
463
+
464
+ with gr.Row():
465
+ with gr.Column(scale=3):
466
+ gr.Markdown("## Step 4: Re-paraphrasing")
467
+ with gr.Column(scale=1):
468
+ step4_time = gr.HTML()
469
+
470
+ reparaphrase_button = gr.Button("Re-paraphrase")
471
+ gr.Markdown("### Reparaphrased Sentences")
472
+ reparaphrased_sentences_tabs = []
473
+ with gr.Tabs(elem_id="reparaphrased-tabs") as reparaphrased_tabs:
474
+ for i in range(120):
475
+ with gr.TabItem(f"Reparaphrased Batch {i+1}"):
476
+ reparaphrased_sent_html = gr.HTML()
477
+ reparaphrased_sentences_tabs.append(reparaphrased_sent_html)
478
+
479
+ # Connect the interface functions to the buttons
480
+ paraphrase_button.click(
481
+ interface.handle_paraphrase,
482
+ inputs=user_input,
483
+ outputs=[
484
+ highlighted_user_prompt,
485
+ highlighted_accepted_sentences,
486
+ highlighted_discarded_sentences,
487
+ step1_time
488
+ ]
489
+ )
490
+
491
+ masking_button.click(
492
+ interface.handle_masking,
493
+ inputs=None,
494
+ outputs=tree1_plots + [step2_time]
495
+ )
496
+
497
+ sampling_button.click(
498
+ interface.handle_sampling,
499
+ inputs=None,
500
+ outputs=tree2_plots + [step3_time]
501
+ )
502
+
503
+ reparaphrase_button.click(
504
+ interface.handle_reparaphrasing,
505
+ inputs=None,
506
+ outputs=reparaphrased_sentences_tabs + [step4_time]
507
+ )
508
+
509
+ return demo
510
+
511
+ if __name__ == "__main__":
512
+ project_root = Path(__file__).parent.parent
513
+ config_path = project_root / "utils" / "config.yaml"
514
+ config = load_config(config_path)['PECCAVI_TEXT']
515
+
516
+ create_gradio_interface(config).launch()
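
For reference, the sequential n-gram indexing performed at the start of handle_paraphrase can be illustrated in isolation. The snippet below is a minimal sketch, not part of the commit; the toy dictionaries (each n-gram mapped to a list of (start, end) positions) are assumptions that mirror the structure the handler reads from self.pipeline.common_grams.

# Minimal sketch of the indexing scheme: n-grams in the original prompt get
# sequential IDs by first occurrence, and paraphrases reuse those IDs so a
# shared n-gram carries the same badge number everywhere.
original_ngrams = {"quick brown fox": [(4, 19)], "lazy dog": [(35, 43)]}   # toy data
paraphrase_ngrams = {"lazy dog": [(10, 18)], "swift fox": [(2, 11)]}       # toy data

seen = {}
ordered = sorted((min(p[0] for p in pos), gram) for gram, pos in original_ngrams.items())
original_indexed = []
for idx, (_, gram) in enumerate(ordered, start=1):
    seen[gram] = idx
    original_indexed.append((idx, gram))

paraphrase_indexed = []
for gram in paraphrase_ngrams:
    idx = seen.setdefault(gram, len(seen) + 1)   # reuse the original index if present
    paraphrase_indexed.append((idx, gram))
paraphrase_indexed.sort()

print(original_indexed)    # [(1, 'quick brown fox'), (2, 'lazy dog')]
print(paraphrase_indexed)  # [(2, 'lazy dog'), (3, 'swift fox')]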
__pycache__/app.cpython-310.pyc ADDED
Binary file (747 Bytes)
 
app.py ADDED
@@ -0,0 +1,21 @@
1
+
2
+ import gradio as gr
3
+ from UI.gradio import create_gradio_interface
4
+
5
+ from pathlib import Path
6
+ from utils.config import load_config
7
+
8
+ project_root = Path(__file__).resolve().parent
9
+ config_path = project_root / "utils" / "config.yaml"
10
+ config = load_config(config_path)['PECCAVI_TEXT']
11
+
12
+ def main():
13
+ """
14
+ This function is the entry point for the PECCAVI Watermarking Model.
15
+
16
+ It creates the Gradio interface for the model and runs it.
17
+ """
18
+ create_gradio_interface(config).launch()
19
+
20
+ if __name__ == "__main__":
21
+ main()
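
As a usage aside (not part of the commit), app.py launches with Gradio's defaults; a sketch binding the server explicitly, using standard gr.Blocks.launch() options:

from pathlib import Path
from UI.gradio import create_gradio_interface
from utils.config import load_config

config = load_config(Path("utils/config.yaml"))["PECCAVI_TEXT"]
demo = create_gradio_interface(config)
# server_name, server_port, and share are standard gr.Blocks.launch() options.
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)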
environment.yml ADDED
@@ -0,0 +1,245 @@
1
+ name: panda
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=conda_forge
7
+ - _openmp_mutex=4.5=2_gnu
8
+ - asttokens=2.4.1=pyhd8ed1ab_0
9
+ - bzip2=1.0.8=h5eee18b_6
10
+ - ca-certificates=2024.8.30=hbcca054_0
11
+ - comm=0.2.2=pyhd8ed1ab_0
12
+ - debugpy=1.8.6=py310hf71b8c6_0
13
+ - decorator=5.1.1=pyhd8ed1ab_0
14
+ - exceptiongroup=1.2.2=pyhd8ed1ab_0
15
+ - executing=2.1.0=pyhd8ed1ab_0
16
+ - ipykernel=6.29.5=pyh3099207_0
17
+ - ipython=8.27.0=pyh707e725_0
18
+ - jedi=0.19.1=pyhd8ed1ab_0
19
+ - jupyter_client=8.6.3=pyhd8ed1ab_0
20
+ - jupyter_core=5.7.2=pyh31011fe_1
21
+ - krb5=1.21.3=h143b758_0
22
+ - ld_impl_linux-64=2.40=h12ee557_0
23
+ - libedit=3.1.20230828=h5eee18b_0
24
+ - libffi=3.4.4=h6a678d5_1
25
+ - libgcc=14.1.0=h77fa898_1
26
+ - libgcc-ng=14.1.0=h69a702a_1
27
+ - libgomp=14.1.0=h77fa898_1
28
+ - libsodium=1.0.20=h4ab18f5_0
29
+ - libstdcxx=14.1.0=hc0a3c3a_1
30
+ - libstdcxx-ng=11.2.0=h1234567_1
31
+ - libuuid=1.41.5=h5eee18b_0
32
+ - matplotlib-inline=0.1.7=pyhd8ed1ab_0
33
+ - ncurses=6.4=h6a678d5_0
34
+ - nest-asyncio=1.6.0=pyhd8ed1ab_0
35
+ - openssl=3.3.2=hb9d3cd8_0
36
+ - packaging=24.1=pyhd8ed1ab_0
37
+ - parso=0.8.4=pyhd8ed1ab_0
38
+ - pexpect=4.9.0=pyhd8ed1ab_0
39
+ - pickleshare=0.7.5=py_1003
40
+ - pip=24.2=py310h06a4308_0
41
+ - platformdirs=4.3.6=pyhd8ed1ab_0
42
+ - prompt-toolkit=3.0.48=pyha770c72_0
43
+ - ptyprocess=0.7.0=pyhd3deb0d_0
44
+ - pure_eval=0.2.3=pyhd8ed1ab_0
45
+ - pygments=2.18.0=pyhd8ed1ab_0
46
+ - python=3.10.14=h955ad1f_1
47
+ - python_abi=3.10=2_cp310
48
+ - pyzmq=26.2.0=py310h71f11fc_2
49
+ - readline=8.2=h5eee18b_0
50
+ - setuptools=75.1.0=py310h06a4308_0
51
+ - sqlite=3.45.3=h5eee18b_0
52
+ - stack_data=0.6.2=pyhd8ed1ab_0
53
+ - tk=8.6.14=h39e8969_0
54
+ - tornado=6.4.1=py310ha75aee5_1
55
+ - traitlets=5.14.3=pyhd8ed1ab_0
56
+ - typing_extensions=4.12.2=pyha770c72_0
57
+ - wcwidth=0.2.13=pyhd8ed1ab_0
58
+ - wheel=0.44.0=py310h06a4308_0
59
+ - xz=5.4.6=h5eee18b_1
60
+ - zeromq=4.3.5=ha4adb4c_5
61
+ - zlib=1.2.13=h5eee18b_1
62
+ - pip:
63
+ - absl-py==2.1.0
64
+ - accelerate==0.33.0
65
+ - aiofiles==23.2.1
66
+ - aiohappyeyeballs==2.3.5
67
+ - aiohttp==3.10.3
68
+ - aiosignal==1.3.1
69
+ - altgraph==0.17.4
70
+ - annotated-types==0.7.0
71
+ - anyio==4.6.0
72
+ - astunparse==1.6.3
73
+ - async-timeout==4.0.3
74
+ - attrs==24.2.0
75
+ - av==12.0.0
76
+ - backports-tarfile==1.2.0
77
+ - beautifulsoup4==4.12.3
78
+ - build==1.2.2
79
+ - cachetools==5.5.0
80
+ - certifi==2024.7.4
81
+ - cffi==1.17.1
82
+ - charset-normalizer==3.3.2
83
+ - clean-fid==0.1.35
84
+ - click==8.1.7
85
+ - colorama==0.4.6
86
+ - contextlib2==21.6.0
87
+ - contourpy==1.2.1
88
+ - cryptography==43.0.1
89
+ - cycler==0.12.1
90
+ - datasets==2.21.0
91
+ - diffusers==0.27.2
92
+ - dill==0.3.8
93
+ - docker-pycreds==0.4.0
94
+ - docutils==0.21.2
95
+ - fastapi==0.115.0
96
+ - ffmpy==0.4.0
97
+ - filelock==3.15.4
98
+ - flatbuffers==24.3.25
99
+ - fonttools==4.53.1
100
+ - frozenlist==1.4.1
101
+ - fsspec==2024.6.1
102
+ - gast==0.4.0
103
+ - gdown==5.2.0
104
+ - gitdb==4.0.11
105
+ - gitpython==3.1.43
106
+ - google-auth==2.35.0
107
+ - google-auth-oauthlib==0.4.6
108
+ - google-pasta==0.2.0
109
+ - gradio==4.44.0
110
+ - gradio-client==1.3.0
111
+ - grpcio==1.65.4
112
+ - h11==0.14.0
113
+ - h5py==3.11.0
114
+ - httpcore==1.0.6
115
+ - httpx==0.27.2
116
+ - huggingface-hub==0.25.2
117
+ - idna==3.7
118
+ - imageio==2.35.0
119
+ - importlib-metadata==8.2.0
120
+ - importlib-resources==6.4.5
121
+ - jaraco-classes==3.4.0
122
+ - jaraco-context==6.0.1
123
+ - jaraco-functools==4.1.0
124
+ - jeepney==0.8.0
125
+ - jinja2==3.1.4
126
+ - joblib==1.4.2
127
+ - json-with-comments==1.2.7
128
+ - keras==3.5.0
129
+ - keras-preprocessing==1.1.2
130
+ - keyring==25.4.1
131
+ - kiwisolver==1.4.5
132
+ - kornia==0.7.4
133
+ - kornia-rs==0.1.7
134
+ - lazy-loader==0.4
135
+ - libclang==18.1.1
136
+ - markdown==3.6
137
+ - markdown-it-py==3.0.0
138
+ - markupsafe==2.1.5
139
+ - matplotlib==3.9.2
140
+ - mdurl==0.1.2
141
+ - ml-collections==0.1.1
142
+ - ml-dtypes==0.4.0
143
+ - more-itertools==10.5.0
144
+ - multidict==6.0.5
145
+ - multiprocess==0.70.16
146
+ - namex==0.0.8
147
+ - networkx==3.3
148
+ - nh3==0.2.18
149
+ - nltk==3.9.1
150
+ - numpy==1.26.4
151
+ - nvidia-cublas-cu11==11.10.3.66
152
+ - nvidia-cuda-nvrtc-cu11==11.7.99
153
+ - nvidia-cuda-runtime-cu11==11.7.99
154
+ - nvidia-cudnn-cu11==8.5.0.96
155
+ - oauthlib==3.2.2
156
+ - opencv-python==4.10.0.84
157
+ - opencv-python-headless==4.10.0.84
158
+ - opt-einsum==3.3.0
159
+ - optree==0.12.1
160
+ - orjson==3.10.7
161
+ - pandas==2.2.2
162
+ - pillow==10.4.0
163
+ - pkginfo==1.10.0
164
+ - plotly==5.24.1
165
+ - protobuf==4.25.5
166
+ - psutil==5.9.8
167
+ - pyarrow==17.0.0
168
+ - pyasn1==0.6.1
169
+ - pyasn1-modules==0.4.1
170
+ - pycparser==2.22
171
+ - pydantic==2.9.2
172
+ - pydantic-core==2.23.4
173
+ - pydub==0.25.1
174
+ - pyinstaller==6.10.0
175
+ - pyinstaller-hooks-contrib==2024.8
176
+ - pyparsing==3.1.2
177
+ - pyproject-hooks==1.1.0
178
+ - pysocks==1.7.1
179
+ - python-dateutil==2.9.0.post0
180
+ - python-multipart==0.0.12
181
+ - pytorch-msssim==1.0.0
182
+ - pytorchcv==0.0.73
183
+ - pytz==2023.3.post1
184
+ - pyyaml==6.0.2
185
+ - readme-renderer==44.0
186
+ - regex==2024.7.24
187
+ - requests==2.32.3
188
+ - requests-oauthlib==2.0.0
189
+ - requests-toolbelt==1.0.0
190
+ - rfc3986==2.0.0
191
+ - rich==13.7.1
192
+ - rsa==4.9
193
+ - ruff==0.6.9
194
+ - safetensors==0.4.4
195
+ - saliency==0.2.1
196
+ - scikit-image==0.24.0
197
+ - scikit-learn==1.6.0
198
+ - scipy==1.14.0
199
+ - secretstorage==3.3.3
200
+ - semantic-version==2.10.0
201
+ - sentence-transformers==3.3.1
202
+ - sentry-sdk==2.15.0
203
+ - setproctitle==1.3.3
204
+ - shapely==2.0.5
205
+ - shellingham==1.5.4
206
+ - six==1.12.0
207
+ - smmap==5.0.1
208
+ - sniffio==1.3.1
209
+ - soupsieve==2.6
210
+ - spaces==0.30.2
211
+ - starlette==0.38.6
212
+ - tenacity==9.0.0
213
+ - tensorboard==2.17.1
214
+ - tensorboard-data-server==0.7.2
215
+ - tensorboard-plugin-wit==1.8.1
216
+ - tensorflow==2.17.0
217
+ - tensorflow-estimator==2.10.0
218
+ - tensorflow-hub==0.16.1
219
+ - tensorflow-intel==0.0.1
220
+ - tensorflow-io-gcs-filesystem==0.31.0
221
+ - termcolor==1.1.0
222
+ - tf-keras==2.17.0
223
+ - threadpoolctl==3.5.0
224
+ - tifffile==2024.8.10
225
+ - timm==1.0.10
226
+ - tokenizers==0.19.1
227
+ - tomli==2.0.1
228
+ - tomlkit==0.12.0
229
+ - torch==1.13.1
230
+ - torchvision==0.14.1
231
+ - tqdm==4.66.5
232
+ - transformers==4.43.3
233
+ - twine==5.1.1
234
+ - typer==0.12.5
235
+ - tzdata==2024.1
236
+ - urllib3==2.2.2
237
+ - uvicorn==0.31.0
238
+ - wandb==0.18.3
239
+ - websockets==12.0
240
+ - werkzeug==3.0.4
241
+ - wrapt==1.11.2
242
+ - xxhash==3.4.1
243
+ - yarl==1.9.4
244
+ - zipp==3.20.0
245
+ prefix: /home/ashhar21137/miniconda3/envs/panda
metrics/distortion.py ADDED
@@ -0,0 +1,370 @@
1
+ import os
2
+ import sys
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ import torch
6
+ import matplotlib.pyplot as plt
7
+ from transformers import GPT2LMHeadModel, GPT2TokenizerFast
8
+ from bert_score import BERTScorer
9
+ from bert_score.utils import model2layers
10
+ from nltk.tokenize import word_tokenize
11
+ from Levenshtein import distance as levenshtein_distance
12
+ from sentence_transformers import SentenceTransformer
13
+ from sklearn.feature_extraction.text import TfidfVectorizer
14
+ from scipy.spatial.distance import cdist
15
+ from scipy.optimize import linear_sum_assignment
16
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
17
+
18
+ from config.config import load_config
19
+ config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml')
20
+ config = load_config(config_path)['PECCAVI_TEXT']['Metrics']
21
+
22
+ class SentenceDistortionCalculator:
23
+ """
24
+ A class to calculate and analyze distortion metrics between an original sentence and modified sentences.
25
+ """
26
+ def __init__(self, config, original_sentence, paraphrased_sentences):
27
+ """
28
+ Initialize the calculator with the original sentence and a list of modified sentences.
29
+ """
30
+ self.original_sentence = original_sentence
31
+ self.paraphrased_sentences = paraphrased_sentences
32
+
33
+ self.levenshtein_distances = {}
34
+ self.bert_scores = {}
35
+ self.mover_scores = {}
36
+
37
+ self.normalized_levenshtein = {}
38
+ self.normalized_bert_scores = {}
39
+ self.normalized_mover_scores = {}
40
+ self.combined_distortions = {}
41
+
42
+ self.tokenizer = GPT2TokenizerFast.from_pretrained(config['Distortion'])
43
+ self.model = GPT2LMHeadModel.from_pretrained(config['Distortion'])
44
+ self.model.eval()
45
+
46
+ def calculate_all_metrics(self):
47
+ """
48
+ Calculate all distortion metrics for each modified sentence.
49
+ """
50
+ for idx, modified_sentence in tqdm(enumerate(self.paraphrased_sentences), total=len(self.paraphrased_sentences), desc="Calculating Metrics"):
51
+ key = f"Sentence_{idx+1}"
52
+ self.levenshtein_distances[key] = self._calculate_levenshtein_distance(modified_sentence)
53
+ self.bert_scores[key] = self._calculate_bert_score(modified_sentence)
54
+ self.mover_scores[key] = self._calculate_mover_score(modified_sentence)
55
+
56
+
57
+ def normalize_metrics(self):
58
+ """
59
+ Normalize all metrics to be between 0 and 1.
60
+ """
61
+ for _ in tqdm(range(1), desc="Normalizing Metrics"): # Add tqdm here (wrap the normalization process)
62
+ self.normalized_levenshtein = self._normalize_dict(self.levenshtein_distances)
63
+ self.normalized_bert_scores = self._normalize_dict(self.bert_scores)
64
+ self.normalized_mover_scores = self._normalize_dict(self.mover_scores)
65
+
66
+ def calculate_combined_distortion(self):
67
+ """
68
+ Calculate the combined distortion using the root mean square of the normalized metrics.
69
+ """
70
+ for _ in tqdm(range(1), desc="Calculating Combined Distortion"): # Add tqdm here
71
+ for key in self.normalized_levenshtein.keys():
72
+ rms = np.sqrt(
73
+ (
74
+ self.normalized_levenshtein[key] ** 2 +
75
+ self.normalized_bert_scores[key] ** 2+
76
+ self.normalized_mover_scores[key] **2
77
+ ) / 3
78
+ )
79
+ self.combined_distortions[key] = rms
80
+
81
+ def plot_metrics(self):
82
+ """
83
+ Plot each normalized metric and the combined distortion in separate graphs.
84
+ """
85
+ keys = list(self.normalized_levenshtein.keys())
86
+ indices = np.arange(len(keys))
87
+
88
+ # Prepare data for plotting
89
+ metrics = {
90
+ 'Levenshtein Distance': [self.normalized_levenshtein[key] for key in keys],
91
+ 'BERTScore': [self.normalized_bert_scores[key] for key in keys],
92
+ 'MOVERscore':[self.normalized_mover_scores[key] for key in keys],
93
+ 'Combined Distortion': [self.combined_distortions[key] for key in keys]
94
+ }
95
+
96
+ # Plot each metric separately
97
+ for metric_name, values in tqdm(metrics.items(), desc="Plotting Metrics"): # Add tqdm here
98
+ plt.figure(figsize=(12, 6))
99
+ plt.plot(indices, values, marker='o', color=np.random.rand(3,))
100
+ plt.xlabel('Sentence Index')
101
+ plt.ylabel('Normalized Value (0-1)')
102
+ plt.title(f'Normalized {metric_name}')
103
+ plt.grid(True)
104
+ plt.tight_layout()
105
+ plt.show()
106
+
107
+ def _calculate_levenshtein_distance(self, modified_sentence):
108
+ """
109
+ Calculate the word-level Levenshtein distance between the original and modified sentence.
110
+ """
111
+ words1 = word_tokenize(self.original_sentence)
112
+ words2 = word_tokenize(modified_sentence)
113
+ lev_distance = levenshtein_distance(words1, words2)
114
+ return (lev_distance / max(len(words1), len(words2)))
115
+
116
+ def _calculate_bert_score(self, modified_sentence):
117
+ """
118
+ Compute the BERTScore similarity between the original and modified sentence.
119
+ Returns 1 - F1 score to represent dissimilarity.
120
+ """
121
+ if not hasattr(self, 'original_sentence'):
122
+ raise ValueError("original_sentence is not set. Please set self.original_sentence before calling this function.")
123
+ if not isinstance(modified_sentence, str):
124
+ raise ValueError("modified_sentence must be a string.")
125
+
126
+ model_type = "microsoft/deberta-xlarge-mnli"
127
+ num_layers = model2layers[model_type]
128
+
129
+ if not hasattr(self, "cached_bertscorer"):
130
+ self.cached_bertscorer = BERTScorer(
131
+ model_type=model_type,
132
+ num_layers=num_layers,
133
+ batch_size=1, # Single sentence comparison
134
+ nthreads=4,
135
+ all_layers=False,
136
+ idf=False,
137
+ device="cuda" if torch.cuda.is_available() else "cpu",
138
+ lang="en"
139
+ )
140
+
141
+ # Compute BERTScore
142
+ _, _, F1 = self.cached_bertscorer.score(
143
+ cands=[modified_sentence],
144
+ refs=[self.original_sentence],
145
+ verbose=False,
146
+ batch_size=1
147
+ )
148
+
149
+ return 1 - F1.item() # Return dissimilarity score
150
+ def _calculate_mover_score(self, modified_sentence, model_name='all-MiniLM-L6-v2'):
151
+ """Compute MoverScore correctly using word-level embeddings."""
152
+ if not self.original_sentence:
153
+ raise ValueError("Original sentence not provided.")
154
+
155
+ # Tokenize sentences
156
+ original_tokens = self.original_sentence.split()
157
+ modified_tokens = modified_sentence.split()
158
+ model = SentenceTransformer(model_name)
159
+
160
+ # Compute word embeddings
161
+ original_embeddings = model.encode(original_tokens, convert_to_numpy=True)
162
+ modified_embeddings = model.encode(modified_tokens, convert_to_numpy=True)
163
+
164
+ # Compute cost matrix (cosine distance)
165
+ cost_matrix = cdist(original_embeddings, modified_embeddings, metric='cosine')
166
+
167
+ # Solve optimal transport problem (Hungarian Algorithm)
168
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
169
+
170
+ # Compute IDF weights
171
+ vectorizer = TfidfVectorizer()
172
+ vectorizer.fit([self.original_sentence, modified_sentence])
173
+ idf_values = dict(zip(vectorizer.get_feature_names_out(), vectorizer.idf_))
174
+
175
+ # Apply IDF weighting to aligned word pairs
176
+ idf_weights_original = np.array([idf_values.get(word.lower(), 1.0) for word in original_tokens])
177
+ idf_weights_modified = np.array([idf_values.get(word.lower(), 1.0) for word in modified_tokens])
178
+ combined_idf_weights = (idf_weights_original[row_ind] + idf_weights_modified[col_ind]) / 2
179
+ weighted_score = np.sum((1 - cost_matrix[row_ind, col_ind]) * combined_idf_weights) / np.sum(combined_idf_weights)
180
+
181
+ return 1 - weighted_score  # Higher score = more dissimilar
182
+
183
+ def _normalize_dict(self, metric_dict):
184
+ """
185
+ Normalize the values in a dictionary to be between 0 and 1.
186
+ """
187
+ values = np.array(list(metric_dict.values()))
188
+ min_val = values.min()
189
+ max_val = values.max()
190
+ if max_val - min_val == 0:
191
+ normalized_values = np.zeros_like(values)
192
+ else:
193
+ normalized_values = (values - min_val) / (max_val - min_val)
194
+ return dict(zip(metric_dict.keys(), normalized_values))
195
+
196
+ def get_normalized_metrics(self):
197
+ """
198
+ Get all normalized metrics as a dictionary.
199
+ """
200
+ return {
201
+ 'Min Edit Distance': self.normalized_levenshtein,
202
+ 'BERTScore': self.normalized_bert_scores,
203
+ 'Mover Score': self.normalized_mover_scores
204
+ }
205
+
206
+ def get_combined_distortions(self):
207
+ """
208
+ Get the dictionary of combined distortion values.
209
+ """
210
+ return self.combined_distortions
211
+
212
+ # Example usage
213
+ if __name__ == "__main__":
214
+
215
+ config = load_config(config_path)['PECCAVI_TEXT']['Metrics']
216
+
217
+ # Original sentence
218
+ original_sentence = "The quick brown fox jumps over the lazy dog"
219
+
220
+ # Paraphrased sentences
221
+ paraphrased_sentences = [
222
+ # Original 1: "A swift auburn fox leaps across a sleepy canine."
223
+ "The swift auburn fox leaps across a sleepy canine.",
224
+ "A quick auburn fox leaps across a sleepy canine.",
225
+ "A swift ginger fox leaps across a sleepy canine.",
226
+ "A swift auburn fox bounds across a sleepy canine.",
227
+ "A swift auburn fox leaps across a tired canine.",
228
+ "Three swift auburn foxes leap across a sleepy canine.",
229
+ "The vulpine specimen rapidly traverses over a dormant dog.",
230
+ "Like lightning, the russet hunter soars over the drowsy guardian.",
231
+ "Tha quick ginger fox jumps o'er the lazy hound, ye ken.",
232
+ "One rapid Vulpes vulpes traverses the path of a quiescent canine.",
233
+ "A swift auburn predator navigates across a lethargic pet.",
234
+ "Subject A (fox) demonstrates velocity over Subject B (dog).",
235
+
236
+ # Original 2: "The agile russet fox bounds over an idle hound."
237
+ "Some agile russet foxes bound over an idle hound.",
238
+ "The nimble russet fox bounds over an idle hound.",
239
+ "The agile brown fox bounds over an idle hound.",
240
+ "The agile russet fox jumps over an idle hound.",
241
+ "The agile russet fox bounds over a lazy hound.",
242
+ "Two agile russet foxes bound over an idle hound.",
243
+ "A dexterous vulpine surpasses a stationary canine.",
244
+ "Quick as thought, the copper warrior sails over the guardian.",
245
+ "Tha nimble reddish fox jumps o'er the doggo, don't ya know.",
246
+ "A dexterous V. vulpes exceeds the plane of an inactive canine.",
247
+ "An agile russet hunter maneuvers above a resting hound.",
248
+ "Test subject F-1 achieves displacement superior to subject D-1.",
249
+
250
+ # Original 3: "A nimble mahogany vulpine vaults above a drowsy dog."
251
+ "The nimble mahogany vulpine vaults above a drowsy dog.",
252
+ "A swift mahogany vulpine vaults above a drowsy dog.",
253
+ "A nimble reddish vulpine vaults above a drowsy dog.",
254
+ "A nimble mahogany fox vaults above a drowsy dog.",
255
+ "A nimble mahogany vulpine leaps above a drowsy dog.",
256
+ "Four nimble mahogany vulpines vault above a drowsy dog.",
257
+ "An agile specimen of reddish fur surpasses a somnolent canine.",
258
+ "Fleet as wind, the earth-toned hunter soars over the sleepy guard.",
259
+ "Tha quick brown beastie jumps o'er the tired pup, aye.",
260
+ "Single V. vulpes demonstrates vertical traverse over C. familiaris.",
261
+ "A nimble rust-colored predator crosses above a drowsy pet.",
262
+ "Observed: Subject Red executes vertical motion over Subject Gray.",
263
+
264
+ # Original 4: "The speedy copper-colored fox hops over the lethargic pup."
265
+ "A speedy copper-colored fox hops over the lethargic pup.",
266
+ "The quick copper-colored fox hops over the lethargic pup.",
267
+ "The speedy bronze fox hops over the lethargic pup.",
268
+ "The speedy copper-colored fox jumps over the lethargic pup.",
269
+ "The speedy copper-colored fox hops over the tired pup.",
270
+ "Multiple speedy copper-colored foxes hop over the lethargic pup.",
271
+ "A rapid vulpine of bronze hue traverses an inactive young canine.",
272
+ "Swift as a dart, the metallic hunter bounds over the lazy puppy.",
273
+ "Tha fast copper beastie leaps o'er the sleepy wee dog.",
274
+ "1 rapid V. vulpes crosses above 1 juvenile C. familiaris.",
275
+ "A fleet copper-toned predator moves past a sluggish young dog.",
276
+ "Field note: Adult fox subject exceeds puppy subject vertically.",
277
+
278
+ # Original 5: "A rapid tawny fox springs over a sluggish dog."
279
+ "The rapid tawny fox springs over a sluggish dog.",
280
+ "A quick tawny fox springs over a sluggish dog.",
281
+ "A rapid golden fox springs over a sluggish dog.",
282
+ "A rapid tawny fox jumps over a sluggish dog.",
283
+ "A rapid tawny fox springs over a lazy dog.",
284
+ "Six rapid tawny foxes spring over a sluggish dog.",
285
+ "An expeditious yellowish vulpine surpasses a torpid canine.",
286
+ "Fast as a bullet, the golden hunter vaults over the idle guard.",
287
+ "Tha swift yellowy fox jumps o'er the lazy mutt, aye.",
288
+ "One V. vulpes displays rapid transit over one inactive C. familiaris.",
289
+ "A speedy yellow-brown predator bypasses a motionless dog.",
290
+ "Log entry: Vulpine subject achieves swift vertical displacement.",
291
+
292
+ # Original 6: "The fleet-footed chestnut fox soars above an indolent canine."
293
+ "A fleet-footed chestnut fox soars above an indolent canine.",
294
+ "The swift chestnut fox soars above an indolent canine.",
295
+ "The fleet-footed brown fox soars above an indolent canine.",
296
+ "The fleet-footed chestnut fox leaps above an indolent canine.",
297
+ "The fleet-footed chestnut fox soars above a lazy canine.",
298
+ "Several fleet-footed chestnut foxes soar above an indolent canine.",
299
+ "A rapid brown vulpine specimen traverses a lethargic domestic dog.",
300
+ "Graceful as a bird, the nutbrown hunter flies over the lazy guard.",
301
+ "Tha quick brown beastie sails o'er the sleepy hound, ken.",
302
+ "Single agile V. vulpes achieves elevation above stationary canine.",
303
+ "A nimble brown predator glides over an unmoving domestic animal.",
304
+ "Research note: Brown subject displays superior vertical mobility.",
305
+
306
+ # Original 7: "A fast ginger fox hurdles past a slothful dog."
307
+ "The fast ginger fox hurdles past a slothful dog.",
308
+ "A quick ginger fox hurdles past a slothful dog.",
309
+ "A fast red fox hurdles past a slothful dog.",
310
+ "A fast ginger fox jumps past a slothful dog.",
311
+ "A fast ginger fox hurdles past a lazy dog.",
312
+ "Five fast ginger foxes hurdle past a slothful dog.",
313
+ "A rapid orange vulpine bypasses a lethargic canine.",
314
+ "Quick as lightning, the flame-colored hunter races past the lazy guard.",
315
+ "Tha swift ginger beastie leaps past the tired doggy, ye see.",
316
+ "1 rapid orange V. vulpes surpasses 1 inactive C. familiaris.",
317
+ "A speedy red-orange predator overtakes a motionless dog.",
318
+ "Data point: Orange subject demonstrates rapid transit past Gray subject.",
319
+
320
+ # Original 8: "The spry rusty-colored fox jumps across a dozing hound."
321
+ "A spry rusty-colored fox jumps across a dozing hound.",
322
+ "The agile rusty-colored fox jumps across a dozing hound.",
323
+ "The spry reddish fox jumps across a dozing hound.",
324
+ "The spry rusty-colored fox leaps across a dozing hound.",
325
+ "The spry rusty-colored fox jumps across a sleeping hound.",
326
+ "Multiple spry rusty-colored foxes jump across a dozing hound.",
327
+ "An agile rust-toned vulpine traverses a somnolent canine.",
328
+ "Nimble as thought, the copper hunter bounds over the resting guard.",
329
+ "Tha lively rust-colored beastie hops o'er the snoozin' hound.",
330
+ "Single dexterous V. vulpes crosses path of dormant C. familiaris.",
331
+ "A lithe rust-tinted predator moves past a slumbering dog.",
332
+ "Observation: Russet subject exhibits agility over dormant subject.",
333
+
334
+ # Original 9: "A quick tan fox leaps over an inactive dog."
335
+ "The quick tan fox leaps over an inactive dog.",
336
+ "A swift tan fox leaps over an inactive dog.",
337
+ "A quick beige fox leaps over an inactive dog.",
338
+ "A quick tan fox jumps over an inactive dog.",
339
+ "A quick tan fox leaps over a motionless dog.",
340
+ "Seven quick tan foxes leap over an inactive dog.",
341
+ "A rapid light-brown vulpine surpasses a stationary canine.",
342
+ "Fast as wind, the sand-colored hunter soars over the still guard.",
343
+ "Tha nimble tan beastie jumps o'er the quiet doggy, aye.",
344
+ "One agile fawn V. vulpes traverses one immobile C. familiaris.",
345
+ "A fleet tan-colored predator bypasses an unmoving dog.",
346
+ "Field report: Tan subject demonstrates movement over static subject.",
347
+
348
+ # Original 10: "The brisk auburn vulpine bounces over a listless canine."
349
+ "Some brisk auburn vulpines bounce over a listless canine.",
350
+ "The quick auburn vulpine bounces over a listless canine.",
351
+ "The brisk russet vulpine bounces over a listless canine.",
352
+ "The brisk auburn fox bounces over a listless canine.",
353
+ "The brisk auburn vulpine jumps over a listless canine.",
354
+ "Five brisk auburn vulpines bounce over a listless canine.",
355
+ "The expeditious specimen supersedes a quiescent Canis lupus.",
356
+ "Swift as wind, the russet hunter vaults over the idle guardian.",
357
+ "Tha quick ginger beastie hops o'er the lazy mutt, aye.",
358
+ "One V. vulpes achieves displacement over inactive C. familiaris.",
359
+ "A high-velocity auburn predator traverses an immobile animal.",
360
+ "Final observation: Red subject shows mobility over Gray subject."
361
+ ]
362
+
363
+ distortion_calculator = SentenceDistortionCalculator(config, original_sentence, paraphrased_sentences)
364
+ for _ in tqdm(range(1)):
365
+ distortion_calculator.calculate_all_metrics()
366
+ distortion_calculator.normalize_metrics()
367
+ distortion_calculator.calculate_combined_distortion()
368
+ distortion_calculator.plot_metrics()
369
+ print("Normalized Metrics:", distortion_calculator.get_normalized_metrics())
370
+ print("Combined Distortion:", distortion_calculator.get_combined_distortions())
renderers/__pycache__/highlighter.cpython-310.pyc ADDED
Binary file (4.98 kB)

renderers/__pycache__/highlighter.cpython-311.pyc ADDED
Binary file (6.79 kB)

renderers/__pycache__/plot_3d.cpython-310.pyc ADDED
Binary file (4.34 kB)

renderers/__pycache__/plot_3d.cpython-311.pyc ADDED
Binary file (6 kB)

renderers/__pycache__/tree.cpython-310.pyc ADDED
Binary file (10.6 kB)

renderers/__pycache__/tree.cpython-311.pyc ADDED
Binary file (21.1 kB)
 
renderers/highlighter.py ADDED
@@ -0,0 +1,162 @@
1
+ import re
2
+
3
+ def highlight_common_words(common_words, sentences, title):
4
+ """
5
+ Highlight common words in sentences by adding color-coded background and unique IDs.
6
+
7
+ Args:
8
+ common_words (list of tuples): List of tuples where each tuple contains a word's index and the word.
9
+ sentences (list of str): List of sentences to search through.
10
+ title (str): The title for the HTML output.
11
+
12
+ Returns:
13
+ str: HTML string with the highlighted sentences.
14
+ """
15
+ color_map = {}
16
+ color_index = 0
17
+ highlighted_html = []
18
+
19
+ # Process each sentence
20
+ for idx, sentence in enumerate(sentences, start=1):
21
+ sentence_with_idx = f"{idx}. {sentence}"
22
+ highlighted_sentence = sentence_with_idx
23
+
24
+ # Highlight common words in each sentence
25
+ for index, word in common_words:
26
+ if word not in color_map:
27
+ color_map[word] = f'hsl({color_index * 60 % 360}, 70%, 80%)'
28
+ color_index += 1
29
+
30
+ # Escape word and create regex pattern to match whole word
31
+ escaped_word = re.escape(word)
32
+ pattern = rf'\b{escaped_word}\b'
33
+
34
+ # Replace the word with highlighted version
35
+ highlighted_sentence = re.sub(
36
+ pattern,
37
+ lambda m, idx=index, color=color_map[word]: (
38
+ f'<span style="background-color: {color}; font-weight: bold;'
39
+ f' padding: 2px 4px; border-radius: 2px; position: relative;">'
40
+ f'<span style="background-color: black; color: white; border-radius: 50%;'
41
+ f' padding: 2px 5px; margin-right: 5px;">{idx}</span>'
42
+ f'{m.group(0)}'
43
+ f'</span>'
44
+ ),
45
+ highlighted_sentence,
46
+ flags=re.IGNORECASE
47
+ )
48
+
49
+ highlighted_html.append(highlighted_sentence)
50
+
51
+ # Format the HTML output with the title
52
+ final_html = "<br><br>".join(highlighted_html)
53
+ return f'''
54
+ <div style="border: solid 1px #FFFFFF; padding: 16px; background-color: #000000; color: #FFFFFF; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 8px;">
55
+ <h3 style="margin-top: 0; font-size: 1em; color: #FFFFFF;">{title}</h3>
56
+ <div style="background-color: #333333; line-height: 1.6; padding: 15px; border-radius: 8px; color: #FFFFFF;">{final_html}</div>
57
+ </div>
58
+ '''
59
+
60
+ def highlight_common_words_dict(common_words, sentences, title):
61
+ """
62
+ Highlight common words in sentences (from a dictionary) by adding color-coded background and unique IDs.
63
+
64
+ Args:
65
+ common_words (list of tuples): List of tuples where each tuple contains a word's index and the word.
66
+ sentences (dict): A dictionary of sentences where the key is the sentence and the value is an entailment score.
67
+ title (str): The title for the HTML output.
68
+
69
+ Returns:
70
+ str: HTML string with the highlighted sentences and their entailment scores.
71
+ """
72
+ color_map = {}
73
+ color_index = 0
74
+ highlighted_html = []
75
+
76
+ # Process each sentence and its score
77
+ for idx, (sentence, score) in enumerate(sentences.items(), start=1):
78
+ sentence_with_idx = f"{idx}. {sentence}"
79
+ highlighted_sentence = sentence_with_idx
80
+
81
+ # Highlight common words in each sentence
82
+ for index, word in common_words:
83
+ if word not in color_map:
84
+ color_map[word] = f'hsl({color_index * 60 % 360}, 70%, 80%)'
85
+ color_index += 1
86
+
87
+ # Escape word and create regex pattern to match whole word
88
+ escaped_word = re.escape(word)
89
+ pattern = rf'\b{escaped_word}\b'
90
+
91
+ # Replace the word with highlighted version
92
+ highlighted_sentence = re.sub(
93
+ pattern,
94
+ lambda m, idx=index, color=color_map[word]: (
95
+ f'<span style="background-color: {color}; font-weight: bold;'
96
+ f' padding: 1px 2px; border-radius: 2px; position: relative;">'
97
+ f'<span style="background-color: black; color: white; border-radius: 50%;'
98
+ f' padding: 1px 3px; margin-right: 3px; font-size: 0.8em;">{idx}</span>'
99
+ f'{m.group(0)}'
100
+ f'</span>'
101
+ ),
102
+ highlighted_sentence,
103
+ flags=re.IGNORECASE
104
+ )
105
+
106
+ # Add the entailment score
107
+ highlighted_html.append(
108
+ f'<div style="margin-bottom: 5px;">'
109
+ f'{highlighted_sentence}'
110
+ f'<div style="display: inline-block; margin-left: 5px; padding: 3px 5px; border-radius: 3px; '
111
+ f'background-color: #333333; color: white; font-size: 0.9em;">'
112
+ f'Entailment Score: {score}</div></div>'
113
+ )
114
+
115
+ # Format the HTML output with the title
116
+ final_html = "<br>".join(highlighted_html)
117
+ return f'''
118
+ <div style="background-color: #000000; color: #FFFFFF;border: solid 1px #FFFFFF; border-radius: 8px;">
119
+ <h3 style="margin-top: 0; font-size: 1em; color: #FFFFFF;">{title}</h3>
120
+ <div style="background-color: #333333; line-height: 1.6; padding: 15px; border-radius: 8px; color: #FFFFFF;">{final_html}</div>
121
+ </div>
122
+ '''
123
+
124
+ def reparaphrased_sentences_html(sentences):
125
+ """
126
+ Create an HTML representation of sentences with numbering.
127
+
128
+ Args:
129
+ sentences (list of str): List of sentences to format.
130
+
131
+ Returns:
132
+ str: HTML string with numbered sentences.
133
+ """
134
+ formatted_sentences = []
135
+
136
+ # Process each sentence
137
+ for idx, sentence in enumerate(sentences, start=1):
138
+ sentence_with_idx = f"{idx}. {sentence}"
139
+ formatted_sentences.append(sentence_with_idx)
140
+
141
+ # Format the HTML output
142
+ final_html = "<br><br>".join(formatted_sentences)
143
+ return f'''
144
+ <div style="border: solid 1px #FFFFFF; background-color: #000000; color: #FFFFFF;
145
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 8px;">
146
+ <div style="background-color: #333333; line-height: 1.6; padding: 15px; border-radius: 8px;">{final_html}</div>
147
+ </div>
148
+ '''
149
+
150
+ if __name__ == "__main__":
151
+ # Example usage
152
+ common_words = [(1, "highlight"), (2, "numbering")]
153
+ sentences = ["This is a test to highlight words.", "Numbering is important for clarity."]
154
+
155
+ # Test highlight_common_words
156
+ highlighted_html = highlight_common_words(common_words, sentences, "Test Highlighting")
157
+ print(highlighted_html)
158
+
159
+ # Test highlight_common_words_dict
160
+ sentences_with_scores = {"Highlight words in this text.": 0.95, "Number sentences for clarity.": 0.8}
161
+ highlighted_html_dict = highlight_common_words_dict(common_words, sentences_with_scores, "Test Dict Highlighting")
162
+ print(highlighted_html_dict)
renderers/plot_3d.py ADDED
@@ -0,0 +1,126 @@
1
+ """
2
+ This file contains the code to plot a 3d tree
3
+ """
4
+ import numpy as np
5
+ import plotly.graph_objects as go
6
+ from scipy.interpolate import griddata
7
+
8
+ def gen_three_D_plot(detectability_val, distortion_val, euclidean_val):
9
+ """
10
+ Generates a 3D surface plot showing the relationship between detectability, distortion,
11
+ and Euclidean distance, with a focus on highlighting the "sweet spot" based on a composite score.
12
+
13
+ The function takes three sets of values: detectability, distortion, and Euclidean distance,
14
+ normalizes them to a [0, 1] range, and computes a composite score that combines these three metrics.
15
+ The "sweet spot" is the point where the composite score is maximized. This sweet spot is plotted
16
+ as a red marker on the 3D surface plot.
17
+
18
+ The function then uses a grid interpolation method (`griddata`) to generate a smooth surface
19
+ for the Euclidean distance over the detectability and distortion values. The result is a surface plot
20
+ where the contours represent different Euclidean distances.
21
+
22
+ Args:
23
+ detectability_val (list or array): A list or array of detectability scores.
24
+ distortion_val (list or array): A list or array of distortion scores.
25
+ euclidean_val (list or array): A list or array of Euclidean distances.
26
+
27
+ Returns:
28
+ plotly.graph_objects.Figure: A Plotly figure object representing the 3D surface plot,
29
+ with contour lines and a marker for the sweet spot.
30
+
31
+ Raises:
32
+ ValueError: If `griddata` fails to generate a valid interpolation, which could happen if the
33
+ input data does not allow for a proper interpolation.
34
+
35
+ Example:
36
+ # Example of usage:
37
+ detectability_vals = [0.1, 0.3, 0.5, 0.7, 0.9]
38
+ distortion_vals = [0.2, 0.4, 0.6, 0.8, 1.0]
39
+ euclidean_vals = [0.5, 0.3, 0.2, 0.4, 0.6]
40
+
41
+ fig = gen_three_D_plot(detectability_vals, distortion_vals, euclidean_vals)
42
+ fig.show() # Displays the plot in a web browser
43
+
44
+ Notes:
45
+ - The composite score is calculated as:
46
+ `composite_score = norm_detectability - (norm_distortion + norm_euclidean)`,
47
+ where the goal is to maximize detectability and minimize distortion and Euclidean distance.
48
+ - The `griddata` function uses linear interpolation to create a smooth surface for the plot.
49
+ - The function uses the "Plasma" colorscale for the surface plot, which provides a perceptually uniform color scheme.
50
+ """
51
+
52
+ detectability = np.array(detectability_val)
53
+ distortion = np.array(distortion_val)
54
+ euclidean = np.array(euclidean_val)
55
+
56
+ # Normalize the values to range [0, 1]
57
+ norm_detectability = (detectability - min(detectability)) / (max(detectability) - min(detectability))
58
+ norm_distortion = (distortion - min(distortion)) / (max(distortion) - min(distortion))
59
+ norm_euclidean = (euclidean - min(euclidean)) / (max(euclidean) - min(euclidean))
60
+
61
+ # Composite score: maximize detectability, minimize distortion and Euclidean distance
62
+ composite_score = norm_detectability - (norm_distortion + norm_euclidean)
63
+
64
+ # Find the index of the maximum score (sweet spot)
65
+ sweet_spot_index = np.argmax(composite_score)
66
+
67
+ # Sweet spot values
68
+ sweet_spot_detectability = detectability[sweet_spot_index]
69
+ sweet_spot_distortion = distortion[sweet_spot_index]
70
+ sweet_spot_euclidean = euclidean[sweet_spot_index]
71
+
72
+ # Create a meshgrid from the data
73
+ x_grid, y_grid = np.meshgrid(np.linspace(min(detectability), max(detectability), 30),
74
+ np.linspace(min(distortion), max(distortion), 30))
75
+
76
+ # Interpolate z values (Euclidean distances) to fit the grid using 'nearest' method
77
+ z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='nearest')
78
+
79
+ if z_grid is None:
80
+ raise ValueError("griddata could not generate a valid interpolation. Check your input data.")
81
+
82
+ # Create the 3D contour plot with the Plasma color scale
83
+ fig = go.Figure(data=go.Surface(
84
+ z=z_grid,
85
+ x=x_grid,
86
+ y=y_grid,
87
+ contours={
88
+ "z": {"show": True, "start": min(euclidean), "end": max(euclidean), "size": 0.1, "usecolormap": True}
89
+ },
90
+ colorscale='Plasma'
91
+ ))
92
+
93
+ # Add a marker for the sweet spot
94
+ fig.add_trace(go.Scatter3d(
95
+ x=[sweet_spot_detectability],
96
+ y=[sweet_spot_distortion],
97
+ z=[sweet_spot_euclidean],
98
+ mode='markers+text',
99
+ marker=dict(size=10, color='red', symbol='circle'),
100
+ text=["Sweet Spot"],
101
+ textposition="top center"
102
+ ))
103
+
104
+ # Set axis labels
105
+ fig.update_layout(
106
+ scene=dict(
107
+ xaxis_title='Detectability Score',
108
+ yaxis_title='Distortion Score',
109
+ zaxis_title='Euclidean Distance'
110
+ ),
111
+ margin=dict(l=0, r=0, b=0, t=0)
112
+ )
113
+
114
+ return fig
115
+
116
+ if __name__ == "__main__":
117
+ # Example input data
118
+ detectability_vals = [0.1, 0.3, 0.5, 0.7, 0.9]
119
+ distortion_vals = [0.2, 0.4, 0.6, 0.8, 1.0]
120
+ euclidean_vals = [0.5, 0.3, 0.2, 0.4, 0.6]
121
+
122
+ # Call the function with example data
123
+ fig = gen_three_D_plot(detectability_vals, distortion_vals, euclidean_vals)
124
+
125
+ # Show the plot
126
+ fig.show()
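
The sweet-spot logic above boils down to min-max normalising the three metrics and taking the argmax of detectability minus the sum of distortion and Euclidean distance. A small sketch of only that scoring step, reusing the example values from the __main__ block (it assumes, as the full function effectively does, that each metric has a non-zero range):

    import numpy as np

    def sweet_spot_index(detectability, distortion, euclidean):
        d, s, e = (np.asarray(v, dtype=float) for v in (detectability, distortion, euclidean))
        norm = lambda a: (a - a.min()) / (a.max() - a.min())  # assumes max > min
        score = norm(d) - (norm(s) + norm(e))
        return int(np.argmax(score))

    print(sweet_spot_index([0.1, 0.3, 0.5, 0.7, 0.9],
                           [0.2, 0.4, 0.6, 0.8, 1.0],
                           [0.5, 0.3, 0.2, 0.4, 0.6]))
    # -> 2, i.e. the third point offers the best trade-off for these values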
renderers/tree.py ADDED
@@ -0,0 +1,490 @@
1
+ import plotly.graph_objects as go
2
+ import textwrap
3
+ import re
4
+ from collections import defaultdict
5
+
6
+ def generate_subplot1(paraphrased_sentence, masked_sentences, strategies, highlight_info, common_grams):
7
+ """
8
+ Generates a subplot visualizing paraphrased and masked sentences in a tree structure.
9
+ Highlights common words with specific colors and applies Longest Common Subsequence (LCS) numbering.
10
+
11
+ Args:
12
+ paraphrased_sentence (str): The paraphrased sentence to be visualized.
13
+ masked_sentences (list of str): A list of masked sentences to be visualized.
14
+ strategies (list of str, optional): List of strategies used for each masked sentence.
15
+ highlight_info (list of tuples): A list of tuples where each tuple contains a word and its associated color for highlighting.
16
+ common_grams (list of tuples): A list of tuples containing an index and a common word or phrase for LCS numbering.
17
+
18
+ Returns:
19
+ plotly.graph_objects.Figure: A Plotly figure representing the tree structure with highlighted words and labeled edges.
20
+ """
21
+ # Combine nodes into one list with appropriate labels
22
+ if isinstance(masked_sentences, str):
23
+ masked_sentences = [masked_sentences]
24
+ nodes = [paraphrased_sentence] + masked_sentences
25
+ nodes[0] += ' L0' # Paraphrased sentence is level 0
26
+ if len(nodes) < 2:
27
+ print("[ERROR] Insufficient nodes for visualization")
28
+ return go.Figure()
29
+
30
+ for i in range(1, len(nodes)):
31
+ nodes[i] += ' L1' # masked sentences are level 1
32
+
33
+ def apply_lcs_numbering(sentence, common_grams):
34
+ """
35
+ Applies LCS numbering to the sentence based on the common_grams.
36
+
37
+ Args:
38
+ sentence (str): The sentence to which the LCS numbering should be applied.
39
+ common_grams (list of tuples): A list of common grams to be replaced with LCS numbers.
40
+
41
+ Returns:
42
+ str: The sentence with LCS numbering applied.
43
+ """
44
+ for idx, lcs in common_grams:
45
+ sentence = re.sub(rf"\b{lcs}\b", f"({idx}){lcs}", sentence)
46
+ return sentence
47
+
48
+ # Apply LCS numbering
49
+ nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]
50
+
51
+
52
+ def highlight_words(sentence, color_map):
53
+ """
54
+ Highlights words in the sentence based on the color_map.
55
+
56
+ Args:
57
+ sentence (str): The sentence where the words will be highlighted.
58
+ color_map (dict): A dictionary mapping words to their colors.
59
+
60
+ Returns:
61
+ str: The sentence with highlighted words.
62
+ """
63
+ for word, color in color_map.items():
64
+ sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
65
+ return sentence
66
+
67
+ # Clean and wrap nodes, and highlight specified words globally
68
+ cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
69
+ global_color_map = dict(highlight_info)
70
+ highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
71
+ wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=55)) for node in highlighted_nodes]
72
+
73
+ def get_levels_and_edges(nodes, strategies=None):
74
+ """
75
+ Determines tree levels and creates edges dynamically.
76
+
77
+ Args:
78
+ nodes (list of str): The nodes representing the sentences.
79
+ strategies (list of str, optional): The strategies used for each edge.
80
+
81
+ Returns:
82
+ tuple: A tuple containing two dictionaries:
83
+ - levels: A dictionary mapping node indices to their levels.
84
+ - edges: A list of edges where each edge is represented by a tuple of node indices.
85
+ """
86
+ levels = {}
87
+ edges = []
88
+ for i, node in enumerate(nodes):
89
+ level = int(node.split()[-1][1])
90
+ levels[i] = level
91
+
92
+ # Add edges from L0 to all L1 nodes
93
+ root_node = next((i for i, level in levels.items() if level == 0), 0)
94
+ for i, level in levels.items():
95
+ if level == 1:
96
+ edges.append((root_node, i))
97
+
98
+ return levels, edges
99
+
100
+ # Get levels and dynamic edges
101
+ levels, edges = get_levels_and_edges(nodes, strategies)
102
+ max_level = max(levels.values(), default=0)
103
+
104
+ # Calculate positions
105
+ positions = {}
106
+ level_heights = defaultdict(int)
107
+ for node, level in levels.items():
108
+ level_heights[level] += 1
109
+
110
+ y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
111
+ x_gap = 2
112
+ l1_y_gap = 10
113
+
114
+     for node, level in levels.items():
+         # Stack nodes that share a level vertically so they do not overlap
+         positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
+         y_offsets[level] += 1
120
+
121
+ def color_highlighted_words(node, color_map):
122
+ """
123
+ Colors the highlighted words in the node text.
124
+
125
+ Args:
126
+ node (str): The node text to be highlighted.
127
+ color_map (dict): A dictionary mapping words to their colors.
128
+
129
+ Returns:
130
+ str: The node text with highlighted words.
131
+ """
132
+ parts = re.split(r'(\{\{.*?\}\})', node)
133
+ colored_parts = []
134
+ for part in parts:
135
+ match = re.match(r'\{\{(.*?)\}\}', part)
136
+ if match:
137
+ word = match.group(1)
138
+ color = color_map.get(word, 'black')
139
+ colored_parts.append(f"<span style='color: {color};'>{word}</span>")
140
+ else:
141
+ colored_parts.append(part)
142
+ return ''.join(colored_parts)
143
+
144
+ # Define the text for each edge
145
+ default_edge_texts = [
146
+ "Highest Entropy Masking", "Pseudo-random Masking", "Random Masking",
147
+ "Greedy Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
148
+ "Inverse Transform Sampling", "Greedy Sampling", "Temperature Sampling",
149
+ "Exponential Minimum Sampling", "Inverse Transform Sampling", "Greedy Sampling",
150
+ "Temperature Sampling", "Exponential Minimum Sampling", "Inverse Transform Sampling"
151
+ ]
152
+
153
+ if len(nodes) < 2:
154
+ print("[ERROR] Insufficient nodes for visualization")
155
+ return go.Figure()
156
+
157
+ # Create figure
158
+ fig1 = go.Figure()
159
+
160
+ # Add nodes to the figure
161
+ for i, node in enumerate(wrapped_nodes):
162
+ colored_node = color_highlighted_words(node, global_color_map)
163
+ x, y = positions[i]
164
+ fig1.add_trace(go.Scatter(
165
+ x=[-x], # Reflect the x coordinate
166
+ y=[y],
167
+ mode='markers',
168
+ marker=dict(size=20, color='blue', line=dict(color='black', width=2)),
169
+ hoverinfo='none'
170
+ ))
171
+ fig1.add_annotation(
172
+ x=-x, # Reflect the x coordinate
173
+ y=y,
174
+ text=colored_node,
175
+ showarrow=False,
176
+ xshift=15,
177
+ align="center",
178
+ font=dict(size=12),
179
+ bordercolor='black',
180
+ borderwidth=2,
181
+ borderpad=4,
182
+ bgcolor='white',
183
+ width=400,
184
+ height=100
185
+ )
186
+
187
+ # Add edges and text above each edge
188
+ for i, edge in enumerate(edges):
189
+ x0, y0 = positions[edge[0]]
190
+ x1, y1 = positions[edge[1]]
191
+
192
+ # Use strategy if available, otherwise use default edge text
193
+ if strategies and i < len(strategies):
194
+ edge_text = strategies[i]
195
+ else:
196
+ edge_text = default_edge_texts[i % len(default_edge_texts)]
197
+
198
+ fig1.add_trace(go.Scatter(
199
+ x=[-x0, -x1], # Reflect the x coordinates
200
+ y=[y0, y1],
201
+ mode='lines',
202
+ line=dict(color='black', width=1)
203
+ ))
204
+
205
+ # Calculate the midpoint of the edge
206
+ mid_x = (-x0 + -x1) / 2
207
+ mid_y = (y0 + y1) / 2
208
+
209
+ # Adjust y position to shift text upwards
210
+ text_y_position = mid_y + 0.8 # Increase this value to shift the text further upwards
211
+
212
+ # Add text annotation above the edge
213
+ fig1.add_annotation(
214
+ x=mid_x,
215
+ y=text_y_position,
216
+ text=edge_text, # Use the text specific to this edge
217
+ showarrow=False,
218
+ font=dict(size=12),
219
+ align="center"
220
+ )
221
+
222
+ fig1.update_layout(
223
+ showlegend=False,
224
+ margin=dict(t=50, b=50, l=50, r=50),
225
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
226
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
227
+ width=800 + max_level * 200, # Adjusted width to accommodate more levels
228
+ height=300 + len(nodes) * 100, # Adjusted height to accommodate more levels
229
+ plot_bgcolor='rgba(240,240,240,0.2)',
230
+ paper_bgcolor='white'
231
+ )
232
+
233
+ return fig1
234
+
235
+ def generate_subplot2(masked_sentences, sampled_sentences, highlight_info, common_grams):
236
+ """
237
+ Generates a subplot visualizing multiple masked sentences and their sampled variants in a tree structure.
238
+ Each masked sentence will have multiple sampled sentences derived from it using different sampling techniques.
239
+
240
+ Args:
241
+ masked_sentences (list of str): A list of masked sentences to be visualized as root nodes.
242
+ sampled_sentences (list of str): A list of sampled sentences derived from masked sentences.
243
+ highlight_info (list of tuples): A list of tuples where each tuple contains a word and its associated color for highlighting.
244
+ common_grams (list of tuples): A list of tuples containing an index and a common word or phrase for LCS numbering.
245
+
246
+ Returns:
247
+ plotly.graph_objects.Figure: A Plotly figure representing the tree structure with highlighted words and labeled edges.
248
+ """
249
+ # Define sampling techniques
250
+ sampling_techniques = [
251
+ "Greedy Sampling",
252
+ "Temperature Sampling",
253
+ "Exponential Minimum Sampling",
254
+ "Inverse Transform Sampling"
255
+ ]
256
+
257
+ # Calculate total number of nodes
258
+ num_masked = len(masked_sentences)
259
+ num_sampled_per_masked = len(sampling_techniques)
260
+ total_nodes = num_masked + (num_masked * num_sampled_per_masked)
261
+
262
+ # Combine all sentences into nodes list with appropriate labels
263
+ nodes = []
264
+ # Level 0: masked sentences (root nodes)
265
+ nodes.extend([s + ' L0' for s in masked_sentences])
266
+
267
+ # Level 1: sampled sentences (branch nodes)
268
+ # For each masked sentence, we should have samples from each technique
269
+ sampled_nodes = []
270
+
271
+ # Validate if we have the expected number of sampled sentences
272
+ expected_sampled_count = num_masked * num_sampled_per_masked
273
+ if len(sampled_sentences) < expected_sampled_count:
274
+ # If insufficient samples provided, pad with placeholder sentences
275
+ print(f"Warning: Expected {expected_sampled_count} sampled sentences, but got {len(sampled_sentences)}")
276
+ while len(sampled_sentences) < expected_sampled_count:
277
+ sampled_sentences.append(f"Placeholder sampled sentence {len(sampled_sentences) + 1}")
278
+
279
+ # Add all sampled sentences with level information
280
+ for s in sampled_sentences[:expected_sampled_count]:
281
+ sampled_nodes.append(s + ' L1')
282
+
283
+ nodes.extend(sampled_nodes)
284
+
285
+ def apply_lcs_numbering(sentence, common_grams):
286
+ """
287
+ Applies LCS numbering to the sentence based on the common_grams.
288
+ """
289
+ for idx, lcs in common_grams:
290
+ sentence = re.sub(rf"\b{lcs}\b", f"({idx}){lcs}", sentence)
291
+ return sentence
292
+
293
+ # Apply LCS numbering
294
+ nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]
295
+
296
+ def highlight_words(sentence, color_map):
297
+ """
298
+ Highlights words in the sentence based on the color_map.
299
+ """
300
+ for word, color in color_map.items():
301
+ sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
302
+ return sentence
303
+
304
+ # Helper function to color highlighted words
305
+ def color_highlighted_words(node, color_map):
306
+ """
307
+ Colors the highlighted words in the node text.
308
+ """
309
+ parts = re.split(r'(\{\{.*?\}\})', node)
310
+ colored_parts = []
311
+ for part in parts:
312
+ match = re.match(r'\{\{(.*?)\}\}', part)
313
+ if match:
314
+ word = match.group(1)
315
+ color = color_map.get(word, 'black')
316
+ colored_parts.append(f"<span style='color: {color};'>{word}</span>")
317
+ else:
318
+ colored_parts.append(part)
319
+ return ''.join(colored_parts)
320
+
321
+ # Clean nodes, highlight words, and wrap text
322
+ cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
323
+ global_color_map = dict(highlight_info)
324
+ highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
325
+ wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=80)) for node in highlighted_nodes]
326
+
327
+ # Generate edges based on the tree structure
328
+ def get_levels_and_edges(nodes):
329
+ levels = {}
330
+ edges = []
331
+
332
+ # Extract level info from node labels
333
+ for i, node in enumerate(nodes):
334
+ level = int(node.split()[-1][1])
335
+ levels[i] = level
336
+
337
+ # Create edges from masked sentences to their sampled variants
338
+ for masked_idx in range(num_masked):
339
+ # For each masked sentence, create edges to its sampled variants
340
+ for technique_idx in range(num_sampled_per_masked):
341
+ sampled_idx = num_masked + (masked_idx * num_sampled_per_masked) + technique_idx
342
+ if sampled_idx < len(nodes):
343
+ edges.append((masked_idx, sampled_idx))
344
+
345
+ return levels, edges
346
+
347
+ levels, edges = get_levels_and_edges(nodes)
348
+
349
+ # Calculate positions with improved spacing
350
+ positions = {}
351
+
352
+ # Calculate horizontal spacing for the root nodes (masked sentences)
353
+ root_x_spacing = 0 # All root nodes at x=0
354
+ root_y_spacing = 8.0 # Vertical spacing between root nodes
355
+
356
+ # Calculate positions for sampled nodes
357
+ sampled_x = 3 # X position for all sampled nodes
358
+
359
+ # Calculate y positions for root nodes (masked sentences)
360
+ root_y_start = -(num_masked - 1) * root_y_spacing / 2
361
+ for i in range(num_masked):
362
+ positions[i] = (root_x_spacing, root_y_start + i * root_y_spacing)
363
+
364
+ # Calculate y positions for sampled nodes
365
+ for masked_idx in range(num_masked):
366
+ root_y = positions[masked_idx][1] # Y position of parent masked sentence
367
+
368
+ # Calculate y-spacing for children of this root
369
+ children_y_spacing = 1.5 # Vertical spacing between children of the same root
370
+ children_y_start = root_y - (num_sampled_per_masked - 1) * children_y_spacing / 2
371
+
372
+ # Position each child
373
+ for technique_idx in range(num_sampled_per_masked):
374
+ child_idx = num_masked + (masked_idx * num_sampled_per_masked) + technique_idx
375
+ child_y = children_y_start + technique_idx * children_y_spacing
376
+ positions[child_idx] = (sampled_x, child_y)
377
+
378
+ # Create figure
379
+ fig2 = go.Figure()
380
+
381
+ # Add nodes
382
+ for i, node in enumerate(wrapped_nodes):
383
+ x, y = positions[i]
384
+
385
+ # Define node color based on level
386
+ node_color = 'blue' if levels[i] == 0 else 'green'
387
+
388
+ # Add the node marker
389
+ fig2.add_trace(go.Scatter(
390
+ x=[x],
391
+ y=[y],
392
+ mode='markers',
393
+ marker=dict(size=20, color=node_color, line=dict(color='black', width=2)),
394
+ hoverinfo='none'
395
+ ))
396
+
397
+ # Add node label with highlighting
398
+ colored_node = color_highlighted_words(node, global_color_map)
399
+
400
+ fig2.add_annotation(
401
+ x=x,
402
+ y=y,
403
+ text=colored_node,
404
+ showarrow=False,
405
+ xshift=15,
406
+ align="left",
407
+ font=dict(size=12),
408
+ bordercolor='black',
409
+ borderwidth=2,
410
+ borderpad=4,
411
+ bgcolor='white',
412
+ width=400,
413
+ height=100
414
+ )
415
+
416
+ # Add edges with labels
417
+ for i, (src, dst) in enumerate(edges):
418
+ x0, y0 = positions[src]
419
+ x1, y1 = positions[dst]
420
+
421
+ # Draw the edge
422
+ fig2.add_trace(go.Scatter(
423
+ x=[x0, x1],
424
+ y=[y0, y1],
425
+ mode='lines',
426
+ line=dict(color='black', width=1)
427
+ ))
428
+
429
+ # Add sampling technique label
430
+ # Determine which sampling technique this is
431
+ parent_idx = src
432
+ technique_count = sum(1 for k, (s, _) in enumerate(edges) if s == parent_idx and k < i)
433
+ technique_label = sampling_techniques[technique_count % len(sampling_techniques)]
434
+
435
+ # Calculate midpoint for the label
436
+ mid_x = (x0 + x1) / 2
437
+ mid_y = (y0 + y1) / 2
438
+
439
+ # Add slight offset to avoid overlap
440
+ label_offset = 0.1
441
+
442
+ fig2.add_annotation(
443
+ x=mid_x,
444
+ y=mid_y + label_offset,
445
+ text=technique_label,
446
+ showarrow=False,
447
+ font=dict(size=8),
448
+ align="center"
449
+ )
450
+
451
+ # Update layout
452
+ fig2.update_layout(
453
+ showlegend=False,
454
+ margin=dict(t=20, b=20, l=20, r=20),
455
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
456
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
457
+ width=1200, # Adjusted width to accommodate more levels
458
+ height=2000, # Adjusted height to accommodate more levels
459
+ plot_bgcolor='rgba(240,240,240,0.2)',
460
+ paper_bgcolor='white'
461
+
462
+ )
463
+
464
+ return fig2
465
+
466
+ if __name__ == "__main__":
467
+ paraphrased_sentence = "The quick brown fox jumps over the lazy dog."
468
+ masked_sentences = [
469
+ "A fast brown fox leaps over the lazy dog.",
470
+ "A quick brown fox hops over a lazy dog."
471
+ ]
472
+ highlight_info = [
473
+ ("quick", "red"),
474
+ ("brown", "green"),
475
+ ("fox", "blue"),
476
+ ("lazy", "purple")
477
+ ]
478
+ common_grams = [
479
+ (1, "quick brown fox"),
480
+ (2, "lazy dog")
481
+ ]
482
+
483
+     fig1 = generate_subplot1(paraphrased_sentence, masked_sentences, None, highlight_info, common_grams)
484
+ fig1.show()
485
+
486
+ sampled_sentence = ["A fast brown fox jumps over a lazy dog."]
487
+
488
+
489
+ fig2 = generate_subplot2(masked_sentences, sampled_sentence, highlight_info, common_grams)
490
+ fig2.show()
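
In generate_subplot2 the tree layout depends on a flat index convention: the first num_masked nodes are the masked sentences, and sample j of masked sentence i sits at index num_masked + i * num_sampled_per_masked + j, which is exactly how the edges are built. A tiny sketch of that indexing for two masked sentences and the four sampling techniques listed above:

    num_masked = 2
    num_sampled_per_masked = 4  # one sampled sentence per sampling technique

    edges = []
    for masked_idx in range(num_masked):
        for technique_idx in range(num_sampled_per_masked):
            sampled_idx = num_masked + masked_idx * num_sampled_per_masked + technique_idx
            edges.append((masked_idx, sampled_idx))

    print(edges)
    # [(0, 2), (0, 3), (0, 4), (0, 5), (1, 6), (1, 7), (1, 8), (1, 9)]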
utils/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from utils.watermark import Watermarker
2
+ from utils.paraphraser import Paraphraser
3
+ from utils.entailment import EntailmentAnalyzer
4
+ from utils.sampling import SamplingProcessor
5
+ from utils.config import load_config
utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (404 Bytes). View file
 
utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (509 Bytes). View file
 
utils/__pycache__/config.cpython-310.pyc ADDED
Binary file (594 Bytes). View file
 
utils/__pycache__/config.cpython-311.pyc ADDED
Binary file (971 Bytes). View file
 
utils/__pycache__/entailment.cpython-310.pyc ADDED
Binary file (3.69 kB). View file
 
utils/__pycache__/entailment.cpython-311.pyc ADDED
Binary file (5.33 kB). View file
 
utils/__pycache__/masking_methods.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
utils/__pycache__/masking_methods.cpython-311.pyc ADDED
Binary file (23.5 kB). View file
 
utils/__pycache__/non_melting_point.cpython-310.pyc ADDED
Binary file (5.05 kB). View file
 
utils/__pycache__/non_melting_point.cpython-311.pyc ADDED
Binary file (9.08 kB). View file
 
utils/__pycache__/paraphraser.cpython-310.pyc ADDED
Binary file (2.85 kB). View file
 
utils/__pycache__/paraphraser.cpython-311.pyc ADDED
Binary file (4.89 kB). View file
 
utils/__pycache__/sampling.cpython-310.pyc ADDED
Binary file (5.06 kB). View file
 
utils/__pycache__/sampling.cpython-311.pyc ADDED
Binary file (9.2 kB). View file
 
utils/__pycache__/watermark.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
utils/__pycache__/watermark.cpython-311.pyc ADDED
Binary file (20.1 kB). View file
 
utils/config.py ADDED
@@ -0,0 +1,18 @@
1
+ """
2
+ This file loads config from config.yaml
3
+ """
4
+
5
+ import yaml
6
+
7
+ def load_config(path):
8
+ """
9
+ Function to load config from config.yaml
10
+ """
11
+ try:
12
+ with open(path, "r") as file:
13
+ config = yaml.safe_load(file)
14
+ return config
15
+ except FileNotFoundError:
16
+ raise FileNotFoundError("Config file not found")
17
+ except Exception as e:
18
+ raise e
utils/config.yaml ADDED
@@ -0,0 +1,48 @@
1
+ # This is the official config file.
2
+ PECCAVI_TEXT:
3
+ Entailment:
4
+ task: "text-classification"
5
+ model: "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
6
+
7
+ Masking:
8
+ task: "fill-mask"
9
+ tokenizer: "bert-base-uncased"
10
+ model: "bert-base-uncased"
11
+ # tokenizer: "bert-large-cased-whole-word-masking"
12
+ # model: "bert-large-cased-whole-word-masking"
13
+
14
+ Vocabulary:
15
+ tokenizer: "bert-base-uncased"
16
+ model: "bert-base-uncased"
17
+ # permissible_ratio: 0.5
18
+ # tokenizer: "bert-large-cased-whole-word-masking"
19
+ # model: "bert-large-cased-whole-word-masking"
20
+ permissible_ratio: 1.0
21
+
22
+ Sampling:
23
+ tokenizer: "bert-base-uncased"
24
+ model: "bert-base-uncased"
25
+ # tokenizer: "bert-large-cased-whole-word-masking"
26
+ # model: "bert-large-cased-whole-word-masking"
27
+
28
+ Metrics:
29
+ EuclideanDistance: "sentence-transformers/all-MiniLM-L6-v2"
30
+ Distortion: "gpt2"
31
+
32
+ Detector:
33
+ tokenizer: "bert-base-uncased"
34
+ model: "bert-base-uncased"
35
+ # tokenizer: "bert-large-cased-whole-word-masking"
36
+ # model: "bert-large-cased-whole-word-masking"
37
+
38
+ Paraphrase:
39
+ tokenizer: "humarin/chatgpt_paraphraser_on_T5_base"
40
+ model: "humarin/chatgpt_paraphraser_on_T5_base"
41
+ num_beams: 10
42
+ num_beam_groups: 10
43
+ num_return_sequences: 10
44
+ repetition_penalty: 10.0
45
+ diversity_penalty: 3.0
46
+ no_repeat_ngram_size: 2
47
+ temperature: 0.7
48
+ max_length: 64
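
Since load_config simply returns the parsed YAML as a nested dictionary, the sections above are reached by plain key lookups under PECCAVI_TEXT. A short sketch (the relative path is illustrative and assumes the code is run from the repository root):

    from utils.config import load_config

    config = load_config("utils/config.yaml")
    print(config["PECCAVI_TEXT"]["Masking"]["model"])                     # bert-base-uncased
    print(config["PECCAVI_TEXT"]["Paraphrase"]["num_return_sequences"])   # 10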
utils/entailment.py ADDED
@@ -0,0 +1,107 @@
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4
+
5
+ import numpy as np
6
+ from transformers import pipeline
7
+ from typing import List
8
+ from utils.config import load_config
9
+
10
+
11
+ class EntailmentAnalyzer:
12
+ # def __init__(self, config_path: str):
13
+ def __init__(self, config):
14
+ """
15
+         Initialize the EntailmentAnalyzer with an entailment configuration dictionary.
16
+
17
+ Args:
18
+             config: Configuration dictionary containing the entailment pipeline `task` and `model`.
19
+ """
20
+ # self.config = load_config(config_path)['PECCAVI_TEXT']['Entailment']
21
+ self.config = config
22
+ self.entailment_pipeline = pipeline(task=self.config['task'], model=self.config['model'])
23
+
24
+ def check_entailment(self, premise: str, hypothesis: str) -> float:
25
+ """
26
+ Check entailment between the premise and hypothesis.
27
+
28
+ Args:
29
+ premise: The premise sentence.
30
+ hypothesis: The hypothesis sentence.
31
+
32
+ Returns:
33
+ float: The entailment score.
34
+ """
35
+ results = self.entailment_pipeline(f"{premise} [SEP] {hypothesis}", top_k=None)
36
+ entailment_score = next(item['score'] for item in results if item['label'] == 'entailment')
37
+ return entailment_score
38
+
39
+ def analyze_entailment(self, original_sentence: str, paraphrased_sentences: List[str], threshold: float) -> tuple:
40
+ """
41
+ Analyze entailment scores for paraphrased sentences. If no selected sentences are found,
42
+ lower the threshold and rerun the analysis.
43
+
44
+ Args:
45
+ original_sentence: The original sentence.
46
+ paraphrased_sentences: List of paraphrased sentences.
47
+ threshold: Minimum score to select a sentence.
48
+
49
+ Returns:
50
+ tuple: A dictionary of all scores, selected sentences, and discarded sentences.
51
+ """
52
+ all_sentences = {}
53
+ selected_sentences = {}
54
+ discarded_sentences = {}
55
+
56
+ # Loop to reduce threshold if no sentences are selected
57
+ while not selected_sentences:
58
+ for paraphrased_sentence in paraphrased_sentences:
59
+ entailment_score = self.check_entailment(original_sentence, paraphrased_sentence)
60
+
61
+ all_sentences[paraphrased_sentence] = entailment_score
62
+ if entailment_score >= threshold:
63
+ selected_sentences[paraphrased_sentence] = entailment_score
64
+ else:
65
+ discarded_sentences[paraphrased_sentence] = entailment_score
66
+
67
+ # If no sentences are selected, lower the threshold
68
+ if not selected_sentences:
69
+ print(f"No selected sentences found. Lowering the threshold by 0.1 (from {threshold} to {threshold - 0.1}).")
70
+ threshold -= 0.1
71
+ if threshold <= 0:
72
+ print("Threshold has reached 0. No sentences meet the criteria.")
73
+ break
74
+
75
+ return all_sentences, selected_sentences, discarded_sentences
76
+
77
+
78
+ if __name__ == "__main__":
79
+     config_path = os.path.join(os.path.dirname(__file__), 'config.yaml')
82
+
83
+ config = load_config(config_path)
84
+
85
+ entailment_analyzer = EntailmentAnalyzer(config['PECCAVI_TEXT']['Entailment'])
86
+
87
+ all_sentences, selected_sentences, discarded_sentences = entailment_analyzer.analyze_entailment(
88
+ "The weather is nice today",
89
+ [
90
+ "The climate is pleasant today",
91
+ "It's a good day weather-wise",
92
+ "Today, the weather is terrible",
93
+ "What a beautiful day it is",
94
+ "The sky is clear and the weather is perfect",
95
+ "It's pouring rain outside today",
96
+ "The weather isn't bad today",
97
+ "A lovely day for outdoor activities"
98
+ ],
99
+ 0.7
100
+ )
101
+
102
+ print("----------------------- All Sentences -----------------------")
103
+ print(all_sentences)
104
+ print("----------------------- Discarded Sentences -----------------------")
105
+ print(discarded_sentences)
106
+ print("----------------------- Selected Sentences -----------------------")
107
+ print(selected_sentences)
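
check_entailment sends the pair as a single "premise [SEP] hypothesis" string to the text-classification pipeline and, because top_k=None returns scores for every label, keeps only the probability of the entailment label. A minimal sketch of that extraction step, with an illustrative results list standing in for the real pipeline output:

    # Shape of the pipeline output when called with top_k=None (values are made up).
    results = [
        {"label": "entailment", "score": 0.91},
        {"label": "neutral", "score": 0.07},
        {"label": "contradiction", "score": 0.02},
    ]

    entailment_score = next(item["score"] for item in results if item["label"] == "entailment")
    print(entailment_score)  # 0.91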
utils/masking_methods.py ADDED
@@ -0,0 +1,304 @@
1
+ import random
2
+ import torch
3
+ import logging
4
+ from transformers import BertTokenizer, BertForMaskedLM
5
+ from nltk.corpus import stopwords
6
+ import nltk
7
+ from transformers import RobertaTokenizer, RobertaForMaskedLM
8
+ from tqdm import tqdm
9
+
10
+ # Set logging to WARNING for a cleaner terminal.
11
+ logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # Ensure stopwords are downloaded
15
+ try:
16
+ nltk.data.find('corpora/stopwords')
17
+ except LookupError:
18
+ nltk.download('stopwords')
19
+
20
+ class MaskingProcessor:
21
+ def __init__(self, tokenizer, model):
22
+ self.tokenizer = tokenizer
23
+ self.model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
24
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ self.stop_words = set(stopwords.words('english'))
26
+ tqdm.write(f"[MaskingProcessor] Initialized on device: {self.device}")
27
+
28
+ def remove_stopwords(self, words):
29
+ return [word for word in words if word.lower() not in self.stop_words]
30
+
31
+ def adjust_ngram_indices(self, original_words, common_ngrams):
32
+ logger.info("Adjusting n-gram indices.")
33
+ non_stop_words = self.remove_stopwords(original_words)
34
+ original_to_non_stop = []
35
+ non_stop_idx = 0
36
+ for original_idx, word in enumerate(original_words):
37
+ if word.lower() not in self.stop_words:
38
+ original_to_non_stop.append((original_idx, non_stop_idx))
39
+ non_stop_idx += 1
40
+ adjusted_ngrams = {}
41
+ for ngram, positions in common_ngrams.items():
42
+ adjusted_positions = []
43
+ for start, end in positions:
44
+ try:
45
+ new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
46
+ new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
47
+ adjusted_positions.append((new_start, new_end))
48
+ except StopIteration:
49
+ continue
50
+ adjusted_ngrams[ngram] = adjusted_positions
51
+ return adjusted_ngrams
52
+
53
+ def mask_sentence_random(self, sentence, common_ngrams):
54
+ tqdm.write(f"[MaskingProcessor] Masking (random) sentence: {sentence}")
55
+ original_words = sentence.split()
56
+ has_punctuation = False
57
+ punctuation = ''
58
+ if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
59
+ has_punctuation = True
60
+ punctuation = original_words[-1][-1]
61
+ original_words = original_words[:-1]
62
+
63
+ non_stop_words = self.remove_stopwords(original_words)
64
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
65
+ mask_indices = []
66
+
67
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
68
+ if ngram_positions:
69
+ first_ngram_start = ngram_positions[0][0]
70
+ if first_ngram_start > 0:
71
+ mask_index_before_ngram = random.randint(0, first_ngram_start-1)
72
+ mask_indices.append(mask_index_before_ngram)
73
+
74
+ for i in range(len(ngram_positions) - 1):
75
+ end_prev = ngram_positions[i][1]
76
+ start_next = ngram_positions[i + 1][0]
77
+ if start_next > end_prev + 1:
78
+ mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
79
+ mask_indices.append(mask_index_between_ngrams)
80
+
81
+ last_ngram_end = ngram_positions[-1][1]
82
+ if last_ngram_end < len(non_stop_words) - 1:
83
+ mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
84
+ mask_indices.append(mask_index_after_ngram)
85
+
86
+ non_stop_to_original = {}
87
+ non_stop_idx = 0
88
+ for orig_idx, word in enumerate(original_words):
89
+ if word.lower() not in self.stop_words:
90
+ non_stop_to_original[non_stop_idx] = orig_idx
91
+ non_stop_idx += 1
92
+
93
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
94
+ masked_words = original_words.copy()
95
+ for idx in original_mask_indices:
96
+ masked_words[idx] = self.tokenizer.mask_token
97
+
98
+ if has_punctuation:
99
+ masked_words.append(punctuation)
100
+
101
+ logger.info(f"Masked sentence (random): {' '.join(masked_words)}")
102
+ return " ".join(masked_words), original_mask_indices
103
+
104
+ def mask_sentence_pseudorandom(self, sentence, common_ngrams):
105
+ logger.info(f"Masking sentence using pseudorandom strategy: {sentence}")
106
+ random.seed(3)
107
+ original_words = sentence.split()
108
+ has_punctuation = False
109
+ punctuation = ''
110
+ if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
111
+ has_punctuation = True
112
+ punctuation = original_words[-1][-1]
113
+ original_words = original_words[:-1]
114
+
115
+ non_stop_words = self.remove_stopwords(original_words)
116
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
117
+ mask_indices = []
118
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
119
+
120
+ if ngram_positions:
121
+ first_ngram_start = ngram_positions[0][0]
122
+ if first_ngram_start > 0:
123
+ mask_index_before_ngram = random.randint(0, first_ngram_start-1)
124
+ mask_indices.append(mask_index_before_ngram)
125
+ for i in range(len(ngram_positions) - 1):
126
+ end_prev = ngram_positions[i][1]
127
+ start_next = ngram_positions[i + 1][0]
128
+ if start_next > end_prev + 1:
129
+ mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
130
+ mask_indices.append(mask_index_between_ngrams)
131
+ last_ngram_end = ngram_positions[-1][1]
132
+ if last_ngram_end < len(non_stop_words) - 1:
133
+ mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
134
+ mask_indices.append(mask_index_after_ngram)
135
+
136
+ non_stop_to_original = {}
137
+ non_stop_idx = 0
138
+ for orig_idx, word in enumerate(original_words):
139
+ if word.lower() not in self.stop_words:
140
+ non_stop_to_original[non_stop_idx] = orig_idx
141
+ non_stop_idx += 1
142
+
143
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
144
+ masked_words = original_words.copy()
145
+ for idx in original_mask_indices:
146
+ masked_words[idx] = self.tokenizer.mask_token
147
+
148
+ if has_punctuation:
149
+ masked_words.append(punctuation)
150
+
151
+ logger.info(f"Masked sentence (pseudorandom): {' '.join(masked_words)}")
152
+ return " ".join(masked_words), original_mask_indices
153
+
154
+ def mask_sentence_entropy(self, sentence, common_ngrams):
155
+ logger.info(f"Masking sentence using entropy strategy: {sentence}")
156
+ original_words = sentence.split()
157
+ has_punctuation = False
158
+ punctuation = ''
159
+ if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
160
+ has_punctuation = True
161
+ punctuation = original_words[-1][-1]
162
+ original_words = original_words[:-1]
163
+
164
+ non_stop_words = self.remove_stopwords(original_words)
165
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
166
+ mask_indices = []
167
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
168
+ non_stop_to_original = {}
169
+ non_stop_idx = 0
170
+ for orig_idx, word in enumerate(original_words):
171
+ if word.lower() not in self.stop_words:
172
+ non_stop_to_original[non_stop_idx] = orig_idx
173
+ non_stop_idx += 1
174
+
175
+ if ngram_positions:
176
+ first_ngram_start = ngram_positions[0][0]
177
+ if first_ngram_start > 0:
178
+ candidate_positions = range(0, first_ngram_start)
179
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos])) for pos in candidate_positions]
180
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
181
+ for i in range(len(ngram_positions) - 1):
182
+ end_prev = ngram_positions[i][1]
183
+ start_next = ngram_positions[i + 1][0]
184
+ if start_next > end_prev + 1:
185
+ candidate_positions = range(end_prev + 1, start_next)
186
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos])) for pos in candidate_positions]
187
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
188
+ last_ngram_end = ngram_positions[-1][1]
189
+ if last_ngram_end < len(non_stop_words) - 1:
190
+ candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
191
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos])) for pos in candidate_positions]
192
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
193
+
194
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
195
+ masked_words = original_words.copy()
196
+ for idx in original_mask_indices:
197
+ masked_words[idx] = self.tokenizer.mask_token
198
+
199
+ if has_punctuation:
200
+ masked_words.append(punctuation)
201
+
202
+ logger.info(f"Masked sentence (entropy): {' '.join(masked_words)}")
203
+ return " ".join(masked_words), original_mask_indices
204
+
205
+ def calculate_mask_logits(self, original_sentence, original_mask_indices):
206
+ logger.info(f"Calculating mask logits for sentence: {original_sentence}")
207
+ words = original_sentence.split()
208
+ mask_logits = {}
209
+ for idx in original_mask_indices:
210
+ masked_words = words.copy()
211
+ masked_words[idx] = self.tokenizer.mask_token
212
+ masked_sentence = " ".join(masked_words)
213
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
214
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
215
+ with torch.no_grad():
216
+ outputs = self.model(input_ids)
217
+ logits = outputs.logits
218
+ mask_logits_tensor = logits[0, mask_token_index, :]
219
+ top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1)
220
+ top_tokens = []
221
+ top_logits = []
222
+ seen_words = set()
223
+ for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
224
+ token = self.tokenizer.convert_ids_to_tokens(token_id.item())
225
+ if token.startswith('##'):
226
+ continue
227
+ word = self.tokenizer.convert_tokens_to_string([token]).strip()
228
+ if word and word not in seen_words:
229
+ seen_words.add(word)
230
+ top_tokens.append(word)
231
+ top_logits.append(logit.item())
232
+ if len(top_tokens) == 50:
233
+ break
234
+ mask_logits[idx] = {
235
+ "tokens": top_tokens,
236
+ "logits": top_logits
237
+ }
238
+ logger.info("Completed calculating mask logits.")
239
+ return mask_logits
240
+
241
+ def calculate_word_entropy(self, sentence, word_position):
242
+ logger.info(f"Calculating word entropy for position {word_position} in sentence: {sentence}")
243
+ words = sentence.split()
244
+ masked_words = words.copy()
245
+ masked_words[word_position] = self.tokenizer.mask_token
246
+ masked_sentence = " ".join(masked_words)
247
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
248
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
249
+ with torch.no_grad():
250
+ outputs = self.model(input_ids)
251
+ logits = outputs.logits
252
+ probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
253
+ entropy = -torch.sum(probs * torch.log(probs + 1e-9))
254
+ logger.info(f"Computed entropy: {entropy.item()}")
255
+ return entropy.item()
256
+
257
+ def process_sentences(self, sentences_list, common_grams, method="random"):
258
+ tqdm.write(f"[MaskingProcessor] Processing sentences using method: {method}")
259
+ results = {}
260
+ for sentence, ngrams in tqdm(common_grams.items(), desc="Masking Sentences"):
261
+ words = sentence.split()
262
+ last_word = words[-1]
263
+ if any(last_word.endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
264
+ words[-1] = last_word[:-1]
265
+ punctuation = last_word[-1]
266
+ processed_sentence = " ".join(words) + " " + punctuation
267
+ else:
268
+ processed_sentence = sentence
269
+
270
+ if method == "random":
271
+ masked_sentence, original_mask_indices = self.mask_sentence_random(processed_sentence, ngrams)
272
+ elif method == "pseudorandom":
273
+ masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(processed_sentence, ngrams)
274
+ else: # entropy
275
+ masked_sentence, original_mask_indices = self.mask_sentence_entropy(processed_sentence, ngrams)
276
+
277
+ logits = self.calculate_mask_logits(processed_sentence, original_mask_indices)
278
+ results[sentence] = {
279
+ "masked_sentence": masked_sentence,
280
+ "mask_logits": logits
281
+ }
282
+ logger.info(f"Processed sentence: {sentence}")
283
+ tqdm.write("[MaskingProcessor] Completed processing sentences.")
284
+ return results
285
+
286
+ if __name__ == "__main__":
287
+ sentences = [
288
+ "The quick brown fox jumps over small cat the lazy dog everyday again and again .",
289
+ ]
290
+ result_dict = {
291
+ 'The quick brown fox jumps over small cat the lazy dog everyday again and again .': {
292
+ 'brown fox': [(2, 3)],
293
+ 'cat': [(7, 7)],
294
+ 'dog': [(10, 10)]
295
+ }
296
+ }
297
+ processor = MaskingProcessor(
298
+ BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking"),
299
+ BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
300
+ )
301
+     results_entropy = processor.process_sentences(sentences, result_dict, method="random")
302
+ for sentence, output in results_entropy.items():
303
+ logger.info(f"Original Sentence (Random): {sentence}")
304
+ logger.info(f"Masked Sentence (Random): {output['masked_sentence']}")
utils/non_melting_point.py ADDED
@@ -0,0 +1,137 @@
1
+ import nltk
2
+ import logging
3
+ from nltk.corpus import stopwords
4
+ from nltk.util import ngrams
5
+ from collections import Counter
6
+ import re
7
+ from tqdm import tqdm
8
+
9
+ # Set logging to WARNING for minimal console output.
10
+ logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
11
+ logger = logging.getLogger(__name__)
12
+
13
+ class NgramProcessor:
14
+ def __init__(self):
15
+ try:
16
+ nltk.data.find('corpora/stopwords')
17
+ except LookupError:
18
+ nltk.download('stopwords')
19
+ self.stop_words = set(stopwords.words('english'))
20
+ tqdm.write("[NgramProcessor] Initialized with stopwords.")
21
+
22
+ def remove_stopwords(self, text):
23
+ # No need for extensive logging inside this helper.
24
+ words = re.findall(r'\w+', text.lower())
25
+ filtered_words = [word for word in words if word not in self.stop_words]
26
+ return ' '.join(filtered_words)
27
+
28
+ def is_exact_match(self, ngram, sentences):
29
+ logger.info(f"Checking exact match for ngram: {ngram}")
30
+ result = all(ngram in sentence for sentence in sentences)
31
+ logger.info(f"Exact match result for '{ngram}': {result}")
32
+ return result
33
+
34
+ def is_substring_of_any(self, ngram, common_ngrams):
35
+ logger.info(f"Checking if ngram: {ngram} is substring of any common ngram.")
36
+ result = any(ngram in other_ngram for other_ngram in common_ngrams if ngram != other_ngram)
37
+ logger.info(f"Substring check result for '{ngram}': {result}")
38
+ return result
39
+
40
+ def find_filtered_ngrams(self, sentences):
41
+ from collections import Counter
42
+ tqdm.write("[NgramProcessor] Cleaning sentences...")
43
+ sentences_cleaned = [self.remove_stopwords(sentence)
44
+ for sentence in tqdm(sentences, desc="Cleaning Sentences")]
45
+ ngram_lengths = [4, 3, 2, 1]
46
+ common_ngrams = []
47
+ result = {}
48
+ for n in ngram_lengths:
49
+ ngrams_list = [list(ngrams(sentence.split(), n)) for sentence in sentences_cleaned]
50
+ ngrams_counter = Counter(ngrams_list[0])
51
+ for ngram in ngrams_counter:
52
+ ngram_str = ' '.join(ngram)
53
+ if any(word in self.stop_words for word in ngram_str.split()):
54
+ continue
55
+ if self.is_exact_match(ngram_str, sentences_cleaned) and not self.is_substring_of_any(ngram_str, common_ngrams):
56
+ common_ngrams.append(ngram_str)
57
+ for sentence, cleaned_sentence in tqdm(zip(sentences, sentences_cleaned),
58
+ total=len(sentences),
59
+ desc="Mapping N-grams"):
60
+ sentence_result = {}
61
+ original_words = sentence.split()
62
+ cleaned_words = cleaned_sentence.split()
63
+ index_map = {}
64
+ cleaned_idx = 0
65
+ for orig_idx, word in enumerate(original_words):
66
+ if word.lower() not in self.stop_words:
67
+ index_map[cleaned_idx] = orig_idx
68
+ cleaned_idx += 1
69
+ for ngram in common_ngrams:
70
+ ngram_words = ngram.split()
71
+ indices = []
72
+ for i in range(len(cleaned_words) - len(ngram_words) + 1):
73
+ if cleaned_words[i:i + len(ngram_words)] == ngram_words:
74
+ if i in index_map:
75
+ start_idx = index_map[i]
76
+ end_idx = index_map.get(i + len(ngram_words) - 1, start_idx)
77
+ if end_idx - start_idx == len(ngram_words) - 1:
78
+ indices.append((start_idx, end_idx))
79
+
80
+ if indices:
81
+ sentence_result[ngram] = indices
82
+ result[sentence] = sentence_result
83
+ return result
84
+
85
+ # def find_relative_order(self, sentence, common_ngrams):
86
+ # from tqdm import tqdm
87
+ # relative_order = []
88
+ # for ngram in tqdm(common_ngrams, desc="Ordering N-grams", leave=False):
89
+ # index = sentence.find(ngram)
90
+ # if index != -1:
91
+ # relative_order.append((index, ngram))
92
+ # return sorted(relative_order)
93
+
94
+ def find_relative_order(self, sentence, common_ngrams):
95
+ from tqdm import tqdm
96
+ sentence = sentence.lower()
97
+ relative_order = []
98
+
99
+ for ngram in tqdm(common_ngrams, desc="Ordering N-grams", leave=False):
100
+ index = sentence.find(ngram.lower())
101
+ if index != -1:
102
+ relative_order.append((index, ngram))
103
+
104
+ sorted_pairs = sorted(relative_order)
105
+ return [(i+1, ngram) for i, (_, ngram) in enumerate(sorted_pairs)]
106
+
107
+ # Example usage
108
+ if __name__ == "__main__":
109
+ sentences = [
110
+ "The quick brown fox jumps over the lazy dog .",
111
+ "A speedy brown fox jumps over a lazy dog.",
112
+ "A swift brown fox leaps over the lethargic dog.",
113
+ ]
114
+ processor = NgramProcessor()
115
+ common_ngrams = processor.find_filtered_ngrams(sentences)
116
+ print(common_ngrams)
117
+ # modified_output = list({
118
+ # (indices[0][0], gram)
119
+ # for grams in common_ngrams.values()
120
+ # for gram, indices in grams.items()
121
+ # })
122
+ # print(modified_output)
123
+ logger.info(f"Common n-grams and their indices per sentence: {common_ngrams}")
124
+ for sentence in sentences:
125
+ order = processor.find_relative_order(sentence, common_ngrams[sentence])
126
+ logger.info(f"Sentence: {sentence} -> Order: {order}")
127
+
128
+
129
+ """
130
+
131
+ {
132
+ 'The quick brown fox jumps over the lazy dog.': {'brown fox': [(1, 2)], 'dog': [(5, 5)]},
133
+ 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(1, 2)], 'dog': [(5, 5)]},
134
+ 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(1, 2)], 'dog': [(5, 5)]}
135
+ }
136
+ """
137
+
utils/old/masking/masking_methods.py ADDED
@@ -0,0 +1,355 @@
1
+ import random
2
+ import torch
3
+ from transformers import BertTokenizer, BertForMaskedLM
4
+ from nltk.corpus import stopwords
5
+ import nltk
6
+
7
+ # Ensure stopwords are downloaded
8
+ try:
9
+ nltk.data.find('corpora/stopwords')
10
+ except LookupError:
11
+ nltk.download('stopwords')
12
+
13
+ class MaskingProcessor:
14
+ def __init__(self, ):
15
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
16
+ self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
17
+ self.stop_words = set(stopwords.words('english'))
18
+
19
+ def adjust_ngram_indices(self, words, common_ngrams, remove_stopwords):
20
+ """
21
+ Adjust indices of common n-grams after removing stop words.
22
+
23
+ Args:
24
+ words (list): List of words in the original sentence.
25
+ common_ngrams (dict): Common n-grams and their indices.
26
+
27
+ Returns:
28
+ dict: Adjusted common n-grams and their indices.
29
+ """
30
+ if not remove_stopwords:
31
+ return common_ngrams
32
+
33
+ non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
34
+ adjusted_ngrams = {}
35
+
36
+ for ngram, positions in common_ngrams.items():
37
+ adjusted_positions = []
38
+ for start, end in positions:
39
+ try:
40
+ new_start = non_stop_word_indices.index(start)
41
+ new_end = non_stop_word_indices.index(end)
42
+ adjusted_positions.append((new_start, new_end))
43
+ except ValueError:
44
+ continue # Skip if indices cannot be mapped
45
+ adjusted_ngrams[ngram] = adjusted_positions
46
+
47
+ return adjusted_ngrams
48
+
49
+ # def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
50
+ # """
51
+ # Mask one word before the first common n-gram, one between two n-grams,
52
+ # and one after the last common n-gram (random selection).
53
+
54
+ # Args:
55
+ # original_sentence (str): Original sentence
56
+ # common_ngrams (dict): Common n-grams and their indices
57
+
58
+ # Returns:
59
+ # str: Masked sentence with original stop words retained
60
+ # """
61
+ # words = original_sentence.split()
62
+ # if remove_stopwords:
63
+ # non_stop_words = [word for word in words if word.lower() not in self.stop_words]
64
+ # non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
65
+ # else:
66
+ # non_stop_words = words
67
+ # non_stop_word_indices = list(range(len(words)))
68
+ # # non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
69
+ # adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
70
+
71
+ # mask_indices = []
72
+ # # Handle before the first common n-gram
73
+ # if adjusted_ngrams:
74
+ # first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
75
+ # if first_ngram_start > 0:
76
+ # mask_indices.append(random.randint(0, first_ngram_start - 1))
77
+
78
+ # # Handle between common n-grams
79
+ # ngram_positions = list(adjusted_ngrams.values())
80
+ # for i in range(len(ngram_positions) - 1):
81
+ # end_prev = ngram_positions[i][-1][1]
82
+ # start_next = ngram_positions[i + 1][0][0]
83
+ # if start_next > end_prev + 1:
84
+ # mask_indices.append(random.randint(end_prev + 1, start_next - 1))
85
+
86
+ # # Handle after the last common n-gram
87
+ # last_ngram_end = ngram_positions[-1][-1][1]
88
+ # if last_ngram_end < len(non_stop_words) - 1:
89
+ # mask_indices.append(random.randint(last_ngram_end + 1, len(non_stop_words) - 1))
90
+
91
+ # # Mask the chosen indices
92
+ # original_masked_sentence = words[:]
93
+ # # for idx in mask_indices:
94
+ # # if idx not in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
95
+ # # non_stop_words[idx] = self.tokenizer.mask_token
96
+ # # original_masked_sentence[idx] = self.tokenizer.mask_token
97
+ # for idx in mask_indices:
98
+ # if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
99
+ # continue # Skip if index belongs to common n-grams
100
+ # if remove_stopwords:
101
+ # original_idx = non_stop_word_indices[idx] # Map back to original indices
102
+ # original_masked_sentence[original_idx] = self.tokenizer.mask_token
103
+ # else:
104
+ # original_masked_sentence[idx] = self.tokenizer.mask_token
105
+
106
+
107
+ # return " ".join(original_masked_sentence)
108
+ def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
109
+ """
110
+ Mask one word before the first common n-gram, one between two n-grams,
111
+ and one after the last common n-gram (random selection).
112
+
113
+ Args:
114
+ original_sentence (str): Original sentence
115
+ common_ngrams (dict): Common n-grams and their indices
116
+ remove_stopwords (bool): Whether to remove stop words
117
+
118
+ Returns:
119
+ str: Masked sentence with original stop words retained
120
+ """
121
+ words = original_sentence.split()
122
+ if remove_stopwords:
123
+ non_stop_words = [word for word in words if word.lower() not in self.stop_words]
124
+ non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
125
+ else:
126
+ non_stop_words = words
127
+ non_stop_word_indices = list(range(len(words)))
128
+
129
+ adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
130
+
131
+ # Collect all indices corresponding to common n-grams
132
+ common_ngram_indices = {
133
+ idx for ngram_positions in adjusted_ngrams.values()
134
+ for start, end in ngram_positions
135
+ for idx in range(start, end + 1)
136
+ }
137
+
138
+ mask_indices = []
139
+ # Handle before the first common n-gram
140
+ if adjusted_ngrams:
141
+ first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
142
+ if first_ngram_start > 0:
143
+ potential_indices = [i for i in range(first_ngram_start) if i not in common_ngram_indices]
144
+ if potential_indices:
145
+ mask_indices.append(random.choice(potential_indices))
146
+
147
+ # Handle between common n-grams
148
+ ngram_positions = list(adjusted_ngrams.values())
149
+ for i in range(len(ngram_positions) - 1):
150
+ end_prev = ngram_positions[i][-1][1]
151
+ start_next = ngram_positions[i + 1][0][0]
152
+ potential_indices = [i for i in range(end_prev + 1, start_next) if i not in common_ngram_indices]
153
+ if potential_indices:
154
+ mask_indices.append(random.choice(potential_indices))
155
+
156
+ # Handle after the last common n-gram
157
+ last_ngram_end = ngram_positions[-1][-1][1]
158
+ if last_ngram_end < len(non_stop_words) - 1:
159
+ potential_indices = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i not in common_ngram_indices]
160
+ if potential_indices:
161
+ mask_indices.append(random.choice(potential_indices))
162
+
163
+ # Mask the chosen indices
164
+ original_masked_sentence = words[:]
165
+ for idx in mask_indices:
166
+ if remove_stopwords:
167
+ original_idx = non_stop_word_indices[idx] # Map back to original indices
168
+ original_masked_sentence[original_idx] = self.tokenizer.mask_token
169
+ else:
170
+ original_masked_sentence[idx] = self.tokenizer.mask_token
171
+
172
+ return " ".join(original_masked_sentence)
173
+
174
+ def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords):
175
+ """
176
+ Mask one word before the first common n-gram, one between two n-grams,
177
+ and one after the last common n-gram (highest entropy selection).
178
+
179
+ Args:
180
+ original_sentence (str): Original sentence
181
+ common_ngrams (dict): Common n-grams and their indices
182
+
183
+ Returns:
184
+ str: Masked sentence with original stop words retained
185
+ """
186
+ words = original_sentence.split()
187
+ # non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
188
+ if remove_stopwords:
189
+ non_stop_words = [word for word in words if word.lower() not in self.stop_words]
190
+ non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
191
+ else:
192
+ non_stop_words = words
193
+ non_stop_word_indices = list(range(len(words)))
194
+ adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
195
+ entropy_scores = {}
196
+
197
+ for idx, word in enumerate(non_stop_words):
198
+ if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
199
+ continue # Skip words in common n-grams
200
+
201
+ masked_sentence = non_stop_words[:idx] + [self.tokenizer.mask_token] + non_stop_words[idx + 1:]
202
+ masked_sentence = " ".join(masked_sentence)
203
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
204
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
205
+
206
+ with torch.no_grad():
207
+ outputs = self.model(input_ids)
208
+ logits = outputs.logits
209
+
210
+ filtered_logits = logits[0, mask_token_index, :]
211
+ probs = torch.softmax(filtered_logits, dim=-1)
212
+ entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item() # Add epsilon to prevent log(0)
213
+ entropy_scores[idx] = entropy
214
+
215
+ mask_indices = []
216
+
217
+ # Handle before the first common n-gram
218
+ if adjusted_ngrams:
219
+ first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
220
+ candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
221
+ if candidates:
222
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
223
+
224
+ # Handle between common n-grams
225
+ ngram_positions = list(adjusted_ngrams.values())
226
+ for i in range(len(ngram_positions) - 1):
227
+ end_prev = ngram_positions[i][-1][1]
228
+ start_next = ngram_positions[i + 1][0][0]
229
+ candidates = [i for i in range(end_prev + 1, start_next) if i in entropy_scores]
230
+ if candidates:
231
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
232
+
233
+ # Handle after the last common n-gram
234
+ last_ngram_end = ngram_positions[-1][-1][1]
235
+ candidates = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i in entropy_scores]
236
+ if candidates:
237
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
238
+
239
+ # Mask the chosen indices
240
+ original_masked_sentence = words[:]
241
+ # for idx in mask_indices:
242
+ # non_stop_words[idx] = self.tokenizer.mask_token
243
+ # original_masked_sentence[idx] = self.tokenizer.mask_token
244
+
245
+ for idx in mask_indices:
246
+ if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
247
+ continue # Skip if index belongs to common n-grams
248
+ if remove_stopwords:
249
+ original_idx = non_stop_word_indices[idx] # Map back to original indices
250
+ original_masked_sentence[original_idx] = self.tokenizer.mask_token
251
+ else:
252
+ original_masked_sentence[idx] = self.tokenizer.mask_token
253
+
254
+
255
+ return " ".join(original_masked_sentence)
256
+
257
+ def calculate_mask_logits(self, masked_sentence):
258
+ """
259
+ Calculate logits for masked tokens in the sentence using BERT.
260
+
261
+ Args:
262
+ masked_sentence (str): Sentence with [MASK] tokens
263
+
264
+ Returns:
265
+ dict: Masked token indices and their logits
266
+ """
267
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
268
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
269
+
270
+ with torch.no_grad():
271
+ outputs = self.model(input_ids)
272
+ logits = outputs.logits
273
+
274
+ mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
275
+ return mask_logits
276
+
277
+ def process_sentences(self, original_sentences, result_dict, method="random", remove_stopwords=False):
278
+ """
279
+ Process a list of sentences and calculate logits for masked tokens using the specified method.
280
+
281
+ Args:
282
+ original_sentences (list): List of original sentences
283
+ result_dict (dict): Common n-grams and their indices for each sentence
284
+ method (str): Masking method ("random" or "entropy")
285
+
286
+ Returns:
287
+ dict: Masked sentences and their logits for each sentence
288
+ """
289
+ results = {}
290
+
291
+ for sentence, ngrams in result_dict.items():
292
+ if method == "random":
293
+ masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
294
+ elif method == "entropy":
295
+ masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
296
+ else:
297
+ raise ValueError("Invalid method. Choose 'random' or 'entropy'.")
298
+
299
+ logits = self.calculate_mask_logits(masked_sentence)
300
+ results[sentence] = {
301
+ "masked_sentence": masked_sentence,
302
+ "mask_logits": logits
303
+ }
304
+
305
+ return results
306
+
307
+ # Example usage
308
+ if __name__ == "__main__":
309
+ # !!! Works in both cases, regardless of whether stopwords are removed or not
310
+ sentences = [
311
+ "The quick brown fox jumps over the lazy dog.",
312
+ "A speedy brown fox jumps over a lazy dog.",
313
+ "A swift brown fox leaps over the lethargic dog."
314
+ ]
315
+ result_dict = {
316
+ 'The quick brown fox jumps over the lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
317
+ 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
318
+ 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
319
+ }
320
+
321
+
322
+ processor = MaskingProcessor()
323
+ results_random = processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=True)
324
+ # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)
325
+
326
+ for sentence, output in results_random.items():
327
+ print(f"Original Sentence (Random): {sentence}")
328
+ print(f"Masked Sentence (Random): {output['masked_sentence']}")
329
+ # # print(f"Mask Logits (Random): {output['mask_logits']}")
330
+ # print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
331
+ # print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
332
+ # print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
333
+ print('--------------------------------')
334
+ # for mask_idx, logits in output["mask_logits"].items():
335
+ # print(f"Logits for [MASK] at position {mask_idx}:")
336
+ # print(f' logits : {logits[:5]}') # List of logits for all vocabulary tokens
337
+
338
+
339
+
340
+
341
+ # result_dict = {
342
+ # "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
343
+ # "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
344
+ # "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
345
+ # }
346
+
347
+
348
+ # print('--------------------------------')
349
+ # for sentence, output in results_entropy.items():
350
+ # print(f"Original Sentence (Entropy): {sentence}")
351
+ # print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
352
+ # # print(f"Mask Logits (Entropy): {output['mask_logits']}")
353
+ # print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
354
+ # print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
355
+ # print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
utils/old/masking/masking_methods_new_work.py ADDED
@@ -0,0 +1,447 @@
1
+ import random
2
+ import torch
3
+ from transformers import BertTokenizer, BertForMaskedLM
4
+ from nltk.corpus import stopwords
5
+ import nltk
6
+
7
+ # Ensure stopwords are downloaded
8
+ try:
9
+ nltk.data.find('corpora/stopwords')
10
+ except LookupError:
11
+ nltk.download('stopwords')
12
+
13
+ class MaskingProcessor:
14
+ def __init__(self):
15
+
16
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
17
+ self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
18
+ self.stop_words = set(stopwords.words('english'))
19
+
20
+ def remove_stopwords(self, words):
21
+ """
22
+ Remove stopwords from the given list of words.
23
+
24
+ Args:
25
+ words (list): List of words.
26
+
27
+ Returns:
28
+ list: List of non-stop words.
29
+ """
30
+ return [word for word in words if word.lower() not in self.stop_words]
31
+
32
+ def adjust_ngram_indices(self, original_words, common_ngrams):
33
+ """
34
+ Adjust indices of common n-grams after removing stopwords.
35
+
36
+ Args:
37
+ original_words (list): Original list of words.
38
+ common_ngrams (dict): Common n-grams and their indices.
39
+
40
+ Returns:
41
+ dict: Adjusted common n-grams with updated indices.
42
+ """
43
+ non_stop_words = self.remove_stopwords(original_words)
44
+ original_to_non_stop = []
45
+ non_stop_idx = 0
46
+
47
+ for original_idx, word in enumerate(original_words):
48
+ if word.lower() not in self.stop_words:
49
+ original_to_non_stop.append((original_idx, non_stop_idx))
50
+ non_stop_idx += 1
51
+
52
+ adjusted_ngrams = {}
53
+ for ngram, positions in common_ngrams.items():
54
+ adjusted_positions = []
55
+ for start, end in positions:
56
+ try:
57
+ new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
58
+ new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
59
+ adjusted_positions.append((new_start, new_end))
60
+ except StopIteration:
61
+ continue # Skip if indices cannot be mapped
62
+ adjusted_ngrams[ngram] = adjusted_positions
63
+
64
+ return adjusted_ngrams
65
+
66
+ def mask_sentence_random(self, sentence, common_ngrams):
67
+ """
68
+ Mask words in the sentence based on the specified rules after removing stopwords.
69
+ """
70
+ original_words = sentence.split()
71
+ print(f' ---- original_words : {original_words} ----- ')
72
+ non_stop_words = self.remove_stopwords(original_words)
73
+ print(f' ---- non_stop_words : {non_stop_words} ----- ')
74
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
75
+ print(f' ---- common_ngrams : {common_ngrams} ----- ')
76
+ print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')
77
+
78
+ mask_indices = []
79
+
80
+ # Extract n-gram positions in non-stop words
81
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
82
+
83
+ # Mask a word before the first common n-gram
84
+ if ngram_positions:
85
+ print(f' ---- ngram_positions : {ngram_positions} ----- ')
86
+ first_ngram_start = ngram_positions[0][0]
87
+ print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
88
+ if first_ngram_start > 0:
89
+ mask_index_before_ngram = random.randint(0, first_ngram_start-1)
90
+ print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
91
+ mask_indices.append(mask_index_before_ngram)
92
+
93
+ # Mask words between common n-grams
94
+ for i in range(len(ngram_positions) - 1):
95
+ end_prev = ngram_positions[i][1]
96
+ print(f' ---- end_prev : {end_prev} ----- ') # end index of the previous n-gram
97
+ start_next = ngram_positions[i + 1][0]
98
+ print(f' ---- start_next : {start_next} ----- ')
99
+ if start_next > end_prev + 1:
100
+ mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
101
+ print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
102
+ mask_indices.append(mask_index_between_ngrams)
103
+
104
+ # Mask a word after the last common n-gram
105
+ last_ngram_end = ngram_positions[-1][1]
106
+ if last_ngram_end < len(non_stop_words) - 1:
107
+ print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
108
+ mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
109
+ print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
110
+ mask_indices.append(mask_index_after_ngram)
111
+
112
+ # Create mapping from non-stop words to original indices
113
+ non_stop_to_original = {}
114
+ non_stop_idx = 0
115
+ for orig_idx, word in enumerate(original_words):
116
+ if word.lower() not in self.stop_words:
117
+ non_stop_to_original[non_stop_idx] = orig_idx
118
+ non_stop_idx += 1
119
+
120
+ # Map mask indices from non-stop word positions to original positions
121
+ print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
122
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
123
+ print(f' ---- original_mask_indices : {original_mask_indices} ----- ')
124
+
125
+ # Apply masks to the original sentence
126
+ masked_words = original_words.copy()
127
+ for idx in original_mask_indices:
128
+ masked_words[idx] = self.tokenizer.mask_token
129
+
130
+ return " ".join(masked_words)
131
+
132
+ def mask_sentence_pseudorandom(self, sentence, common_ngrams):
133
+ """
134
+ Mask words in the sentence based on the specified rules after removing stopwords.
135
+ """
136
+ random.seed(42)
137
+ original_words = sentence.split()
138
+ print(f' ---- original_words : {original_words} ----- ')
139
+ non_stop_words = self.remove_stopwords(original_words)
140
+ print(f' ---- non_stop_words : {non_stop_words} ----- ')
141
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
142
+ print(f' ---- common_ngrams : {common_ngrams} ----- ')
143
+ print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')
144
+
145
+ mask_indices = []
146
+
147
+ # Extract n-gram positions in non-stop words
148
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
149
+
150
+ # Mask a word before the first common n-gram
151
+ if ngram_positions:
152
+ print(f' ---- ngram_positions : {ngram_positions} ----- ')
153
+ first_ngram_start = ngram_positions[0][0]
154
+ print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
155
+ if first_ngram_start > 0:
156
+ mask_index_before_ngram = random.randint(0, first_ngram_start-1)
157
+ print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
158
+ mask_indices.append(mask_index_before_ngram)
159
+
160
+ # Mask words between common n-grams
161
+ for i in range(len(ngram_positions) - 1):
162
+ end_prev = ngram_positions[i][1]
163
+ print(f' ---- end_prev : {end_prev} ----- ')
164
+ start_next = ngram_positions[i + 1][0]
165
+ print(f' ---- start_next : {start_next} ----- ')
166
+ if start_next > end_prev + 1:
167
+ mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
168
+ print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
169
+ mask_indices.append(mask_index_between_ngrams)
170
+
171
+ # Mask a word after the last common n-gram
172
+ last_ngram_end = ngram_positions[-1][1]
173
+ if last_ngram_end < len(non_stop_words) - 1:
174
+ print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
175
+ mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
176
+ print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
177
+ mask_indices.append(mask_index_after_ngram)
178
+
179
+ # Create mapping from non-stop words to original indices
180
+ non_stop_to_original = {}
181
+ non_stop_idx = 0
182
+ for orig_idx, word in enumerate(original_words):
183
+ if word.lower() not in self.stop_words:
184
+ non_stop_to_original[non_stop_idx] = orig_idx
185
+ non_stop_idx += 1
186
+
187
+ # Map mask indices from non-stop word positions to original positions
188
+ print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
189
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
190
+ print(f' ---- original_mask_indices : {original_mask_indices} ----- ')
191
+
192
+ # Apply masks to the original sentence
193
+ masked_words = original_words.copy()
194
+ for idx in original_mask_indices:
195
+ masked_words[idx] = self.tokenizer.mask_token
196
+
197
+ return " ".join(masked_words)
198
+
199
+
200
+ def calculate_word_entropy(self, sentence, word_position):
201
+ """
202
+ Calculate entropy for a specific word position in the sentence.
203
+
204
+ Args:
205
+ sentence (str): The input sentence
206
+ word_position (int): Position of the word to calculate entropy for
207
+
208
+ Returns:
209
+ float: Entropy value for the word
210
+ """
211
+ words = sentence.split()
212
+ masked_words = words.copy()
213
+ masked_words[word_position] = self.tokenizer.mask_token
214
+ masked_sentence = " ".join(masked_words)
215
+
216
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
217
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
218
+
219
+ with torch.no_grad():
220
+ outputs = self.model(input_ids)
221
+ logits = outputs.logits
222
+
223
+ # Get probabilities for the masked position
224
+ probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
225
+ # Calculate entropy: -sum(p * log(p))
226
+ entropy = -torch.sum(probs * torch.log(probs + 1e-9))
227
+
228
+ return entropy.item()
229
+
230
+ def mask_sentence_entropy(self, sentence, common_ngrams):
231
+ """
232
+ Mask words in the sentence based on entropy, following n-gram positioning rules.
233
+
234
+ Args:
235
+ sentence (str): Original sentence
236
+ common_ngrams (dict): Common n-grams and their indices
237
+
238
+ Returns:
239
+ str: Masked sentence
240
+ """
241
+ original_words = sentence.split()
242
+ non_stop_words = self.remove_stopwords(original_words)
243
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
244
+
245
+ # Create mapping from non-stop words to original indices
246
+ non_stop_to_original = {}
247
+ original_to_non_stop = {}
248
+ non_stop_idx = 0
249
+ for orig_idx, word in enumerate(original_words):
250
+ if word.lower() not in self.stop_words:
251
+ non_stop_to_original[non_stop_idx] = orig_idx
252
+ original_to_non_stop[orig_idx] = non_stop_idx
253
+ non_stop_idx += 1
254
+
255
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
256
+ mask_indices = []
257
+
258
+ if ngram_positions:
259
+ # Handle words before first n-gram
260
+ first_ngram_start = ngram_positions[0][0]
261
+ if first_ngram_start > 0:
262
+ # Calculate entropy for all candidate positions
263
+ candidate_positions = range(0, first_ngram_start)
264
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
265
+ for pos in candidate_positions]
266
+ # Select position with highest entropy
267
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
268
+
269
+ # Handle words between n-grams
270
+ for i in range(len(ngram_positions) - 1):
271
+ end_prev = ngram_positions[i][1]
272
+ start_next = ngram_positions[i + 1][0]
273
+ if start_next > end_prev + 1:
274
+ candidate_positions = range(end_prev + 1, start_next)
275
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
276
+ for pos in candidate_positions]
277
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
278
+
279
+ # Handle words after last n-gram
280
+ last_ngram_end = ngram_positions[-1][1]
281
+ if last_ngram_end < len(non_stop_words) - 1:
282
+ candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
283
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
284
+ for pos in candidate_positions]
285
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
286
+
287
+ # Map mask indices to original sentence positions and apply masks
288
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
289
+ masked_words = original_words.copy()
290
+ for idx in original_mask_indices:
291
+ masked_words[idx] = self.tokenizer.mask_token
292
+
293
+ return " ".join(masked_words)
294
+
295
+
296
+ def calculate_mask_logits(self, masked_sentence):
297
+ """
298
+ Calculate logits for masked tokens in the sentence using BERT.
299
+
300
+ Args:
301
+ masked_sentence (str): Sentence with [MASK] tokens.
302
+
303
+ Returns:
304
+ dict: Masked token indices and their logits.
305
+ """
306
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
307
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
308
+
309
+ with torch.no_grad():
310
+ outputs = self.model(input_ids)
311
+ logits = outputs.logits
312
+
313
+ mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
314
+ return mask_logits
315
+
316
+ def process_sentences(self, sentences, result_dict, method="random"):
317
+ """
318
+ Process sentences and calculate logits for masked tokens.
319
+
320
+ Args:
321
+ sentences (list): List of sentences
322
+ result_dict (dict): Dictionary of common n-grams
323
+ method (str): Masking method ("random" or "entropy")
324
+
325
+ Returns:
326
+ dict: Masked sentences and logits for each sentence
327
+ """
328
+ results = {}
329
+
330
+ for sentence, ngrams in result_dict.items():
331
+ if method == "random":
332
+ masked_sentence = self.mask_sentence_random(sentence, ngrams)
333
+ elif method == "pseudorandom":
334
+ masked_sentence = self.mask_sentence_pseudorandom(sentence, ngrams)
335
+ else: # entropy
336
+ masked_sentence = self.mask_sentence_entropy(sentence, ngrams)
337
+
338
+ logits = self.calculate_mask_logits(masked_sentence)
339
+ results[sentence] = {
340
+ "masked_sentence": masked_sentence,
341
+ "mask_logits": logits
342
+ }
343
+
344
+ return results
345
+
346
+
347
+
348
+ if __name__ == "__main__":
349
+ # !!! Works in both cases, regardless of whether stopwords are removed or not
350
+ sentences = [
351
+ "The quick brown fox jumps over the lazy dog everyday.",
352
+ # "A speedy brown fox jumps over a lazy dog.",
353
+ # "A swift brown fox leaps over the lethargic dog."
354
+ ]
355
+ result_dict = {
356
+ 'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
357
+ # 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
358
+ # 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
359
+ }
360
+
361
+
362
+ processor = MaskingProcessor()
363
+ # results_random = processor.process_sentences(sentences, result_dict)
364
+ results_entropy = processor.process_sentences(sentences, result_dict, method="random")
365
+
366
+ # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)
367
+
368
+ for sentence, output in results_entropy.items():
369
+ print(f"Original Sentence (Random): {sentence}")
370
+ print(f"Masked Sentence (Random): {output['masked_sentence']}")
371
+ # print(f"Mask Logits (Random): {output['mask_logits']}")
372
+ print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
373
+ print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
374
+ print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
375
+ print('--------------------------------')
376
+ for mask_idx, logits in output["mask_logits"].items():
377
+ print(f"Logits for [MASK] at position {mask_idx}:")
378
+ print(f' logits : {logits[:5]}') # List of logits for all vocabulary tokens
379
+ print(f' len(logits) : {len(logits)}')
380
+
381
+
382
+
383
+
384
+ # -------------------------------------------------------------------------------------------
385
+ # def mask_sentence(self, sentence, common_ngrams):
386
+ # """
387
+ # Mask words in the sentence based on the specified rules after removing stopwords.
388
+
389
+ # Args:
390
+ # sentence (str): Original sentence.
391
+ # common_ngrams (dict): Common n-grams and their indices.
392
+
393
+ # Returns:
394
+ # str: Masked sentence.
395
+ # """
396
+ # original_words = sentence.split()
397
+ # print(f' ---- original_words : {original_words} ----- ')
398
+ # non_stop_words = self.remove_stopwords(original_words)
399
+ # print(f' ---- non_stop_words : {non_stop_words} ----- ')
400
+ # adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
401
+ # print(f' ---- common_ngrams : {common_ngrams} ----- ')
402
+ # print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')
403
+
404
+ # mask_indices = []
405
+
406
+ # # Extract n-gram positions in non-stop words
407
+ # ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
408
+ # print(f' ---- ngram_positions : {ngram_positions} ----- ')
409
+ # # Mask a word before the first common n-gram
410
+ # if ngram_positions:
411
+ # first_ngram_start = ngram_positions[0][0]
412
+ # print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
413
+ # if first_ngram_start > 0:
414
+ # mask_index_before_ngram = random.randint(0, first_ngram_start-1)
415
+ # print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
416
+ # mask_indices.append(mask_index_before_ngram)
417
+
418
+ # # Mask words between common n-grams
419
+ # for i in range(len(ngram_positions) - 1):
420
+ # end_prev = ngram_positions[i][1]
421
+ # print(f' ---- end_prev : {end_prev} ----- ')
422
+ # start_next = ngram_positions[i + 1][0]
423
+ # print(f' ---- start_next : {start_next} ----- ')
424
+ # if start_next > end_prev + 1:
425
+ # mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
426
+ # print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
427
+ # mask_indices.append(mask_index_between_ngrams)
428
+
429
+ # # Mask a word after the last common n-gram
430
+ # last_ngram_end = ngram_positions[-1][1]
431
+ # print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
432
+ # if last_ngram_end < len(non_stop_words) - 1:
433
+ # mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
434
+ # print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
435
+ # mask_indices.append(mask_index_after_ngram)
436
+
437
+ # # Map mask indices back to original sentence
438
+ # adjusted_indices = [
439
+ # orig for orig, non_stop in enumerate(original_words)
440
+ # if non_stop in mask_indices
441
+ # ]
442
+
443
+ # # Apply masks to the original sentence
444
+ # for idx in adjusted_indices:
445
+ # original_words[idx] = self.tokenizer.mask_token
446
+
447
+ # return " ".join(original_words)
utils/old/masking/masking_methods_ok_working.py ADDED
@@ -0,0 +1,257 @@
1
+ import random
2
+ import torch
3
+ from transformers import BertTokenizer, BertForMaskedLM
4
+ from nltk.corpus import stopwords
5
+ import nltk
6
+
7
+ # Ensure stopwords are downloaded
8
+ try:
9
+ nltk.data.find('corpora/stopwords')
10
+ except LookupError:
11
+ nltk.download('stopwords')
12
+
13
+ class MaskingProcessor:
14
+ def __init__(self):
15
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
16
+ self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
17
+ self.stop_words = set(stopwords.words('english'))
18
+
19
+ def adjust_ngram_indices(self, words, common_ngrams, remove_stopwords):
20
+ """
21
+ Adjust indices of common n-grams after removing stop words.
22
+
23
+ Args:
24
+ words (list): List of words in the original sentence.
25
+ common_ngrams (dict): Common n-grams and their indices.
26
+
27
+ Returns:
28
+ dict: Adjusted common n-grams and their indices.
29
+ """
30
+ if not remove_stopwords:
31
+ return common_ngrams
32
+
33
+ non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
34
+ adjusted_ngrams = {}
35
+
36
+ for ngram, positions in common_ngrams.items():
37
+ adjusted_positions = []
38
+ for start, end in positions:
39
+ try:
40
+ new_start = non_stop_word_indices.index(start)
41
+ new_end = non_stop_word_indices.index(end)
42
+ adjusted_positions.append((new_start, new_end))
43
+ except ValueError:
44
+ continue # Skip if indices cannot be mapped
45
+ adjusted_ngrams[ngram] = adjusted_positions
46
+
47
+ return adjusted_ngrams
48
+
49
+ def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
50
+ """
51
+ Mask one word before the first common n-gram, one between two n-grams,
52
+ and one after the last common n-gram (random selection).
53
+
54
+ Args:
55
+ original_sentence (str): Original sentence
56
+ common_ngrams (dict): Common n-grams and their indices
57
+
58
+ Returns:
59
+ str: Masked sentence with original stop words retained
60
+ """
61
+ words = original_sentence.split()
62
+ non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
63
+ adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
64
+
65
+ mask_indices = []
66
+ # Handle before the first common n-gram
67
+ if adjusted_ngrams:
68
+ first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
69
+ if first_ngram_start > 0:
70
+ mask_indices.append(random.randint(0, first_ngram_start - 1))
71
+
72
+ # Handle between common n-grams
73
+ ngram_positions = list(adjusted_ngrams.values())
74
+ for i in range(len(ngram_positions) - 1):
75
+ end_prev = ngram_positions[i][-1][1]
76
+ start_next = ngram_positions[i + 1][0][0]
77
+ if start_next > end_prev + 1:
78
+ mask_indices.append(random.randint(end_prev + 1, start_next - 1))
79
+
80
+ # Handle after the last common n-gram
81
+ last_ngram_end = ngram_positions[-1][-1][1]
82
+ if last_ngram_end < len(non_stop_words) - 1:
83
+ mask_indices.append(random.randint(last_ngram_end + 1, len(non_stop_words) - 1))
84
+
85
+ # Mask the chosen indices
86
+ original_masked_sentence = words[:]
87
+ for idx in mask_indices:
88
+ if idx not in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
89
+ non_stop_words[idx] = self.tokenizer.mask_token
90
+ original_masked_sentence[idx] = self.tokenizer.mask_token
91
+
92
+ return " ".join(original_masked_sentence)
93
+
94
+ def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords):
95
+ """
96
+ Mask one word before the first common n-gram, one between two n-grams,
97
+ and one after the last common n-gram (highest entropy selection).
98
+
99
+ Args:
100
+ original_sentence (str): Original sentence
101
+ common_ngrams (dict): Common n-grams and their indices
102
+
103
+ Returns:
104
+ str: Masked sentence with original stop words retained
105
+ """
106
+ words = original_sentence.split()
107
+ non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
108
+ adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
109
+ entropy_scores = {}
110
+
111
+ for idx, word in enumerate(non_stop_words):
112
+ if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
113
+ continue # Skip words in common n-grams
114
+
115
+ masked_sentence = non_stop_words[:idx] + [self.tokenizer.mask_token] + non_stop_words[idx + 1:]
116
+ masked_sentence = " ".join(masked_sentence)
117
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
118
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
119
+
120
+ with torch.no_grad():
121
+ outputs = self.model(input_ids)
122
+ logits = outputs.logits
123
+
124
+ filtered_logits = logits[0, mask_token_index, :]
125
+ probs = torch.softmax(filtered_logits, dim=-1)
126
+ entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item() # Add epsilon to prevent log(0)
127
+ entropy_scores[idx] = entropy
128
+
129
+ mask_indices = []
130
+
131
+ # Handle before the first common n-gram
132
+ if adjusted_ngrams:
133
+ first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
134
+ candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
135
+ if candidates:
136
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
137
+
138
+ # Handle between common n-grams
139
+ ngram_positions = list(adjusted_ngrams.values())
140
+ for i in range(len(ngram_positions) - 1):
141
+ end_prev = ngram_positions[i][-1][1]
142
+ start_next = ngram_positions[i + 1][0][0]
143
+ candidates = [i for i in range(end_prev + 1, start_next) if i in entropy_scores]
144
+ if candidates:
145
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
146
+
147
+ # Handle after the last common n-gram
148
+ last_ngram_end = ngram_positions[-1][-1][1]
149
+ candidates = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i in entropy_scores]
150
+ if candidates:
151
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
152
+
153
+ # Mask the chosen indices
154
+ original_masked_sentence = words[:]
155
+ for idx in mask_indices:
156
+ non_stop_words[idx] = self.tokenizer.mask_token
157
+ original_masked_sentence[idx] = self.tokenizer.mask_token
158
+
159
+ return " ".join(original_masked_sentence)
160
+
161
+ def calculate_mask_logits(self, masked_sentence):
162
+ """
163
+ Calculate logits for masked tokens in the sentence using BERT.
164
+
165
+ Args:
166
+ masked_sentence (str): Sentence with [MASK] tokens
167
+
168
+ Returns:
169
+ dict: Masked token indices and their logits
170
+ """
171
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
172
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
173
+
174
+ with torch.no_grad():
175
+ outputs = self.model(input_ids)
176
+ logits = outputs.logits
177
+
178
+ mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
179
+ return mask_logits
180
+
181
+ def process_sentences(self, original_sentences, result_dict, method="random", remove_stopwords=False):
182
+ """
183
+ Process a list of sentences and calculate logits for masked tokens using the specified method.
184
+
185
+ Args:
186
+ original_sentences (list): List of original sentences
187
+ result_dict (dict): Common n-grams and their indices for each sentence
188
+ method (str): Masking method ("random" or "entropy")
189
+
190
+ Returns:
191
+ dict: Masked sentences and their logits for each sentence
192
+ """
193
+ results = {}
194
+
195
+ for sentence, ngrams in result_dict.items():
196
+ if method == "random":
197
+ masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
198
+ elif method == "entropy":
199
+ masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
200
+ else:
201
+ raise ValueError("Invalid method. Choose 'random' or 'entropy'.")
202
+
203
+ logits = self.calculate_mask_logits(masked_sentence)
204
+ results[sentence] = {
205
+ "masked_sentence": masked_sentence,
206
+ "mask_logits": logits
207
+ }
208
+
209
+ return results
210
+
211
+ # Example usage
212
+ if __name__ == "__main__":
213
+ # !!! Works in both cases, regardless of whether stopwords are removed or not
214
+ sentences = [
215
+ "The quick brown fox jumps over the lazy dog.",
216
+ "A quick brown dog outpaces a lazy fox.",
217
+ "Quick brown animals leap over lazy obstacles."
218
+ ]
219
+
220
+ result_dict = {
221
+ "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
222
+ "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
223
+ "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
224
+ }
225
+
226
+ # result_dict = {
227
+ # "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
228
+ # "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
229
+ # "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
230
+ # }
231
+
232
+ processor = MaskingProcessor()
233
+ results_random = processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
234
+ # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)
235
+
236
+ for sentence, output in results_random.items():
237
+ print(f"Original Sentence (Random): {sentence}")
238
+ print(f"Masked Sentence (Random): {output['masked_sentence']}")
239
+ # print(f"Mask Logits (Random): {output['mask_logits']}")
240
+ print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
241
+ print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
242
+ print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
243
+ print('--------------------------------')
244
+ for mask_idx, logits in output["mask_logits"].items():
245
+ print(f"Logits for [MASK] at position {mask_idx}:")
246
+ print(f' logits : {logits[:5]}') # List of logits for all vocabulary tokens
247
+
248
+
249
+
250
+ # print('--------------------------------')
251
+ # for sentence, output in results_entropy.items():
252
+ # print(f"Original Sentence (Entropy): {sentence}")
253
+ # print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
254
+ # # print(f"Mask Logits (Entropy): {output['mask_logits']}")
255
+ # print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
256
+ # print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
257
+ # print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
utils/old/masking/masking_methods_v1_working.py ADDED
@@ -0,0 +1,233 @@
1
+ import random
2
+ import torch
3
+ from transformers import BertTokenizer, BertForMaskedLM
4
+ from nltk.corpus import stopwords
5
+ import nltk
6
+
7
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
8
+ # THIS WORKS WHEN THE COORDINATES ARE COMPUTED WITHOUT REMOVING STOPWORDS
9
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
10
+
11
+
12
+ # Ensure stopwords are downloaded
13
+ try:
14
+ nltk.data.find('corpora/stopwords')
15
+ except LookupError:
16
+ nltk.download('stopwords')
17
+
18
+ class MaskingProcessor:
19
+ def __init__(self):
20
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
21
+ self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
22
+ self.stop_words = set(stopwords.words('english'))
23
+
24
+ def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords=False):
25
+ """
26
+ Mask one word before the first common n-gram, one between two n-grams,
27
+ and one after the last common n-gram (random selection).
28
+
29
+ Args:
30
+ original_sentence (str): Original sentence
31
+ common_ngrams (dict): Common n-grams and their indices
32
+
33
+ Returns:
34
+ str: Masked sentence
35
+ """
36
+ if remove_stopwords:
37
+ words = original_sentence.split()
38
+ words = [word for word in words if word not in self.stop_words]
39
+ else:
40
+ words = original_sentence.split()
41
+
42
+ mask_indices = []
43
+ # Handle before the first common n-gram
44
+ if common_ngrams:
45
+ first_ngram_start = list(common_ngrams.values())[0][0][0]
46
+ if first_ngram_start > 0:
47
+ mask_indices.append(random.randint(0, first_ngram_start - 1))
48
+
49
+ # Handle between common n-grams
50
+ ngram_positions = list(common_ngrams.values())
51
+ for i in range(len(ngram_positions) - 1):
52
+ end_prev = ngram_positions[i][-1][1]
53
+ start_next = ngram_positions[i + 1][0][0]
54
+ if start_next > end_prev + 1:
55
+ mask_indices.append(random.randint(end_prev + 1, start_next - 1))
56
+
57
+ # Handle after the last common n-gram
58
+ last_ngram_end = ngram_positions[-1][-1][1]
59
+ if last_ngram_end < len(words) - 1:
60
+ mask_indices.append(random.randint(last_ngram_end + 1, len(words) - 1))
61
+
62
+ # Mask the chosen indices
63
+ for idx in mask_indices:
64
+ if idx not in [index for ngram_indices in common_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
65
+ words[idx] = self.tokenizer.mask_token
66
+
67
+ return " ".join(words)
68
+
69
+ def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords=False):
70
+ """
71
+ Mask one word before the first common n-gram, one between two n-grams,
72
+ and one after the last common n-gram (highest entropy selection).
73
+
74
+ Args:
75
+ original_sentence (str): Original sentence
76
+ common_ngrams (dict): Common n-grams and their indices
77
+
78
+ Returns:
79
+ str: Masked sentence
80
+ """
81
+ if remove_stopwords:
82
+ words = original_sentence.split()
83
+ words = [word for word in words if word not in self.stop_words]
84
+ else:
85
+ words = original_sentence.split()
86
+ entropy_scores = {}
87
+
88
+ for idx, word in enumerate(words):
89
+ if idx in [index for ngram_indices in common_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
90
+ continue # Skip words in common n-grams
91
+
92
+ masked_sentence = words[:idx] + [self.tokenizer.mask_token] + words[idx + 1:]
93
+ masked_sentence = " ".join(masked_sentence)
94
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
95
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
96
+
97
+ with torch.no_grad():
98
+ outputs = self.model(input_ids)
99
+ logits = outputs.logits
100
+
101
+ filtered_logits = logits[0, mask_token_index, :]
102
+ probs = torch.softmax(filtered_logits, dim=-1)
103
+ entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item() # Add epsilon to prevent log(0)
104
+ entropy_scores[idx] = entropy
105
+
106
+ mask_indices = []
107
+
108
+ # Handle before the first common n-gram
109
+ if common_ngrams:
110
+ first_ngram_start = list(common_ngrams.values())[0][0][0]
111
+ candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
112
+ if candidates:
113
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
114
+
115
+ # Handle between common n-grams
116
+ ngram_positions = list(common_ngrams.values())
117
+ for i in range(len(ngram_positions) - 1):
118
+ end_prev = ngram_positions[i][-1][1]
119
+ start_next = ngram_positions[i + 1][0][0]
120
+ candidates = [i for i in range(end_prev + 1, start_next) if i in entropy_scores]
121
+ if candidates:
122
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
123
+
124
+ # Handle after the last common n-gram
125
+ last_ngram_end = ngram_positions[-1][-1][1]
126
+ candidates = [i for i in range(last_ngram_end + 1, len(words)) if i in entropy_scores]
127
+ if candidates:
128
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
129
+
130
+ # Mask the chosen indices
131
+ for idx in mask_indices:
132
+ words[idx] = self.tokenizer.mask_token
133
+
134
+ return " ".join(words)
135
+
136
+ def calculate_mask_logits(self, masked_sentence):
137
+ """
138
+ Calculate logits for masked tokens in the sentence using BERT.
139
+
140
+ Args:
141
+ masked_sentence (str): Sentence with [MASK] tokens
142
+
143
+ Returns:
144
+ dict: Masked token indices and their logits
145
+ """
146
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
147
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
148
+
149
+ with torch.no_grad():
150
+ outputs = self.model(input_ids)
151
+ logits = outputs.logits
152
+
153
+ mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
154
+ return mask_logits
155
+
156
+ def process_sentences(self, original_sentences, result_dict, remove_stopwords=False, method="random"):
157
+ """
158
+ Process a list of sentences and calculate logits for masked tokens using the specified method.
159
+
160
+ Args:
161
+ original_sentences (list): List of original sentences
162
+ result_dict (dict): Common n-grams and their indices for each sentence
163
+ method (str): Masking method ("random" or "entropy")
164
+
165
+ Returns:
166
+ dict: Masked sentences and their logits for each sentence
167
+ """
168
+ results = {}
169
+
170
+ for sentence, ngrams in result_dict.items():
171
+ if method == "random":
172
+ masked_sentence = self.mask_sentence_random(sentence, ngrams)
173
+ elif method == "entropy":
174
+ masked_sentence = self.mask_sentence_entropy(sentence, ngrams)
175
+ else:
176
+ raise ValueError("Invalid method. Choose 'random' or 'entropy'.")
177
+
178
+ logits = self.calculate_mask_logits(masked_sentence)
179
+ results[sentence] = {
180
+ "masked_sentence": masked_sentence,
181
+ "mask_logits": logits
182
+ }
183
+
184
+ return results
185
+
186
+ # Example usage
187
+ if __name__ == "__main__":
188
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
189
+ # THIS WORKS WHEN THE COORDINATES ARE COMPUTED WITHOUT REMOVING STOPWORDS
190
+
191
+ sentences = [
192
+ "The quick brown fox jumps over the lazy dog.",
193
+ "A quick brown dog outpaces a lazy fox.",
194
+ "Quick brown animals leap over lazy obstacles."
195
+ ]
196
+
197
+ result_dict = {
198
+ "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
199
+ "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
200
+ "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
201
+ }
202
+
203
+ # result_dict = {
204
+ # "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
205
+ # "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
206
+ # "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
207
+ # }
208
+
209
+ processor = MaskingProcessor()
210
+ results_random = processor.process_sentences(sentences, result_dict, remove_stopwords=True, method="random")
211
+ results_entropy = processor.process_sentences(sentences, result_dict, remove_stopwords=True, method="entropy")
212
+
213
+ for sentence, output in results_random.items():
214
+ print(f"Original Sentence (Random): {sentence}")
215
+ print(f"Masked Sentence (Random): {output['masked_sentence']}")
216
+ # print(f"Mask Logits (Random): {output['mask_logits']}")
217
+
218
+ for sentence, output in results_entropy.items():
219
+ print(f"Original Sentence (Entropy): {sentence}")
220
+ print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
221
+ # print(f"Mask Logits (Entropy): {output['mask_logits']}")
222
+
223
+
224
+
225
+
226
+ '''
227
+ result_dict = {
228
+ "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
229
+ "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
230
+ "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
231
+ }
232
+
233
+ '''
utils/old/masking_methods_final_copy.py ADDED
@@ -0,0 +1,619 @@
1
+ import random
2
+ import torch
3
+ from transformers import BertTokenizer, BertForMaskedLM
4
+ from nltk.corpus import stopwords
5
+ import nltk
6
+ from transformers import RobertaTokenizer, RobertaForMaskedLM
7
+
8
+
9
+ # Ensure stopwords are downloaded
10
+ try:
11
+ nltk.data.find('corpora/stopwords')
12
+ except LookupError:
13
+ nltk.download('stopwords')
14
+
15
+ class MaskingProcessor:
16
+ # def __init__(self, tokenizer, model):
17
+ def __init__(self):
18
+ # self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
19
+ # self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
20
+
21
+ # self.tokenizer = tokenizer
22
+ # self.model = model
23
+
24
+ self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
25
+ self.model = BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
26
+
27
+ # self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
28
+ # self.model = RobertaForMaskedLM.from_pretrained("roberta-base")
29
+
30
+ self.stop_words = set(stopwords.words('english'))
31
+
32
+ def remove_stopwords(self, words):
33
+ """
34
+ Remove stopwords from the given list of words.
35
+
36
+ Args:
37
+ words (list): List of words.
38
+
39
+ Returns:
40
+ list: List of non-stop words.
41
+ """
42
+ return [word for word in words if word.lower() not in self.stop_words]
43
+
44
+ def adjust_ngram_indices(self, original_words, common_ngrams):
45
+ """
46
+ Adjust indices of common n-grams after removing stopwords.
47
+
48
+ Args:
49
+ original_words (list): Original list of words.
50
+ common_ngrams (dict): Common n-grams and their indices.
51
+
52
+ Returns:
53
+ dict: Adjusted common n-grams with updated indices.
54
+ """
55
+ non_stop_words = self.remove_stopwords(original_words)
56
+ original_to_non_stop = []
57
+ non_stop_idx = 0
58
+
59
+ for original_idx, word in enumerate(original_words):
60
+ if word.lower() not in self.stop_words:
61
+ original_to_non_stop.append((original_idx, non_stop_idx))
62
+ non_stop_idx += 1
63
+
64
+ adjusted_ngrams = {}
65
+ for ngram, positions in common_ngrams.items():
66
+ adjusted_positions = []
67
+ for start, end in positions:
68
+ try:
69
+ new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
70
+ new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
71
+ adjusted_positions.append((new_start, new_end))
72
+ except StopIteration:
73
+ continue # Skip if indices cannot be mapped
74
+ adjusted_ngrams[ngram] = adjusted_positions
75
+
76
+ return adjusted_ngrams
77
+
78
+ def mask_sentence_random(self, sentence, common_ngrams):
79
+ """
80
+ Mask words in the sentence based on the specified rules after removing stopwords.
81
+ """
82
+ # Split sentence into words
83
+ original_words = sentence.split()
84
+
85
+ # Handle punctuation at the end
86
+ has_punctuation = False
87
+ punctuation = None
88
+ if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
89
+ has_punctuation = True
90
+ punctuation = original_words[-1][-1]
91
+ original_words = original_words[:-1]
92
+
93
+ print(f' ---- original_words : {original_words} ----- ')
94
+
95
+ # Process words without punctuation
96
+ non_stop_words = self.remove_stopwords(original_words)
97
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
98
+
99
+ # Rest of the existing function code...
100
+ mask_indices = []
101
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
102
+
103
+ if ngram_positions:
104
+ first_ngram_start = ngram_positions[0][0]
105
+ if first_ngram_start > 0:
106
+ mask_index_before_ngram = random.randint(0, first_ngram_start-1)
107
+ mask_indices.append(mask_index_before_ngram)
108
+
109
+ # Mask words between common n-grams
110
+ for i in range(len(ngram_positions) - 1):
111
+ end_prev = ngram_positions[i][1]
112
+ start_next = ngram_positions[i + 1][0]
113
+ if start_next > end_prev + 1:
114
+ mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
115
+ mask_indices.append(mask_index_between_ngrams)
116
+
117
+ # Mask a word after the last common n-gram
118
+ last_ngram_end = ngram_positions[-1][1]
119
+ if last_ngram_end < len(non_stop_words) - 1:
120
+ mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
121
+ mask_indices.append(mask_index_after_ngram)
122
+
123
+ # Create mapping from non-stop words to original indices
124
+ non_stop_to_original = {}
125
+ non_stop_idx = 0
126
+ for orig_idx, word in enumerate(original_words):
127
+ if word.lower() not in self.stop_words:
128
+ non_stop_to_original[non_stop_idx] = orig_idx
129
+ non_stop_idx += 1
130
+
131
+ # Map mask indices and apply masks
132
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
133
+ masked_words = original_words.copy()
134
+ for idx in original_mask_indices:
135
+ masked_words[idx] = self.tokenizer.mask_token
136
+ # masked_words[idx] = '<mask>' # for roberta
137
+
138
+ # Add back punctuation if it existed
139
+ if has_punctuation:
140
+ masked_words.append(punctuation)
141
+
142
+ print(f' ***** masked_words at end : {masked_words} ***** ')
143
+ print(f' ***** original_mask_indices : {original_mask_indices} ***** ')
144
+ print(f' ***** TESTING : {" ".join(masked_words)} ***** ')
145
+
146
+ return " ".join(masked_words), original_mask_indices
147
+
148
+ def mask_sentence_pseudorandom(self, sentence, common_ngrams):
149
+ """
150
+ Mask words in the sentence based on the specified rules after removing stopwords.
151
+ """
152
+ # Split sentence into words
153
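+ # Fixed seed: the "pseudorandom" variant always picks the same mask positions for a given input.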
+ random.seed(3)
154
+ original_words = sentence.split()
155
+
156
+ # Handle punctuation at the end
157
+ has_punctuation = False
158
+ punctuation = None
159
+ if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
160
+ has_punctuation = True
161
+ punctuation = original_words[-1][-1]
162
+ original_words = original_words[:-1]
163
+
164
+ print(f' ---- original_words : {original_words} ----- ')
165
+
166
+ # Process words without punctuation
167
+ non_stop_words = self.remove_stopwords(original_words)
168
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
169
+
170
+ # Choose mask positions around the common n-grams
171
+ mask_indices = []
172
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
173
+
174
+ if ngram_positions:
175
+ first_ngram_start = ngram_positions[0][0]
176
+ if first_ngram_start > 0:
177
+ mask_index_before_ngram = random.randint(0, first_ngram_start-1)
178
+ mask_indices.append(mask_index_before_ngram)
179
+
180
+ # Mask words between common n-grams
181
+ for i in range(len(ngram_positions) - 1):
182
+ end_prev = ngram_positions[i][1]
183
+ start_next = ngram_positions[i + 1][0]
184
+ if start_next > end_prev + 1:
185
+ mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
186
+ mask_indices.append(mask_index_between_ngrams)
187
+
188
+ # Mask a word after the last common n-gram
189
+ last_ngram_end = ngram_positions[-1][1]
190
+ if last_ngram_end < len(non_stop_words) - 1:
191
+ mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
192
+ mask_indices.append(mask_index_after_ngram)
193
+
194
+ # Create mapping from non-stop words to original indices
195
+ non_stop_to_original = {}
196
+ non_stop_idx = 0
197
+ for orig_idx, word in enumerate(original_words):
198
+ if word.lower() not in self.stop_words:
199
+ non_stop_to_original[non_stop_idx] = orig_idx
200
+ non_stop_idx += 1
201
+
202
+ # Map mask indices and apply masks
203
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
204
+ masked_words = original_words.copy()
205
+ for idx in original_mask_indices:
206
+ masked_words[idx] = self.tokenizer.mask_token
207
+ # masked_words[idx] = '<mask>' # for roberta
208
+
209
+ # Add back punctuation if it existed
210
+ if has_punctuation:
211
+ masked_words.append(punctuation)
212
+
213
+ print(f' ***** masked_words at end : {masked_words} ***** ')
214
+ print(f' ***** original_mask_indices : {original_mask_indices} ***** ')
215
+ print(f' ***** TESTING : {" ".join(masked_words)} ***** ')
216
+
217
+ return " ".join(masked_words), original_mask_indices
218
+
219
+
220
+ def calculate_word_entropy(self, sentence, word_position):
221
+ """
222
+ Calculate entropy for a specific word position in the sentence.
223
+
224
+ Args:
225
+ sentence (str): The input sentence
226
+ word_position (int): Position of the word to calculate entropy for
227
+
228
+ Returns:
229
+ float: Entropy value for the word
230
+ """
231
+ words = sentence.split()
232
+ masked_words = words.copy()
233
+ masked_words[word_position] = self.tokenizer.mask_token
234
+ masked_sentence = " ".join(masked_words)
235
+
236
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
237
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
238
+
239
+ with torch.no_grad():
240
+ outputs = self.model(input_ids)
241
+ logits = outputs.logits
242
+
243
+ # Get probabilities for the masked position
244
+ probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
245
+ # Calculate entropy: -sum(p * log(p))
246
+ entropy = -torch.sum(probs * torch.log(probs + 1e-9))
247
+
248
+ return entropy.item()
249
+
250
+ def mask_sentence_entropy(self, sentence, common_ngrams):
251
+ """
252
+ Mask words in the sentence based on entropy, following n-gram positioning rules.
253
+
254
+ Args:
255
+ sentence (str): Original sentence
256
+ common_ngrams (dict): Common n-grams and their indices
257
+
258
+ Returns:
259
+ tuple: (masked sentence with [MASK] tokens, list of masked word indices)
260
+ """
261
+ # Split sentence into words
262
+ original_words = sentence.split()
263
+
264
+ # Handle punctuation at the end
265
+ has_punctuation = False
266
+ punctuation = None
267
+ if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
268
+ has_punctuation = True
269
+ punctuation = original_words[-1][-1]
270
+ original_words = original_words[:-1]
271
+
272
+ # Process words without punctuation
273
+ non_stop_words = self.remove_stopwords(original_words)
274
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
275
+
276
+ # Create mapping from non-stop words to original indices
277
+ non_stop_to_original = {}
278
+ original_to_non_stop = {}
279
+ non_stop_idx = 0
280
+ for orig_idx, word in enumerate(original_words):
281
+ if word.lower() not in self.stop_words:
282
+ non_stop_to_original[non_stop_idx] = orig_idx
283
+ original_to_non_stop[orig_idx] = non_stop_idx
284
+ non_stop_idx += 1
285
+
286
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
287
+ mask_indices = []
288
+
289
+ if ngram_positions:
290
+ # Handle words before first n-gram
291
+ first_ngram_start = ngram_positions[0][0]
292
+ if first_ngram_start > 0:
293
+ candidate_positions = range(0, first_ngram_start)
294
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
295
+ for pos in candidate_positions]
296
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
297
+
298
+ # Handle words between n-grams
299
+ for i in range(len(ngram_positions) - 1):
300
+ end_prev = ngram_positions[i][1]
301
+ start_next = ngram_positions[i + 1][0]
302
+ if start_next > end_prev + 1:
303
+ candidate_positions = range(end_prev + 1, start_next)
304
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
305
+ for pos in candidate_positions]
306
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
307
+
308
+ # Handle words after last n-gram
309
+ last_ngram_end = ngram_positions[-1][1]
310
+ if last_ngram_end < len(non_stop_words) - 1:
311
+ candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
312
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
313
+ for pos in candidate_positions]
314
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
315
+
316
+ # Map mask indices to original sentence positions and apply masks
317
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
318
+ masked_words = original_words.copy()
319
+ for idx in original_mask_indices:
320
+ masked_words[idx] = self.tokenizer.mask_token
321
+
322
+ # Add back punctuation if it existed
323
+ if has_punctuation:
324
+ masked_words.append(punctuation)
325
+
326
+ return " ".join(masked_words), original_mask_indices
327
+
328
+ def calculate_mask_logits(self, original_sentence, original_mask_indices):
329
+ """
330
+ Calculate logits for masked tokens in the sentence using BERT.
331
+
332
+ Args:
333
+ original_sentence (str): Original sentence without masks
334
+ original_mask_indices (list): List of indices to mask
335
+
336
+ Returns:
337
+ dict: mask position -> {"tokens": top candidate words, "logits": their scores}
338
+ """
339
+ print('==========================================================================================================')
340
+ words = original_sentence.split()
341
+ print(f' ##### calculate_mask_logits >> words : {words} ##### ')
342
+ mask_logits = {}
343
+
344
+ for idx in original_mask_indices:
345
+ # Create a copy of words and mask the current position
346
+ print(f' ---- idx : {idx} ----- ')
347
+ masked_words = words.copy()
348
+ masked_words[idx] = '[MASK]'
349
+ # masked_words[idx] = '<mask>' # for roberta
350
+ masked_sentence = " ".join(masked_words)
351
+ print(f' ---- masked_sentence : {masked_sentence} ----- ')
352
+
353
+ # Calculate logits for the current mask
354
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
355
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
356
+
357
+ with torch.no_grad():
358
+ outputs = self.model(input_ids)
359
+ logits = outputs.logits
360
+
361
+ # Extract logits for the masked position
362
+ mask_logits_tensor = logits[0, mask_token_index, :]
363
+
364
+ # Get top logits and corresponding tokens
365
+ top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1) # Get more candidates
366
+
367
+ # Convert token IDs to words and filter out subword tokens
368
+ top_tokens = []
369
+ top_logits = []
370
+ seen_words = set() # To keep track of unique words
371
+
372
+ for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
373
+ token = self.tokenizer.convert_ids_to_tokens(token_id.item())
374
+
375
+ # Skip if it's a subword token (starts with ##)
376
+ if token.startswith('##'):
377
+ continue
378
+
379
+ # Convert token to proper word
380
+ word = self.tokenizer.convert_tokens_to_string([token]).strip()
381
+
382
+ # Only add if it's a new word and not empty
383
+ if word and word not in seen_words:
384
+ seen_words.add(word)
385
+ top_tokens.append(word)
386
+ top_logits.append(logit.item())
387
+
388
+ # Break if we have 50 unique complete words
389
+ if len(top_tokens) == 50:
390
+ break
391
+
392
+ # print(f' ---- top_tokens : {top_tokens} ----- ')
393
+
394
+ # Store results
395
+ mask_logits[idx] = {
396
+ "tokens": top_tokens,
397
+ "logits": top_logits
398
+ }
399
+
400
+ return mask_logits
401
+
402
+ # def calculate_mask_logits(self, original_sentence, original_mask_indices):
403
+ # """
404
+ # Calculate logits for masked tokens in the sentence using BERT.
405
+
406
+ # Args:
407
+ # original_sentence (str): Original sentence without masks
408
+ # original_mask_indices (list): List of indices to mask
409
+
410
+ # Returns:
411
+ # dict: Masked token indices and their logits
412
+ # """
413
+ # words = original_sentence.split()
414
+ # print(f' ##### calculate_mask_logits >> words : {words} ##### ')
415
+ # mask_logits = {}
416
+
417
+ # for idx in original_mask_indices:
418
+ # # Create a copy of words and mask the current position
419
+ # print(f' ---- idx : {idx} ----- ')
420
+ # masked_words = words.copy()
421
+ # print(f' ---- words : {masked_words} ----- ')
422
+ # # masked_words[idx] = self.tokenizer.mask_token
423
+ # masked_words[idx] = '[MASK]'
424
+ # print(f' ---- masked_words : {masked_words} ----- ')
425
+ # masked_sentence = " ".join(masked_words)
426
+ # print(f' ---- masked_sentence : {masked_sentence} ----- ')
427
+
428
+ # # Calculate logits for the current mask
429
+ # input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
430
+ # mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
431
+
432
+ # with torch.no_grad():
433
+ # outputs = self.model(input_ids)
434
+ # logits = outputs.logits
435
+
436
+ # # Extract logits for the masked position
437
+ # mask_logits_tensor = logits[0, mask_token_index, :]
438
+
439
+ # # Get top 50 logits and corresponding tokens
440
+ # top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 50, dim=-1)
441
+
442
+ # # Convert token IDs to words
443
+ # top_tokens = [self.tokenizer.convert_ids_to_tokens(token_id.item()) for token_id in top_mask_indices[0]]
444
+ # print(f' ---- top_tokens : {top_tokens} ----- ')
445
+
446
+ # # Store results
447
+ # mask_logits[idx] = {
448
+ # "tokens": top_tokens,
449
+ # "logits": top_mask_logits.tolist()
450
+ # }
451
+
452
+ # return mask_logits
453
+
454
+
455
+ def process_sentences(self, sentences, result_dict, method="random"):
456
+ """
457
+ Process sentences and calculate logits for masked tokens.
458
+ """
459
+ results = {}
460
+
461
+ for sentence, ngrams in result_dict.items():
462
+ # Split punctuation from the last word before processing
463
+ words = sentence.split()
464
+ last_word = words[-1]
465
+ if any(last_word.endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
466
+ # Split the last word and punctuation
467
+ words[-1] = last_word[:-1]
468
+ punctuation = last_word[-1]
469
+ # Rejoin with space before punctuation to treat it as separate token
470
+ processed_sentence = " ".join(words) + " " + punctuation
471
+ else:
472
+ processed_sentence = sentence
473
+
474
+ if method == "random":
475
+ masked_sentence, original_mask_indices = self.mask_sentence_random(processed_sentence, ngrams)
476
+ elif method == "pseudorandom":
477
+ masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(processed_sentence, ngrams)
478
+ else: # entropy
479
+ masked_sentence, original_mask_indices = self.mask_sentence_entropy(processed_sentence, ngrams)
480
+
481
+ logits = self.calculate_mask_logits(processed_sentence, original_mask_indices)
482
+ results[sentence] = {
483
+ "masked_sentence": masked_sentence,
484
+ "mask_logits": logits
485
+ }
486
+
487
+ return results
488
+
489
+
490
+
491
+ if __name__ == "__main__":
492
+ # !!! Works in both cases, whether or not stopwords are removed
493
+ sentences = [
494
+ "The quick brown fox jumps over small cat the lazy dog everyday again and again .",
495
+ # "A speedy brown fox jumps over a lazy dog.",
496
+ # "A swift brown fox leaps over the lethargic dog."
497
+
498
+ ]
499
+ result_dict ={
500
+ 'The quick brown fox jumps over small cat the lazy dog everyday again and again .': {'brown fox': [(2, 3)],'cat': [(7, 7)], 'dog': [(10, 10)]},
501
+ # 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
502
+ # 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
503
+ }
504
+
505
+
506
+ processor = MaskingProcessor()
507
+ # results_random = processor.process_sentences(sentences, result_dict)
508
+ results_entropy = processor.process_sentences(sentences, result_dict, method="random")
509
+
510
+ '''
511
+ results structure :
512
+ results = {
513
+ "The quick brown fox jumps over the lazy dog everyday.":
514
+ { # Original sentence as key
515
+ "masked_sentence": str, # The sentence with [MASK] tokens
516
+ "mask_logits":
517
+ { # Dictionary of mask positions and their predictions
518
+ 1:
519
+ { # Position of mask in sentence
520
+ "tokens" (words) : list, # List of top 50 predicted tokens
521
+ "logits" (probabilities) : list # Corresponding logits for those tokens
522
+ },
523
+ 7:
524
+ {
525
+ "tokens" (words) : list,
526
+ "logits" (probabilities) : list
527
+ },
528
+ 10:
529
+ {
530
+ "tokens (words)": list,
531
+ "logits (probabilities)": list
532
+ }
533
+ }
534
+ }
535
+ }
536
+
537
+ '''
538
+ # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)
539
+
540
+ for sentence, output in results_entropy.items():
541
+ print(f"Original Sentence (Random): {sentence}")
542
+ print(f"Masked Sentence (Random): {output['masked_sentence']}")
543
+ # print(f"Mask Logits (Random): {output['mask_logits']}")
544
+ # print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
545
+ # print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
546
+ # print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
547
+ # print('--------------------------------')
548
+ # for mask_idx, logits in output["mask_logits"].items():
549
+ # print(f"Logits for [MASK] at position {mask_idx}:")
550
+ # print(f' logits : {logits[:5]}') # List of logits for all vocabulary tokens
551
+ # print(f' len(logits) : {len(logits)}')
552
+
553
+
554
+ # ------------------------------------------------------------------------------------------------
555
+ # def mask_sentence_random(self, sentence, common_ngrams):
556
+ # """
557
+ # Mask words in the sentence based on the specified rules after removing stopwords.
558
+ # """
559
+ # original_words = sentence.split()
560
+ # # print(f' ---- original_words : {original_words} ----- ')
561
+ # non_stop_words = self.remove_stopwords(original_words)
562
+ # # print(f' ---- non_stop_words : {non_stop_words} ----- ')
563
+ # adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
564
+ # # print(f' ---- common_ngrams : {common_ngrams} ----- ')
565
+ # # print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')
566
+
567
+ # mask_indices = []
568
+
569
+ # # Extract n-gram positions in non-stop words
570
+ # ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
571
+
572
+ # # Mask a word before the first common n-gram
573
+ # if ngram_positions:
574
+ # # print(f' ---- ngram_positions : {ngram_positions} ----- ')
575
+ # first_ngram_start = ngram_positions[0][0]
576
+ # # print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
577
+ # if first_ngram_start > 0:
578
+ # mask_index_before_ngram = random.randint(0, first_ngram_start-1)
579
+ # # print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
580
+ # mask_indices.append(mask_index_before_ngram)
581
+
582
+ # # Mask words between common n-grams
583
+ # for i in range(len(ngram_positions) - 1):
584
+ # end_prev = ngram_positions[i][1]
585
+ # # print(f' ---- end_prev : {end_prev} ----- ')
586
+ # start_next = ngram_positions[i + 1][0]
587
+ # # print(f' ---- start_next : {start_next} ----- ')
588
+ # if start_next > end_prev + 1:
589
+ # mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
590
+ # # print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
591
+ # mask_indices.append(mask_index_between_ngrams)
592
+
593
+ # # Mask a word after the last common n-gram
594
+ # last_ngram_end = ngram_positions[-1][1]
595
+ # if last_ngram_end < len(non_stop_words) - 1:
596
+ # # print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
597
+ # mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
598
+ # # print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
599
+ # mask_indices.append(mask_index_after_ngram)
600
+
601
+ # # Create mapping from non-stop words to original indices
602
+ # non_stop_to_original = {}
603
+ # non_stop_idx = 0
604
+ # for orig_idx, word in enumerate(original_words):
605
+ # if word.lower() not in self.stop_words:
606
+ # non_stop_to_original[non_stop_idx] = orig_idx
607
+ # non_stop_idx += 1
608
+
609
+ # # Map mask indices from non-stop word positions to original positions
610
+ # # print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
611
+ # original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
612
+ # # print(f' ---- original_mask_indices : {original_mask_indices} ----- ')
613
+
614
+ # # Apply masks to the original sentence
615
+ # masked_words = original_words.copy()
616
+ # for idx in original_mask_indices:
617
+ # masked_words[idx] = self.tokenizer.mask_token
618
+
619
+ # return " ".join(masked_words), original_mask_indices
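A minimal, self-contained sketch of the entropy scoring that mask_sentence_entropy relies on, shown outside the class for readability. The smaller bert-base-uncased checkpoint is only an illustrative assumption; the processor above loads bert-large-cased-whole-word-masking.

import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumption: smaller model for a quick demo
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
model.eval()

def word_entropy(sentence, word_position):
    # Mask one word, run the MLM, and compute -sum(p * log p) over the predicted distribution.
    words = sentence.split()
    words[word_position] = tokenizer.mask_token
    inputs = tokenizer(" ".join(words), return_tensors="pt")
    mask_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits[0, mask_index], dim=-1)
    return -(probs * torch.log(probs + 1e-9)).sum().item()

# Higher entropy means the model is less certain about that word, so it is a better masking candidate.
print(word_entropy("The quick brown fox jumps over the lazy dog", 3))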
utils/old/non_melting_points_v1.py ADDED
@@ -0,0 +1,244 @@
1
+ import nltk
2
+ from nltk.corpus import stopwords
3
+ from nltk.util import ngrams
4
+ from collections import Counter
5
+ import re
6
+
7
+ class NgramProcessor:
8
+ def __init__(self):
9
+ try:
10
+ nltk.data.find('corpora/stopwords')
11
+ except LookupError:
12
+ nltk.download('stopwords')
13
+
14
+ self.stop_words = set(stopwords.words('english'))
15
+
16
+ def remove_stopwords(self, text):
17
+ """
18
+ Remove stopwords using NLTK's stopword list
19
+
20
+ Args:
21
+ text (str): Input text
22
+
23
+ Returns:
24
+ str: Cleaned text with stopwords removed
25
+ """
26
+ words = re.findall(r'\w+', text.lower())
27
+ filtered_words = [word for word in words if word not in self.stop_words]
28
+ return ' '.join(filtered_words)
29
+
30
+ def is_exact_match(self, ngram, sentences):
31
+ """
32
+ Check if the given n-gram has an exact match in all sentences
33
+
34
+ Args:
35
+ ngram (str): The n-gram to search for
36
+ sentences (list): List of sentences to search in
37
+
38
+ Returns:
39
+ bool: True if n-gram has exact match in all sentences, False otherwise
40
+ """
41
+ return all(ngram in sentence for sentence in sentences)
42
+
43
+ def is_substring_of_any(self, ngram, common_ngrams):
44
+ """
45
+ Check if the given n-gram is an exact substring of any previously found common n-grams
46
+
47
+ Args:
48
+ ngram (str): The n-gram to check
49
+ common_ngrams (list): List of previously found common n-grams
50
+
51
+ Returns:
52
+ bool: True if ngram is a substring of any common_ngrams, False otherwise
53
+ """
54
+ return any(ngram in other_ngram for other_ngram in common_ngrams if ngram != other_ngram)
55
+
56
+ def find_filtered_ngrams(self, sentences):
57
+ """
58
+ Find all n-grams that have exact matches across all sentences,
59
+ excluding those that are part of larger common n-grams
60
+
61
+ Args:
62
+ sentences (list): List of sentences to analyze
63
+
64
+ Returns:
65
+ list: List of tuples where each tuple contains the n-gram and its indices in each sentence
66
+ """
67
+ original_sentences = sentences[:]
68
+ sentences = [self.remove_stopwords(sentence) for sentence in sentences]
69
+ ngram_lengths = [4, 3, 2, 1] # Quadgram, trigram, bigram, unigram
70
+ common_ngrams = []
71
+
72
+ for n in ngram_lengths:
73
+ ngrams_list = [list(ngrams(sentence.split(), n)) for sentence in sentences]
74
+ ngrams_counter = Counter(ngrams_list[0])
75
+
76
+ for ngram in ngrams_counter:
77
+ ngram_str = ' '.join(ngram)
78
+ if self.is_exact_match(ngram_str, sentences) and not self.is_substring_of_any(ngram_str, [ng[0] for ng in common_ngrams]):
79
+ indices = []
80
+ for original_sentence in original_sentences:
81
+ words = original_sentence.split()
82
+ ngram_indices = [
83
+ (i, i + n - 1) for i in range(len(words) - n + 1)
84
+ if ' '.join(words[i:i + n]).lower() == ngram_str
85
+ ]
86
+ indices.append(ngram_indices)
87
+ common_ngrams.append((ngram_str, indices))
88
+
89
+ return common_ngrams
90
+
91
+ def find_relative_order(self, sentence, common_ngrams):
92
+ """
93
+ Find the relative order of the common n-grams in the sentence
94
+
95
+ Args:
96
+ sentence (str): Sentence in which to find the relative order
97
+ common_ngrams (list): List of common n-grams
98
+
99
+ Returns:
100
+ list: List of tuples with the relative position and the n-gram
101
+ """
102
+ relative_order = []
103
+ for ngram, _ in common_ngrams:
104
+ index = sentence.find(ngram)
105
+ if index != -1:
106
+ relative_order.append((index, ngram))
107
+
108
+ return sorted(relative_order)
109
+
110
+ # Example usage
111
+ if __name__ == "__main__":
112
+ sentences = [
113
+ "The quick brown fox jumps over the lazy dog.",
114
+ "A quick brown dog outpaces a lazy fox.",
115
+ "Quick brown animals leap over lazy obstacles."
116
+ ]
117
+
118
+ processor = NgramProcessor()
119
+ common_ngrams = processor.find_filtered_ngrams(sentences)
120
+ print("Common n-grams and their indices:")
121
+ for ngram, indices in common_ngrams:
122
+ print(f"{ngram}: {indices}")
123
+
124
+ for sentence in sentences:
125
+ relative_order = processor.find_relative_order(sentence, common_ngrams)
126
+ print(f"Relative order in sentence '{sentence}':", relative_order)
127
+
128
+
129
+
130
+ # import nltk
131
+ # from nltk.corpus import stopwords
132
+ # from nltk.util import ngrams
133
+ # from collections import Counter
134
+ # import re
135
+
136
+ # class NgramProcessor:
137
+ # def __init__(self):
138
+ # try:
139
+ # nltk.data.find('corpora/stopwords')
140
+ # except LookupError:
141
+ # nltk.download('stopwords')
142
+
143
+ # self.stop_words = set(stopwords.words('english'))
144
+
145
+ # def remove_stopwords(self, text):
146
+ # """
147
+ # Remove stopwords using NLTK's stopword list
148
+
149
+ # Args:
150
+ # text (str): Input text
151
+
152
+ # Returns:
153
+ # str: Cleaned text with stopwords removed
154
+ # """
155
+ # words = re.findall(r'\w+', text.lower())
156
+ # filtered_words = [word for word in words if word not in self.stop_words]
157
+ # return ' '.join(filtered_words)
158
+
159
+ # def is_exact_match(self, ngram, sentences):
160
+ # """
161
+ # Check if the given n-gram has an exact match in all sentences
162
+
163
+ # Args:
164
+ # ngram (str): The n-gram to search for
165
+ # sentences (list): List of sentences to search in
166
+
167
+ # Returns:
168
+ # bool: True if n-gram has exact match in all sentences, False otherwise
169
+ # """
170
+ # return all(ngram in sentence for sentence in sentences)
171
+
172
+ # def is_substring_of_any(self, ngram, common_ngrams):
173
+ # """
174
+ # Check if the given n-gram is an exact substring of any previously found common n-grams
175
+
176
+ # Args:
177
+ # ngram (str): The n-gram to check
178
+ # common_ngrams (list): List of previously found common n-grams
179
+
180
+ # Returns:
181
+ # bool: True if ngram is a substring of any common_ngrams, False otherwise
182
+ # """
183
+ # return any(ngram in other_ngram for other_ngram in common_ngrams if ngram != other_ngram)
184
+
185
+ # def find_filtered_ngrams(self, sentences):
186
+ # """
187
+ # Find all n-grams that have exact matches across all sentences,
188
+ # excluding those that are part of larger common n-grams
189
+
190
+ # Args:
191
+ # sentences (list): List of sentences to analyze
192
+
193
+ # Returns:
194
+ # list: List of all common n-grams in order of their appearance in the first sentence
195
+ # """
196
+ # sentences = [self.remove_stopwords(sentence) for sentence in sentences]
197
+ # ngram_lengths = [4, 3, 2, 1] # Quadgram, trigram, bigram, unigram
198
+ # common_ngrams = []
199
+
200
+ # for n in ngram_lengths:
201
+ # ngrams_list = [list(ngrams(sentence.split(), n)) for sentence in sentences]
202
+ # ngrams_counter = Counter(ngrams_list[0])
203
+
204
+ # for ngram in ngrams_counter:
205
+ # ngram_str = ' '.join(ngram)
206
+ # if self.is_exact_match(ngram_str, sentences) and not self.is_substring_of_any(ngram_str, common_ngrams):
207
+ # common_ngrams.append(ngram_str)
208
+
209
+ # return common_ngrams
210
+
211
+ # def find_relative_order(self, sentence, common_ngrams):
212
+ # """
213
+ # Find the relative order of the common n-grams in the sentence
214
+
215
+ # Args:
216
+ # sentence (str): Sentence in which to find the relative order
217
+ # common_ngrams (list): List of common n-grams
218
+
219
+ # Returns:
220
+ # list: List of tuples with the relative position and the n-gram
221
+ # """
222
+ # relative_order = []
223
+ # for ngram in common_ngrams:
224
+ # index = sentence.find(ngram)
225
+ # if index != -1:
226
+ # relative_order.append((index, ngram))
227
+
228
+ # return sorted(relative_order)
229
+
230
+ # # Example usage
231
+ # if __name__ == "__main__":
232
+ # sentences = [
233
+ # "The quick brown fox jumps over the lazy dog.",
234
+ # "A quick brown dog outpaces a lazy fox.",
235
+ # "Quick brown animals leap over lazy obstacles."
236
+ # ]
237
+
238
+ # processor = NgramProcessor()
239
+ # common_ngrams = processor.find_filtered_ngrams(sentences)
240
+ # print("Common n-grams:", common_ngrams)
241
+
242
+ # for sentence in sentences:
243
+ # relative_order = processor.find_relative_order(sentence, common_ngrams)
244
+ # print(f"Relative order in sentence '{sentence}':", relative_order)
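For orientation, here is a compact sketch of the same longest-first common n-gram idea used by find_filtered_ngrams above: keep an n-gram only if it appears in every sentence and is not already covered by a longer common n-gram. This is an illustration, not a drop-in replacement; it skips the stopword removal and the per-sentence index bookkeeping.

from nltk.util import ngrams

def common_ngrams(sentences, max_n=4):
    found = []
    for n in range(max_n, 0, -1):  # longest n-grams first
        for gram in set(ngrams(sentences[0].lower().split(), n)):
            text = " ".join(gram)
            in_all = all(text in s.lower() for s in sentences)
            covered = any(text in longer for longer in found)
            if in_all and not covered:
                found.append(text)
    return found

print(common_ngrams(["The quick brown fox", "A quick brown dog"]))  # ['quick brown']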
utils/old/sampling/sampling.py ADDED
@@ -0,0 +1,330 @@
1
+ import torch
2
+ import random
3
+ from masking_methods import MaskingProcessor
4
+ import nltk
5
+ from nltk.corpus import words
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class SamplingProcessor:
10
+ def __init__(self, tokenizer):
11
+ """
12
+ Initialize the SamplingProcessor.
13
+
14
+ Args:
15
+ tokenizer: BERT tokenizer instance
16
+ """
17
+ self.tokenizer = tokenizer
18
+ self.subtoken_prefix = self._get_subtoken_prefix()
19
+ self.subtoken_ids = self._get_subtoken_ids()
20
+ try:
21
+ nltk.data.find('corpora/words')
22
+ except LookupError:
23
+ nltk.download('words')
24
+ self.english_words = set(words.words())
25
+
26
+ # def _get_subtoken_prefix(self):
27
+ # """
28
+ # Identify the subtoken prefix based on the tokenizer.
29
+
30
+ # Returns:
31
+ # str: The prefix used for subtokens (e.g., "##" for BERT).
32
+ # """
33
+ # # This method assumes that the tokenizer uses a consistent subtoken prefix.
34
+ # # Adjust accordingly if using different tokenizers.
35
+ # # For BERT's WordPiece tokenizer:
36
+ # if hasattr(self.tokenizer, "init_kwargs") and "wordpiece_prefix" in self.tokenizer.init_kwargs:
37
+ # return self.tokenizer.init_kwargs["wordpiece_prefix"]
38
+ # elif hasattr(self.tokenizer, "prefix_tokens"):
39
+ # return self.tokenizer.prefix_tokens
40
+ # else:
41
+ # # Default to BERT's subtoken prefix
42
+ # return "##"
43
+
44
+ def _get_subtoken_prefix(self):
45
+ """
46
+ Identify the subtoken prefix based on the tokenizer.
47
+
48
+ Returns:
49
+ str: The prefix used for subtokens (e.g., "##" for BERT).
50
+ """
51
+ # This method assumes that the tokenizer uses a consistent subtoken prefix.
52
+ # Adjust accordingly if using different tokenizers.
53
+ # For BERT's WordPiece tokenizer:
54
+ if hasattr(self.tokenizer, "init_kwargs") and "wordpiece_prefix" in self.tokenizer.init_kwargs:
55
+ return self.tokenizer.init_kwargs["wordpiece_prefix"]
56
+ elif hasattr(self.tokenizer, "prefix_tokens"):
57
+ return self.tokenizer.prefix_tokens
58
+ else:
59
+ # Default to BERT's subtoken prefix
60
+ return "##"
61
+
62
+
63
+ # def _get_subtoken_ids(self):
64
+ # """
65
+ # Retrieve all token IDs that correspond to subtokens.
66
+
67
+ # Returns:
68
+ # set: A set of subtoken IDs.
69
+ # """
70
+ # vocab = self.tokenizer.get_vocab()
71
+ # subtoken_ids = set()
72
+ # for token, idx in vocab.items():
73
+ # if token.startswith(self.subtoken_prefix):
74
+ # subtoken_ids.add(idx)
75
+ # return subtoken_ids
76
+
77
+ def _get_subtoken_ids(self):
78
+ """
79
+ Retrieve all token IDs that correspond to subtokens.
80
+
81
+ Returns:
82
+ list: A list of subtoken IDs.
83
+ """
84
+ vocab = self.tokenizer.get_vocab()
85
+ subtoken_ids = []
86
+ for token, idx in vocab.items():
87
+ if token.startswith(self.subtoken_prefix):
88
+ subtoken_ids.append(idx)
89
+ return subtoken_ids # Changed from set to list
90
+
91
+
92
+ def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
93
+ tokens = self.tokenizer.tokenize(masked_sentence)
94
+
95
+ for mask_pos in sorted(mask_logits_dict.keys()):
96
+ try:
97
+ # Get logits and squeeze extra dimension
98
+ mask_logits = torch.tensor(mask_logits_dict[mask_pos]).squeeze(0) # Remove the extra dimension
99
+
100
+ # Create a mask for valid tokens (no special tokens, no subwords)
101
+ valid_mask = torch.zeros_like(mask_logits, dtype=torch.bool)
102
+ for idx in range(len(mask_logits)):
103
+ token = self.tokenizer.convert_ids_to_tokens([idx])[0]
104
+ # Only allow regular words (no special tokens, no subwords)
105
+ if token.isalpha() and not token.startswith('[') and not token.startswith('##'):
106
+ valid_mask[idx] = True
107
+
108
+ # Get valid logits
109
+ valid_logits = mask_logits[valid_mask]
110
+ valid_indices = torch.where(valid_mask)[0]
111
+
112
+ if len(valid_logits) == 0:
113
+ print(f"Warning: No valid tokens found for position {mask_pos}")
114
+ continue
115
+
116
+ if sampling_technique == "inverse_transform":
117
+ probs = torch.softmax(valid_logits / temperature, dim=-1)
118
+ cumulative_probs = torch.cumsum(probs, dim=-1)
119
+ random_prob = random.random()
120
+ sampled_idx = torch.where(cumulative_probs >= random_prob)[0][0].item()
121
+ sampled_index = valid_indices[sampled_idx].item()
122
+
123
+ elif sampling_technique == "exponential_minimum":
124
+ probs = torch.softmax(valid_logits / temperature, dim=-1)
125
+ exp_probs = torch.exp(-torch.log(probs))
126
+ random_probs = torch.rand_like(exp_probs)
127
+ sampled_idx = torch.argmax(random_probs * exp_probs).item()
128
+ sampled_index = valid_indices[sampled_idx].item()
129
+
130
+ elif sampling_technique == "temperature":
131
+ valid_logits = torch.clamp(valid_logits, min=-1e8, max=1e8)
132
+ probs = torch.softmax(valid_logits / temperature, dim=-1)
133
+ if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
134
+ raise ValueError("The computed probabilities contain NaN or inf values.")
135
+ probs = torch.max(probs, torch.tensor(1e-8))
136
+ probs = probs / torch.sum(probs)
137
+ sampled_idx = torch.multinomial(probs, 1)[0].item()
138
+ sampled_index = valid_indices[sampled_idx].item()
139
+
140
+ elif sampling_technique == 'greedy':
141
+ sampled_idx = torch.argmax(valid_logits).item()
142
+ sampled_index = valid_indices[sampled_idx].item()
143
+
144
+ # Replace mask with sampled token
145
+ sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
146
+ tokens[mask_pos] = sampled_token
147
+
148
+ except Exception as e:
149
+ print(f"Error sampling for position {mask_pos}: {str(e)}")
150
+ continue
151
+
152
+ return self.tokenizer.convert_tokens_to_string(tokens)
153
+
154
+
155
+
156
+ def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
157
+ """
158
+ Process all masked sentences in the results dictionary.
159
+
160
+ Args:
161
+ results_dict (dict): Dictionary containing masked sentences and their logits
162
+ sampling_technique (str): Sampling method to use
163
+ temperature (float): Temperature parameter for sampling
164
+
165
+ Returns:
166
+ dict: Dictionary containing original, masked, and sampled sentences
167
+ """
168
+ processed_results = {}
169
+
170
+ for original_sentence, data in results_dict.items():
171
+ masked_sentence = data["masked_sentence"]
172
+ mask_logits = data["mask_logits"]
173
+
174
+ sampled_sentence = self.sample_tokens(
175
+ mask_logits,
176
+ masked_sentence,
177
+ sampling_technique,
178
+ temperature
179
+ )
180
+
181
+ processed_results[original_sentence] = {
182
+ "masked_sentence": masked_sentence,
183
+ "sampled_sentence": sampled_sentence
184
+ }
185
+
186
+ return processed_results
187
+
188
+ if __name__ == "__main__":
189
+ sentences = [
190
+ "The quick brown fox jumps over the lazy dog everyday.",
191
+ ]
192
+ result_dict = {
193
+ 'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
194
+ }
195
+
196
+ # First, mask the sentences
197
+ masking_processor = MaskingProcessor()
198
+ masking_results = masking_processor.process_sentences(sentences, result_dict)
199
+
200
+ # Then, sample replacements for the masks
201
+ sampling_processor = SamplingProcessor(masking_processor.tokenizer)
202
+
203
+ # Try different sampling techniques
204
+ sampling_techniques = ["temperature", "greedy", "inverse_transform", "exponential_minimum"]
205
+
206
+ for technique in sampling_techniques:
207
+ print(f"\nSampling using {technique}:")
208
+ sampled_results = sampling_processor.process_masked_sentences(
209
+ masking_results,
210
+ sampling_technique=technique,
211
+ temperature=1.0
212
+ )
213
+
214
+ for original_sentence, result in sampled_results.items():
215
+ print(f"Original: {original_sentence}")
216
+ print(f"Masked: {result['masked_sentence']}")
217
+ print(f"Sampled: {result['sampled_sentence']}")
218
+ print("---")
219
+
220
+ # --------------------------------------------------------------------------------------------------
221
+ # def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0, top_k=100):
222
+ # words = masked_sentence.split()
223
+ # mask_positions = sorted(mask_logits_dict.keys())
224
+
225
+ # for mask_pos in mask_positions:
226
+ # mask_logits = torch.tensor(mask_logits_dict[mask_pos])
227
+
228
+ # try:
229
+ # if sampling_technique == "inverse_transform":
230
+ # probs = torch.softmax(mask_logits / temperature, dim=-1)
231
+ # cumulative_probs = torch.cumsum(probs, dim=-1)
232
+ # random_prob = random.random()
233
+ # sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
234
+
235
+ # elif sampling_technique == "exponential_minimum":
236
+ # probs = torch.softmax(mask_logits / temperature, dim=-1)
237
+ # exp_probs = torch.exp(-torch.log(probs))
238
+ # random_probs = torch.rand_like(exp_probs)
239
+ # sampled_index = torch.argmax(random_probs * exp_probs).item()
240
+
241
+ # elif sampling_technique == "temperature":
242
+ # mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
243
+ # probs = torch.softmax(mask_logits / temperature, dim=-1)
244
+ # if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
245
+ # raise ValueError("The computed probabilities contain NaN or inf values.")
246
+ # probs = torch.max(probs, torch.tensor(1e-8))
247
+ # probs = probs / torch.sum(probs)
248
+ # sampled_index = torch.multinomial(probs, 1)[0].item()
249
+
250
+ # elif sampling_technique == 'greedy':
251
+ # sampled_index = torch.argmax(mask_logits).item()
252
+
253
+ # else:
254
+ # raise ValueError(f"Unknown sampling technique: {sampling_technique}")
255
+
256
+ # # Replace mask with sampled token
257
+ # sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
258
+ # words[mask_pos] = sampled_token
259
+
260
+ # except Exception as e:
261
+ # print(f"Error sampling for position {mask_pos}: {str(e)}")
262
+ # continue
263
+
264
+ # return " ".join(words)
265
+
266
+ ## MORE WEIRD RESULTS
267
+ # def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0, top_k=100):
268
+ # words = masked_sentence.split()
269
+ # mask_positions = sorted(mask_logits_dict.keys())
270
+
271
+ # for mask_pos in mask_positions:
272
+ # mask_logits = torch.tensor(mask_logits_dict[mask_pos])
273
+
274
+ # try:
275
+ # # Create a mask for valid tokens (no special tokens, no subwords)
276
+ # valid_mask = torch.zeros_like(mask_logits, dtype=torch.bool)
277
+ # for idx in range(len(mask_logits)):
278
+ # token = self.tokenizer.convert_ids_to_tokens([idx])[0]
279
+ # # Only allow regular words (no special tokens, no subwords)
280
+ # if token.isalpha() and not token.startswith('[') and not token.startswith('##'):
281
+ # valid_mask[idx] = True
282
+
283
+ # # Get valid logits
284
+ # valid_logits = mask_logits[valid_mask]
285
+ # valid_indices = torch.where(valid_mask)[0]
286
+
287
+ # if len(valid_logits) == 0:
288
+ # print(f"Warning: No valid tokens found for position {mask_pos}")
289
+ # continue
290
+
291
+ # if sampling_technique == "inverse_transform":
292
+ # probs = torch.softmax(valid_logits / temperature, dim=-1)
293
+ # cumulative_probs = torch.cumsum(probs, dim=-1)
294
+ # random_prob = random.random()
295
+ # sampled_idx = torch.where(cumulative_probs >= random_prob)[0][0].item()
296
+ # sampled_index = valid_indices[sampled_idx].item()
297
+
298
+ # elif sampling_technique == "exponential_minimum":
299
+ # probs = torch.softmax(valid_logits / temperature, dim=-1)
300
+ # exp_probs = torch.exp(-torch.log(probs))
301
+ # random_probs = torch.rand_like(exp_probs)
302
+ # sampled_idx = torch.argmax(random_probs * exp_probs).item()
303
+ # sampled_index = valid_indices[sampled_idx].item()
304
+
305
+ # elif sampling_technique == "temperature":
306
+ # valid_logits = torch.clamp(valid_logits, min=-1e8, max=1e8)
307
+ # probs = torch.softmax(valid_logits / temperature, dim=-1)
308
+ # if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
309
+ # raise ValueError("The computed probabilities contain NaN or inf values.")
310
+ # probs = torch.max(probs, torch.tensor(1e-8))
311
+ # probs = probs / torch.sum(probs)
312
+ # sampled_idx = torch.multinomial(probs, 1)[0].item()
313
+ # sampled_index = valid_indices[sampled_idx].item()
314
+
315
+ # elif sampling_technique == 'greedy':
316
+ # sampled_idx = torch.argmax(valid_logits).item()
317
+ # sampled_index = valid_indices[sampled_idx].item()
318
+
319
+ # else:
320
+ # raise ValueError(f"Unknown sampling technique: {sampling_technique}")
321
+
322
+ # # Replace mask with sampled token
323
+ # sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
324
+ # words[mask_pos] = sampled_token
325
+
326
+ # except Exception as e:
327
+ # print(f"Error sampling for position {mask_pos}: {str(e)}")
328
+ # continue
329
+
330
+ # return " ".join(words)
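As a quick reference, the greedy and temperature choices inside sample_tokens reduce to the following on a toy logits vector (a sketch only; candidate filtering and tokenizer handling are omitted).

import torch

def pick(logits, technique="temperature", temperature=1.0):
    logits = torch.clamp(logits, min=-1e8, max=1e8)
    if technique == "greedy":
        return torch.argmax(logits).item()
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, 1).item()

toy = torch.tensor([2.0, 1.0, 0.1])
print(pick(toy, "greedy"))             # always index 0
print(pick(toy, "temperature", 0.5))   # low temperature: sharper distribution, usually index 0
print(pick(toy, "temperature", 2.0))   # high temperature: flatter distribution, more variety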
utils/old/sampling/sampling_methods.py ADDED
@@ -0,0 +1,291 @@
1
+ from transformers import BertTokenizer, BertForMaskedLM
2
+ import torch
3
+ import random
4
+ from masking_methods import MaskingProcessor
5
+ from transformers import pipeline
6
+
7
+ class SamplingProcessorWithModel:
8
+ def __init__(self, model_name='bert-base-uncased'):
9
+ self.tokenizer = BertTokenizer.from_pretrained(model_name)
10
+ self.model = BertForMaskedLM.from_pretrained(model_name)
11
+ self.model.eval() # Set the model to evaluation mode
12
+
13
+ def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
14
+ """
15
+ Fills each mask in the masked sentence using the specified sampling technique.
16
+
17
+ Args:
18
+ masked_sentence (str): Sentence with [MASK] tokens.
19
+ sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
20
+ temperature (float): Temperature parameter for sampling methods.
21
+
22
+ Returns:
23
+ str: Sentence with the masks filled.
24
+ """
25
+ input_ids = self.tokenizer.encode(masked_sentence, return_tensors="pt")
26
+
27
+ while self.tokenizer.mask_token_id in input_ids[0]:
28
+ # Find indices of all [MASK] tokens
29
+ mask_indices = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
30
+
31
+ # Process the first [MASK] token in the sequence
32
+ mask_index = mask_indices[0].item()
33
+
34
+ # Get logits from the model
35
+ with torch.no_grad():
36
+ outputs = self.model(input_ids)
37
+ logits = outputs.logits
38
+
39
+ # Extract logits for the [MASK] token
40
+ mask_logits = logits[0, mask_index]
41
+
42
+ if sampling_technique == "inverse_transform":
43
+ probs = torch.softmax(mask_logits / temperature, dim=-1)
44
+ cumulative_probs = torch.cumsum(probs, dim=-1)
45
+ random_prob = random.random()
46
+ sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
47
+
48
+ elif sampling_technique == "exponential_minimum":
49
+ probs = torch.softmax(mask_logits / temperature, dim=-1)
50
+ exp_probs = torch.exp(-torch.log(probs))
51
+ random_probs = torch.rand_like(exp_probs)
52
+ sampled_index = torch.argmax(random_probs * exp_probs).item()
53
+
54
+ elif sampling_technique == "temperature":
55
+ mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
56
+ probs = torch.softmax(mask_logits / temperature, dim=-1)
57
+ if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
58
+ raise ValueError("The computed probabilities contain NaN or inf values.")
59
+ probs = torch.max(probs, torch.tensor(1e-8, device=mask_logits.device))
60
+ probs = probs / torch.sum(probs)
61
+ probs = probs.flatten()
62
+ if probs.size(0) > 1:
63
+ sampled_index = torch.multinomial(probs, 1).item()
64
+ else:
65
+ sampled_index = torch.argmax(probs).item()
66
+
67
+ elif sampling_technique == 'greedy':
68
+ sampled_index = torch.argmax(mask_logits).item()
69
+
70
+ else:
71
+ raise ValueError(f"Unknown sampling technique: {sampling_technique}")
72
+
73
+ # Replace the first [MASK] with the selected token
74
+ input_ids[0, mask_index] = sampled_index
75
+
76
+ return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
77
+
78
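+ # NOTE: this second fill_masked_sentence definition shadows the one above and relies on
+ # self.unmasker (a fill-mask pipeline) that __init__ never creates, so calling it as written
+ # would raise AttributeError.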
+ def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
79
+ """
80
+ Fills each mask in the masked sentence using the specified sampling technique.
81
+
82
+ Args:
83
+ masked_sentence (str): Sentence with [MASK] tokens.
84
+ sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
85
+ temperature (float): Temperature parameter for sampling methods.
86
+
87
+ Returns:
88
+ str: Sentence with the masks filled.
89
+ """
90
+ while '[MASK]' in masked_sentence:
91
+ # Get predictions for the first [MASK]
92
+ predictions = self.unmasker(masked_sentence)
93
+
94
+ # Ensure predictions is a list of dictionaries
95
+ if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
96
+ raise ValueError("Unexpected structure in predictions from the pipeline.")
97
+
98
+ # Extract logits (scores) from the predictions
99
+ logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)
100
+
101
+ if sampling_technique == "inverse_transform":
102
+ probs = torch.softmax(logits / temperature, dim=-1)
103
+ cumulative_probs = torch.cumsum(probs, dim=-1)
104
+ random_prob = random.random()
105
+ sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
106
+
107
+ elif sampling_technique == "exponential_minimum":
108
+ probs = torch.softmax(logits / temperature, dim=-1)
109
+ exp_probs = torch.exp(-torch.log(probs))
110
+ random_probs = torch.rand_like(exp_probs)
111
+ sampled_index = torch.argmax(random_probs * exp_probs).item()
112
+
113
+ elif sampling_technique == "temperature":
114
+ logits = torch.clamp(logits, min=-1e8, max=1e8)
115
+ probs = torch.softmax(logits / temperature, dim=-1)
116
+ if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
117
+ raise ValueError("The computed probabilities contain NaN or inf values.")
118
+ probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
119
+ probs = probs / torch.sum(probs)
120
+ probs = probs.flatten()
121
+ if probs.size(0) > 1:
122
+ sampled_index = torch.multinomial(probs, 1).item()
123
+ else:
124
+ sampled_index = torch.argmax(probs).item()
125
+
126
+ elif sampling_technique == 'greedy':
127
+ sampled_index = torch.argmax(logits).item()
128
+
129
+ else:
130
+ raise ValueError(f"Unknown sampling technique: {sampling_technique}")
131
+
132
+ # Replace the first [MASK] with the selected word
133
+ sampled_token = predictions[sampled_index]['token_str']
134
+ masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)
135
+
136
+ return masked_sentence
137
+
138
+
139
+ # Example usage
140
+ if __name__ == "__main__":
141
+ from transformers import BertTokenizer
142
+
143
+ # Define sentences and result_dict
144
+ sentences = [
145
+ "The quick brown fox jumps over the lazy dog.",
146
+ "A quick brown dog outpaces a lazy fox.",
147
+ "Quick brown dog leaps over lazy the fox."
148
+ ]
149
+ result_dict = {
150
+ "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
151
+ "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
152
+ "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
153
+ }
154
+
155
+ masking_processor = MaskingProcessor()
156
+ masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
157
+
158
+ # Use SamplingProcessor
159
+ sampling_processor = SamplingProcessorWithModel()
160
+
161
+ # Iterate through masking results to apply sampling
162
+ for sentence, result in masking_results.items():
163
+ print(f"Original Sentence (Random): {sentence}")
164
+ print(f"Masked Sentence (Random): {result['masked_sentence']}")
165
+ masked_sentence = result["masked_sentence"]
166
+
167
+ # Apply different sampling techniques
168
+ for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
169
+ print(f"Sampling Technique: {technique}")
170
+ filled_sentence = sampling_processor.fill_masked_sentence(
171
+ masked_sentence=masked_sentence,
172
+ sampling_technique=technique,
173
+ temperature=1.0 # Adjust temperature as needed
174
+ )
175
+ print(f"Filled Sentence: {filled_sentence}\n")
176
+ print('--------------------------------')
177
+
178
+
179
+
180
+ # from transformers import pipeline
181
+ # import torch
182
+ # import random
183
+ # from masking_methods import MaskingProcessor
184
+
185
+
186
+ # class SamplingProcessorWithPipeline:
187
+ # def __init__(self, model_name='bert-base-uncased'):
188
+ # self.unmasker = pipeline('fill-mask', model=model_name)
189
+ # self.tokenizer = self.unmasker.tokenizer
190
+
191
+ # def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
192
+ # """
193
+ # Fills each mask in the masked sentence using the specified sampling technique.
194
+
195
+ # Args:
196
+ # masked_sentence (str): Sentence with [MASK] tokens.
197
+ # sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
198
+ # temperature (float): Temperature parameter for sampling methods.
199
+
200
+ # Returns:
201
+ # str: Sentence with the masks filled.
202
+ # """
203
+ # while '[MASK]' in masked_sentence:
204
+ # # Get predictions for the first [MASK]
205
+ # predictions = self.unmasker(masked_sentence)
206
+ # print(f' predictions : {predictions}')
207
+ # print(f' type of predictions : {type(predictions)}')
208
+
209
+ # # Ensure predictions is a list of dictionaries for the first [MASK]
210
+ # if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
211
+ # raise ValueError("Unexpected structure in predictions from the pipeline.")
212
+
213
+ # # Extract logits (scores) from the predictions
214
+ # logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)
215
+
216
+ # if sampling_technique == "inverse_transform":
217
+ # probs = torch.softmax(logits / temperature, dim=-1)
218
+ # cumulative_probs = torch.cumsum(probs, dim=-1)
219
+ # random_prob = random.random()
220
+ # sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
221
+
222
+ # elif sampling_technique == "exponential_minimum":
223
+ # probs = torch.softmax(logits / temperature, dim=-1)
224
+ # exp_probs = torch.exp(-torch.log(probs))
225
+ # random_probs = torch.rand_like(exp_probs)
226
+ # sampled_index = torch.argmax(random_probs * exp_probs).item()
227
+
228
+ # elif sampling_technique == "temperature":
229
+ # logits = torch.clamp(logits, min=-1e8, max=1e8)
230
+ # probs = torch.softmax(logits / temperature, dim=-1)
231
+ # if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
232
+ # raise ValueError("The computed probabilities contain NaN or inf values.")
233
+ # probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
234
+ # probs = probs / torch.sum(probs)
235
+ # probs = probs.flatten()
236
+ # if probs.size(0) > 1:
237
+ # sampled_index = torch.multinomial(probs, 1).item()
238
+ # else:
239
+ # sampled_index = torch.argmax(probs).item()
240
+
241
+ # elif sampling_technique == 'greedy':
242
+ # sampled_index = torch.argmax(logits).item()
243
+
244
+ # else:
245
+ # raise ValueError(f"Unknown sampling technique: {sampling_technique}")
246
+
247
+ # # Replace the first [MASK] with the selected word
248
+ # sampled_token = predictions[sampled_index]['token_str']
249
+ # masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)
250
+
251
+ # return masked_sentence
252
+
253
+
254
+ # # Example usage
255
+ # if __name__ == "__main__":
256
+ # from transformers import BertTokenizer
257
+
258
+ # # Define sentences and result_dict
259
+ # sentences = [
260
+ # "The quick brown fox jumps over the lazy dog.",
261
+ # "A quick brown dog outpaces a lazy fox.",
262
+ # "Quick brown animals leap over lazy obstacles."
263
+ # ]
264
+ # result_dict = {
265
+ # "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
266
+ # "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
267
+ # "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
268
+ # }
269
+
270
+ # masking_processor = MaskingProcessor()
271
+ # masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
272
+
273
+ # # Use SamplingProcessor
274
+ # sampling_processor = SamplingProcessorWithPipeline()
275
+
276
+ # # Iterate through masking results to apply sampling
277
+ # for sentence, result in masking_results.items():
278
+ # print(f"Original Sentence (Random): {sentence}")
279
+ # print(f"Masked Sentence (Random): {result['masked_sentence']}")
280
+ # masked_sentence = result["masked_sentence"]
281
+
282
+ # # Apply different sampling techniques
283
+ # for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
284
+ # print(f"Sampling Technique: {technique}")
285
+ # filled_sentence = sampling_processor.fill_masked_sentence(
286
+ # masked_sentence=masked_sentence,
287
+ # sampling_technique=technique,
288
+ # temperature=1.0 # Adjust temperature as needed
289
+ # )
290
+ # print(f"Filled Sentence: {filled_sentence}\n")
291
+ # print('--------------------------------')
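The pipeline-based variant commented out above fills masks one at a time, left to right. A minimal sketch of that loop with a greedy choice is shown below; the model name is only an example, and multi-mask behaviour depends on the installed transformers version, so treat this as an assumption-laden illustration rather than the project's implementation.

from transformers import pipeline

unmasker = pipeline("fill-mask", model="bert-base-uncased")  # assumption: any MLM checkpoint works here

def fill_greedy(masked_sentence):
    while "[MASK]" in masked_sentence:
        predictions = unmasker(masked_sentence)
        # With several masks, recent transformers versions return one candidate list per mask;
        # take the list for the first remaining mask.
        if predictions and isinstance(predictions[0], list):
            predictions = predictions[0]
        best = predictions[0]["token_str"]  # candidates are sorted by score, so index 0 is greedy
        masked_sentence = masked_sentence.replace("[MASK]", best, 1)
    return masked_sentence

print(fill_greedy("The quick [MASK] fox jumps over the [MASK] dog."))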
utils/old/sampling/sampling_methods_v1.py ADDED
@@ -0,0 +1,146 @@
1
+ import torch
2
+ import random
3
+ from masking_methods import MaskingProcessor
4
+
5
+ class SamplingProcessor:
6
+ def __init__(self, tokenizer):
7
+ self.tokenizer = tokenizer
8
+
9
+ def fill_masked_sentence(self, original_sentence, mask_logits, sampling_technique, temperature=1.0):
10
+ """
11
+ Fills each mask in the masked sentence using the specified sampling technique.
12
+
13
+ Args:
14
+ original_sentence (str): The original masked sentence.
15
+ mask_logits (dict): Logits for each [MASK] token.
16
+ sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
17
+ temperature (float): Temperature parameter for sampling methods.
18
+
19
+ Returns:
20
+ str: Sentence with the masks filled.
21
+ """
22
+ sentence_tokens = self.tokenizer.tokenize(original_sentence)
23
+ mask_token_indices = [i for i, token in enumerate(sentence_tokens) if token == self.tokenizer.mask_token]
24
+
25
+ if len(mask_token_indices) != len(mask_logits):
26
+ raise ValueError("Mismatch between number of [MASK] tokens and logits provided.")
27
+
28
+ for mask_idx, filtered_logits in zip(mask_token_indices, mask_logits.values()):
29
+ # Convert logits to a tensor
30
+ filtered_logits = torch.tensor(filtered_logits)
31
+ # filtered_logits, _ = torch.sort(filtered_logits, descending=True)
32
+ # print(f' type of filtered_logits : {type(filtered_logits)}')
33
+ # filtered_logits = filtered_logits[:5]
34
+
35
+ if sampling_technique == "inverse_transform":
36
+ probs = torch.softmax(filtered_logits / temperature, dim=-1)
37
+ cumulative_probs = torch.cumsum(probs, dim=-1)
38
+ random_prob = random.random()
39
+ sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
40
+
41
+ elif sampling_technique == "exponential_minimum":
42
+ probs = torch.softmax(filtered_logits / temperature, dim=-1)
43
+ exp_probs = torch.exp(-torch.log(probs))
44
+ random_probs = torch.rand_like(exp_probs)
45
+ sampled_index = torch.argmax(random_probs * exp_probs).item()
46
+
47
+ elif sampling_technique == "temperature":
48
+ filtered_logits = torch.clamp(filtered_logits, min=-1e8, max=1e8)
49
+ probs = torch.softmax(filtered_logits / temperature, dim=-1)
50
+ if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
51
+ raise ValueError("The computed probabilities contain NaN or inf values.")
52
+ probs = torch.max(probs, torch.tensor(1e-8, device=filtered_logits.device))
53
+ probs = probs / torch.sum(probs)
54
+ probs = probs.flatten()
55
+ if probs.size(0) > 1:
56
+ sampled_index = torch.multinomial(probs, 1).item()
57
+ else:
58
+ sampled_index = torch.argmax(probs).item()
59
+
60
+ elif sampling_technique == 'greedy':
61
+ sampled_index = torch.argmax(filtered_logits).item()
62
+
63
+ else:
64
+ raise ValueError(f"Unknown sampling technique: {sampling_technique}")
65
+
66
+ sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
67
+ sentence_tokens[mask_idx] = sampled_token
68
+
69
+ return self.tokenizer.convert_tokens_to_string(sentence_tokens)
70
+
71
+
72
+
73
+ def process_samples(self, masked_sentences, mask_logits, sampling_technique, temperature=1.0):
74
+ """
75
+ Process multiple masked sentences and fill their masks using the specified sampling technique.
76
+
77
+ Args:
78
+ masked_sentences (list): List of masked sentences.
79
+ mask_logits (dict): Logits for each [MASK] token in each sentence.
80
+ sampling_technique (str): Sampling technique to use.
81
+ temperature (float): Temperature parameter for sampling methods.
82
+
83
+ Returns:
84
+ list: List of sentences with masks filled.
85
+ """
86
+ filled_sentences = []
87
+ for sentence, logits in zip(masked_sentences, mask_logits):
88
+ filled_sentence = self.fill_masked_sentence(sentence, logits, sampling_technique, temperature)
89
+ filled_sentences.append(filled_sentence)
90
+ return filled_sentences
91
+
92
+ # Example usage
93
+ if __name__ == "__main__":
94
+ from transformers import BertTokenizer
95
+
96
+ # tokenizer = BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
97
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
98
+ processor = SamplingProcessor(tokenizer)
99
+
100
+ sentences = [
101
+ "The quick brown fox jumps over the lazy dog.",
102
+ "A quick brown dog outpaces a lazy fox.",
103
+ "Quick brown dog leaps over lazy the fox."
104
+ ]
105
+ result_dict = {
106
+ "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
107
+ "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
108
+ "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
109
+ }
110
+
111
+
112
+ masking_processor = MaskingProcessor()
113
+ masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
114
+ # masked_sentence = "The [MASK] brown fox jumps [MASK] the lazy dog."
115
+ # mask_logits = {
116
+ # 1: torch.randn(len(tokenizer)), # Example logits for first [MASK]
117
+ # 5: torch.randn(len(tokenizer)), # Example logits for second [MASK]
118
+ # }
119
+
120
+ # Iterate through masking results to apply sampling
121
+ for sentence, result in masking_results.items():
122
+ print(f"Original Sentence (Random): {sentence}")
123
+ print(f"Masked Sentence (Random): {result['masked_sentence']}")
124
+ # print(f"Mask Logits (Random): {output['mask_logits']}")
125
+ print(f' type(result["mask_logits"]) : {type(result["mask_logits"])}')
126
+ print(f' length of result["mask_logits"] : {len(result["mask_logits"])}')
127
+ print(f' result["mask_logits"].keys() : {result["mask_logits"].keys()}')
128
+ masked_sentence = result["masked_sentence"]
129
+ mask_logits = result["mask_logits"]
130
+
131
+ print(f"Original Masked Sentence: {masked_sentence}")
132
+
133
+ # Apply different sampling techniques
134
+ for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
135
+ print(f"Sampling Technique: {technique}")
136
+
137
+ # Fill the masks using the sampling processor
138
+ filled_sentence = processor.fill_masked_sentence(
139
+ original_sentence=masked_sentence,
140
+ mask_logits=mask_logits,
141
+ sampling_technique=technique,
142
+ temperature=1.0 # Adjust temperature as needed
143
+ )
144
+
145
+ print(f"Filled Sentence: {filled_sentence}\n")
146
+ print('--------------------------------')
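Note: the four sampling rules implemented in this file can be exercised on a toy logits tensor without loading a masking model. A minimal sketch with made-up vocabulary and logits (illustrative only, not part of the committed file):

import torch
import random

def sample_index(logits, technique, temperature=1.0):
    # Mirrors the four branches of SamplingProcessor.fill_masked_sentence above.
    probs = torch.softmax(logits / temperature, dim=-1)
    if technique == "greedy":
        return torch.argmax(logits).item()
    if technique == "inverse_transform":
        cumulative = torch.cumsum(probs, dim=-1)
        hits = torch.where(cumulative >= random.random())[0]
        return hits[0].item() if len(hits) > 0 else probs.size(0) - 1
    if technique == "exponential_minimum":
        # exp(-log(p)) is simply 1/p: scale it by uniform noise and take the argmax.
        return torch.argmax(torch.rand_like(probs) / probs).item()
    if technique == "temperature":
        return torch.multinomial(probs, 1).item()
    raise ValueError(f"Unknown sampling technique: {technique}")

vocab = ["quick", "swift", "speedy", "fast", "rapid"]   # hypothetical candidate tokens
logits = torch.tensor([2.0, 1.5, 1.0, 0.5, 0.1])        # hypothetical scores
for technique in ["greedy", "inverse_transform", "exponential_minimum", "temperature"]:
    print(technique, "->", vocab[sample_index(logits, technique)])

Greedy always returns the highest-scoring candidate; the other three are stochastic, so repeated runs can pick different tokens.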
utils/old/sampling/sampling_methods_v2.py ADDED
@@ -0,0 +1,112 @@
1
+ from transformers import pipeline
2
+ import torch
3
+ import random
4
+ from masking_methods import MaskingProcessor
5
+
6
+
7
+ class SamplingProcessorWithPipeline:
8
+ def __init__(self, model_name='bert-base-uncased'):
9
+ self.unmasker = pipeline('fill-mask', model=model_name)
10
+ self.tokenizer = self.unmasker.tokenizer
11
+
12
+ def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
13
+ """
14
+ Fills each mask in the masked sentence using the specified sampling technique.
15
+
16
+ Args:
17
+ masked_sentence (str): Sentence with [MASK] tokens.
18
+ sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
19
+ temperature (float): Temperature parameter for sampling methods.
20
+
21
+ Returns:
22
+ str: Sentence with the masks filled.
23
+ """
24
+ while '[MASK]' in masked_sentence:
25
+ # Get predictions for the first [MASK]
26
+ predictions = self.unmasker(masked_sentence)
27
+ print(f' predictions : {predictions}')
28
+ print(f' type of predictions : {type(predictions)}')
29
+
30
+ # Ensure predictions is a list of dictionaries
31
+ if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
32
+ raise ValueError("Unexpected structure in predictions from the pipeline.")
33
+
34
+ # Extract logits (scores) from the predictions
35
+ logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)
36
+
37
+ if sampling_technique == "inverse_transform":
38
+ probs = torch.softmax(logits / temperature, dim=-1)
39
+ cumulative_probs = torch.cumsum(probs, dim=-1)
40
+ random_prob = random.random()
41
+ sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
42
+
43
+ elif sampling_technique == "exponential_minimum":
44
+ probs = torch.softmax(logits / temperature, dim=-1)
45
+ exp_probs = torch.exp(-torch.log(probs))
46
+ random_probs = torch.rand_like(exp_probs)
47
+ sampled_index = torch.argmax(random_probs * exp_probs).item()
48
+
49
+ elif sampling_technique == "temperature":
50
+ logits = torch.clamp(logits, min=-1e8, max=1e8)
51
+ probs = torch.softmax(logits / temperature, dim=-1)
52
+ if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
53
+ raise ValueError("The computed probabilities contain NaN or inf values.")
54
+ probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
55
+ probs = probs / torch.sum(probs)
56
+ probs = probs.flatten()
57
+ if probs.size(0) > 1:
58
+ sampled_index = torch.multinomial(probs, 1).item()
59
+ else:
60
+ sampled_index = torch.argmax(probs).item()
61
+
62
+ elif sampling_technique == 'greedy':
63
+ sampled_index = torch.argmax(logits).item()
64
+
65
+ else:
66
+ raise ValueError(f"Unknown sampling technique: {sampling_technique}")
67
+
68
+ # Replace the first [MASK] with the selected word
69
+ sampled_token = predictions[sampled_index]['token_str']
70
+ masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)
71
+
72
+ return masked_sentence
73
+
74
+
75
+ # Example usage
76
+ if __name__ == "__main__":
77
+ from transformers import BertTokenizer
78
+
79
+ # Define sentences and result_dict
80
+ sentences = [
81
+ "The quick brown fox jumps over the lazy dog.",
82
+ "A quick brown dog outpaces a lazy fox.",
83
+ "Quick brown dog leaps over lazy the fox."
84
+ ]
85
+ result_dict = {
86
+ "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
87
+ "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
88
+ "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
89
+ }
90
+
91
+ masking_processor = MaskingProcessor()
92
+ masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
93
+
94
+ # Use SamplingProcessor
95
+ sampling_processor = SamplingProcessorWithPipeline()
96
+
97
+ # Iterate through masking results to apply sampling
98
+ for sentence, result in masking_results.items():
99
+ print(f"Original Sentence (Random): {sentence}")
100
+ print(f"Masked Sentence (Random): {result['masked_sentence']}")
101
+ masked_sentence = result["masked_sentence"]
102
+
103
+ # Apply different sampling techniques
104
+ for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
105
+ print(f"Sampling Technique: {technique}")
106
+ filled_sentence = sampling_processor.fill_masked_sentence(
107
+ masked_sentence=masked_sentence,
108
+ sampling_technique=technique,
109
+ temperature=1.0 # Adjust temperature as needed
110
+ )
111
+ print(f"Filled Sentence: {filled_sentence}\n")
112
+ print('--------------------------------')
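One detail worth noting about this pipeline-based variant: the fill-mask pipeline's 'score' fields are already softmax probabilities over its top-k candidates, so the branches above re-apply softmax to probabilities rather than raw logits. A minimal sketch of sampling directly from the returned scores, assuming the standard transformers fill-mask output format (illustrative, not part of the committed file):

import torch
from transformers import pipeline

unmasker = pipeline("fill-mask", model="bert-base-uncased")

def fill_first_mask(sentence, temperature=1.0):
    # Replace the first [MASK] by sampling from the pipeline's top-k candidates.
    predictions = unmasker(sentence)  # for a single mask: list of dicts with 'score' and 'token_str'
    scores = torch.tensor([p["score"] for p in predictions])
    probs = scores ** (1.0 / temperature)   # p**(1/T) is the probability-space analogue of logits/T
    probs = probs / probs.sum()             # renormalise over the top-k candidates
    idx = torch.multinomial(probs, 1).item()
    return sentence.replace("[MASK]", predictions[idx]["token_str"], 1)

print(fill_first_mask("The quick brown [MASK] jumps over the lazy dog."))

(With more than one [MASK] in the input the pipeline returns one candidate list per mask, so the sketch above is written for a single mask.)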
utils/old/sampling_final_copy.py ADDED
@@ -0,0 +1,168 @@
1
+ import torch
2
+ import random
3
+ from masking_methods import MaskingProcessor
4
+
5
+ class SamplingProcessor:
6
+ def __init__(self, tokenizer):
7
+ """
8
+ Initialize the SamplingProcessor.
9
+
10
+ Args:
11
+ tokenizer: BERT tokenizer instance
12
+ """
13
+ self.tokenizer = tokenizer
14
+
15
+ def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
16
+ """
17
+ Sample tokens for each mask in the sentence using the specified sampling technique.
18
+
19
+ Args:
20
+ mask_logits_dict (dict): Dictionary of mask positions and their logits/tokens
21
+ masked_sentence (str): Sentence with [MASK] tokens
22
+ sampling_technique (str): Sampling method to use
23
+ temperature (float): Temperature parameter for sampling
24
+
25
+ Returns:
26
+ str: Sentence with sampled tokens replacing masks
27
+ """
28
+ words = masked_sentence.split()
29
+
30
+ # Convert positions and logits to sorted list to process masks in order
31
+ mask_positions = sorted(mask_logits_dict.keys())
32
+
33
+ for mask_pos in mask_positions:
34
+ mask_data = mask_logits_dict[mask_pos]
35
+ mask_logits = torch.tensor(mask_data['logits'])
36
+ candidate_tokens = mask_data['tokens']
37
+
38
+ try:
39
+ if sampling_technique == "inverse_transform":
40
+ probs = torch.softmax(mask_logits / temperature, dim=-1)
41
+ cumulative_probs = torch.cumsum(probs, dim=-1)
42
+ random_prob = random.random()
43
+ sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
44
+
45
+ elif sampling_technique == "exponential_minimum":
46
+ probs = torch.softmax(mask_logits / temperature, dim=-1)
47
+ exp_probs = torch.exp(-torch.log(probs))
48
+ random_probs = torch.rand_like(exp_probs)
49
+ sampled_index = torch.argmax(random_probs * exp_probs).item()
50
+
51
+ elif sampling_technique == "temperature":
52
+ mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
53
+ probs = torch.softmax(mask_logits / temperature, dim=-1)
54
+ if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
55
+ raise ValueError("The computed probabilities contain NaN or inf values.")
56
+ probs = torch.max(probs, torch.tensor(1e-8))
57
+ probs = probs / torch.sum(probs)
58
+ probs = probs.flatten()
59
+ if probs.size(0) > 1:
60
+ sampled_index = torch.multinomial(probs, 1).item()
61
+ else:
62
+ sampled_index = torch.argmax(probs).item()
63
+
64
+ elif sampling_technique == 'greedy':
65
+ sampled_index = torch.argmax(mask_logits).item()
66
+
67
+ else:
68
+ raise ValueError(f"Unknown sampling technique: {sampling_technique}")
69
+
70
+ # Use the sampled index to get the corresponding token
71
+ sampled_token = candidate_tokens[sampled_index]
72
+ # Remove ## if it's a subword token
73
+ sampled_token = sampled_token.replace('##', '')
74
+ words[mask_pos] = sampled_token
75
+
76
+ except Exception as e:
77
+ print(f"Error sampling for position {mask_pos}: {str(e)}")
78
+ continue
79
+
80
+ return " ".join(words)
81
+
82
+ def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
83
+ """
84
+ Process all masked sentences in the results dictionary.
85
+
86
+ Args:
87
+ results_dict (dict): Dictionary containing masked sentences and their logits
88
+ sampling_technique (str): Sampling method to use
89
+ temperature (float): Temperature parameter for sampling
90
+
91
+ Returns:
92
+ dict: Dictionary containing original, masked, and sampled sentences
93
+ """
94
+ processed_results = {}
95
+
96
+ for original_sentence, data in results_dict.items():
97
+ masked_sentence = data["masked_sentence"]
98
+ mask_logits = data["mask_logits"]
99
+
100
+ sampled_sentence = self.sample_tokens(
101
+ mask_logits,
102
+ masked_sentence,
103
+ sampling_technique,
104
+ temperature
105
+ )
106
+
107
+ processed_results[original_sentence] = {
108
+ "masked_sentence": masked_sentence,
109
+ "sampled_sentence": sampled_sentence
110
+ }
111
+
112
+ return processed_results
113
+
114
+
115
+ if __name__ == "__main__":
116
+ sentences = [
117
+ "The quick brown fox jumps over the lazy dog everyday.",
118
+ "A speedy brown fox jumps over a lazy dog.",
119
+ "A swift brown fox leaps over the lethargic dog."
120
+
121
+ ]
122
+ result_dict ={
123
+ 'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
124
+ 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
125
+ 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
126
+ }
127
+
128
+ # First, mask the sentences
129
+ masking_processor = MaskingProcessor()
130
+ masking_results = masking_processor.process_sentences(sentences, result_dict)
131
+
132
+ # Then, sample replacements for the masks
133
+ sampling_processor = SamplingProcessor(masking_processor.tokenizer)
134
+
135
+ # Try different sampling techniques
136
+ sampling_techniques = ["temperature", "greedy", "inverse_transform", "exponential_minimum"]
137
+
138
+ for technique in sampling_techniques:
139
+ print(f"\nSampling using {technique}:")
140
+ sampled_results = sampling_processor.process_masked_sentences(
141
+ masking_results,
142
+ sampling_technique=technique,
143
+ temperature=1.0
144
+ )
145
+
146
+ '''
147
+ {
148
+ "original_sentence_1":
149
+ {
150
+ "masked_sentence": "sentence with [MASK] tokens",
151
+ "sampling_method1": "sentence with sampled tokens",
152
+ },
153
+ "original_sentence_2":
154
+ {
155
+ "masked_sentence": "sentence with [MASK] tokens",
156
+ "sampling_method": "sentence with sampled tokens"
157
+ },
158
+ # ... and so on for each input sentence
159
+ },
160
+
161
+ '''
162
+
163
+ for original_sentence, result in sampled_results.items():
164
+ print(f"Original: {original_sentence}")
165
+ print(f"Masked: {result['masked_sentence']}")
166
+ print(f"Sampled: {result['sampled_sentence']}")
167
+ print("---")
168
+
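Unlike the earlier versions, this final copy does not sample over the whole vocabulary: each mask position carries its own candidate list, and the sampled index picks from 'tokens' directly. A sketch of the input shape that sample_tokens and process_masked_sentences expect, with hypothetical candidates and logits (inferred from the code above, not taken from the repository):

results_dict = {
    "The quick brown fox jumps over the lazy dog everyday.": {
        "masked_sentence": "The quick brown fox [MASK] over the lazy dog everyday.",
        "mask_logits": {
            4: {                                      # word index of the [MASK] in the sentence
                "tokens": ["jumps", "leaps", "hops"],  # candidate tokens for this position
                "logits": [3.2, 2.9, 1.1],             # matching scores, same order as tokens
            },
        },
    },
}

process_masked_sentences then returns, per original sentence, the masked sentence and a 'sampled_sentence' with each [MASK] replaced by one of its candidates.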
utils/paraphraser.py ADDED
@@ -0,0 +1,75 @@
1
+ """
2
+ This file contains the code to generate paraphrases of sentences.
3
+ """
4
+ import os
5
+ import sys
6
+ import logging
7
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
8
+ from tqdm import tqdm # for progress bars
9
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
10
+
11
+ from utils.config import load_config
12
+ # config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml')
13
+ # config = load_config(config_path)['PECCAVI_TEXT']['Paraphrase']
14
+
15
+ # Configure logging to show only warnings or above on the terminal.
16
+ logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
17
+ logger = logging.getLogger(__name__)
18
+
19
+ class Paraphraser:
20
+ """
21
+ Paraphraser class to generate paraphrases of sentences.
22
+ """
23
+ def __init__(self, config):
24
+ self.config = config
25
+ import torch
26
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ tqdm.write(f"[Paraphraser] Initializing on device: {self.device}")
28
+ self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
29
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model']).to(self.device)
30
+ self.num_beams = config['num_beams']
31
+ self.num_beam_groups = config['num_beam_groups']
32
+ self.num_return_sequences = config['num_return_sequences']
33
+ self.repetition_penalty = config['repetition_penalty']
34
+ self.diversity_penalty = config['diversity_penalty']
35
+ self.no_repeat_ngram_size = config['no_repeat_ngram_size']
36
+ self.temperature = config['temperature']
37
+ self.max_length = config['max_length']
38
+
39
+ def paraphrase(self, sentence: str, num_return_sequences: int=None, num_beams: int=None, num_beam_groups: int=None):
40
+ tqdm.write(f"[Paraphraser] Starting paraphrase for sentence: {sentence}")
41
+ if num_return_sequences is None:
42
+ num_return_sequences = self.num_return_sequences
43
+ if num_beams is None:
44
+ num_beams = self.num_beams
45
+ if num_beam_groups is None:
46
+ num_beam_groups = self.num_beam_groups
47
+
48
+ inputs = self.tokenizer.encode("paraphrase: " + sentence,
49
+ return_tensors="pt",
50
+ max_length=self.max_length,
51
+ truncation=True).to(self.device)
52
+ outputs = self.model.generate(
53
+ inputs,
54
+ max_length=self.max_length,
55
+ num_beams=num_beams,
56
+ num_beam_groups=num_beam_groups,
57
+ num_return_sequences=num_return_sequences,
58
+ repetition_penalty=self.repetition_penalty,
59
+ diversity_penalty=self.diversity_penalty,
60
+ no_repeat_ngram_size=self.no_repeat_ngram_size,
61
+ temperature=self.temperature
62
+ )
63
+ paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True)
64
+ for output in tqdm(outputs, desc="Decoding Paraphrases")]
65
+ tqdm.write(f"[Paraphraser] Paraphrase completed. {len(paraphrases)} outputs generated.")
66
+ return paraphrases
67
+
68
+ if __name__ == "__main__":
69
+ config_path = '/home/jigyasu/PECCAVI-Text/utils/config.yaml'
70
+ config = load_config(config_path)
71
+ paraphraser = Paraphraser(config['PECCAVI_TEXT']['Paraphrase'])
72
+ sentence = "The quick brown fox jumps over the lazy dog."
73
+ paraphrases = paraphraser.paraphrase(sentence)
74
+ for paraphrase in paraphrases:
75
+ print(paraphrase)
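The constructor reads every generation setting from the 'Paraphrase' section of the config. A stand-in dictionary showing the keys it expects — the checkpoint name and values here are placeholders, not the contents of utils/config.yaml:

paraphrase_config = {
    "tokenizer": "humarin/chatgpt_paraphraser_on_T5_base",  # placeholder seq2seq checkpoint
    "model": "humarin/chatgpt_paraphraser_on_T5_base",      # placeholder seq2seq checkpoint
    "num_beams": 10,
    "num_beam_groups": 10,        # diverse beam search: num_beams must be divisible by this
    "num_return_sequences": 5,    # must not exceed num_beams
    "repetition_penalty": 10.0,
    "diversity_penalty": 3.0,     # only meaningful when num_beam_groups > 1
    "no_repeat_ngram_size": 2,
    "temperature": 0.7,
    "max_length": 128,
}

from utils.paraphraser import Paraphraser

paraphraser = Paraphraser(paraphrase_config)
for p in paraphraser.paraphrase("The quick brown fox jumps over the lazy dog."):
    print(p)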