# Commit 060ac52 — "Add entire pipeline" (jgyasu)
import gradio as gr
from utils.watermark import Watermarker
from utils.config import load_config
from renderers.highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html
from renderers.tree import generate_subplot1, generate_subplot2
from pathlib import Path
import time
from typing import Dict, List, Tuple, Any
import plotly.graph_objects as go
class WatermarkerInterface:
    """Gradio-facing wrapper around the watermarking pipeline.

    Holds the Watermarker instance together with the n-gram and highlight
    state that step 1 (paraphrasing) produces and steps 2-3 consume.
    """

    def __init__(self, config):
        """Create the pipeline and initialise shared per-session state."""
        self.pipeline = Watermarker(config)
        # Starts as an empty dict; handle_paraphrase replaces it with a
        # list of (index, ngram) pairs for the original prompt.
        self.common_grams = {}
        # (ngram, colour) pairs consumed by the tree renderers;
        # populated in handle_paraphrase.
        self.highlight_info = []
        # NOTE(review): never written in this file — presumably reserved
        # for masked-sentence bookkeeping; confirm before removing.
        self.masked_sentences = []
def handle_paraphrase(self, prompt: str) -> Tuple[str, str, str, str]:
"""Wrapper for paraphrasing that includes highlighting"""
start_time = time.time()
# Run paraphrasing
self.pipeline.Paraphrase(prompt)
# Step 1: Process the original sentence first
seen_ngrams = {} # Stores first occurrence index of each n-gram
original_indexed_ngrams = [] # Final indexed list for original
original_sentence = self.pipeline.user_prompt
original_ngrams = self.pipeline.common_grams.get(original_sentence, {})
# Step 1.1: Extract n-grams and their first occurrence index
ngram_occurrences = [
(min(indices, key=lambda x: x[0])[0], gram) # Get first index
for gram, indices in original_ngrams.items()
]
# Step 1.2: Sort n-grams based on their first occurrence
ngram_occurrences.sort()
# Step 1.3: Assign sequential indices
for idx, (position, gram) in enumerate(ngram_occurrences, start=1):
seen_ngrams[gram] = idx # Assign sequential index
original_indexed_ngrams.append((idx, gram))
print("Original Indexed N-grams:", original_indexed_ngrams)
#generate highlight_info
colors = ["red", "blue", "green", "purple", "orange"]
highlight_info = [
(ngram, colors[i % len(colors)])
for i, (index, ngram) in enumerate(original_indexed_ngrams)
]
common_grams = original_indexed_ngrams
self.highlight_info = highlight_info
self.common_grams = common_grams
# Step 2: Process paraphrased sentences and match indices
paraphrase_indexed_ngrams = {}
for sentence in self.pipeline.paraphrased_sentences:
sentence_ngrams = [] # Stores n-grams for this sentence
sentence_ngrams_dict = self.pipeline.common_grams.get(sentence, {})
for gram, indices in sentence_ngrams_dict.items():
first_occurrence = min(indices, key=lambda x: x[0])[0]
# Use the original's index if exists, otherwise assign a new one
if gram in seen_ngrams:
index = seen_ngrams[gram] # Use the same index as original
else:
index = len(seen_ngrams) + 1 # Assign new index
seen_ngrams[gram] = index # Store it
sentence_ngrams.append((index, gram))
sentence_ngrams.sort()
paraphrase_indexed_ngrams[sentence] = sentence_ngrams
print("Paraphrase Indexed N-grams:", paraphrase_indexed_ngrams)
# Step 3: Generate highlighted versions using the renderer
highlighted_prompt = highlight_common_words(
common_grams,
[self.pipeline.user_prompt],
"Original Prompt with Highlighted Common Sequences"
)
highlighted_accepted = highlight_common_words_dict(
common_grams,
self.pipeline.selected_sentences,
"Accepted Paraphrased Sentences with Entailment Scores"
)
highlighted_discarded = highlight_common_words_dict(
common_grams,
self.pipeline.discarded_sentences,
"Discarded Paraphrased Sentences with Entailment Scores"
)
execution_time = f"<div class='execution-time'>Step 1 completed in {time.time() - start_time:.2f} seconds</div>"
self.highlight_info = highlight_info
self.common_grams = common_grams
return highlighted_prompt, highlighted_accepted, highlighted_discarded, execution_time
def handle_masking(self) -> Tuple[List[go.Figure], str]:
"""Wrapper for masking that generates visualization trees"""
start_time = time.time()
masking_results = self.pipeline.Masking()
trees = []
highlight_info = self.highlight_info
common_grams = self.common_grams
sentence_to_masked = {}
# Create a consolidated figure with all strategies
original_sentence = None
# First pass - gather all sentences and strategies
for strategy, sentence_dict in masking_results.items():
for sent, data in sentence_dict.items():
if sent not in sentence_to_masked:
sentence_to_masked[sent] = []
try:
if not isinstance(data, dict):
print(f"[ERROR] Data is not a dictionary for {sent} with strategy {strategy}")
continue
masked_sentence = data.get("masked_sentence", "")
if masked_sentence:
sentence_to_masked[sent].append((masked_sentence, strategy))
except Exception as e:
print(f"Error processing {strategy} for sentence {sent}: {e}")
for original_sentence, masked_sentences_data in sentence_to_masked.items():
if not masked_sentences_data:
continue
masked_sentences = [ms[0] for ms in masked_sentences_data]
strategies = [ms[1] for ms in masked_sentences_data]
try:
fig = generate_subplot1(
original_sentence,
masked_sentences,
strategies,
highlight_info,
common_grams
)
trees.append(fig)
except Exception as e:
print(f"Error generating multi-strategy tree: {e}")
trees.append(go.Figure())
# Pad with empty plots if needed
while len(trees) < 10:
trees.append(go.Figure())
execution_time = f"<div class='execution-time'>Step 2 completed in {time.time() - start_time:.2f} seconds</div>"
return trees[:10] + [execution_time]
def handle_sampling(self) -> Tuple[List[go.Figure], str]:
"""Wrapper for sampling that generates visualization trees"""
start_time = time.time()
sampling_results = self.pipeline.Sampling()
trees = []
# Group sentences by original sentence
organized_results = {}
# Generate trees for each sampled sentence
for sampling_strategy, masking_dict in sampling_results.items():
for masking_strategy, sentences in masking_dict.items():
for original_sentence, data in sentences.items():
if original_sentence not in organized_results:
organized_results[original_sentence] = {}
if masking_strategy not in organized_results[original_sentence]:
organized_results[original_sentence][masking_strategy] = {
"masked_sentence": data.get("masked_sentence", ""), # Corrected reference
"sampled_sentences": {}
}
# Add this sampling result
organized_results[original_sentence][masking_strategy]["sampled_sentences"][sampling_strategy] = data.get("sampled_sentence", "")
for original_sentence, data in organized_results.items():
masked_sentences = []
all_sampled_sentences = []
for masking_strategy, masking_data in list(data.items())[:3]: # Ensure this iteration is safe
masked_sentence = masking_data.get("masked_sentence", "")
if masked_sentence:
masked_sentences.append(masked_sentence)
for sampling_strategy, sampled_sentence in masking_data.get("sampled_sentences", {}).items():
if sampled_sentence:
all_sampled_sentences.append(sampled_sentence)
if masked_sentences:
try:
fig = generate_subplot2(
masked_sentences,
all_sampled_sentences,
self.highlight_info,
self.common_grams
)
trees.append(fig)
except Exception as e:
print(f"Error generating subplot for {original_sentence}: {e}")
trees.append(go.Figure())
while len(trees) < 10:
trees.append(go.Figure())
execution_time = f"<div class='execution-time'>Step 3 completed in {time.time() - start_time:.2f} seconds</div>"
return trees[:10] + [execution_time]
def handle_reparaphrasing(self) -> Tuple[List[str], str]:
"""Wrapper for re-paraphrasing that formats results as HTML"""
start_time = time.time()
results = self.pipeline.re_paraphrasing()
html_outputs = []
# Generate HTML for each batch of re-paraphrased sentences
for sampling_strategy, masking_dict in results.items():
for masking_strategy, sentences in masking_dict.items():
for original_sent, data in sentences.items():
if data["re_paraphrased_sentences"]:
html = reparaphrased_sentences_html(data["re_paraphrased_sentences"])
html_outputs.append(html)
# Pad with empty HTML if needed
while len(html_outputs) < 120:
html_outputs.append("")
execution_time = f"<div class='execution-time'>Step 4 completed in {time.time() - start_time:.2f} seconds</div>"
return html_outputs[:120] + [execution_time]
def create_gradio_interface(config):
    """Build the Gradio Blocks demo wired to a WatermarkerInterface.

    Lays out the four pipeline steps (paraphrase, mask, sample,
    re-paraphrase) as sections, each with a trigger button, result
    widgets and an execution-time display, then connects every button
    to the matching WatermarkerInterface handler.

    Args:
        config: Pipeline configuration forwarded to Watermarker.

    Returns:
        The assembled (not yet launched) gr.Blocks demo.
    """
    interface = WatermarkerInterface(config)
    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        # CSS to enable scrolling for reparaphrased sentences and sampling plots
        demo.css = """
        /* Set fixed height for the reparaphrased tabs container only */
        .gradio-container .tabs[id="reparaphrased-tabs"],
        .gradio-container .tabs[id="sampling-tabs"] {
            overflow-x: hidden;
            white-space: normal;
            border-radius: 8px;
            max-height: 600px; /* Set fixed height for the entire tabs component */
            overflow-y: auto; /* Enable vertical scrolling inside the container */
        }
        /* Tab content styling for reparaphrased and sampling tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tabitem,
        .gradio-container .tabs[id="sampling-tabs"] .tabitem {
            overflow-x: hidden;
            white-space: normal;
            display: block;
            border-radius: 8px;
        }
        /* Make the tab navigation fixed at the top for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav {
            display: flex;
            overflow-x: auto;
            white-space: nowrap;
            scrollbar-width: thin;
            border-radius: 8px;
            scrollbar-color: #888 #f1f1f1;
            position: sticky;
            top: 0;
            background: white;
            z-index: 100;
        }
        /* Dropdown menu for scrollable tabs styling */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown {
            position: relative;
            display: inline-block;
        }
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content {
            display: none;
            position: absolute;
            background-color: #f9f9f9;
            min-width: 160px;
            box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
            z-index: 1;
            max-height: 300px;
            overflow-y: auto;
        }
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content {
            display: block;
        }
        /* Scrollbar styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar {
            height: 8px;
            border-radius: 8px;
        }
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-track,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-track {
            background: #f1f1f1;
            border-radius: 8px;
        }
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-thumb,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-thumb {
            background: #888;
            border-radius: 8px;
        }
        /* Tab button styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-item,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-item {
            flex: 0 0 auto;
            border-radius: 8px;
        }
        /* Plot container styling specifically for sampling tabs */
        .gradio-container .tabs[id="sampling-tabs"] .plot-container {
            min-height: 600px;
            max-height: 1800px;
            overflow-y: auto;
        }
        /* Ensure text wraps in HTML components */
        .gradio-container .prose {
            white-space: normal;
            word-wrap: break-word;
            overflow-wrap: break-word;
        }
        /* Dropdown button styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button {
            background-color: #f0f0f0;
            border: 1px solid #ddd;
            border-radius: 4px;
            padding: 5px 10px;
            cursor: pointer;
            margin: 2px;
        }
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button:hover,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button:hover {
            background-color: #e0e0e0;
        }
        /* Style dropdown content items for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div {
            padding: 8px 12px;
            cursor: pointer;
        }
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div:hover,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div:hover {
            background-color: #e0e0e0;
        }
        /* Custom styling for execution time display */
        .execution-time {
            text-align: right;
            padding: 8px 16px;
            font-family: inherit;
            color: #555;
            font-size: 0.9rem;
            font-style: italic;
            margin-left: auto;
            width: 100%;
            border-top: 1px solid #eee;
            margin-top: 8px;
        }
        /* Layout for section headers with execution time */
        .section-header {
            display: flex;
            justify-content: space-between;
            align-items: center;
            width: 100%;
            margin-bottom: 12px;
        }
        .section-header h3 {
            margin: 0;
        }
        """
        gr.Markdown("# **AIISC Watermarking Model**")

        # Prompt input section.
        with gr.Column():
            gr.Markdown("## Input Prompt")
            user_input = gr.Textbox(
                label="Enter Your Prompt",
                placeholder="Type your text here..."
            )

        # --- Step 1: paraphrasing + highlighted results ---
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 1: Paraphrasing, LCS and Entailment Analysis")
            with gr.Column(scale=1):
                # Filled by handle_paraphrase's execution-time snippet.
                step1_time = gr.HTML()
        paraphrase_button = gr.Button("Generate Paraphrases")
        highlighted_user_prompt = gr.HTML(label="Highlighted User Prompt")
        with gr.Tabs():
            with gr.TabItem("Accepted Paraphrased Sentences"):
                highlighted_accepted_sentences = gr.HTML()
            with gr.TabItem("Discarded Paraphrased Sentences"):
                highlighted_discarded_sentences = gr.HTML()

        # --- Step 2: masking + tree plots (10 fixed slots) ---
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 2: Where to Mask?")
            with gr.Column(scale=1):
                step2_time = gr.HTML()
        masking_button = gr.Button("Apply Masking")
        gr.Markdown("### Masked Sentence Trees")
        tree1_plots = []
        with gr.Tabs() as tree1_tabs:
            for i in range(10):
                with gr.TabItem(f"Masked Sentence {i+1}"):
                    tree1 = gr.Plot()
                    tree1_plots.append(tree1)

        # --- Step 3: sampling + tree plots (10 fixed slots) ---
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 3: How to Mask?")
            with gr.Column(scale=1):
                step3_time = gr.HTML()
        sampling_button = gr.Button("Sample Words")
        gr.Markdown("### Sampled Sentence Trees")
        tree2_plots = []
        # Add elem_id to make this tab container scrollable (see demo.css).
        with gr.Tabs(elem_id="sampling-tabs") as tree2_tabs:
            for i in range(10):
                with gr.TabItem(f"Sampled Sentence {i+1}"):
                    # Add a custom class to the container to enable proper styling
                    with gr.Column(elem_classes=["plot-container"]):
                        tree2 = gr.Plot()
                        tree2_plots.append(tree2)

        # --- Step 4: re-paraphrasing (120 fixed HTML slots) ---
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 4: Re-paraphrasing")
            with gr.Column(scale=1):
                step4_time = gr.HTML()
        reparaphrase_button = gr.Button("Re-paraphrase")
        gr.Markdown("### Reparaphrased Sentences")
        reparaphrased_sentences_tabs = []
        with gr.Tabs(elem_id="reparaphrased-tabs") as reparaphrased_tabs:
            for i in range(120):
                with gr.TabItem(f"Reparaphrased Batch {i+1}"):
                    reparaphrased_sent_html = gr.HTML()
                    reparaphrased_sentences_tabs.append(reparaphrased_sent_html)

        # Connect the interface functions to the buttons.  The handlers
        # return flat lists whose lengths match these outputs lists.
        paraphrase_button.click(
            interface.handle_paraphrase,
            inputs=user_input,
            outputs=[
                highlighted_user_prompt,
                highlighted_accepted_sentences,
                highlighted_discarded_sentences,
                step1_time
            ]
        )
        masking_button.click(
            interface.handle_masking,
            inputs=None,
            outputs=tree1_plots + [step2_time]
        )
        sampling_button.click(
            interface.handle_sampling,
            inputs=None,
            outputs=tree2_plots + [step3_time]
        )
        reparaphrase_button.click(
            interface.handle_reparaphrasing,
            inputs=None,
            outputs=reparaphrased_sentences_tabs + [step4_time]
        )
    return demo
if __name__ == "__main__":
    # config.yaml lives in <project root>/utils/, one level above this file's directory.
    project_root = Path(__file__).parent.parent
    config_path = project_root / "utils" / "config.yaml"
    peccavi_config = load_config(config_path)['PECCAVI_TEXT']
    demo = create_gradio_interface(peccavi_config)
    demo.launch()