"""Gradio interface for the AIISC watermarking pipeline."""

import time
from pathlib import Path
from typing import Any, List, Tuple

import gradio as gr
import plotly.graph_objects as go

from utils.watermark import Watermarker
from utils.config import load_config
from renderers.highlighter import (
    highlight_common_words,
    highlight_common_words_dict,
    reparaphrased_sentences_html,
)
from renderers.tree import generate_subplot1, generate_subplot2
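

# WatermarkerInterface wraps the Watermarker pipeline and keeps the n-gram
# indices and highlight colors computed in Step 1 so the masking and sampling
# steps can reuse them when drawing their visualization trees.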
class WatermarkerInterface:
    def __init__(self, config):
        self.pipeline = Watermarker(config)
        self.common_grams = {}
        self.highlight_info = []
        self.masked_sentences = []

    def handle_paraphrase(self, prompt: str) -> Tuple[str, str, str, str]:
        """Wrapper for paraphrasing that includes highlighting"""
        start_time = time.time()

        self.pipeline.Paraphrase(prompt)
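
        # Assign a stable 1-based index to every common n-gram found in the
        # original prompt, ordered by its first occurrence in the sentence.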
        seen_ngrams = {}
        original_indexed_ngrams = []

        original_sentence = self.pipeline.user_prompt
        original_ngrams = self.pipeline.common_grams.get(original_sentence, {})

        ngram_occurrences = [
            (min(indices, key=lambda x: x[0])[0], gram)
            for gram, indices in original_ngrams.items()
        ]
        ngram_occurrences.sort()

        for idx, (position, gram) in enumerate(ngram_occurrences, start=1):
            seen_ngrams[gram] = idx
            original_indexed_ngrams.append((idx, gram))

        print("Original Indexed N-grams:", original_indexed_ngrams)
        colors = ["red", "blue", "green", "purple", "orange"]
        highlight_info = [
            (ngram, colors[i % len(colors)])
            for i, (_, ngram) in enumerate(original_indexed_ngrams)
        ]
        common_grams = original_indexed_ngrams
        self.highlight_info = highlight_info
        self.common_grams = common_grams
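
        # Re-index the n-grams of each paraphrased sentence, reusing the index
        # of any n-gram already seen in the original prompt and appending new
        # indices for n-grams that appear only in the paraphrases.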
        paraphrase_indexed_ngrams = {}
        for sentence in self.pipeline.paraphrased_sentences:
            sentence_ngrams = []
            sentence_ngrams_dict = self.pipeline.common_grams.get(sentence, {})

            for gram in sentence_ngrams_dict:
                if gram in seen_ngrams:
                    index = seen_ngrams[gram]
                else:
                    index = len(seen_ngrams) + 1
                    seen_ngrams[gram] = index
                sentence_ngrams.append((index, gram))

            sentence_ngrams.sort()
            paraphrase_indexed_ngrams[sentence] = sentence_ngrams

        print("Paraphrase Indexed N-grams:", paraphrase_indexed_ngrams)
        highlighted_prompt = highlight_common_words(
            common_grams,
            [self.pipeline.user_prompt],
            "Original Prompt with Highlighted Common Sequences"
        )
        highlighted_accepted = highlight_common_words_dict(
            common_grams,
            self.pipeline.selected_sentences,
            "Accepted Paraphrased Sentences with Entailment Scores"
        )
        highlighted_discarded = highlight_common_words_dict(
            common_grams,
            self.pipeline.discarded_sentences,
            "Discarded Paraphrased Sentences with Entailment Scores"
        )

        execution_time = f"<div class='execution-time'>Step 1 completed in {time.time() - start_time:.2f} seconds</div>"

        return highlighted_prompt, highlighted_accepted, highlighted_discarded, execution_time

    def handle_masking(self) -> List[Any]:
        """Wrapper for masking that generates visualization trees"""
        start_time = time.time()

        masking_results = self.pipeline.Masking()
        trees = []
        highlight_info = self.highlight_info
        common_grams = self.common_grams
        sentence_to_masked = {}
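
        # Group the masked variants produced by each masking strategy under the
        # sentence they were derived from.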
        for strategy, sentence_dict in masking_results.items():
            for sent, data in sentence_dict.items():
                if sent not in sentence_to_masked:
                    sentence_to_masked[sent] = []
                try:
                    if not isinstance(data, dict):
                        print(f"[ERROR] Data is not a dictionary for {sent} with strategy {strategy}")
                        continue

                    masked_sentence = data.get("masked_sentence", "")
                    if masked_sentence:
                        sentence_to_masked[sent].append((masked_sentence, strategy))
                except Exception as e:
                    print(f"Error processing {strategy} for sentence {sent}: {e}")
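
        # Build one tree figure per original sentence and pad with empty
        # figures so the output always fills the 10 plot slots in the UI.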
        for original_sentence, masked_sentences_data in sentence_to_masked.items():
            if not masked_sentences_data:
                continue
            masked_sentences = [ms[0] for ms in masked_sentences_data]
            strategies = [ms[1] for ms in masked_sentences_data]
            try:
                fig = generate_subplot1(
                    original_sentence,
                    masked_sentences,
                    strategies,
                    highlight_info,
                    common_grams
                )
                trees.append(fig)
            except Exception as e:
                print(f"Error generating multi-strategy tree: {e}")
                trees.append(go.Figure())

        while len(trees) < 10:
            trees.append(go.Figure())

        execution_time = f"<div class='execution-time'>Step 2 completed in {time.time() - start_time:.2f} seconds</div>"

        return trees[:10] + [execution_time]

    def handle_sampling(self) -> List[Any]:
        """Wrapper for sampling that generates visualization trees"""
        start_time = time.time()

        sampling_results = self.pipeline.Sampling()
        trees = []
        organized_results = {}
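
        # Pivot the results, which arrive keyed by sampling strategy, then
        # masking strategy, then sentence, into a per-sentence structure so one
        # figure can be drawn for each original sentence.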
        for sampling_strategy, masking_dict in sampling_results.items():
            for masking_strategy, sentences in masking_dict.items():
                for original_sentence, data in sentences.items():
                    if original_sentence not in organized_results:
                        organized_results[original_sentence] = {}

                    if masking_strategy not in organized_results[original_sentence]:
                        organized_results[original_sentence][masking_strategy] = {
                            "masked_sentence": data.get("masked_sentence", ""),
                            "sampled_sentences": {}
                        }

                    organized_results[original_sentence][masking_strategy]["sampled_sentences"][sampling_strategy] = data.get("sampled_sentence", "")
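
        # For each original sentence, collect the masked sentences from up to
        # three masking strategies plus all of their sampled variants, then
        # draw a single subplot for that sentence.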
        for original_sentence, data in organized_results.items():
            masked_sentences = []
            all_sampled_sentences = []

            for masking_strategy, masking_data in list(data.items())[:3]:
                masked_sentence = masking_data.get("masked_sentence", "")
                if masked_sentence:
                    masked_sentences.append(masked_sentence)

                for sampling_strategy, sampled_sentence in masking_data.get("sampled_sentences", {}).items():
                    if sampled_sentence:
                        all_sampled_sentences.append(sampled_sentence)

            if masked_sentences:
                try:
                    fig = generate_subplot2(
                        masked_sentences,
                        all_sampled_sentences,
                        self.highlight_info,
                        self.common_grams
                    )
                    trees.append(fig)
                except Exception as e:
                    print(f"Error generating subplot for {original_sentence}: {e}")
                    trees.append(go.Figure())

        while len(trees) < 10:
            trees.append(go.Figure())

        execution_time = f"<div class='execution-time'>Step 3 completed in {time.time() - start_time:.2f} seconds</div>"

        return trees[:10] + [execution_time]

    def handle_reparaphrasing(self) -> List[str]:
        """Wrapper for re-paraphrasing that formats results as HTML"""
        start_time = time.time()

        results = self.pipeline.re_paraphrasing()
        html_outputs = []
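
        # Flatten the nested results into one HTML block per
        # (sampling strategy, masking strategy, sentence) combination and pad
        # to the 120 tabs reserved in the UI.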
        for sampling_strategy, masking_dict in results.items():
            for masking_strategy, sentences in masking_dict.items():
                for original_sent, data in sentences.items():
                    if data["re_paraphrased_sentences"]:
                        html = reparaphrased_sentences_html(data["re_paraphrased_sentences"])
                        html_outputs.append(html)

        while len(html_outputs) < 120:
            html_outputs.append("")

        execution_time = f"<div class='execution-time'>Step 4 completed in {time.time() - start_time:.2f} seconds</div>"

        return html_outputs[:120] + [execution_time]


def create_gradio_interface(config):
    """Creates the Gradio interface with the updated pipeline"""
    interface = WatermarkerInterface(config)

    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
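        # Custom CSS: fixed-height, scrollable containers for the re-paraphrased
        # and sampling tabs, sticky tab navigation with styled scrollbars, and
        # layout rules for the execution-time and section-header elements.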
        demo.css = """
        /* Set fixed height for the reparaphrased and sampling tabs containers */
        .gradio-container .tabs[id="reparaphrased-tabs"],
        .gradio-container .tabs[id="sampling-tabs"] {
            overflow-x: hidden;
            white-space: normal;
            border-radius: 8px;
            max-height: 600px; /* Set fixed height for the entire tabs component */
            overflow-y: auto; /* Enable vertical scrolling inside the container */
        }

        /* Tab content styling for reparaphrased and sampling tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tabitem,
        .gradio-container .tabs[id="sampling-tabs"] .tabitem {
            overflow-x: hidden;
            white-space: normal;
            display: block;
            border-radius: 8px;
        }

        /* Make the tab navigation fixed at the top for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav {
            display: flex;
            overflow-x: auto;
            white-space: nowrap;
            scrollbar-width: thin;
            border-radius: 8px;
            scrollbar-color: #888 #f1f1f1;
            position: sticky;
            top: 0;
            background: white;
            z-index: 100;
        }

        /* Dropdown menu for scrollable tabs styling */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown {
            position: relative;
            display: inline-block;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content {
            display: none;
            position: absolute;
            background-color: #f9f9f9;
            min-width: 160px;
            box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
            z-index: 1;
            max-height: 300px;
            overflow-y: auto;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content {
            display: block;
        }

        /* Scrollbar styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar {
            height: 8px;
            border-radius: 8px;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-track,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-track {
            background: #f1f1f1;
            border-radius: 8px;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-thumb,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-thumb {
            background: #888;
            border-radius: 8px;
        }

        /* Tab button styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-item,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-item {
            flex: 0 0 auto;
            border-radius: 8px;
        }

        /* Plot container styling specifically for sampling tabs */
        .gradio-container .tabs[id="sampling-tabs"] .plot-container {
            min-height: 600px;
            max-height: 1800px;
            overflow-y: auto;
        }

        /* Ensure text wraps in HTML components */
        .gradio-container .prose {
            white-space: normal;
            word-wrap: break-word;
            overflow-wrap: break-word;
        }

        /* Dropdown button styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button {
            background-color: #f0f0f0;
            border: 1px solid #ddd;
            border-radius: 4px;
            padding: 5px 10px;
            cursor: pointer;
            margin: 2px;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button:hover,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button:hover {
            background-color: #e0e0e0;
        }

        /* Style dropdown content items for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div {
            padding: 8px 12px;
            cursor: pointer;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div:hover,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div:hover {
            background-color: #e0e0e0;
        }

        /* Custom styling for execution time display */
        .execution-time {
            text-align: right;
            padding: 8px 16px;
            font-family: inherit;
            color: #555;
            font-size: 0.9rem;
            font-style: italic;
            margin-left: auto;
            width: 100%;
            border-top: 1px solid #eee;
            margin-top: 8px;
        }

        /* Layout for section headers with execution time */
        .section-header {
            display: flex;
            justify-content: space-between;
            align-items: center;
            width: 100%;
            margin-bottom: 12px;
        }

        .section-header h3 {
            margin: 0;
        }
        """
gr.Markdown("# **AIISC Watermarking Model**") |
|
|
|
with gr.Column(): |
|
gr.Markdown("## Input Prompt") |
|
user_input = gr.Textbox( |
|
label="Enter Your Prompt", |
|
placeholder="Type your text here..." |
|
) |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=3): |
|
gr.Markdown("## Step 1: Paraphrasing, LCS and Entailment Analysis") |
|
with gr.Column(scale=1): |
|
step1_time = gr.HTML() |
|
|
|
paraphrase_button = gr.Button("Generate Paraphrases") |
|
highlighted_user_prompt = gr.HTML(label="Highlighted User Prompt") |
|
|
|
with gr.Tabs(): |
|
with gr.TabItem("Accepted Paraphrased Sentences"): |
|
highlighted_accepted_sentences = gr.HTML() |
|
with gr.TabItem("Discarded Paraphrased Sentences"): |
|
highlighted_discarded_sentences = gr.HTML() |
|
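
        # Steps 2-4 pre-allocate a fixed number of output components (10 plots,
        # 10 plots, 120 HTML blocks); the handlers pad their results to match.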
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 2: Where to Mask?")
            with gr.Column(scale=1):
                step2_time = gr.HTML()

        masking_button = gr.Button("Apply Masking")
        gr.Markdown("### Masked Sentence Trees")
        tree1_plots = []
        with gr.Tabs() as tree1_tabs:
            for i in range(10):
                with gr.TabItem(f"Masked Sentence {i+1}"):
                    tree1 = gr.Plot()
                    tree1_plots.append(tree1)

        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 3: How to Mask?")
            with gr.Column(scale=1):
                step3_time = gr.HTML()

        sampling_button = gr.Button("Sample Words")
        gr.Markdown("### Sampled Sentence Trees")
        tree2_plots = []
        with gr.Tabs(elem_id="sampling-tabs") as tree2_tabs:
            for i in range(10):
                with gr.TabItem(f"Sampled Sentence {i+1}"):
                    with gr.Column(elem_classes=["plot-container"]):
                        tree2 = gr.Plot()
                        tree2_plots.append(tree2)

        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 4: Re-paraphrasing")
            with gr.Column(scale=1):
                step4_time = gr.HTML()

        reparaphrase_button = gr.Button("Re-paraphrase")
        gr.Markdown("### Reparaphrased Sentences")
        reparaphrased_sentences_tabs = []
        with gr.Tabs(elem_id="reparaphrased-tabs") as reparaphrased_tabs:
            for i in range(120):
                with gr.TabItem(f"Reparaphrased Batch {i+1}"):
                    reparaphrased_sent_html = gr.HTML()
                    reparaphrased_sentences_tabs.append(reparaphrased_sent_html)
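
        # Wire each button to its handler; every handler returns its visual
        # outputs followed by an execution-time HTML snippet.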
        paraphrase_button.click(
            interface.handle_paraphrase,
            inputs=user_input,
            outputs=[
                highlighted_user_prompt,
                highlighted_accepted_sentences,
                highlighted_discarded_sentences,
                step1_time
            ]
        )

        masking_button.click(
            interface.handle_masking,
            inputs=None,
            outputs=tree1_plots + [step2_time]
        )

        sampling_button.click(
            interface.handle_sampling,
            inputs=None,
            outputs=tree2_plots + [step3_time]
        )

        reparaphrase_button.click(
            interface.handle_reparaphrasing,
            inputs=None,
            outputs=reparaphrased_sentences_tabs + [step4_time]
        )

    return demo
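

# When run directly, load utils/config.yaml relative to the project root and
# launch the demo with the PECCAVI_TEXT section of the configuration.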
if __name__ == "__main__":
    project_root = Path(__file__).parent.parent
    config_path = project_root / "utils" / "config.yaml"
    config = load_config(config_path)['PECCAVI_TEXT']

    create_gradio_interface(config).launch()