", "")
predictions.append(seq)
except Exception as e:
logger.error(f"Error processing prediction: {str(e)}")
print(f"Error processing prediction: {str(e)}")
predictions.append("") # Append empty string if processing fails
return predictions
def calculate_metrics(self, predictions, answers):
# Compute, for each (prediction, answer) pair, a length-normalized edit
# distance and a dict of graph-similarity scores.
# Returns (levenshtein_scores, graph_scores), one entry per pair; any failure
# for a pair records worst-case scores (1.0 edit distance, all-zero graph
# metrics) instead of aborting the whole evaluation.
levenshtein_scores = []
graph_scores = []
for pred, answer in zip(predictions, answers):
try:
# NOTE(review): this line looks corrupted -- the regex is unbalanced and
# re.sub is missing its string argument, so it raises on every pair and
# the except branch below records worst-case scores. Recover the original
# pattern and arguments before trusting these metrics.
pred = re.sub(r"(?:(?<=>) | (?=", "\n")
# Normalize Levenshtein distance by the longer of the two sequences.
edit_dist = edit_distance(pred, answer) / max(len(pred), len(answer))
logger.info(f"Prediction: {pred}")
logger.info(f" Answer: {answer}")
logger.info(f" Normed ED: {edit_dist}")
print(f"Prediction: {pred}")
print(f" Answer: {answer}")
print(f" Normed ED: {edit_dist}")
levenshtein_scores.append(edit_dist)
# Build graphs from both sequences for structural comparison.
pred_graph = self.create_graph_from_string(pred)
answer_graph = self.create_graph_from_string(answer)
# Added this to reorder the predicted graphs ignoring the node order for better validation
# Match nodes based on labels and reorder
node_mapping = match_nodes_by_label(pred_graph, answer_graph)
pred_graph_aligned = rebuild_graph_with_mapped_nodes(pred_graph, node_mapping)
# Compare the aligned graphs
# graph_scores.append(self.compare_graphs_with_timeout(pred_graph_aligned, answer_graph, timeout=60))
logger.info('Calculating the GED')
print('Calculating the GED')
# graph_scores.append(self.compare_graphs_with_timeout(pred_graph, answer_graph, timeout=60))
# The comparison runs in a subprocess with a 60 s wall-clock cap.
graph_scores.append(self.compare_graphs_with_timeout(pred_graph_aligned, answer_graph, timeout=60))
logger.info('Got the GED results')
print('Got the GED results')
except Exception as e:
logger.error(f"Error calculating metrics: {str(e)}")
print(f"Error calculating metrics: {str(e)}")
levenshtein_scores.append(1.0) # Worst possible score
graph_scores.append({metric: 0.0 for metric in self.graph_metrics}) # Worst possible scores
return levenshtein_scores, graph_scores
@staticmethod
def compare_graphs_with_timeout(pred_graph, answer_graph, timeout=60):
    """Run DonutModelPLModule.compare_graphs in a subprocess with a wall-clock cap.

    Args:
        pred_graph: predicted graph to compare.
        answer_graph: ground-truth graph.
        timeout: seconds to wait before abandoning the comparison.

    Returns:
        dict of graph-similarity metrics; all-zero scores when the comparison
        times out or when the child process stored no result.
    """
    # Single source of truth for the worst-case scores (the original
    # duplicated this literal in both failure branches).
    zero_scores = {
        "fast_graph_similarity": 0.0,
        "node_label_similarity": 0.0,
        "edge_similarity": 0.0,
        "degree_sequence_similarity": 0.0,
        "node_coverage": 0.0,
        "edge_precision": 0.0,
        "edge_recall": 0.0,
    }

    # NOTE(review): a nested function is not picklable, so this only works
    # with the "fork" start method (Linux default); under "spawn"
    # (Windows/macOS) Process creation fails -- confirm the target platform.
    def wrapper(return_dict):
        return_dict['result'] = DonutModelPLModule.compare_graphs(pred_graph, answer_graph)

    manager = multiprocessing.Manager()
    return_dict = manager.dict()
    p = multiprocessing.Process(target=wrapper, args=(return_dict,))
    p.start()
    p.join(timeout)
    if p.is_alive():
        # Still running after `timeout` seconds: kill it and report zeros.
        logger.warning('Graph comparison timed out. Returning default values.')
        print('Graph comparison timed out. Returning default values.')
        p.terminate()
        p.join()
        return dict(zero_scores)
    return return_dict.get('result', dict(zero_scores))
@staticmethod
def create_graph_from_string(xml_string):
# Parse an XML-like node/edge sequence into an undirected networkx graph.
# Node labels are stored lowercased in the 'label' attribute; any parse
# failure is logged and a (possibly partial) graph is returned.
G = nx.Graph()
try:
# Extract nodes
# NOTE(review): the regex patterns below appear to have had their XML
# tags stripped (the node/edge element markup is missing), so the capture
# groups these loops unpack no longer exist. Recover the original
# patterns before relying on this parser.
nodes = re.findall(r'(.*?)', xml_string, re.DOTALL)
for node_id, label in nodes:
G.add_node(node_id, label=label.lower())
# Extract edges
edges = re.findall(r'', xml_string)
for src, tgt in edges:
G.add_edge(src, tgt)
except Exception as e:
logger.error(f"Error creating graph from string: {str(e)}")
print(f"Error creating graph from string: {str(e)}")
return G
@staticmethod
def normalized_levenshtein(s1, s2):
    """Levenshtein distance between s1 and s2, normalized to [0, 1] by the longer string.

    Returns 0 when both strings are empty.
    """
    longest = max(len(s1), len(s2))
    if longest == 0:
        return 0
    return Levenshtein.distance(s1, s2) / longest
@staticmethod
def calculate_node_coverage(G1, G2, threshold=0.2):
    """Fraction of nodes in G1 whose label approximately matches a label in G2.

    A node counts as matched when its normalized Levenshtein distance to some
    node label in G2 is at most `threshold`. The count is normalized by the
    larger of the two node counts.

    Args:
        G1: predicted graph (nodes carry a 'label' attribute).
        G2: ground-truth graph.
        threshold: maximum normalized edit distance for a label match.

    Returns:
        float in [0, 1]; 0.0 when both graphs are empty.
    """
    denom = max(len(G1), len(G2))
    if denom == 0:
        # Both graphs empty: the original raised ZeroDivisionError here,
        # which the caller's broad except turned into all-zero metrics.
        return 0.0
    matched_nodes = 0
    for n1 in G1.nodes(data=True):
        if any(DonutModelPLModule.normalized_levenshtein(n1[1]['label'], n2[1]['label']) <= threshold
               for n2 in G2.nodes(data=True)):
            matched_nodes += 1
    return matched_nodes / denom
@staticmethod
def node_label_similarity(G1, G2):
    """Average best label match from G1's nodes into G2's nodes.

    For each label in G1, take the highest (1 - normalized Levenshtein)
    score against G2's labels; return the mean over G1's labels
    (0 when G1 has no labels).
    """
    labels1 = list(nx.get_node_attributes(G1, 'label').values())
    labels2 = list(nx.get_node_attributes(G2, 'label').values())
    if not labels1:
        return 0
    total = 0
    for lab in labels1:
        scores = [1 - DonutModelPLModule.normalized_levenshtein(lab, other) for other in labels2]
        total += max(scores) if scores else 0
    return total / len(labels1)
@staticmethod
def edge_similarity(G1, G2):
    """Overlap of the two edge sets, normalized by the larger edge count.

    Edges are canonicalized by sorting their endpoints so that an undirected
    edge matches regardless of the orientation it was stored with (the
    original intersected raw (u, v) tuples, which can miss matches because
    networkx reports endpoints in insertion order).

    Returns 1 when both graphs have no edges.
    """
    e1 = {tuple(sorted(e)) for e in G1.edges()}
    e2 = {tuple(sorted(e)) for e in G2.edges()}
    denom = max(len(e1), len(e2))
    if denom == 0:
        return 1
    return len(e1 & e2) / denom
@staticmethod
def degree_sequence_similarity(G1, G2):
    """Similarity of the two graphs' sorted degree sequences (1.0 = identical).

    Sequences are zero-padded to equal length; the summed pairwise gap is
    normalized by twice the total degree of G1. Returns 0.0 when either
    graph has no nodes or G1 has no edges.
    """
    degrees1 = sorted((deg for _, deg in G1.degree()), reverse=True)
    degrees2 = sorted((deg for _, deg in G2.degree()), reverse=True)
    # Either sequence empty -> no meaningful comparison.
    if not degrees1 or not degrees2:
        return 0.0
    # Zero-pad the shorter sequence so the two can be compared pairwise.
    width = max(len(degrees1), len(degrees2))
    degrees1.extend([0] * (width - len(degrees1)))
    degrees2.extend([0] * (width - len(degrees2)))
    total = sum(degrees1)
    if total <= 0:
        return 0.0
    mismatch = sum(abs(a - b) for a, b in zip(degrees1, degrees2))
    return 1 - mismatch / (2 * total)
@staticmethod
def fast_graph_similarity(G1, G2):
    """Unweighted mean of node-label, edge, and degree-sequence similarities."""
    components = (
        DonutModelPLModule.node_label_similarity(G1, G2),
        DonutModelPLModule.edge_similarity(G1, G2),
        DonutModelPLModule.degree_sequence_similarity(G1, G2),
    )
    return sum(components) / 3
@staticmethod
def compare_graphs(G1, G2):
    """Compute all graph-similarity metrics between prediction G1 and ground truth G2.

    Returns:
        dict with fast_graph_similarity, node_label_similarity,
        edge_similarity, degree_sequence_similarity, node_coverage,
        edge_precision and edge_recall; all zeros if any metric raises.
    """
    try:
        node_coverage = DonutModelPLModule.calculate_node_coverage(G1, G2)
        # Canonicalize endpoint order so an undirected edge matches
        # regardless of the orientation it was stored with.
        pred_edges = {tuple(sorted(e)) for e in G1.edges()}
        truth_edges = {tuple(sorted(e)) for e in G2.edges()}
        correct_edges = len(pred_edges & truth_edges)
        # Precision is measured against the *predicted* edges and recall
        # against the *ground-truth* edges (the original had the two
        # denominators swapped).
        edge_precision = correct_edges / len(pred_edges) if pred_edges else 0
        edge_recall = correct_edges / len(truth_edges) if truth_edges else 0
        return {
            "fast_graph_similarity": DonutModelPLModule.fast_graph_similarity(G1, G2),
            "node_label_similarity": DonutModelPLModule.node_label_similarity(G1, G2),
            "edge_similarity": DonutModelPLModule.edge_similarity(G1, G2),
            "degree_sequence_similarity": DonutModelPLModule.degree_sequence_similarity(G1, G2),
            "node_coverage": node_coverage,
            "edge_precision": edge_precision,
            "edge_recall": edge_recall
        }
    except Exception as e:
        logger.error(f"Error comparing graphs: {str(e)}")
        print(f"Error comparing graphs: {str(e)}")
        # Worst-case scores keep the evaluation loop running.
        return {
            "fast_graph_similarity": 0.0,
            "node_label_similarity": 0.0,
            "edge_similarity": 0.0,
            "degree_sequence_similarity": 0.0,
            "node_coverage": 0.0,
            "edge_precision": 0.0,
            "edge_recall": 0.0
        }
def configure_optimizers(self):
    """Build AdamW plus a per-step linear-warmup schedule.

    The learning rate ramps linearly from 0 to `self.lr` over
    `self.warmup_steps` steps and then stays constant.

    Returns:
        Lightning-style dict with 'optimizer' and an 'lr_scheduler' entry
        stepped once per training step.
    """
    optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)

    def lr_lambda(current_step):
        # Linear ramp during warmup; constant 1.0 afterwards. The constant
        # tail could be swapped for a decay schedule (e.g. exponential).
        if current_step >= self.warmup_steps:
            return 1.0
        return float(current_step) / float(max(1, self.warmup_steps))

    scheduler = LambdaLR(optimizer, lr_lambda)
    return {
        'optimizer': optimizer,
        'lr_scheduler': {
            'scheduler': scheduler,
            'interval': 'step',  # step the LR after every training step
            'frequency': 1,      # invoke the scheduler on every interval
        },
    }
def on_validation_epoch_end(self):
    """Log epoch-level validation metrics to MLflow and reset the accumulators."""
    avg_val_loss = self.val_loss_epoch_total / self.val_batch_count
    mlflow.log_metric("validation_crossentropy_loss", avg_val_loss, step=self.current_epoch)
    # Reset the running sums for the next validation epoch.
    self.val_loss_epoch_total = 0.0
    self.val_batch_count = 0
    # Edit-distance and graph metrics are only produced every N epochs.
    # NOTE(review): no default is supplied for this config key; a missing key
    # would make the modulo below raise -- confirm it is always set.
    frequency = self.config.get("edit_distance_validation_frequency")
    if (self.current_epoch + 1) % frequency == 0:
        if self.edit_distance_scores:
            mlflow.log_metric("validation_edit_distance", self.edit_distance_scores[-1], step=self.current_epoch)
        for metric, history in self.graph_metrics.items():
            if history:
                mlflow.log_metric(f"validation_{metric}", history[-1], step=self.current_epoch)
        print('[INFO] - Finished the validation for epoch ', self.current_epoch + 1)
def on_train_epoch_end(self):
    """Log the epoch's average training loss and periodically save weights."""
    print(f'[INFO] - Finished epoch {self.current_epoch + 1}')
    avg_train_loss = self.train_loss_epoch_total / self.train_batch_count
    print(f'[INFO] - Train loss: {avg_train_loss}')
    mlflow.log_metric("training_crossentropy_loss", avg_train_loss, step=self.current_epoch)
    # Reset the running totals for the next epoch.
    self.train_loss_epoch_total = 0.0
    self.train_batch_count = 0
    # Checkpoint every `save_model_weights_frequency` epochs (default 10).
    save_every = self.config.get("save_model_weights_frequency", 10)
    if (self.current_epoch + 1) % save_every == 0:
        self.save_model()
def on_fit_end(self):
    """Persist the final model once training has completed."""
    self.save_model()
def save_model(self):
    """Save a local HuggingFace checkpoint and log the model to MLflow."""
    model_dir = "Donut_model"
    os.makedirs(model_dir, exist_ok=True)
    # Local HuggingFace-format checkpoint.
    self.model.save_pretrained(model_dir)
    print('[INFO] - Saving the model to dagshub using mlflow')
    components = {
        "model": self.model,
        "feature_extractor": self.processor.feature_extractor,
        "image_processor": self.processor.image_processor,
        "tokenizer": self.processor.tokenizer,
    }
    mlflow.transformers.log_model(
        transformers_model=components,
        artifact_path=model_dir,
        # MLflow cannot infer the task from the loaded model, so set it explicitly.
        task="image-to-text",
    )
    print('[INFO] - Saved the model to dagshub using mlflow')
def train_dataloader(self):
    """Return the module-level training DataLoader."""
    return train_dataloader
def val_dataloader(self):
    """Return the module-level validation DataLoader."""
    return val_dataloader
# Training configuration for the Donut Lightning module.
config = {
    "max_epochs": 200,
    "check_val_every_n_epoch": 1,  # validate once per epoch
    "gradient_clip_val": 1.0,
    "lr": 8e-4,
    "train_batch_sizes": [1],
    "val_batch_sizes": [1],
    "num_nodes": 1,
    "warmup_steps": 200,  # roughly 10% of the schedule
    "verbose": True,
}
model_module = DonutModelPLModule(config, processor, model)

# Evaluate on the held-out test split.
dataset = split_dataset['test']

# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
class Sharpen:
    """Callable image transform applying PIL's SHARPEN filter."""

    def __call__(self, img):
        return img.filter(ImageFilter.SHARPEN)
def preprocess_image(image):
    """Return a sharpened PIL image; accepts a PIL image or an array-like."""
    if not isinstance(image, Image.Image):
        # Convert e.g. a numpy array into a PIL image first.
        image = Image.fromarray(image)
    return Sharpen()(image)
def perform_inference(image):
    """Run Donut generation on one image and return the decoded output sequence."""
    # Encode the image and move the tensor to the active device.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
    # Seed the decoder with the start token for every item in the batch.
    batch_size = pixel_values.shape[0]
    decoder_input_ids = torch.full((batch_size, 1), model.config.decoder_start_token_id, device=device)
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=max_length,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,  # greedy decoding
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    decoded_output = processor.batch_decode(outputs.sequences)[0]
    print("Raw model output:", decoded_output)
    return decoded_output
def display_example(index):
    """Return the dataset image at `index`; graph/JSON outputs start empty."""
    return dataset[index]["image"], None, None
def from_json_like_to_xml_like(data):
# Serialize {"nodes": [...], "edges": [...]} into the XML-like token sequence
# the model is trained on.
# NOTE(review): the f-strings below appear to have lost their XML tag text
# (the markup around {label} and the entire edge element are missing), so the
# current output is just bare labels and blank lines -- restore the original
# tag templates before use.
def parse_nodes(nodes):
node_elements = []
for node in nodes:
label = node["label"]
node_elements.append(f'{label}')
return "\n" + "".join(node_elements) + "\n"
def parse_edges(edges):
edge_elements = []
for edge in edges:
edge_elements.append(f'')
return "\n" + "".join(edge_elements) + "\n"
nodes_xml = parse_nodes(data["nodes"])
edges_xml = parse_edges(data["edges"])
return nodes_xml + "\n" + edges_xml
def reshape_json_data_to_fit_visualize_graph(graph_data):
    """Convert column-oriented graph data into the row-oriented form visualize_graph expects.

    Args:
        graph_data: dict with "nodes" ({"id": [...], "label": [...]}) and
            "edges" ({"source": [...], "target": [...]}) as parallel columns.

    Returns:
        {"nodes": [{"id", "label"}, ...], "edges": [{"source", "target", "type"}, ...]}
        with every edge typed "->".
    """
    nodes = graph_data["nodes"]
    edges = graph_data["edges"]
    # Pair up the parallel columns directly instead of indexing by position.
    transformed_nodes = [
        {"id": node_id, "label": label}
        for node_id, label in zip(nodes["id"], nodes["label"])
    ]
    transformed_edges = [
        {"source": source, "target": target, "type": "->"}
        for source, target in zip(edges["source"], edges["target"])
    ]
    return {"nodes": transformed_nodes, "edges": transformed_edges}
def get_ground_truth(index):
    """Build the XML-like ground-truth sequence for the dataset example at `index`."""
    example = dataset[index]
    reshaped = reshape_json_data_to_fit_visualize_graph(example)
    # NOTE(review): the dumps/loads round-trip only normalizes the structure
    # (e.g. tuples -> lists); kept to preserve the original behavior.
    normalized = json.loads(json.dumps(reshaped))
    ground_truth = from_json_like_to_xml_like(normalized)
    print(f'Ground truth sequence: {ground_truth}')
    return ground_truth
def transform_image(img, index, physics_enabled):
    """Run inference on `img` and return (graph HTML, graph JSON, raw sequence)."""
    sequence = perform_inference(img)
    # Parse the generated sequence into node/edge data.
    graph_data = transform_sequence(sequence)
    graph_html = visualize_graph(graph_data, physics_enabled)
    # Pin the embedded iframe to a fixed height instead of the full viewport.
    graph_html = graph_html.replace('height: 100vh;', 'height: 500px;')
    json_data = json.dumps(graph_data, indent=2)
    return graph_html, json_data, sequence
import re
from typing import Dict, List, Tuple
def transform_sequence(sequence: str) -> Dict[str, List[Dict[str, str]]]:
# Parse the model's XML-like output into {"nodes": [...], "edges": [...]}.
# Raises ValueError when the node/edge sections cannot be located.
# NOTE(review): the regex patterns in this function appear to have had their
# XML tags stripped (the section and element markup is missing), so the
# searches below cannot match real model output and the finditer loops unpack
# capture groups that no longer exist -- restore the original patterns.
# Extract nodes and edges
nodes_match = re.search(r'(.*?)', sequence, re.DOTALL)
edges_match = re.search(r'(.*?)', sequence, re.DOTALL)
if not nodes_match or not edges_match:
raise ValueError("Invalid input sequence: nodes or edges not found")
nodes_text = nodes_match.group(1)
edges_text = edges_match.group(1)
# Parse nodes
nodes = []
for node_match in re.finditer(r'(.*?)', nodes_text):
node_id, node_label = node_match.groups()
nodes.append({
"id": node_id.strip(),
"label": node_label.strip()
})
# Parse edges
edges = []
for edge_match in re.finditer(r'', edges_text):
source, target = edge_match.groups()
edges.append({
"source": source.strip(),
"target": target.strip(),
"type": "->"
})
return {
"nodes": nodes,
"edges": edges
}
# function to visualize the extracted graph
import json
from pyvis.network import Network
def create_graph(nodes, edges, physics_enabled=True):
    """Build a pyvis Network from node/edge dicts, styled for a dark background."""
    net = Network(
        notebook=True,
        height="100vh",
        width="100vw",
        bgcolor="#222222",
        font_color="white",
        cdn_resources="remote",
    )
    # "OOP" nodes are highlighted in blue; everything else is green.
    for node in nodes:
        node_color = "blue" if node["label"] == "OOP" else "green"
        net.add_node(node["id"], label=node["label"], title=node["label"], color=node_color)
    for edge in edges:
        net.add_edge(edge["source"], edge["target"], title=edge["type"])
    # Force-directed layout parameters.
    net.force_atlas_2based(
        gravity=-50,
        central_gravity=0.01,
        spring_length=100,
        spring_strength=0.08,
        damping=0.4,
    )
    net.set_options(json.dumps({
        "nodes": {"physics": physics_enabled},
        "edges": {"smooth": True},
        "interaction": {"hover": True, "zoomView": True},
        "physics": {
            "enabled": physics_enabled,
            "stabilization": {"enabled": True, "iterations": 200},
        },
    }))
    return net
def visualize_graph(json_data, physics_enabled=True):
# Render graph data (dict or JSON string) to pyvis-generated HTML.
if isinstance(json_data, str):
data = json.loads(json_data)
else:
data = json_data
nodes = data["nodes"]
edges = data["edges"]
net = create_graph(nodes, edges, physics_enabled)
html = net.generate_html()
html = html.replace("'", '"')
# NOTE(review): the replace() call below is truncated -- its arguments were
# lost, leaving an unterminated string -- and the function never returns
# `html`. Recover the original replacement and the return statement.
html = html.replace(
'"""
def update_physics(json_data, physics_enabled):
    """Re-render the graph HTML with physics toggled; None input passes through."""
    if json_data is None:
        return None
    graph_html = visualize_graph(json.loads(json_data), physics_enabled)
    # Pin the embedded iframe to a fixed height.
    return graph_html.replace('height: 100vh;', 'height: 500px;')
# function to calculate the graph similarity metrics between the prediction and the ground-truth
def calculate_and_display_metrics(pred_graph, ground_truth_graph):
if pred_graph is None or ground_truth_graph is None:
return "Please generate a prediction and ensure a ground truth graph is available."
#removing the start token from the string
pred_graph = pred_graph.replace('', "").replace("", "\n").replace('src=" ', 'src="').replace('tgt=" ', 'tgt="').replace('