|
import plotly.graph_objects as go |
|
import textwrap |
|
import re |
|
from collections import defaultdict |
|
|
|
def generate_subplot1(paraphrased_sentence, masked_sentences, strategies, highlight_info, common_grams):
    """
    Generates a subplot visualizing paraphrased and masked sentences in a tree structure.

    Highlights common words with specific colors and applies Longest Common
    Subsequence (LCS) numbering.

    Args:
        paraphrased_sentence (str): The paraphrased sentence shown as the root node.
        masked_sentences (list of str or str): Masked sentence(s) shown as child
            nodes; a single string is treated as a one-element list.
        strategies (list of str or None): Strategy label per edge; when falsy or
            too short, default labels are used for the remaining edges.
        highlight_info (list of tuples): (word, color) pairs for highlighting.
        common_grams (list of tuples): (index, phrase) pairs for LCS numbering.

    Returns:
        plotly.graph_objects.Figure: A Plotly figure representing the tree
        structure with highlighted words and labeled edges; an empty figure
        when there are not enough nodes to draw.
    """
    if isinstance(masked_sentences, str):
        masked_sentences = [masked_sentences]

    # Tag each node with its tree level (' L0' = root, ' L1' = child) so the
    # layout helper can recover the level from the text itself.
    nodes = [paraphrased_sentence + ' L0'] + [s + ' L1' for s in masked_sentences]
    if len(nodes) < 2:
        print("[ERROR] Insufficient nodes for visualization")
        return go.Figure()

    def apply_lcs_numbering(sentence, common_grams):
        """
        Applies LCS numbering to the sentence based on the common_grams.

        Args:
            sentence (str): The sentence to which the LCS numbering is applied.
            common_grams (list of tuples): (index, phrase) pairs; each phrase
                occurrence is prefixed with its "(index)" marker.

        Returns:
            str: The sentence with LCS numbering applied.
        """
        for idx, lcs in common_grams:
            sentence = re.sub(rf"\b{lcs}\b", f"({idx}){lcs}", sentence)
        return sentence

    nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]

    def highlight_words(sentence, color_map):
        """
        Wraps each word found in color_map with {{...}} placeholder braces.

        The braces are later converted to colored HTML spans by
        color_highlighted_words; the colors themselves are not needed here.

        Args:
            sentence (str): The sentence where the words will be marked.
            color_map (dict): Mapping of words to colors (only keys are used).

        Returns:
            str: The sentence with placeholder-wrapped words.
        """
        for word in color_map:
            sentence = re.sub(rf"\b{word}\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
        return sentence

    # Strip the trailing level tags before anything is rendered.
    cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
    global_color_map = dict(highlight_info)
    highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
    wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=55)) for node in highlighted_nodes]

    def get_levels_and_edges(nodes):
        """
        Determines tree levels and creates edges dynamically.

        Levels are read back from the trailing ' L<n>' tag of each node; the
        single level-0 root is connected to every level-1 node.

        Args:
            nodes (list of str): The tagged nodes representing the sentences.

        Returns:
            tuple: (levels, edges) where levels maps node index -> level and
            edges is a list of (source_index, target_index) pairs.
        """
        levels = {i: int(node.split()[-1][1]) for i, node in enumerate(nodes)}
        root_node = next((i for i, level in levels.items() if level == 0), 0)
        edges = [(root_node, i) for i, level in levels.items() if level == 1]
        return levels, edges

    levels, edges = get_levels_and_edges(nodes)
    max_level = max(levels.values(), default=0)

    # Column layout: the level fixes x; siblings within a level stack in y,
    # centered around 0 via the initial offset below.
    positions = {}
    level_heights = defaultdict(int)
    for level in levels.values():
        level_heights[level] += 1

    y_offsets = {level: -(height - 1) / 2 for level, height in level_heights.items()}
    x_gap = 2
    l1_y_gap = 10

    for node, level in levels.items():
        positions[node] = (level * x_gap, y_offsets[level] * l1_y_gap)
        # BUG FIX: the offset was previously advanced only for non-level-1
        # nodes, so every level-1 sibling overlapped at the same y coordinate.
        y_offsets[level] += 1

    def color_highlighted_words(node, color_map):
        """
        Colors the highlighted words in the node text.

        Replaces each {{word}} placeholder produced by highlight_words with a
        colored <span>; unknown words fall back to black.

        Args:
            node (str): The node text containing {{...}} placeholders.
            color_map (dict): A dictionary mapping words to their colors.

        Returns:
            str: The node text with HTML-colored words.
        """
        parts = re.split(r'(\{\{.*?\}\})', node)
        colored_parts = []
        for part in parts:
            match = re.match(r'\{\{(.*?)\}\}', part)
            if match:
                word = match.group(1)
                color = color_map.get(word, 'black')
                colored_parts.append(f"<span style='color: {color};'>{word}</span>")
            else:
                colored_parts.append(part)
        return ''.join(colored_parts)

    # Fallback labels, cycled through when no strategy is supplied for an edge.
    default_edge_texts = [
        "Highest Entropy Masking", "Pseudo-random Masking", "Random Masking",
        "Greedy Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
        "Inverse Transform Sampling", "Greedy Sampling", "Temperature Sampling",
        "Exponential Minimum Sampling", "Inverse Transform Sampling", "Greedy Sampling",
        "Temperature Sampling", "Exponential Minimum Sampling", "Inverse Transform Sampling"
    ]

    fig1 = go.Figure()

    # Draw one marker + one annotation box per node.
    for i, node in enumerate(wrapped_nodes):
        colored_node = color_highlighted_words(node, global_color_map)
        x, y = positions[i]
        fig1.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=20, color='blue', line=dict(color='black', width=2)),
            hoverinfo='none'
        ))
        fig1.add_annotation(
            x=x,
            y=y,
            text=colored_node,
            showarrow=False,
            xshift=15,
            align="center",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=2,
            borderpad=4,
            bgcolor='white',
            width=400,
            height=100
        )

    # Draw edges and label each one with its strategy (or a default label).
    for i, (src, dst) in enumerate(edges):
        x0, y0 = positions[src]
        x1, y1 = positions[dst]

        if strategies and i < len(strategies):
            edge_text = strategies[i]
        else:
            edge_text = default_edge_texts[i % len(default_edge_texts)]

        fig1.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))

        # Place the label slightly above the edge's midpoint.
        mid_x = (x0 + x1) / 2
        mid_y = (y0 + y1) / 2
        fig1.add_annotation(
            x=mid_x,
            y=mid_y + 0.8,
            text=edge_text,
            showarrow=False,
            font=dict(size=12),
            align="center"
        )

    fig1.update_layout(
        showlegend=False,
        margin=dict(t=50, b=50, l=50, r=50),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        # Scale the canvas with the tree's depth and node count.
        width=800 + max_level * 200,
        height=300 + len(nodes) * 100,
        plot_bgcolor='rgba(240,240,240,0.2)',
        paper_bgcolor='white'
    )

    return fig1
|
|
|
def generate_subplot2(masked_sentences, sampled_sentences, highlight_info, common_grams):
    """
    Generates a subplot visualizing multiple masked sentences and their sampled
    variants in a tree structure.

    Each masked sentence (level 0) is connected to one sampled sentence per
    sampling technique (level 1). If fewer sampled sentences than expected are
    supplied, the missing ones are padded with placeholders; the caller's list
    is never modified.

    Args:
        masked_sentences (list of str): Masked sentences shown as root nodes.
        sampled_sentences (list of str): Sampled sentences derived from the
            masked ones; expected length is len(masked_sentences) * 4.
        highlight_info (list of tuples): (word, color) pairs for highlighting.
        common_grams (list of tuples): (index, phrase) pairs for LCS numbering.

    Returns:
        plotly.graph_objects.Figure: A Plotly figure representing the tree
        structure with highlighted words and technique-labeled edges.
    """
    sampling_techniques = [
        "Greedy Sampling",
        "Temperature Sampling",
        "Exponential Minimum Sampling",
        "Inverse Transform Sampling"
    ]

    num_masked = len(masked_sentences)
    num_sampled_per_masked = len(sampling_techniques)
    expected_sampled_count = num_masked * num_sampled_per_masked

    # BUG FIX: work on a copy — the original appended placeholder entries
    # directly to the caller's sampled_sentences list, mutating the argument.
    sampled = list(sampled_sentences)
    if len(sampled) < expected_sampled_count:
        print(f"Warning: Expected {expected_sampled_count} sampled sentences, but got {len(sampled)}")
        while len(sampled) < expected_sampled_count:
            sampled.append(f"Placeholder sampled sentence {len(sampled) + 1}")

    # Tag each node with its tree level: masked roots = L0, sampled = L1.
    nodes = [s + ' L0' for s in masked_sentences]
    nodes.extend(s + ' L1' for s in sampled[:expected_sampled_count])

    def apply_lcs_numbering(sentence, common_grams):
        """
        Applies LCS numbering to the sentence based on the common_grams.
        """
        for idx, lcs in common_grams:
            sentence = re.sub(rf"\b{lcs}\b", f"({idx}){lcs}", sentence)
        return sentence

    nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]

    def highlight_words(sentence, color_map):
        """
        Wraps each word found in color_map with {{...}} placeholder braces
        (colors are applied later by color_highlighted_words).
        """
        for word in color_map:
            sentence = re.sub(rf"\b{word}\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
        return sentence

    def color_highlighted_words(node, color_map):
        """
        Replaces {{word}} placeholders with HTML <span> tags colored per
        color_map; unknown words fall back to black.
        """
        parts = re.split(r'(\{\{.*?\}\})', node)
        colored_parts = []
        for part in parts:
            match = re.match(r'\{\{(.*?)\}\}', part)
            if match:
                word = match.group(1)
                color = color_map.get(word, 'black')
                colored_parts.append(f"<span style='color: {color};'>{word}</span>")
            else:
                colored_parts.append(part)
        return ''.join(colored_parts)

    # Strip the trailing level tags before display.
    cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
    global_color_map = dict(highlight_info)
    highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
    wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=80)) for node in highlighted_nodes]

    def get_levels_and_edges(nodes):
        """
        Recovers each node's level from its ' L<n>' tag and connects every
        masked node to its contiguous block of sampled children.
        """
        levels = {i: int(node.split()[-1][1]) for i, node in enumerate(nodes)}
        edges = []
        for masked_idx in range(num_masked):
            for technique_idx in range(num_sampled_per_masked):
                sampled_idx = num_masked + masked_idx * num_sampled_per_masked + technique_idx
                if sampled_idx < len(nodes):
                    edges.append((masked_idx, sampled_idx))
        return levels, edges

    levels, edges = get_levels_and_edges(nodes)

    # Layout: roots in one column (x = 0), children in another (x = 3);
    # roots are spread vertically and each root's children cluster around it.
    positions = {}
    root_x_spacing = 0
    root_y_spacing = 8.0
    sampled_x = 3

    root_y_start = -(num_masked - 1) * root_y_spacing / 2
    for i in range(num_masked):
        positions[i] = (root_x_spacing, root_y_start + i * root_y_spacing)

    for masked_idx in range(num_masked):
        root_y = positions[masked_idx][1]
        children_y_spacing = 1.5
        children_y_start = root_y - (num_sampled_per_masked - 1) * children_y_spacing / 2
        for technique_idx in range(num_sampled_per_masked):
            child_idx = num_masked + masked_idx * num_sampled_per_masked + technique_idx
            child_y = children_y_start + technique_idx * children_y_spacing
            positions[child_idx] = (sampled_x, child_y)

    fig2 = go.Figure()

    # Draw one marker + one annotation box per node (roots blue, samples green).
    for i, node in enumerate(wrapped_nodes):
        x, y = positions[i]
        node_color = 'blue' if levels[i] == 0 else 'green'
        fig2.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=20, color=node_color, line=dict(color='black', width=2)),
            hoverinfo='none'
        ))
        colored_node = color_highlighted_words(node, global_color_map)
        fig2.add_annotation(
            x=x,
            y=y,
            text=colored_node,
            showarrow=False,
            xshift=15,
            align="left",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=2,
            borderpad=4,
            bgcolor='white',
            width=400,
            height=100
        )

    # Draw edges and label each with its sampling technique.
    for src, dst in edges:
        x0, y0 = positions[src]
        x1, y1 = positions[dst]

        fig2.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))

        # The child's offset within its parent's block identifies the
        # technique directly — O(1) per edge instead of the original
        # O(edges) rescan of the edge list.
        technique_label = sampling_techniques[(dst - num_masked) % num_sampled_per_masked]

        mid_x = (x0 + x1) / 2
        mid_y = (y0 + y1) / 2
        label_offset = 0.1
        fig2.add_annotation(
            x=mid_x,
            y=mid_y + label_offset,
            text=technique_label,
            showarrow=False,
            font=dict(size=8),
            align="center"
        )

    fig2.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1200,
        height=2000,
        plot_bgcolor='rgba(240,240,240,0.2)',
        paper_bgcolor='white'
    )

    return fig2
|
|
|
if __name__ == "__main__":
    # Demo data for a quick visual smoke test of both figures.
    paraphrased_sentence = "The quick brown fox jumps over the lazy dog."
    masked_sentences = [
        "A fast brown fox leaps over the lazy dog.",
        "A quick brown fox hops over a lazy dog."
    ]
    highlight_info = [
        ("quick", "red"),
        ("brown", "green"),
        ("fox", "blue"),
        ("lazy", "purple")
    ]
    common_grams = [
        (1, "quick brown fox"),
        (2, "lazy dog")
    ]

    # BUG FIX: generate_subplot1 takes `strategies` as its third positional
    # parameter; the original call omitted it, raising a TypeError at runtime.
    # Passing None falls back to the default edge labels.
    fig1 = generate_subplot1(paraphrased_sentence, masked_sentences, None, highlight_info, common_grams)
    fig1.show()

    # Fewer than num_masked * 4 samples: generate_subplot2 pads the rest
    # with placeholder sentences and prints a warning.
    sampled_sentence = ["A fast brown fox jumps over a lazy dog."]
    fig2 = generate_subplot2(masked_sentences, sampled_sentence, highlight_info, common_grams)
    fig2.show()
|
|