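"""Gradio demo: knowledge-graph extraction from diagram images with a fine-tuned Donut model.

The app loads a fine-tuned VisionEncoderDecoder (Donut) model from an MLflow run tracked on
DagsHub, runs inference on examples from the Pyviz/MarkMap synthetic datasets, renders the
predicted graph with pyvis, and reports similarity metrics against the ground-truth graph.
"""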
import gradio as gr
from datasets import load_dataset
from PIL import Image
import json
import torch
from torchvision import transforms
from transformers import DonutProcessor, VisionEncoderDecoderModel
# import subprocess
# # Install mlflow and dagshub without dependencies
# subprocess.run(['pip', 'install', '--no-deps', 'mlflow'])
# subprocess.run(['pip', 'install', '--no-deps', 'dagshub'])
import dagshub
import mlflow
import time
import os
# from kaggle_secrets import UserSecretsClient
# user_secrets = UserSecretsClient()
# token = user_secrets.get_secret("dags_hub_token")
# from google.colab import userdata
# token = userdata.get('dags_hub_token')
token = os.getenv('dags_hub_token')
dagshub.auth.add_app_token(token)
dagshub.init(repo_owner='zaheramasha',
repo_name='Finetuning_paligemma_Zaka_capstone',
mlflow=True)
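# dagshub.init(..., mlflow=True) points the MLflow client at the DagsHub tracking server
# for this repo, so the run/artifact URIs below can be resolved.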
# Define the MLflow run ID and artifact path
run_id = "c41cfd149a8c44f3a92d8e0f1253af35" # Donut model trained on the PyvizAndMarkMap dataset for 27 epochs reaching a train loss of 0.168
run_id = "89bafd5e525a4d3e9d004e13c9574198" # Donut model trained on the PyvizAndMarkMap dataset for 27 + 51 = 78 epochs reaching a train loss of 0.0353. This run was a continuation of the 27 epoch one
artifact_path = "Donut_model/model"
# Create the model URI using the run ID and artifact path
model_uri = f"runs:/{run_id}/{artifact_path}"
print(mlflow.artifacts.list_artifacts(run_id=run_id, artifact_path=artifact_path))
# Load the model and processors from the MLflow artifact
# loaded_model_bundle = mlflow.transformers.load_model(artifact_path=artifact_path, run_id=run_id)
# for the 27-epoch model
model_uri = "mlflow-artifacts:/0a5d0550f55c4169b80cd6439556be8b/c41cfd149a8c44f3a92d8e0f1253af35/artifacts/Donut_model"
# for the fully trained 78-epoch model
model_uri = "mlflow-artifacts:/17c375f6eab34c63b2a2e7792803132e/89bafd5e525a4d3e9d004e13c9574198/artifacts/Donut_model"
loaded_model_bundle = mlflow.transformers.load_model(model_uri=model_uri, device='cpu')  # use 'cuda' if a GPU is available
model = loaded_model_bundle.model
processor = DonutProcessor(tokenizer=loaded_model_bundle.tokenizer, feature_extractor=loaded_model_bundle.feature_extractor, image_processor=loaded_model_bundle.image_processor)
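# Sanity check: the encoder input resolution and decoder max sequence length the model was trained with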
print(model.config.encoder.image_size)
print(model.config.decoder.max_length)
import random
import re
from typing import Any, List, Tuple, Dict
from torch.utils.data import Dataset
from datasets import DatasetDict, concatenate_datasets
from PIL import ImageFilter
# Load and split the dataset
Pyviz_dataset = load_dataset("Zaherrr/OOP_KG_Pyviz_Synthetic_Dataset", revision="Sorted_edges")
MarkMap_dataset = load_dataset("Zaherrr/OOP_KG_MarkMap_Synthetic_Dataset")
combined_dataset = concatenate_datasets([Pyviz_dataset['data'], MarkMap_dataset['data']])
train_test_split = combined_dataset.train_test_split(test_size=0.2, seed=42)
train_val_split = train_test_split["train"].train_test_split(test_size=0.125, seed=42)
split_dataset = DatasetDict(
{
"train": train_val_split["train"],
"val": train_val_split["test"],
"test": train_test_split["test"],
}
)
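# Each example is expected to provide columnar "nodes" ({"id": [...], "label": [...]}) and
# "edges" ({"source": [...], "target": [...]}) fields alongside the rendered "image",
# which is how reshape_json_data_to_fit_visualize_graph below indexes them.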
def reshape_json_data_to_fit_visualize_graph(graph_data):
nodes = graph_data["nodes"]
edges = graph_data["edges"]
transformed_nodes = [
{"id": nodes["id"][idx], "label": nodes["label"][idx]}
for idx in range(len(nodes["id"]))
]
transformed_edges = [
{"source": edges["source"][idx], "target": edges["target"][idx], "type": "->"}
for idx in range(len(edges["source"]))
]
return {"nodes": transformed_nodes, "edges": transformed_edges}
def from_json_like_to_xml_like(data):
def parse_nodes(nodes):
node_elements = []
for node in nodes:
label = node["label"]
node_elements.append(f'<n id="{node["id"]}">{label}</n>')
return "<nodes>\n" + "".join(node_elements) + "\n</nodes>"
def parse_edges(edges):
edge_elements = []
for edge in edges:
edge_elements.append(f'<e src="{edge["source"]}" tgt="{edge["target"]}"/>')
return "<edges>\n" + "".join(edge_elements) + "\n</edges>"
nodes_xml = parse_nodes(data["nodes"])
edges_xml = parse_edges(data["edges"])
return nodes_xml + "\n" + edges_xml
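# Illustrative target sequence produced by from_json_like_to_xml_like (hypothetical labels):
#   <nodes>
#   <n id="1">ClassA</n><n id="2">methodB</n>
#   </nodes>
#   <edges>
#   <e src="1" tgt="2"/>
#   </edges>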
# Training-time augmentation: shuffle node order (and renumber ids) on the fly to reduce ordering bias; not invoked by this inference app
def flexible_node_shuffle(sequence):
# Split the sequence into nodes and edges
nodes_match = re.search(r'<nodes>(.*?)</nodes>', sequence, re.DOTALL)
edges_match = re.search(r'<edges>(.*?)</edges>', sequence, re.DOTALL)
if not nodes_match or not edges_match:
print("Error: Could not find nodes or edges in the sequence.")
return sequence
nodes_content = nodes_match.group(1)
edges_content = edges_match.group(1)
# Extract individual nodes
nodes = re.findall(r'<n id="(\d+)">(.*?)</n>', nodes_content, re.DOTALL)
# Shuffle the nodes
random.shuffle(nodes)
# Create a mapping of old ids to new ids
id_mapping = {old_id: str(new_id) for new_id, (old_id, _) in enumerate(nodes, start=1)}
# Reconstruct the nodes section with new ids
new_nodes_content = "".join(f'<n id="{new_id}">{content}</n>' for new_id, (_, content) in enumerate(nodes, start=1))
# Extract and update edge information
edges = re.findall(r'<e src="(\d+)" tgt="(\d+)"/>', edges_content)
new_edges = []
for src, tgt in edges:
new_src = int(id_mapping[src])
new_tgt = int(id_mapping[tgt])
# Append edge as a tuple of remapped (src, tgt) ids
new_edges.append((new_src, new_tgt))
# Sort edges by their endpoint pair, smaller id first
new_edges.sort(key=lambda x: (min(x[0], x[1]), max(x[0], x[1])))
# Reconstruct the edges section, normalizing each edge so the smaller id appears as src
new_edges_content = "".join(f'<e src="{src}" tgt="{tgt}"/>' if src < tgt else f'<e src="{tgt}" tgt="{src}"/>' for src, tgt in new_edges)
# Reconstruct the full sequence
new_sequence = f'<nodes><newline>{new_nodes_content}<newline></nodes><newline><edges><newline>{new_edges_content}<newline></edges>'
return new_sequence
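# PIL sharpening filter wrapped as a callable, so it can be used like a torchvision transform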
class Sharpen:
def __call__(self, img):
return img.filter(ImageFilter.SHARPEN)
# Lightning module with edit-distance and graph-similarity validation
from nltk import edit_distance
import numpy as np
import pytorch_lightning as pl
import networkx as nx
import Levenshtein
import xml.etree.ElementTree as ET
import multiprocessing
import logging
from torch.optim.lr_scheduler import LambdaLR
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Node matching: align predicted-graph nodes with ground-truth nodes by label similarity
def match_nodes_by_label(G_pred, G_gt):
"""Match nodes from predicted graph to ground truth graph based on label similarity."""
node_mapping = {}
for n_pred, pred_data in G_pred.nodes(data=True):
best_match = None
best_score = float('inf') # Levenshtein is a distance metric, lower is better
for n_gt, gt_data in G_gt.nodes(data=True):
sim_score = DonutModelPLModule.normalized_levenshtein(pred_data['label'], gt_data['label'])
if sim_score < best_score:
best_score = sim_score
best_match = n_gt
if best_match:
node_mapping[n_pred] = best_match
return node_mapping
# Also used for the reordering: rebuild the predicted graph with the matched node ids
def rebuild_graph_with_mapped_nodes(G_pred, node_mapping):
"""Rebuild the predicted graph with nodes aligned to the ground truth."""
G_aligned = nx.Graph()
for node_pred, node_gt in node_mapping.items():
G_aligned.add_node(node_gt, label=G_pred.nodes[node_pred]['label'])
for u, v in G_pred.edges():
if u in node_mapping and v in node_mapping:
G_aligned.add_edge(node_mapping[u], node_mapping[v])
return G_aligned
class DonutModelPLModule(pl.LightningModule):
def __init__(self, config, processor, model):
super().__init__()
self.config = config
self.processor = processor
self.model = model
self.train_loss_epoch_total = 0.0
self.val_loss_epoch_total = 0.0
self.train_batch_count = 0
self.val_batch_count = 0
self.edit_distance_scores = []
self.graph_metrics = {
'fast_graph_similarity': [],
'node_label_similarity': [],
'edge_similarity': [],
'degree_sequence_similarity': [],
'node_coverage': [],
'edge_precision': [],
'edge_recall': []
}
self.lr = config["lr"]
self.warmup_steps = config["warmup_steps"]
def training_step(self, batch, batch_idx):
pixel_values, labels, _ = batch
outputs = self.model(pixel_values, labels=labels)
loss = outputs.loss
self.train_loss_epoch_total += loss.item()
self.train_batch_count += 1
self.log("train_loss", loss, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx, dataset_idx=0):
pixel_values, labels, answers = batch
outputs = self.model(pixel_values, labels=labels)
val_loss = outputs.loss
self.val_loss_epoch_total += val_loss.item()
self.val_batch_count += 1
self.log("val_loss", val_loss)
if (self.current_epoch + 1) % self.config.get("edit_distance_validation_frequency", 1) == 0:  # default: validate every epoch if unset
logger.info(f'Finished epoch: {self.current_epoch + 1}')
print(f'Finished epoch: {self.current_epoch + 1}')
batch_size = pixel_values.shape[0]
decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
try:
outputs = self.model.generate(pixel_values,
decoder_input_ids=decoder_input_ids,
max_length=self.config.get("max_length", 512),
early_stopping=True,
pad_token_id=self.processor.tokenizer.pad_token_id,
eos_token_id=self.processor.tokenizer.eos_token_id,
use_cache=True,
num_beams=1,
bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
return_dict_in_generate=True,)
predictions = self.process_predictions(outputs)
logger.info('Calculating graph metrics')
print('Calculating graph metrics')
levenshtein_scores, graph_scores = self.calculate_metrics(predictions, answers)
logger.info('Finished calculating graph metrics')
print('Finished calculating graph metrics')
self.edit_distance_scores.append(np.mean(levenshtein_scores))
for metric in self.graph_metrics:
self.graph_metrics[metric].append(np.mean([score[metric] for score in graph_scores if metric in score]))
self.log("val_edit_distance", np.mean(levenshtein_scores), prog_bar=True)
for metric in self.graph_metrics:
self.log(f"val_{metric}", self.graph_metrics[metric][-1], prog_bar=True)
except Exception as e:
logger.error(f"Error in validation step: {str(e)}")
print(f"Error in validation step: {str(e)}")
def process_predictions(self, outputs):
predictions = []
for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
try:
seq = (
seq.replace(self.processor.tokenizer.eos_token, "")
.replace(self.processor.tokenizer.pad_token, "")
.replace('<n id=" ', '<n id="')
.replace('src=" ', 'src="')
.replace('tgt=" ', 'tgt="')
.replace('<newline>', '\n')
)
seq = re.sub(r"<s>", "", seq, count=1).strip()
seq = seq.replace("<s>", "")
predictions.append(seq)
except Exception as e:
logger.error(f"Error processing prediction: {str(e)}")
print(f"Error processing prediction: {str(e)}")
predictions.append("") # Append empty string if processing fails
return predictions
def calculate_metrics(self, predictions, answers):
levenshtein_scores = []
graph_scores = []
for pred, answer in zip(predictions, answers):
try:
pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
answer = answer.replace(self.processor.tokenizer.bos_token, "").replace(self.processor.tokenizer.eos_token, "").replace("<newline>", "\n")
edit_dist = edit_distance(pred, answer) / max(len(pred), len(answer))
logger.info(f"Prediction: {pred}")
logger.info(f" Answer: {answer}")
logger.info(f" Normed ED: {edit_dist}")
print(f"Prediction: {pred}")
print(f" Answer: {answer}")
print(f" Normed ED: {edit_dist}")
levenshtein_scores.append(edit_dist)
pred_graph = self.create_graph_from_string(pred)
answer_graph = self.create_graph_from_string(answer)
# Added this to reorder the predicted graphs ignoring the node order for better validation
# Match nodes based on labels and reorder
node_mapping = match_nodes_by_label(pred_graph, answer_graph)
pred_graph_aligned = rebuild_graph_with_mapped_nodes(pred_graph, node_mapping)
# Compare the aligned graphs
# graph_scores.append(self.compare_graphs_with_timeout(pred_graph_aligned, answer_graph, timeout=60))
logger.info('Calculating the GED')
print('Calculating the GED')
# graph_scores.append(self.compare_graphs_with_timeout(pred_graph, answer_graph, timeout=60))
graph_scores.append(self.compare_graphs_with_timeout(pred_graph_aligned, answer_graph, timeout=60))
logger.info('Got the GED results')
print('Got the GED results')
except Exception as e:
logger.error(f"Error calculating metrics: {str(e)}")
print(f"Error calculating metrics: {str(e)}")
levenshtein_scores.append(1.0) # Worst possible score
graph_scores.append({metric: 0.0 for metric in self.graph_metrics}) # Worst possible scores
return levenshtein_scores, graph_scores
@staticmethod
def compare_graphs_with_timeout(pred_graph, answer_graph, timeout=60):
def wrapper(return_dict):
return_dict['result'] = DonutModelPLModule.compare_graphs(pred_graph, answer_graph)
manager = multiprocessing.Manager()
return_dict = manager.dict()
p = multiprocessing.Process(target=wrapper, args=(return_dict,))
p.start()
p.join(timeout)
if p.is_alive():
logger.warning('Graph comparison timed out. Returning default values.')
print('Graph comparison timed out. Returning default values.')
p.terminate()
p.join()
return {
"fast_graph_similarity": 0.0,
"node_label_similarity": 0.0,
"edge_similarity": 0.0,
"degree_sequence_similarity": 0.0,
"node_coverage": 0.0,
"edge_precision": 0.0,
"edge_recall": 0.0
}
else:
return return_dict.get('result', {
"fast_graph_similarity": 0.0,
"node_label_similarity": 0.0,
"edge_similarity": 0.0,
"degree_sequence_similarity": 0.0,
"node_coverage": 0.0,
"edge_precision": 0.0,
"edge_recall": 0.0
})
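# NOTE: the timeout wrapper above uses a locally defined target function, which works with
# fork-based multiprocessing (the Linux default) but would not pickle under the "spawn" start method.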
@staticmethod
def create_graph_from_string(xml_string):
G = nx.Graph()
try:
# Extract nodes
nodes = re.findall(r'<n id="(\d+)">(.*?)</n>', xml_string, re.DOTALL)
for node_id, label in nodes:
G.add_node(node_id, label=label.lower())
# Extract edges
edges = re.findall(r'<e src="(\d+)" tgt="(\d+)"/>', xml_string)
for src, tgt in edges:
G.add_edge(src, tgt)
except Exception as e:
logger.error(f"Error creating graph from string: {str(e)}")
print(f"Error creating graph from string: {str(e)}")
return G
@staticmethod
def normalized_levenshtein(s1, s2):
distance = Levenshtein.distance(s1, s2)
max_length = max(len(s1), len(s2))
return distance / max_length if max_length > 0 else 0
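# e.g. normalized_levenshtein("ClassA", "ClassB") -> 1/6 (one substitution over max length 6)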
@staticmethod
def calculate_node_coverage(G1, G2, threshold=0.2):
matched_nodes = 0
for n1 in G1.nodes(data=True):
if any(DonutModelPLModule.normalized_levenshtein(n1[1]['label'], n2[1]['label']) <= threshold
for n2 in G2.nodes(data=True)):
matched_nodes += 1
return matched_nodes / max(len(G1), len(G2)) if max(len(G1), len(G2)) > 0 else 0.0  # guard against two empty graphs
@staticmethod
def node_label_similarity(G1, G2):
labels1 = list(nx.get_node_attributes(G1, 'label').values())
labels2 = list(nx.get_node_attributes(G2, 'label').values())
total_similarity = 0
for label1 in labels1:
similarities = [1 - DonutModelPLModule.normalized_levenshtein(label1, label2) for label2 in labels2]
total_similarity += max(similarities) if similarities else 0
return total_similarity / len(labels1) if labels1 else 0
@staticmethod
def edge_similarity(G1, G2):
return len(set(G1.edges()) & set(G2.edges())) / max(len(G1.edges()), len(G2.edges())) if max(len(G1.edges()), len(G2.edges())) > 0 else 1
@staticmethod
def degree_sequence_similarity(G1, G2):
seq1 = sorted([d for n, d in G1.degree()], reverse=True)
seq2 = sorted([d for n, d in G2.degree()], reverse=True)
# If either sequence is empty, return 0 similarity
if not seq1 or not seq2:
return 0.0
# Padding sequences to make them the same length
max_len = max(len(seq1), len(seq2))
seq1 += [0] * (max_len - len(seq1))
seq2 += [0] * (max_len - len(seq2))
# Calculate degree sequence similarity
diff_sum = sum(abs(x - y) for x, y in zip(seq1, seq2))
# Return similarity, handle edge case where the sum of degrees is zero
return 1 - diff_sum / (2 * sum(seq1)) if sum(seq1) > 0 else 0.0
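# e.g. a 3-node path (degrees [2, 1, 1]) vs. a single edge (padded degrees [1, 1, 0]):
# diff_sum = 2 and sum(seq1) = 4, so similarity = 1 - 2 / 8 = 0.75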
@staticmethod
def fast_graph_similarity(G1, G2):
node_sim = DonutModelPLModule.node_label_similarity(G1, G2)
edge_sim = DonutModelPLModule.edge_similarity(G1, G2)
degree_sim = DonutModelPLModule.degree_sequence_similarity(G1, G2)
return (node_sim + edge_sim + degree_sim) / 3
@staticmethod
def compare_graphs(G1, G2):
try:
node_coverage = DonutModelPLModule.calculate_node_coverage(G1, G2)
G1_edges = set(G1.edges())
G2_edges = set(G2.edges())
correct_edges = len(G1_edges & G2_edges)
edge_precision = correct_edges / len(G2_edges) if G2_edges else 0
edge_recall = correct_edges / len(G1_edges) if G1_edges else 0
return {
"fast_graph_similarity": DonutModelPLModule.fast_graph_similarity(G1, G2),
"node_label_similarity": DonutModelPLModule.node_label_similarity(G1, G2),
"edge_similarity": DonutModelPLModule.edge_similarity(G1, G2),
"degree_sequence_similarity": DonutModelPLModule.degree_sequence_similarity(G1, G2),
"node_coverage": node_coverage,
"edge_precision": edge_precision,
"edge_recall": edge_recall
}
except Exception as e:
logger.error(f"Error comparing graphs: {str(e)}")
print(f"Error comparing graphs: {str(e)}")
return {
"fast_graph_similarity": 0.0,
"node_label_similarity": 0.0,
"edge_similarity": 0.0,
"degree_sequence_similarity": 0.0,
"node_coverage": 0.0,
"edge_precision": 0.0,
"edge_recall": 0.0
}
def configure_optimizers(self):
# Define the optimizer
optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
# Define a linear warmup scheduler (constant LR after warmup; swap a decay into lr_lambda if needed)
def lr_lambda(current_step):
if current_step < self.warmup_steps:
return float(current_step) / float(max(1, self.warmup_steps))
return 1.0 # You can replace this with a decay function like exponential decay
scheduler = LambdaLR(optimizer, lr_lambda)
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'interval': 'step', # Update the learning rate after every training step
'frequency': 1, # How often the scheduler is called (every step)
}
}
def on_validation_epoch_end(self):
avg_val_loss = self.val_loss_epoch_total / self.val_batch_count
mlflow.log_metric("validation_crossentropy_loss", avg_val_loss, step=self.current_epoch)
self.val_loss_epoch_total = 0.0
self.val_batch_count = 0
if (self.current_epoch + 1) % self.config.get("edit_distance_validation_frequency", 1) == 0:  # default: every epoch if unset
if self.edit_distance_scores:
mlflow.log_metric("validation_edit_distance", self.edit_distance_scores[-1], step=self.current_epoch)
for metric in self.graph_metrics:
if self.graph_metrics[metric]:
mlflow.log_metric(f"validation_{metric}", self.graph_metrics[metric][-1], step=self.current_epoch)
print('[INFO] - Finished the validation for epoch ', self.current_epoch + 1)
def on_train_epoch_end(self):
print(f'[INFO] - Finished epoch {self.current_epoch + 1}')
avg_train_loss = self.train_loss_epoch_total / self.train_batch_count
print(f'[INFO] - Train loss: {avg_train_loss}')
mlflow.log_metric("training_crossentropy_loss", avg_train_loss, step=self.current_epoch)
self.train_loss_epoch_total = 0.0
self.train_batch_count = 0
if ((self.current_epoch + 1) % self.config.get("save_model_weights_frequency", 10)) == 0:
self.save_model()
def on_fit_end(self):
self.save_model()
def save_model(self):
model_dir = "Donut_model"
os.makedirs(model_dir, exist_ok=True)
self.model.save_pretrained(model_dir)
print('[INFO] - Saving the model to dagshub using mlflow')
mlflow.transformers.log_model(
transformers_model={
"model": self.model,
"feature_extractor": self.processor.feature_extractor,
"image_processor": self.processor.image_processor,
"tokenizer": self.processor.tokenizer
},
artifact_path=model_dir,
# Set task explicitly since MLflow cannot infer it from the loaded model
task="image-to-text"
)
print('[INFO] - Saved the model to dagshub using mlflow')
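# NOTE: the two dataloader hooks below return dataloaders that were defined globally in the
# training notebook; they are never called by this inference-only app.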
def train_dataloader(self):
return train_dataloader
def val_dataloader(self):
return val_dataloader
config = {"max_epochs":200,
# "val_check_interval":0.2, # how many times we want to validate during an epoch
"check_val_every_n_epoch":1,
"gradient_clip_val":1.0,
# "num_training_samples_per_epoch": 800,
"lr":8e-4, #3e-4, #3e-5,
"train_batch_sizes": [1], #[8], #[1],#[8],
"val_batch_sizes": [1],
# "seed":2022,
"num_nodes": 1,
"warmup_steps": 200, # 800/8*30/10, 10%
"verbose": True,
}
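# Only "lr" and "warmup_steps" are read when constructing the module below; the remaining
# keys are kept from the training configuration for reference.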
model_module = DonutModelPLModule(config, processor, model)
# Use the held-out test split to drive the demo examples
dataset = split_dataset['test']
# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
def preprocess_image(image):
# Convert to PIL Image if it's not already
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
# Apply sharpening
sharpen = Sharpen()
sharpened_image = sharpen(image)
return sharpened_image
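# NOTE: preprocess_image is available but not applied inside perform_inference below;
# sharpen the input manually first if desired.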
def perform_inference(image):
# Preprocess the image
inputs = processor(images=image, return_tensors="pt")
pixel_values = inputs.pixel_values.to(device)
# Prepare decoder input ids
batch_size = pixel_values.shape[0]
decoder_input_ids = torch.full((batch_size, 1), model.config.decoder_start_token_id, device=device)
# Generate output
outputs = model.generate(
pixel_values,
decoder_input_ids=decoder_input_ids,
max_length=model.config.decoder.max_length,  # decoder max generation length; adjust as needed
early_stopping=True,
pad_token_id=processor.tokenizer.pad_token_id,
eos_token_id=processor.tokenizer.eos_token_id,
use_cache=True,
num_beams=1,
bad_words_ids=[[processor.tokenizer.unk_token_id]],
return_dict_in_generate=True,
)
# Decode the output
decoded_output = processor.batch_decode(outputs.sequences)[0]
print("Raw model output:", decoded_output)
return decoded_output
def display_example(index):
example = dataset[index]
img = example["image"]
return img, None, None
def get_ground_truth(index):
example = dataset[index]
ground_truth = from_json_like_to_xml_like(reshape_json_data_to_fit_visualize_graph(example))
print(f'Ground truth sequence: {ground_truth}')
return ground_truth
def transform_image(img, index, physics_enabled):
# Perform inference
sequence = perform_inference(img)
# Transform the sequence to graph data
graph_data = transform_sequence(sequence)
# Generate the graph visualization
graph_html = visualize_graph(graph_data, physics_enabled)
# Modify the iframe to have a fixed height
graph_html = graph_html.replace('height: 100vh;', 'height: 500px;')
# Convert graph_data to a formatted JSON string
json_data = json.dumps(graph_data, indent=2)
return graph_html, json_data, sequence
def transform_sequence(sequence: str) -> Dict[str, List[Dict[str, str]]]:
# Extract nodes and edges
nodes_match = re.search(r'<nodes>(.*?)</nodes>', sequence, re.DOTALL)
edges_match = re.search(r'<edges>(.*?)</edges>', sequence, re.DOTALL)
if not nodes_match or not edges_match:
raise ValueError("Invalid input sequence: nodes or edges not found")
nodes_text = nodes_match.group(1)
edges_text = edges_match.group(1)
# Parse nodes
nodes = []
for node_match in re.finditer(r'<n id="\s*(\d+)">(.*?)</n>', nodes_text):
node_id, node_label = node_match.groups()
nodes.append({
"id": node_id.strip(),
"label": node_label.strip()
})
# Parse edges
edges = []
for edge_match in re.finditer(r'<e src="\s*(\d+)" tgt="\s*(\d+)"/>', edges_text):
source, target = edge_match.groups()
edges.append({
"source": source.strip(),
"target": target.strip(),
"type": "->"
})
return {
"nodes": nodes,
"edges": edges
}
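# Illustrative example (assumed model output format):
#   transform_sequence('<nodes><n id="1">A</n><n id="2">B</n></nodes><edges><e src="1" tgt="2"/></edges>')
#   -> {"nodes": [{"id": "1", "label": "A"}, {"id": "2", "label": "B"}],
#       "edges": [{"source": "1", "target": "2", "type": "->"}]}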
# Function to visualize the extracted graph with pyvis
from pyvis.network import Network
def create_graph(nodes, edges, physics_enabled=True):
net = Network(
notebook=True,
height="100vh",
width="100vw",
bgcolor="#222222",
font_color="white",
cdn_resources="remote",
)
for node in nodes:
net.add_node(
node["id"],
label=node["label"],
title=node["label"],
color="blue" if node["label"] == "OOP" else "green",
)
for edge in edges:
net.add_edge(edge["source"], edge["target"], title=edge["type"])
net.force_atlas_2based(
gravity=-50,
central_gravity=0.01,
spring_length=100,
spring_strength=0.08,
damping=0.4,
)
options = {
"nodes": {"physics": physics_enabled},
"edges": {"smooth": True},
"interaction": {"hover": True, "zoomView": True},
"physics": {
"enabled": physics_enabled,
"stabilization": {"enabled": True, "iterations": 200},
},
}
net.set_options(json.dumps(options))
return net
def visualize_graph(json_data, physics_enabled=True):
if isinstance(json_data, str):
data = json.loads(json_data)
else:
data = json_data
nodes = data["nodes"]
edges = data["edges"]
net = create_graph(nodes, edges, physics_enabled)
html = net.generate_html()
html = html.replace("'", '"')
html = html.replace(
'<div id="mynetwork"', '<div id="mynetwork" style="height: 100vh; width: 100%;"'
)
return f"""<iframe style="width: 100%; height: 100vh; border: none; margin: 0; padding: 0;" srcdoc='{html}'></iframe>"""
def update_physics(json_data, physics_enabled):
if json_data is None:
return None
data = json.loads(json_data)
graph_html = visualize_graph(data, physics_enabled)
graph_html = graph_html.replace('height: 100vh;', 'height: 500px;')
return graph_html
# function to calculate the graph similarity metrics between the prediction and the ground-truth
def calculate_and_display_metrics(pred_graph, ground_truth_graph):
if pred_graph is None or ground_truth_graph is None:
return "Please generate a prediction and ensure a ground truth graph is available."
# Strip the start token, restore newlines, and remove stray spaces inside attribute values of the raw prediction
pred_graph = pred_graph.replace('<s>', "").replace("<newline>", "\n").replace('src=" ', 'src="').replace('tgt=" ', 'tgt="').replace('<n id=" ', '<n id="')
print(f'Prediction: {pred_graph}')
# Assuming the graphs are in the correct format for the calculate_metrics function
metrics = model_module.calculate_metrics([pred_graph], [ground_truth_graph])
# Format the metrics for display
overall_metric = metrics[0][0]
detailed_metrics = metrics[1][0]
# output = f"Overall Metric: {overall_metric:.4f}\n\nDetailed Metrics:\n"
output = f"Detailed Metrics:\n"
for key, value in detailed_metrics.items():
output += f"{key}: {value:.4f}\n"
return output
def create_interface():
with gr.Blocks() as demo:
gr.Markdown("# Knowledge Graph Visualizer with Model Inference")
with gr.Row():
index_slider = gr.Slider(
minimum=0,
maximum=len(dataset) - 1,
step=1,
label="Example Index"
)
with gr.Row():
image_output = gr.Image(type="pil", label="Image", height=500, interactive=False)
graph_output = gr.HTML(label="Knowledge Graph")
with gr.Row():
transform_button = gr.Button("Transform")
physics_toggle = gr.Checkbox(label="Enable Physics", value=True)
with gr.Row():
json_output = gr.Code(language="json", label="Graph JSON Data")
ground_truth_output = gr.Textbox(visible=False)  # alternative: gr.JSON(label="Ground Truth Graph", visible=False)
predicted_raw_sequence = gr.Textbox(visible=False)
with gr.Row():
metrics_button = gr.Button("Calculate Metrics")
metrics_output = gr.Textbox(label="Similarity Metrics", lines=10)
index_slider.change(
fn=display_example,
inputs=[index_slider],
outputs=[image_output, graph_output, json_output],
).then(
fn=get_ground_truth,
inputs=[index_slider],
outputs=[ground_truth_output],
)
transform_button.click(
fn=transform_image,
inputs=[image_output, index_slider, physics_toggle],
outputs=[graph_output, json_output, predicted_raw_sequence],
).then(
fn=calculate_and_display_metrics,
inputs=[predicted_raw_sequence, ground_truth_output],
outputs=[metrics_output],  # alternative: gr.Textbox(label="Metrics")
)
metrics_button.click(
fn=calculate_and_display_metrics,
inputs=[predicted_raw_sequence, ground_truth_output],
outputs=[metrics_output],
)
physics_toggle.change(
fn=update_physics,
inputs=[json_output, physics_toggle],
outputs=[graph_output],
)
return demo
# Create and launch the interface
if __name__ == "__main__":
demo = create_interface()
demo.launch(share=True, debug=True)