import colorsys
import json
import os
import re

import gradio as gr
import openai
from dotenv import load_dotenv
from transformers import pipeline

load_dotenv()

ner_pipeline = pipeline("ner")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")

client = openai.AzureOpenAI(
    # Fall back to a pinned API version when the env var is not set
    api_version=AZURE_OPENAI_API_VERSION or "2024-05-01-preview",
    api_key=AZURE_OPENAI_API_KEY,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
)
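# Expected .env entries for the client above (values are placeholders):
#   AZURE_OPENAI_API_KEY=<your-key>
#   AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com/
#   AZURE_OPENAI_API_VERSION=2024-05-01-preview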
def extract_entities_gpt(
    original_text,
    compared_text,
    text_generation_model="o1-mini",  # "gpt-4o-mini" or "o1-mini"
):
    prompt = f"""
Compare the ORIGINAL TEXT and the COMPARED TEXT.
Find entity pairs with significantly different meanings after paraphrasing.
Focus only on these significantly changed entities. These include:
* **Numerical changes:** e.g., "five" -> "ten," "10%" -> "50%"
* **Time changes:** e.g., "Monday" -> "Sunday," "10th" -> "21st"
* **Name changes:** e.g., "Tokyo" -> "New York," "Japan" -> "Japanese"
* **Opposite meanings:** e.g., "increase" -> "decrease," "good" -> "bad"
* **Semantically different words:** e.g., "car" -> "truck," "walk" -> "run"
Exclude entities where the meaning remains essentially the same,
even if the wording is different
(e.g., "big" changed to "large," "house" changed to "residence").
Also exclude purely stylistic changes that don't affect the core meaning.
Output the extracted entity pairs, one pair per line,
in the following JSON-like list format without wrapping characters:
[
    ["ORIGINAL_TEXT_entity_1", "COMPARED_TEXT_entity_1"],
    ["ORIGINAL_TEXT_entity_2", "COMPARED_TEXT_entity_2"]
]
If there are no entities that satisfy the above conditions, output an empty list: "[]".
---
# ORIGINAL TEXT:
{original_text}
---
# COMPARED TEXT:
{compared_text}
"""
    # Generate the comparison using the selected model
    try:
        response = client.chat.completions.create(
            model=text_generation_model,
            messages=[{"role": "user", "content": prompt}],
        )
        res = response.choices[0].message.content
    except openai.OpenAIError as e:
        print(f"Error interacting with OpenAI API: {e}")
        res = ""
    return res
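# Illustrative call (inputs and output are examples, not guaranteed):
#   extract_entities_gpt("The UK pledged aid.", "Japan pledged aid.")
#   is expected to return a string such as '[["UK", "Japan"]]', possibly
#   wrapped in ```json fences, which highlight_entities strips below.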
def read_json(json_string) -> list[list[str]]:
    try:
        entities = json.loads(json_string)
        # Drop duplicate entity pairs while preserving order
        unique_entities = []
        for inner_list in entities:
            if inner_list not in unique_entities:
                unique_entities.append(inner_list)
        return unique_entities
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        return []
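# Example:
#   read_json('[["UK", "Japan"], ["UK", "Japan"], ["Monday", "Sunday"]]')
#   -> [["UK", "Japan"], ["Monday", "Sunday"]]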
def lighten_color(hex_color, factor=1.8):
    """Lightens a HEX color by increasing its brightness in HSV space."""
    hex_color = hex_color.lstrip("#")
    r, g, b = (
        int(hex_color[0:2], 16),
        int(hex_color[2:4], 16),
        int(hex_color[4:6], 16),
    )
    # Convert to HSV
    h, s, v = colorsys.rgb_to_hsv(r / 255.0, g / 255.0, b / 255.0)
    v = min(1.0, v * factor)  # Increase brightness
    # Convert back to HEX
    r, g, b = (int(c * 255) for c in colorsys.hsv_to_rgb(h, s, v))
    return f"#{r:02x}{g:02x}{b:02x}"


def darken_color(hex_color, factor=0.7):
    """Darkens a HEX color by reducing its brightness in HSV space."""
    hex_color = hex_color.lstrip("#")
    r, g, b = (
        int(hex_color[0:2], 16),
        int(hex_color[2:4], 16),
        int(hex_color[4:6], 16),
    )
    # Convert to HSV
    h, s, v = colorsys.rgb_to_hsv(r / 255.0, g / 255.0, b / 255.0)
    v = max(0.0, v * factor)  # Reduce brightness
    # Convert back to HEX
    r, g, b = (int(c * 255) for c in colorsys.hsv_to_rgb(h, s, v))
    return f"#{r:02x}{g:02x}{b:02x}"
def generate_color(index, total_colors=20):
    """Generates a unique, evenly spaced color for each index using HSL."""
    hue = index / total_colors  # Spread hues in range [0,1]
    saturation = 0.65  # Keep colors vivid
    lightness = 0.75  # Balanced brightness
    # Convert HSL to RGB (colorsys expects hue, lightness, saturation)
    r, g, b = colorsys.hls_to_rgb(hue, lightness, saturation)
    r, g, b = int(r * 255), int(g * 255), int(b * 255)
    return f"#{r:02x}{g:02x}{b:02x}"  # Convert to hex
def assign_colors_to_entities(entities):
    total_colors = len(entities)
    # Pair each entity's original/compared surface forms with a color
    entities_colors = []
    for index, entity in enumerate(entities):
        color = generate_color(index, total_colors)
        entities_colors.append(
            {"color": color, "input": entity[0], "source": entity[1]},
        )
    return entities_colors
def highlight_entities(text1, text2):
    if text1 is None or text2 is None:
        return None
    entities_text = extract_entities_gpt(text1, text2)
    # Clean up the response: strip Markdown code-fence wrappers
    entities_text = entities_text.replace("```json", "").replace("```", "")
    entities = read_json(entities_text)
    if len(entities) == 0:
        return None
    # Assign colors to entities
    entities_with_colors = assign_colors_to_entities(entities)
    return entities_with_colors
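# Return shape: a list such as
#   [{"color": "#e89595", "input": "UK", "source": "Japan"}, ...]
# or None when the model reports no significantly changed entities.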
def apply_highlight(text, entities_with_colors, key="input", count=0):
    if entities_with_colors is None:
        return text, []

    all_starts = []
    all_ends = []
    highlighted_text = ""
    temp_text = text
    for index, entity in enumerate(entities_with_colors):
        highlighted_text = ""
        # Find the start/end offsets of every occurrence of the entity in
        # the text; "\b" anchors the match at word boundaries
        starts = []
        ends = []
        for m in re.finditer(
            r"\b" + re.escape(entity[key]) + r"\b",
            temp_text,
        ):
            starts.append(m.start())
            ends.append(m.end())
        all_starts.extend(starts)
        all_ends.extend(ends)

        color = entities_with_colors[index]["color"]
        entity_color = lighten_color(
            color,
            factor=2.2,
        )  # Lightened color for the text background
        label_color = darken_color(
            entity_color,
            factor=0.7,
        )  # Darker color for the index label background

        # Apply highlighting to each occurrence of the entity
        prev_end = 0
        for start, end in zip(starts, ends):
            # Append the non-highlighted text before this occurrence
            highlighted_text += temp_text[prev_end:start]
            # Style the index as a label
            index_label = (
                f'<span style="background-color:{label_color};color:white;'
                f"padding:1px 4px;border-radius:4px;font-size:12px;"
                f'font-weight:bold;display:inline-block;margin-right:4px;">{index + 1 + count}</span>'  # noqa: E501
            )
            # Append the highlighted occurrence with its index label
            highlighted_text += (
                f'\n<span style="background-color:{entity_color};color:black;'
                f'border-radius:3px;font-size:14px;display:inline-block;">'
                f"{index_label}{temp_text[start:end]}</span>\n"
            )
            prev_end = end
        highlighted_text += temp_text[prev_end:]
        temp_text = highlighted_text

    if highlighted_text == "":
        return text, []
    highlight_idx_list = get_index_list(highlighted_text)
    return highlighted_text, highlight_idx_list
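# Usage: key="input" highlights the original-side surface forms and
# key="source" the compared-side ones; "count" offsets the numeric labels
# so two rendered panels can share one numbering sequence.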
def get_index_list(highlighted_text):
    """
    Returns the indices of whitespace-separated tokens that fall inside
    the highlighted <span> regions produced by apply_highlight.

    Args:
        highlighted_text: Text containing the span markup.

    Returns:
        A list of token indices covered by highlighted spans (a heuristic
        based on whitespace splitting).
    """
    highlighted_index = []
    words = highlighted_text.split()
    start_index = None
    for index, word in enumerate(words):
        if word.startswith("<span"):
            start_index = index
        if word.endswith("</span>") and start_index is not None:
            end_index = index
            highlighted_index.extend(list(range(start_index, end_index + 1)))
    return highlighted_index
def extract_entities(text):
    output = ner_pipeline(text)
    words = extract_words(output)
    words = combine_subwords(words)
    # Deduplicate the extracted words while preserving order
    entities = []
    for entity in words:
        if entity not in entities:
            entities.append(entity)
    return entities
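# Illustrative output (token-level and model-dependent): with the default
# NER model, extract_entities("Volodymyr Zelensky visited Kyiv") yields
# something like ["Volodymyr", "Zelensky", "Kyiv"] after subword merging.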
def extract_words(entities):
    """
    Extracts the words from a list of entities.

    Args:
        entities: A list of entity dicts as returned by the NER pipeline.

    Returns:
        A list of words extracted from the entities.
    """
    return [entity["word"] for entity in entities]
def combine_subwords(word_list):
    """
    Combines subwords (indicated by "##") with the preceding word in a list.

    Args:
        word_list: A list of words, where subwords are prefixed with "##".

    Returns:
        A new list with subwords combined with their preceding words.
    """
    result = []
    i = 0
    while i < len(word_list):
        if word_list[i].startswith("##") and result:
            # Remove "##" and append to the previous word
            result[-1] += word_list[i][2:]
        elif i < len(word_list) - 2 and word_list[i + 1] == "-":
            # Combine hyphenated words
            result.append(word_list[i] + word_list[i + 1] + word_list[i + 2])
            i += 2  # Skip the hyphen and the word after it
        else:
            result.append(word_list[i])
        i += 1
    return result
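# Examples:
#   combine_subwords(["Volo", "##dym", "##yr", "said"]) -> ["Volodymyr", "said"]
#   combine_subwords(["well", "-", "known"]) -> ["well-known"]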
original_text = """ | |
Title: UK pledges support for Ukraine with 100-year pact | |
Content: Sir Keir Starmer has pledged to put Ukraine in the "strongest | |
possible position" on a trip to Kyiv where he signed a "landmark" | |
100-year pact with the war-stricken country. The prime minister's | |
visit on Thursday was at one point marked by loud blasts and air | |
raid sirens after a reported Russian drone attack was intercepted | |
by Ukraine's defence systems. Acknowledging the "hello" from Russia, | |
Volodymyr Zelensky said Ukraine would send its own "hello back". | |
An estimated one million people have been killed or wounded in the | |
war so far. As the invasion reaches the end of its third year, Ukraine | |
is losing territory in the east. Zelensky praised the UK's commitment | |
on Thursday, amid wider concerns that the US President-elect Donald | |
Trump, who is set to take office on Monday, could potentially reduce aid. | |
""" | |
compared_text = """ | |
Title: Japan pledges support for Ukraine with 100-year pact | |
Content: A leading Japanese figure has pledged to put Ukraine | |
in the "strongest possible position" on a trip to Kyiv where | |
they signed a "landmark" 100-year pact with the war-stricken country. | |
The visit on Thursday was at one point marked by loud blasts and air | |
raid sirens after a reported Russian drone attack was intercepted by | |
Ukraine's defence systems. Acknowledging the "hello" from Russia, | |
Volodymyr Zelensky said Ukraine would send its own "hello back". | |
An estimated one million people have been killed or wounded in the | |
war so far. As the invasion reaches the end of its third year, Ukraine | |
is losing territory in the east. Zelensky praised Japan's commitment | |
on Thursday, amid wider concerns that the next US President, who is | |
set to take office on Monday, could potentially reduce aid. | |
""" | |
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("### Highlight Matching Parts Between Two Paragraphs")
        text1_input = gr.Textbox(
            label="Paragraph 1",
            lines=5,
            value=original_text,
        )
        text2_input = gr.Textbox(
            label="Paragraph 2",
            lines=5,
            value=compared_text,
        )
        submit_button = gr.Button("Highlight Matches")
        output1 = gr.HTML("<br>" * 10)
        output2 = gr.HTML("<br>" * 10)
        # Bind the two-output wrapper (highlight_entities alone returns
        # entity metadata, not the two rendered HTML strings)
        submit_button.click(
            fn=compare_and_highlight,
            inputs=[text1_input, text2_input],
            outputs=[output1, output2],
        )
    # Launch the Gradio app
    demo.launch()