""" | |
Author: Khanh Phan | |
Date: 2024-12-04 | |
""" | |
import colorsys | |
import json | |
import re | |
import gradio as gr | |
import openai | |
from transformers import pipeline | |
from src.application.config import ( | |
AZUREOPENAI_CLIENT, | |
ENTITY_BRIGHTNESS, | |
ENTITY_DARKEN_COLOR, | |
ENTITY_LIGHTEN_COLOR, | |
ENTITY_SATURATION, | |
GPT_ENTITY_MODEL, | |
) | |
ner_pipeline = pipeline("ner") | |


def extract_entities_gpt(
    original_text,
    compared_text,
    text_generation_model=GPT_ENTITY_MODEL,
) -> str:
    """
    Extracts entity pairs with significantly different meanings between
    two texts using a GPT model.

    Args:
        original_text (str): The original text.
        compared_text (str): The paraphrased or compared text.
        text_generation_model (str, optional): The GPT model
            to use for entity extraction.

    Returns:
        str: A JSON-like string containing the extracted entity pairs,
            or an empty string if an error occurs.
    """
    # Construct the prompt for the GPT model.
    # TODO: Move to config or prompt file
    prompt = f"""
    Compare the ORIGINAL TEXT and the COMPARED TEXT.
    Find entity pairs with significantly different meanings after paraphrasing.
    Focus only on these significantly changed entities. These include:
        * **Numerical changes:** e.g., "five" -> "ten," "10%" -> "50%"
        * **Time changes:** e.g., "Monday" -> "Sunday," "10th" -> "21st"
        * **Name changes:** e.g., "Tokyo" -> "New York," "Japan" -> "Japanese"
        * **Opposite meanings:** e.g., "increase" -> "decrease," "good" -> "bad"
        * **Semantically different words:** e.g., "car" -> "truck," "walk" -> "run"

    Exclude entities whose meaning remains essentially the same,
    even if the wording is different
    (e.g., "big" changed to "large," "house" changed to "residence").
    Also exclude purely stylistic changes that don't affect the core meaning.

    Output the extracted entity pairs, one pair per line,
    in the following JSON-like list format without wrapping characters:
    [
        ["ORIGINAL_TEXT_entity_1", "COMPARED_TEXT_entity_1"],
        ["ORIGINAL_TEXT_entity_2", "COMPARED_TEXT_entity_2"]
    ]

    If no entities satisfy the above conditions, output an empty list "[]".

    ---
    # ORIGINAL TEXT:
    {original_text}

    ---
    # COMPARED TEXT:
    {compared_text}
    """

    # Generate text using the selected model.
    try:
        # Send the prompt to the GPT model and get the response.
        response = AZUREOPENAI_CLIENT.chat.completions.create(
            model=text_generation_model,
            messages=[{"role": "user", "content": prompt}],
        )

        # Extract the generated content from the response.
        res = response.choices[0].message.content
    except openai.OpenAIError as e:
        print(f"Error interacting with OpenAI API: {e}")
        res = ""

    return res
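
# Illustrative example (model-dependent, not executed here): for the sample
# texts defined at the bottom of this file, the model would be expected to
# return a JSON-like string such as
#   '[["UK", "Japan"], ["Sir Keir Starmer", "A leading Japanese figure"]]',
# possibly wrapped in ```json fences, which highlight_entities() strips
# before parsing.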


def read_json(json_string: str) -> list[list[str]]:
    """
    Parses a JSON string and returns a list of unique entity pairs.

    Args:
        json_string (str): The JSON string to parse.

    Returns:
        list[list[str]]: A list of unique entity pairs,
            or an empty list if parsing fails.
    """
    try:
        # Attempt to parse the JSON string into a Python object.
        entities = json.loads(json_string)

        # Remove duplicate entity pairs while preserving order.
        unique_entities = []
        for inner_list in entities:
            # Skip the pair if it has already been seen.
            if inner_list not in unique_entities:
                unique_entities.append(inner_list)
        return unique_entities
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        return []
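
# Illustrative example: duplicate pairs returned by the model are dropped, e.g.
#   read_json('[["five", "ten"], ["five", "ten"], ["Monday", "Sunday"]]')
# returns [["five", "ten"], ["Monday", "Sunday"]].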


def set_color_brightness(
    hex_color: str,
    brightness_factor: float = ENTITY_LIGHTEN_COLOR,
) -> str:
    """
    Adjusts the brightness of a HEX color by scaling its value in HSV space.

    Args:
        hex_color (str): The HEX color code (e.g., "#RRGGBB").
        brightness_factor (float, optional): The factor by which to scale
            the brightness (values > 1.0 lighten, values < 1.0 darken).

    Returns:
        str: The adjusted HEX color code.
    """
    # Remove the '#' prefix if present.
    hex_color = hex_color.lstrip("#")

    # Convert the HEX color to RGB (red, green, blue) integers.
    r, g, b = (
        int(hex_color[0:2], 16),  # Red component
        int(hex_color[2:4], 16),  # Green component
        int(hex_color[4:6], 16),  # Blue component
    )

    # Convert RGB to HSV (hue, saturation, value/brightness).
    h, s, v = colorsys.rgb_to_hsv(r / 255.0, g / 255.0, b / 255.0)

    # Scale the brightness by the specified factor, capped at 1.0.
    v = min(1.0, v * brightness_factor)

    # Convert the modified HSV back to RGB.
    r, g, b = (int(c * 255) for c in colorsys.hsv_to_rgb(h, s, v))

    # Convert the RGB values back to a HEX color code.
    return f"#{r:02x}{g:02x}{b:02x}"


def generate_colors(index: int, total_colors: int = 20) -> str:
    """
    Generates a unique, evenly spaced color for each index using HLS.

    Args:
        index (int): The index for which to generate a color.
        total_colors (int, optional): The total number of colors to
            distribute evenly. Defaults to 20.

    Returns:
        str: A HEX color code representing the generated color.
    """
    # Calculate the hue based on the index and the total number of colors.
    # This spreads hues evenly across the color spectrum.
    hue = index / total_colors  # Hue in the range [0, 1]

    # Convert HLS to RGB.
    # Note: colorsys.hls_to_rgb expects (hue, lightness, saturation).
    r, g, b = colorsys.hls_to_rgb(hue, ENTITY_SATURATION, ENTITY_BRIGHTNESS)

    # Scale the RGB values from [0, 1] to [0, 255].
    r, g, b = int(r * 255), int(g * 255), int(b * 255)

    # Convert to HEX.
    return f"#{r:02x}{g:02x}{b:02x}"


def assign_colors_to_entities(entities: list) -> list[dict]:
    """
    Assigns a unique color to each entity pair in a list.

    Args:
        entities (list): A list of entity pairs,
            where each pair is a list of two strings.
            Example: [["entity1_original", "entity1_compared"]]

    Returns:
        list: A list of dictionaries, where each dictionary contains
            - "color": the color assigned to the entity pair.
            - "input": the original entity string.
            - "source": the compared entity string.
    """
    # Number of colors needed.
    total_colors = len(entities)

    # Assign a color to each entity pair based on its index.
    entities_colors = []
    for index, entity in enumerate(entities):
        color = generate_colors(index, total_colors)

        # Append the color and the entity pair to entities_colors.
        entities_colors.append(
            {"color": color, "input": entity[0], "source": entity[1]},
        )
    return entities_colors
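
# Illustrative example: assign_colors_to_entities([["UK", "Japan"]]) returns a
# single-element list of the form
#   [{"color": "#......", "input": "UK", "source": "Japan"}],
# where the HEX value comes from generate_colors().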


def highlight_entities(text1: str, text2: str) -> list[dict] | None:
    """
    Highlights entities with significant differences between
    two texts by assigning them unique colors.

    Args:
        text1 (str): The input text.
        text2 (str): The source text.

    Returns:
        list: A list of dictionaries, where each dictionary contains
            the highlighted entity information (color, input, source),
            or None if no significant entities are found or an error occurs.
    """
    if text1 is None or text2 is None:
        return None

    # Extract entities with significant differences using a GPT model.
    entities_text = extract_entities_gpt(text1, text2)

    # Clean up the extracted entities string by removing wrapping characters.
    entities_text = entities_text.replace("```json", "").replace("```", "")

    # Parse the cleaned entities string into a Python list of entity pairs.
    entities = read_json(entities_text)

    # If no significant entities are found, return None.
    if len(entities) == 0:
        return None

    # Assign unique colors to the extracted entities.
    entities_with_colors = assign_colors_to_entities(entities)

    return entities_with_colors
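
# Illustrative usage (assumed flow, mirroring the Gradio demo below): the pairs
# returned here are rendered with apply_highlight(), once per text, e.g.
#   entities = highlight_entities(original_text, compared_text)
#   html_1, idx_1 = apply_highlight(original_text, entities, key="input")
#   html_2, idx_2 = apply_highlight(compared_text, entities, key="source")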


def apply_highlight(
    text: str,
    entities_with_colors: list[dict],
    key: str = "input",
    count: int = 0,
) -> tuple[str, list[int]]:
    """
    Applies highlighting to the specified entities within a text,
    assigning them unique colors and index labels.

    Args:
        text (str): The text to highlight.
        entities_with_colors (list): A list of dictionaries,
            where each dictionary represents an entity and its color.
        key (str, optional): The key in the entity dictionary that
            contains the entity text to highlight.
        count (int, optional): An offset to add to the index labels.

    Returns:
        tuple:
            - The highlighted text (str).
            - A list of index positions of the highlighted words (list).
    """
    if entities_with_colors is None:
        return text, []

    # Start & end indices of highlighted entities.
    all_starts = []
    all_ends = []

    highlighted_text = ""
    temp_text = text

    # Apply highlighting to each entity.
    for index, entity in enumerate(entities_with_colors):
        highlighted_text = ""
        starts = []
        ends = []

        # Find every occurrence of the entity, using word boundaries (\b)
        # and escaping special regex characters.
        for m in re.finditer(
            r"\b" + re.escape(entity[key]) + r"\b",
            temp_text,
        ):
            starts.append(m.start())
            ends.append(m.end())

        all_starts.extend(starts)
        all_ends.extend(ends)

        # Get the color for this entity.
        color = entities_with_colors[index]["color"]

        # Lightened color for the entity background.
        entity_color = set_color_brightness(
            color,
            brightness_factor=ENTITY_LIGHTEN_COLOR,
        )

        # Darker color for the label (index) background.
        label_color = set_color_brightness(
            entity_color,
            brightness_factor=ENTITY_DARKEN_COLOR,
        )

        # Apply highlighting to each occurrence of the entity.
        # Note: the tags use underscores ("<span_style=", "1px_4px") instead of
        # spaces, which keeps each opening tag a single whitespace-delimited
        # token for get_index_list() below.
        prev_end = 0
        for start, end in zip(starts, ends):
            # Non-highlighted text before the entity.
            highlighted_text += temp_text[prev_end:start]

            # Create the index label with the specified color and style.
            index_label = (
                f'<span_style="background-color:{label_color};color:white;'
                f"padding:1px_4px;border-radius:4px;font-size:12px;"
                f'font-weight:bold;display:inline-block;margin-right:4px;">{index + 1 + count}</span>'  # noqa: E501
            )

            # Highlighted entity with the specified color and style.
            highlighted_text += (
                f'<span_style="background-color:{entity_color};color:black;'
                f'border-radius:3px;font-size:14px;display:inline-block;">'
                f"{index_label}{temp_text[start:end]}</span>"
            )
            prev_end = end

        # Append any remaining text after the last occurrence.
        highlighted_text += temp_text[prev_end:]

        # Carry the highlighted text forward for the next entity.
        temp_text = highlighted_text

    if highlighted_text == "":
        return text, []

    # Get the word-index list of the highlighted text.
    highlight_idx_list = get_index_list(highlighted_text)
    return highlighted_text, highlight_idx_list
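
# Illustrative example: with
#   entities_with_colors = [{"color": "#ff0000", "input": "UK", "source": "Japan"}]
# and key="input", every standalone occurrence of "UK" in the text is wrapped
# in a numbered, colored <span_style=...> tag, and the returned index list
# holds the whitespace-token positions of those wrapped words.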


def get_index_list(highlighted_text: str) -> list[int]:
    """
    Generates a list of word indices of the highlighted spans within a text.

    Args:
        highlighted_text (str): The text containing highlighted words
            wrapped in HTML-like span tags.

    Returns:
        list: A list of indices corresponding to the highlighted words.
            An empty list if no highlighted words are found.
    """
    highlighted_index = []
    start_index = None
    end_index = None

    # Split on whitespace; each opening tag is a single token because
    # apply_highlight() writes it without spaces.
    words = highlighted_text.split()
    for index, word in enumerate(words):
        # Check if the word starts with an opening span tag.
        if word.startswith("<span_style"):
            start_index = index

        # Check if the word ends with a closing span tag.
        if word.endswith("</span>"):
            end_index = index

            if start_index is not None:
                # Add the range of indices to the result list.
                highlighted_index.extend(
                    list(
                        range(
                            start_index,
                            end_index + 1,
                        ),
                    ),
                )

                start_index = None
                end_index = None

    return highlighted_index
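
# Illustrative example: for the string
#   'before <span_style="...">UK</span> after'
# the whitespace tokens are ["before", '<span_style="...">UK</span>', "after"],
# so get_index_list() returns [1]; a multi-word entity yields a run of indices.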


def extract_entities(text: str) -> list[str]:
    """
    Extracts named entities from the given text.

    Args:
        text (str): The input text to extract entities from.

    Returns:
        list: A list of unique extracted entities (strings).
    """
    # Apply the Named Entity Recognition (NER) pipeline to the input text.
    output = ner_pipeline(text)

    # Extract words from the NER pipeline output.
    words = extract_words(output)

    # Combine subwords into complete words.
    words = combine_subwords(words)

    # Keep each entity only once, preserving order.
    entities = []
    for entity in words:
        if entity not in entities:
            entities.append(entity)

    return entities
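
# Illustrative example (the exact tokens depend on the default NER model): for
# the input "Keir Starmer visited Kyiv", the pipeline output is reduced to a
# list of unique entity words such as ["Keir", "Starmer", "Kyiv"].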


def extract_words(entities: list[dict]) -> list[str]:
    """
    Extracts the words from a list of entities.

    Args:
        entities (list): A list of entities,
            where each entity is expected to be a dictionary
            containing a "word" key.

    Returns:
        list[str]: A list of words extracted from the entities.
    """
    words = []
    for entity in entities:
        words.append(entity["word"])
    return words
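
# Illustrative example: extract_words([{"word": "Heath"}, {"word": "##row"}])
# returns ["Heath", "##row"]; combine_subwords() below merges such pieces.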


def combine_subwords(word_list: list[str]) -> list[str]:
    """
    Combines subwords (indicated by "##") with the preceding word in a list.

    Args:
        word_list (list): A list of words,
            where subwords are prefixed with "##".

    Returns:
        list: A new list with subwords combined with their preceding words
            and hyphenated words combined.
    """
    result = []
    i = 0
    while i < len(word_list):
        if word_list[i].startswith("##"):
            # Remove "##" and append the remainder to the previous word.
            result[-1] += word_list[i][2:]
        elif i < len(word_list) - 2 and word_list[i + 1] == "-":
            # Combine the current word, the hyphen, and the next word.
            result.append(word_list[i] + word_list[i + 1] + word_list[i + 2])
            i += 2  # Skip the hyphen and the following word.
        else:
            # Neither a subword nor a hyphenated word:
            # append the current word as-is.
            result.append(word_list[i])
        i += 1
    return result
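
# Illustrative example:
#   combine_subwords(["Heath", "##row", "state", "-", "owned"])
# returns ["Heathrow", "state-owned"]: "##row" is glued onto the previous word
# and the hyphenated triple is merged into a single token.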
original_text = """ | |
Title: UK pledges support for Ukraine with 100-year pact | |
Content: Sir Keir Starmer has pledged to put Ukraine in the "strongest | |
possible position" on a trip to Kyiv where he signed a "landmark" | |
100-year pact with the war-stricken country. The prime minister's | |
visit on Thursday was at one point marked by loud blasts and air | |
raid sirens after a reported Russian drone attack was intercepted | |
by Ukraine's defence systems. Acknowledging the "hello" from Russia, | |
Volodymyr Zelensky said Ukraine would send its own "hello back". | |
An estimated one million people have been killed or wounded in the | |
war so far. As the invasion reaches the end of its third year, Ukraine | |
is losing territory in the east. Zelensky praised the UK's commitment | |
on Thursday, amid wider concerns that the US President-elect Donald | |
Trump, who is set to take office on Monday, could potentially reduce aid. | |
""" | |
compared_text = """ | |
Title: Japan pledges support for Ukraine with 100-year pact | |
Content: A leading Japanese figure has pledged to put Ukraine | |
in the "strongest possible position" on a trip to Kyiv where | |
they signed a "landmark" 100-year pact with the war-stricken country. | |
The visit on Thursday was at one point marked by loud blasts and air | |
raid sirens after a reported Russian drone attack was intercepted by | |
Ukraine's defence systems. Acknowledging the "hello" from Russia, | |
Volodymyr Zelensky said Ukraine would send its own "hello back". | |
An estimated one million people have been killed or wounded in the | |
war so far. As the invasion reaches the end of its third year, Ukraine | |
is losing territory in the east. Zelensky praised Japan's commitment | |
on Thursday, amid wider concerns that the next US President, who is | |
set to take office on Monday, could potentially reduce aid. | |
""" | |


if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("### Highlight Matching Parts Between Two Texts")
        text1_input = gr.Textbox(
            label="Text 1",
            lines=5,
            value=original_text,
        )
        text2_input = gr.Textbox(
            label="Text 2",
            lines=5,
            value=compared_text,
        )
        submit_button = gr.Button("Highlight Matches")
        output1 = gr.HTML("<br>" * 10)
        output2 = gr.HTML("<br>" * 10)

        # Render one highlighted HTML block per text (see the helper above).
        submit_button.click(
            fn=highlight_both_texts,
            inputs=[text1_input, text2_input],
            outputs=[output1, output2],
        )

    # Launch the Gradio app.
    demo.launch()