import gradio as gr from datasets import load_dataset from PIL import Image import json import torch from torchvision import transforms from transformers import DonutProcessor, VisionEncoderDecoderModel # import subprocess # # Install mlflow and dagshub without dependencies # subprocess.run(['pip', 'install', '--no-deps', 'mlflow']) # subprocess.run(['pip', 'install', '--no-deps', 'dagshub']) import dagshub import mlflow import time import os # from kaggle_secrets import UserSecretsClient # user_secrets = UserSecretsClient() # token = user_secrets.get_secret("dags_hub_token") # from google.colab import userdata # token = userdata.get('dags_hub_token') token = os.getenv('dags_hub_token') dagshub.auth.add_app_token(token) dagshub.init(repo_owner='zaheramasha', repo_name='Finetuning_paligemma_Zaka_capstone', mlflow=True) # Define the MLflow run ID and artifact path run_id = "c41cfd149a8c44f3a92d8e0f1253af35" # Donut model trained on the PyvizAndMarkMap dataset for 27 epochs reaching a train loss of 0.168 run_id = "89bafd5e525a4d3e9d004e13c9574198" # Donut model trained on the PyvizAndMarkMap dataset for 27 + 51 = 78 epochs reaching a train loss of 0.0353. This run was a continuation of the 27 epoch one artifact_path = "Donut_model/model" # Create the model URI using the run ID and artifact path model_uri = f"runs:/{run_id}/{artifact_path}" print(mlflow.artifacts.list_artifacts(run_id=run_id, artifact_path=artifact_path)) # Load the model and processors from the MLflow artifact # loaded_model_bundle = mlflow.transformers.load_model(artifact_path=artifact_path, run_id=run_id) # for the 20 epochs trained model model_uri = f"mlflow-artifacts:/0a5d0550f55c4169b80cd6439556be8b/c41cfd149a8c44f3a92d8e0f1253af35/artifacts/Donut_model" # for the fully 70 epochs trained model model_uri = f"mlflow-artifacts:/17c375f6eab34c63b2a2e7792803132e/89bafd5e525a4d3e9d004e13c9574198/artifacts/Donut_model" loaded_model_bundle = mlflow.transformers.load_model(model_uri=model_uri, device='cpu')#'cuda') model = loaded_model_bundle.model processor = DonutProcessor(tokenizer=loaded_model_bundle.tokenizer, feature_extractor=loaded_model_bundle.feature_extractor, image_processor=loaded_model_bundle.image_processor) print(model.config.encoder.image_size) print(model.config.decoder.max_length) import json import random from typing import Any, List, Tuple, Dict import torch from torch.utils.data import Dataset from datasets import load_dataset, DatasetDict, concatenate_datasets from PIL import Image, ImageFilter from torchvision import transforms import re # Load and split the dataset Pyviz_dataset = load_dataset("Zaherrr/OOP_KG_Pyviz_Synthetic_Dataset", revision="Sorted_edges") MarkMap_dataset = load_dataset("Zaherrr/OOP_KG_MarkMap_Synthetic_Dataset") combined_dataset = concatenate_datasets([Pyviz_dataset['data'], MarkMap_dataset['data']]) train_test_split = combined_dataset.train_test_split(test_size=0.2, seed=42) train_val_split = train_test_split["train"].train_test_split(test_size=0.125, seed=42) split_dataset = DatasetDict( { "train": train_val_split["train"], "val": train_val_split["test"], "test": train_test_split["test"], } ) def reshape_json_data_to_fit_visualize_graph(graph_data): nodes = graph_data["nodes"] edges = graph_data["edges"] transformed_nodes = [ {"id": nodes["id"][idx], "label": nodes["label"][idx]} for idx in range(len(nodes["id"])) ] transformed_edges = [ {"source": edges["source"][idx], "target": edges["target"][idx], "type": "->"} for idx in range(len(edges["source"])) ] return {"nodes": transformed_nodes, "edges": transformed_edges} def from_json_like_to_xml_like(data): def parse_nodes(nodes): node_elements = [] for node in nodes: label = node["label"] node_elements.append(f'{label}') return "\n" + "".join(node_elements) + "\n" def parse_edges(edges): edge_elements = [] for edge in edges: edge_elements.append(f'') return "\n" + "".join(edge_elements) + "\n" nodes_xml = parse_nodes(data["nodes"]) edges_xml = parse_edges(data["edges"]) return nodes_xml + "\n" + edges_xml # function to shuffle the nodes on the fly in an attempt to reduce the bias from random node extraction def flexible_node_shuffle(sequence): # Split the sequence into nodes and edges nodes_match = re.search(r'(.*?)', sequence, re.DOTALL) edges_match = re.search(r'(.*?)', sequence, re.DOTALL) if not nodes_match or not edges_match: print("Error: Could not find nodes or edges in the sequence.") return sequence nodes_content = nodes_match.group(1) edges_content = edges_match.group(1) # Extract individual nodes nodes = re.findall(r'(.*?)', nodes_content, re.DOTALL) # Shuffle the nodes random.shuffle(nodes) # Create a mapping of old ids to new ids id_mapping = {old_id: str(new_id) for new_id, (old_id, _) in enumerate(nodes, start=1)} # Reconstruct the nodes section with new ids new_nodes_content = "".join(f'{content}' for new_id, (_, content) in enumerate(nodes, start=1)) # Extract and update edge information edges = re.findall(r'', edges_content) new_edges = [] for src, tgt in edges: new_src = int(id_mapping[src]) new_tgt = int(id_mapping[tgt]) # Append edge as tuple (original_src, original_tgt) new_edges.append((new_src, new_tgt)) # Sort edges: first by the new src node id, then by the new tgt node id (preserving the original direction) new_edges.sort(key=lambda x: (min(x[0], x[1]), max(x[0], x[1]))) # Reconstruct the edges section, preserving original direction new_edges_content = "".join(f'' if src < tgt else f'' for src, tgt in new_edges) # Reconstruct the full sequence new_sequence = f'{new_nodes_content}{new_edges_content}' return new_sequence class Sharpen: def __call__(self, img): return img.filter(ImageFilter.SHARPEN) # with the graph edit distance validation import re from nltk import edit_distance import numpy as np import torch import pytorch_lightning as pl import mlflow import networkx as nx import Levenshtein import xml.etree.ElementTree as ET import multiprocessing import logging from torch.optim.lr_scheduler import LambdaLR logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # for the node matching and reordering to align with the ground truth graph def match_nodes_by_label(G_pred, G_gt): """Match nodes from predicted graph to ground truth graph based on label similarity.""" node_mapping = {} for n_pred, pred_data in G_pred.nodes(data=True): best_match = None best_score = float('inf') # Levenshtein is a distance metric, lower is better for n_gt, gt_data in G_gt.nodes(data=True): sim_score = DonutModelPLModule.normalized_levenshtein(pred_data['label'], gt_data['label']) if sim_score < best_score: best_score = sim_score best_match = n_gt if best_match: node_mapping[n_pred] = best_match return node_mapping # also for the reodering def rebuild_graph_with_mapped_nodes(G_pred, node_mapping): """Rebuild the predicted graph with nodes aligned to the ground truth.""" G_aligned = nx.Graph() for node_pred, node_gt in node_mapping.items(): G_aligned.add_node(node_gt, label=G_pred.nodes[node_pred]['label']) for u, v in G_pred.edges(): if u in node_mapping and v in node_mapping: G_aligned.add_edge(node_mapping[u], node_mapping[v]) return G_aligned class DonutModelPLModule(pl.LightningModule): def __init__(self, config, processor, model): super().__init__() self.config = config self.processor = processor self.model = model self.train_loss_epoch_total = 0.0 self.val_loss_epoch_total = 0.0 self.train_batch_count = 0 self.val_batch_count = 0 self.edit_distance_scores = [] self.graph_metrics = { 'fast_graph_similarity': [], 'node_label_similarity': [], 'edge_similarity': [], 'degree_sequence_similarity': [], 'node_coverage': [], 'edge_precision': [], 'edge_recall': [] } self.lr = config["lr"] self.warmup_steps = config["warmup_steps"] def training_step(self, batch, batch_idx): pixel_values, labels, _ = batch outputs = self.model(pixel_values, labels=labels) loss = outputs.loss self.train_loss_epoch_total += loss.item() self.train_batch_count += 1 self.log("train_loss", loss, prog_bar=True) return loss def validation_step(self, batch, batch_idx, dataset_idx=0): pixel_values, labels, answers = batch outputs = self.model(pixel_values, labels=labels) val_loss = outputs.loss self.val_loss_epoch_total += val_loss.item() self.val_batch_count += 1 self.log("val_loss", val_loss) if (self.current_epoch + 1) % self.config.get("edit_distance_validation_frequency") == 0: logger.info(f'Finished epoch: {self.current_epoch + 1}') print(f'Finished epoch: {self.current_epoch + 1}') batch_size = pixel_values.shape[0] decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device) try: outputs = self.model.generate(pixel_values, decoder_input_ids=decoder_input_ids, max_length=self.config.get("max_length", 512), early_stopping=True, pad_token_id=self.processor.tokenizer.pad_token_id, eos_token_id=self.processor.tokenizer.eos_token_id, use_cache=True, num_beams=1, bad_words_ids=[[self.processor.tokenizer.unk_token_id]], return_dict_in_generate=True,) predictions = self.process_predictions(outputs) logger.info('Calculating graph metrics') print('Calculating graph metrics') levenshtein_scores, graph_scores = self.calculate_metrics(predictions, answers) logger.info('Finished calculating graph metrics') print('Finished calculating graph metrics') self.edit_distance_scores.append(np.mean(levenshtein_scores)) for metric in self.graph_metrics: self.graph_metrics[metric].append(np.mean([score[metric] for score in graph_scores if metric in score])) self.log("val_edit_distance", np.mean(levenshtein_scores), prog_bar=True) for metric in self.graph_metrics: self.log(f"val_{metric}", self.graph_metrics[metric][-1], prog_bar=True) except Exception as e: logger.error(f"Error in validation step: {str(e)}") print(f"Error in validation step: {str(e)}") def process_predictions(self, outputs): predictions = [] for seq in self.processor.tokenizer.batch_decode(outputs.sequences): try: seq = ( seq.replace(self.processor.tokenizer.eos_token, "") .replace(self.processor.tokenizer.pad_token, "") .replace('', '\n') ) seq = re.sub(r"", "", seq, count=1).strip() seq = seq.replace("", "") predictions.append(seq) except Exception as e: logger.error(f"Error processing prediction: {str(e)}") print(f"Error processing prediction: {str(e)}") predictions.append("") # Append empty string if processing fails return predictions def calculate_metrics(self, predictions, answers): levenshtein_scores = [] graph_scores = [] for pred, answer in zip(predictions, answers): try: pred = re.sub(r"(?:(?<=>) | (?=", "\n") edit_dist = edit_distance(pred, answer) / max(len(pred), len(answer)) logger.info(f"Prediction: {pred}") logger.info(f" Answer: {answer}") logger.info(f" Normed ED: {edit_dist}") print(f"Prediction: {pred}") print(f" Answer: {answer}") print(f" Normed ED: {edit_dist}") levenshtein_scores.append(edit_dist) pred_graph = self.create_graph_from_string(pred) answer_graph = self.create_graph_from_string(answer) # Added this to reorder the predicted graphs ignoring the node order for better validation # Match nodes based on labels and reorder node_mapping = match_nodes_by_label(pred_graph, answer_graph) pred_graph_aligned = rebuild_graph_with_mapped_nodes(pred_graph, node_mapping) # Compare the aligned graphs # graph_scores.append(self.compare_graphs_with_timeout(pred_graph_aligned, answer_graph, timeout=60)) logger.info('Calculating the GED') print('Calculating the GED') # graph_scores.append(self.compare_graphs_with_timeout(pred_graph, answer_graph, timeout=60)) graph_scores.append(self.compare_graphs_with_timeout(pred_graph_aligned, answer_graph, timeout=60)) logger.info('Got the GED results') print('Got the GED results') except Exception as e: logger.error(f"Error calculating metrics: {str(e)}") print(f"Error calculating metrics: {str(e)}") levenshtein_scores.append(1.0) # Worst possible score graph_scores.append({metric: 0.0 for metric in self.graph_metrics}) # Worst possible scores return levenshtein_scores, graph_scores @staticmethod def compare_graphs_with_timeout(pred_graph, answer_graph, timeout=60): def wrapper(return_dict): return_dict['result'] = DonutModelPLModule.compare_graphs(pred_graph, answer_graph) manager = multiprocessing.Manager() return_dict = manager.dict() p = multiprocessing.Process(target=wrapper, args=(return_dict,)) p.start() p.join(timeout) if p.is_alive(): logger.warning('Graph comparison timed out. Returning default values.') print('Graph comparison timed out. Returning default values.') p.terminate() p.join() return { "fast_graph_similarity": 0.0, "node_label_similarity": 0.0, "edge_similarity": 0.0, "degree_sequence_similarity": 0.0, "node_coverage": 0.0, "edge_precision": 0.0, "edge_recall": 0.0 } else: return return_dict.get('result', { "fast_graph_similarity": 0.0, "node_label_similarity": 0.0, "edge_similarity": 0.0, "degree_sequence_similarity": 0.0, "node_coverage": 0.0, "edge_precision": 0.0, "edge_recall": 0.0 }) @staticmethod def create_graph_from_string(xml_string): G = nx.Graph() try: # Extract nodes nodes = re.findall(r'(.*?)', xml_string, re.DOTALL) for node_id, label in nodes: G.add_node(node_id, label=label.lower()) # Extract edges edges = re.findall(r'', xml_string) for src, tgt in edges: G.add_edge(src, tgt) except Exception as e: logger.error(f"Error creating graph from string: {str(e)}") print(f"Error creating graph from string: {str(e)}") return G @staticmethod def normalized_levenshtein(s1, s2): distance = Levenshtein.distance(s1, s2) max_length = max(len(s1), len(s2)) return distance / max_length if max_length > 0 else 0 @staticmethod def calculate_node_coverage(G1, G2, threshold=0.2): matched_nodes = 0 for n1 in G1.nodes(data=True): if any(DonutModelPLModule.normalized_levenshtein(n1[1]['label'], n2[1]['label']) <= threshold for n2 in G2.nodes(data=True)): matched_nodes += 1 return matched_nodes / max(len(G1), len(G2)) @staticmethod def node_label_similarity(G1, G2): labels1 = list(nx.get_node_attributes(G1, 'label').values()) labels2 = list(nx.get_node_attributes(G2, 'label').values()) total_similarity = 0 for label1 in labels1: similarities = [1 - DonutModelPLModule.normalized_levenshtein(label1, label2) for label2 in labels2] total_similarity += max(similarities) if similarities else 0 return total_similarity / len(labels1) if labels1 else 0 @staticmethod def edge_similarity(G1, G2): return len(set(G1.edges()) & set(G2.edges())) / max(len(G1.edges()), len(G2.edges())) if max(len(G1.edges()), len(G2.edges())) > 0 else 1 @staticmethod def degree_sequence_similarity(G1, G2): seq1 = sorted([d for n, d in G1.degree()], reverse=True) seq2 = sorted([d for n, d in G2.degree()], reverse=True) # If either sequence is empty, return 0 similarity if not seq1 or not seq2: return 0.0 # Padding sequences to make them the same length max_len = max(len(seq1), len(seq2)) seq1 += [0] * (max_len - len(seq1)) seq2 += [0] * (max_len - len(seq2)) # Calculate degree sequence similarity diff_sum = sum(abs(x - y) for x, y in zip(seq1, seq2)) # Return similarity, handle edge case where the sum of degrees is zero return 1 - diff_sum / (2 * sum(seq1)) if sum(seq1) > 0 else 0.0 @staticmethod def fast_graph_similarity(G1, G2): node_sim = DonutModelPLModule.node_label_similarity(G1, G2) edge_sim = DonutModelPLModule.edge_similarity(G1, G2) degree_sim = DonutModelPLModule.degree_sequence_similarity(G1, G2) return (node_sim + edge_sim + degree_sim) / 3 @staticmethod def compare_graphs(G1, G2): try: node_coverage = DonutModelPLModule.calculate_node_coverage(G1, G2) G1_edges = set(G1.edges()) G2_edges = set(G2.edges()) correct_edges = len(G1_edges & G2_edges) edge_precision = correct_edges / len(G2_edges) if G2_edges else 0 edge_recall = correct_edges / len(G1_edges) if G1_edges else 0 return { "fast_graph_similarity": DonutModelPLModule.fast_graph_similarity(G1, G2), "node_label_similarity": DonutModelPLModule.node_label_similarity(G1, G2), "edge_similarity": DonutModelPLModule.edge_similarity(G1, G2), "degree_sequence_similarity": DonutModelPLModule.degree_sequence_similarity(G1, G2), "node_coverage": node_coverage, "edge_precision": edge_precision, "edge_recall": edge_recall } except Exception as e: logger.error(f"Error comparing graphs: {str(e)}") print(f"Error comparing graphs: {str(e)}") return { "fast_graph_similarity": 0.0, "node_label_similarity": 0.0, "edge_similarity": 0.0, "degree_sequence_similarity": 0.0, "node_coverage": 0.0, "edge_precision": 0.0, "edge_recall": 0.0 } def configure_optimizers(self): # Define the optimizer optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr) # Define the warmup + decay scheduler def lr_lambda(current_step): if current_step < self.warmup_steps: return float(current_step) / float(max(1, self.warmup_steps)) return 1.0 # You can replace this with a decay function like exponential decay scheduler = LambdaLR(optimizer, lr_lambda) return { 'optimizer': optimizer, 'lr_scheduler': { 'scheduler': scheduler, 'interval': 'step', # Update the learning rate after every training step 'frequency': 1, # How often the scheduler is called (every step) } } def on_validation_epoch_end(self): avg_val_loss = self.val_loss_epoch_total / self.val_batch_count mlflow.log_metric("validation_crossentropy_loss", avg_val_loss, step=self.current_epoch) self.val_loss_epoch_total = 0.0 self.val_batch_count = 0 if (self.current_epoch + 1) % self.config.get("edit_distance_validation_frequency") == 0: if self.edit_distance_scores: mlflow.log_metric("validation_edit_distance", self.edit_distance_scores[-1], step=self.current_epoch) for metric in self.graph_metrics: if self.graph_metrics[metric]: mlflow.log_metric(f"validation_{metric}", self.graph_metrics[metric][-1], step=self.current_epoch) print('[INFO] - Finished the validation for epoch ', self.current_epoch + 1) def on_train_epoch_end(self): print(f'[INFO] - Finished epoch {self.current_epoch + 1}') avg_train_loss = self.train_loss_epoch_total / self.train_batch_count print(f'[INFO] - Train loss: {avg_train_loss}') mlflow.log_metric("training_crossentropy_loss", avg_train_loss, step=self.current_epoch) self.train_loss_epoch_total = 0.0 self.train_batch_count = 0 if ((self.current_epoch + 1) % self.config.get("save_model_weights_frequency", 10)) == 0: self.save_model() def on_fit_end(self): self.save_model() def save_model(self): model_dir = "Donut_model" os.makedirs(model_dir, exist_ok=True) self.model.save_pretrained(model_dir) print('[INFO] - Saving the model to dagshub using mlflow') mlflow.transformers.log_model( transformers_model={ "model": self.model, "feature_extractor": self.processor.feature_extractor, "image_processor": self.processor.image_processor, "tokenizer": self.processor.tokenizer }, artifact_path=model_dir, # Set task explicitly since MLflow cannot infer it from the loaded model task = "image-to-text" ) print('[INFO] - Saved the model to dagshub using mlflow') def train_dataloader(self): return train_dataloader def val_dataloader(self): return val_dataloader config = {"max_epochs":200, # "val_check_interval":0.2, # how many times we want to validate during an epoch "check_val_every_n_epoch":1, "gradient_clip_val":1.0, # "num_training_samples_per_epoch": 800, "lr":8e-4, #3e-4, #3e-5, "train_batch_sizes": [1], #[8], #[1],#[8], "val_batch_sizes": [1], # "seed":2022, "num_nodes": 1, "warmup_steps": 200, # 800/8*30/10, 10% "verbose": True, } model_module = DonutModelPLModule(config, processor, model) # Load dataset dataset = split_dataset['test'] # Set up device device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) class Sharpen: def __call__(self, img): return img.filter(ImageFilter.SHARPEN) def preprocess_image(image): # Convert to PIL Image if it's not already if not isinstance(image, Image.Image): image = Image.fromarray(image) # Apply sharpening sharpen = Sharpen() sharpened_image = sharpen(image) return sharpened_image def perform_inference(image): # Preprocess the image inputs = processor(images=image, return_tensors="pt") pixel_values = inputs.pixel_values.to(device) # Prepare decoder input ids batch_size = pixel_values.shape[0] decoder_input_ids = torch.full((batch_size, 1), model.config.decoder_start_token_id, device=device) # Generate output outputs = model.generate( pixel_values, decoder_input_ids=decoder_input_ids, max_length=max_length, # + 500, #512, # Adjust as needed early_stopping=True, pad_token_id=processor.tokenizer.pad_token_id, eos_token_id=processor.tokenizer.eos_token_id, use_cache=True, num_beams=1, bad_words_ids=[[processor.tokenizer.unk_token_id]], return_dict_in_generate=True, ) # Decode the output decoded_output = processor.batch_decode(outputs.sequences)[0] print("Raw model output:", decoded_output) return decoded_output def display_example(index): example = dataset[index] img = example["image"] return img, None, None def from_json_like_to_xml_like(data): def parse_nodes(nodes): node_elements = [] for node in nodes: label = node["label"] node_elements.append(f'{label}') return "\n" + "".join(node_elements) + "\n" def parse_edges(edges): edge_elements = [] for edge in edges: edge_elements.append(f'') return "\n" + "".join(edge_elements) + "\n" nodes_xml = parse_nodes(data["nodes"]) edges_xml = parse_edges(data["edges"]) return nodes_xml + "\n" + edges_xml def reshape_json_data_to_fit_visualize_graph(graph_data): nodes = graph_data["nodes"] edges = graph_data["edges"] transformed_nodes = [ {"id": nodes["id"][idx], "label": nodes["label"][idx]} for idx in range(len(nodes["id"])) ] transformed_edges = [ {"source": edges["source"][idx], "target": edges["target"][idx], "type": "->"} for idx in range(len(edges["source"])) ] return {"nodes": transformed_nodes, "edges": transformed_edges} def get_ground_truth(index): example = dataset[index] ground_truth = json.dumps(reshape_json_data_to_fit_visualize_graph(example)) ground_truth = from_json_like_to_xml_like(json.loads(ground_truth)) print(f'Ground truth sequence: {ground_truth}') return ground_truth def transform_image(img, index, physics_enabled): # Perform inference sequence = perform_inference(img) # Transform the sequence to graph data graph_data = transform_sequence(sequence) # Generate the graph visualization graph_html = visualize_graph(graph_data, physics_enabled) # Modify the iframe to have a fixed height graph_html = graph_html.replace('height: 100vh;', 'height: 500px;') # Convert graph_data to a formatted JSON string json_data = json.dumps(graph_data, indent=2) return graph_html, json_data, sequence import re from typing import Dict, List, Tuple def transform_sequence(sequence: str) -> Dict[str, List[Dict[str, str]]]: # Extract nodes and edges nodes_match = re.search(r'(.*?)', sequence, re.DOTALL) edges_match = re.search(r'(.*?)', sequence, re.DOTALL) if not nodes_match or not edges_match: raise ValueError("Invalid input sequence: nodes or edges not found") nodes_text = nodes_match.group(1) edges_text = edges_match.group(1) # Parse nodes nodes = [] for node_match in re.finditer(r'(.*?)', nodes_text): node_id, node_label = node_match.groups() nodes.append({ "id": node_id.strip(), "label": node_label.strip() }) # Parse edges edges = [] for edge_match in re.finditer(r'', edges_text): source, target = edge_match.groups() edges.append({ "source": source.strip(), "target": target.strip(), "type": "->" }) return { "nodes": nodes, "edges": edges } # function to visualize the extracted graph import json from pyvis.network import Network def create_graph(nodes, edges, physics_enabled=True): net = Network( notebook=True, height="100vh", width="100vw", bgcolor="#222222", font_color="white", cdn_resources="remote", ) for node in nodes: net.add_node( node["id"], label=node["label"], title=node["label"], color="blue" if node["label"] == "OOP" else "green", ) for edge in edges: net.add_edge(edge["source"], edge["target"], title=edge["type"]) net.force_atlas_2based( gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.08, damping=0.4, ) options = { "nodes": {"physics": physics_enabled}, "edges": {"smooth": True}, "interaction": {"hover": True, "zoomView": True}, "physics": { "enabled": physics_enabled, "stabilization": {"enabled": True, "iterations": 200}, }, } net.set_options(json.dumps(options)) return net def visualize_graph(json_data, physics_enabled=True): if isinstance(json_data, str): data = json.loads(json_data) else: data = json_data nodes = data["nodes"] edges = data["edges"] net = create_graph(nodes, edges, physics_enabled) html = net.generate_html() html = html.replace("'", '"') html = html.replace( '
""" def update_physics(json_data, physics_enabled): if json_data is None: return None data = json.loads(json_data) graph_html = visualize_graph(data, physics_enabled) graph_html = graph_html.replace('height: 100vh;', 'height: 500px;') return graph_html # function to calculate the graph similarity metrics between the prediction and the ground-truth def calculate_and_display_metrics(pred_graph, ground_truth_graph): if pred_graph is None or ground_truth_graph is None: return "Please generate a prediction and ensure a ground truth graph is available." #removing the start token from the string pred_graph = pred_graph.replace('', "").replace("", "\n").replace('src=" ', 'src="').replace('tgt=" ', 'tgt="').replace('