import gradio as gr
from datasets import load_dataset
from PIL import Image
import json
import torch
from torchvision import transforms
from transformers import DonutProcessor, VisionEncoderDecoderModel
# import subprocess
# # Install mlflow and dagshub without dependencies
# subprocess.run(['pip', 'install', '--no-deps', 'mlflow'])
# subprocess.run(['pip', 'install', '--no-deps', 'dagshub'])
import dagshub
import mlflow
import time
import os

# Read the DagsHub token from the environment (e.g. a Space secret)
# from kaggle_secrets import UserSecretsClient
# user_secrets = UserSecretsClient()
# token = user_secrets.get_secret("dags_hub_token")
# from google.colab import userdata
# token = userdata.get('dags_hub_token')
token = os.getenv('dags_hub_token')
dagshub.auth.add_app_token(token)
dagshub.init(repo_owner='zaheramasha',
             repo_name='Finetuning_paligemma_Zaka_capstone',
             mlflow=True)
# Define the MLflow run ID and artifact path
run_id = "c41cfd149a8c44f3a92d8e0f1253af35"  # Donut model trained on the PyvizAndMarkMap dataset for 27 epochs, reaching a train loss of 0.168
run_id = "89bafd5e525a4d3e9d004e13c9574198"  # Donut model trained for 27 + 51 = 78 epochs, reaching a train loss of 0.0353 (a continuation of the 27-epoch run)
artifact_path = "Donut_model/model"

# Create the model URI using the run ID and artifact path
model_uri = f"runs:/{run_id}/{artifact_path}"
print(mlflow.artifacts.list_artifacts(run_id=run_id, artifact_path=artifact_path))

# Load the model and processors from the MLflow artifact
# loaded_model_bundle = mlflow.transformers.load_model(artifact_path=artifact_path, run_id=run_id)
# URI for the 27-epoch model:
model_uri = "mlflow-artifacts:/0a5d0550f55c4169b80cd6439556be8b/c41cfd149a8c44f3a92d8e0f1253af35/artifacts/Donut_model"
# URI for the fully trained 78-epoch model:
model_uri = "mlflow-artifacts:/17c375f6eab34c63b2a2e7792803132e/89bafd5e525a4d3e9d004e13c9574198/artifacts/Donut_model"
loaded_model_bundle = mlflow.transformers.load_model(model_uri=model_uri, device='cpu')  # or 'cuda'
model = loaded_model_bundle.model
processor = DonutProcessor(tokenizer=loaded_model_bundle.tokenizer,
                           feature_extractor=loaded_model_bundle.feature_extractor,
                           image_processor=loaded_model_bundle.image_processor)
print(model.config.encoder.image_size)
print(model.config.decoder.max_length)
import json
import random
from typing import Any, List, Tuple, Dict
import torch
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict, concatenate_datasets
from PIL import Image, ImageFilter
from torchvision import transforms
import re

# Load and split the dataset
Pyviz_dataset = load_dataset("Zaherrr/OOP_KG_Pyviz_Synthetic_Dataset", revision="Sorted_edges")
MarkMap_dataset = load_dataset("Zaherrr/OOP_KG_MarkMap_Synthetic_Dataset")
combined_dataset = concatenate_datasets([Pyviz_dataset['data'], MarkMap_dataset['data']])
train_test_split = combined_dataset.train_test_split(test_size=0.2, seed=42)
train_val_split = train_test_split["train"].train_test_split(test_size=0.125, seed=42)
split_dataset = DatasetDict(
    {
        "train": train_val_split["train"],
        "val": train_val_split["test"],
        "test": train_test_split["test"],
    }
)
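# The two-stage split above works out to roughly 70/10/20: 20% is held out
# for test, and 12.5% of the remaining 80% (i.e. 10% overall) becomes the
# validation set.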
def reshape_json_data_to_fit_visualize_graph(graph_data):
    nodes = graph_data["nodes"]
    edges = graph_data["edges"]
    transformed_nodes = [
        {"id": nodes["id"][idx], "label": nodes["label"][idx]}
        for idx in range(len(nodes["id"]))
    ]
    transformed_edges = [
        {"source": edges["source"][idx], "target": edges["target"][idx], "type": "->"}
        for idx in range(len(edges["source"]))
    ]
    return {"nodes": transformed_nodes, "edges": transformed_edges}
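# Illustrative example (hypothetical values): the datasets store nodes and
# edges column-wise,
#   {"nodes": {"id": [1, 2], "label": ["OOP", "Class"]},
#    "edges": {"source": [1], "target": [2]}}
# and this helper reshapes them into the row-wise records the visualization
# code expects:
#   {"nodes": [{"id": 1, "label": "OOP"}, {"id": 2, "label": "Class"}],
#    "edges": [{"source": 1, "target": 2, "type": "->"}]}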
def from_json_like_to_xml_like(data):
    def parse_nodes(nodes):
        node_elements = []
        for node in nodes:
            label = node["label"]
            node_elements.append(f'<n id="{node["id"]}">{label}</n>')
        return "<nodes>\n" + "".join(node_elements) + "\n</nodes>"

    def parse_edges(edges):
        edge_elements = []
        for edge in edges:
            edge_elements.append(f'<e src="{edge["source"]}" tgt="{edge["target"]}"/>')
        return "<edges>\n" + "".join(edge_elements) + "\n</edges>"

    nodes_xml = parse_nodes(data["nodes"])
    edges_xml = parse_edges(data["edges"])
    return nodes_xml + "\n" + edges_xml
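# Illustrative output for a two-node, one-edge graph:
#   <nodes>
#   <n id="1">OOP</n><n id="2">Class</n>
#   </nodes>
#   <edges>
#   <e src="1" tgt="2"/>
#   </edges>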
# Shuffle the nodes on the fly, to reduce the bias introduced by a fixed node order
def flexible_node_shuffle(sequence):
    # Split the sequence into its nodes and edges sections
    nodes_match = re.search(r'<nodes>(.*?)</nodes>', sequence, re.DOTALL)
    edges_match = re.search(r'<edges>(.*?)</edges>', sequence, re.DOTALL)
    if not nodes_match or not edges_match:
        print("Error: Could not find nodes or edges in the sequence.")
        return sequence
    nodes_content = nodes_match.group(1)
    edges_content = edges_match.group(1)

    # Extract and shuffle the individual nodes
    nodes = re.findall(r'<n id="(\d+)">(.*?)</n>', nodes_content, re.DOTALL)
    random.shuffle(nodes)

    # Map old ids to new ids that follow the shuffled order
    id_mapping = {old_id: str(new_id) for new_id, (old_id, _) in enumerate(nodes, start=1)}

    # Reconstruct the nodes section with the new ids
    new_nodes_content = "".join(f'<n id="{new_id}">{content}</n>' for new_id, (_, content) in enumerate(nodes, start=1))

    # Re-map the edge endpoints to the new node ids
    edges = re.findall(r'<e src="(\d+)" tgt="(\d+)"/>', edges_content)
    new_edges = []
    for src, tgt in edges:
        new_src = int(id_mapping[src])
        new_tgt = int(id_mapping[tgt])
        new_edges.append((new_src, new_tgt))

    # Sort edges by their smaller endpoint first, then the larger one
    new_edges.sort(key=lambda x: (min(x[0], x[1]), max(x[0], x[1])))

    # Reconstruct the edges section, normalizing each edge so that src < tgt
    new_edges_content = "".join(f'<e src="{src}" tgt="{tgt}"/>' if src < tgt else f'<e src="{tgt}" tgt="{src}"/>' for src, tgt in new_edges)

    # Reassemble the full sequence
    new_sequence = f'<nodes><newline>{new_nodes_content}<newline></nodes><newline><edges><newline>{new_edges_content}<newline></edges>'
    return new_sequence
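# Note that the output format differs slightly from the input: nodes are
# re-numbered 1..N in their shuffled order, each edge is rewritten so that
# src < tgt, edges are sorted by (min, max) endpoint, and <newline> tokens
# are inserted as explicit line-break markers for the target sequences.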
class Sharpen:
    def __call__(self, img):
        return img.filter(ImageFilter.SHARPEN)
# Validation with graph-edit-distance-style metrics
import re
from nltk import edit_distance
import numpy as np
import torch
import pytorch_lightning as pl
import mlflow
import networkx as nx
import Levenshtein
import xml.etree.ElementTree as ET
import multiprocessing
import logging
from torch.optim.lr_scheduler import LambdaLR

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Node matching and reordering, to align predictions with the ground-truth graph
def match_nodes_by_label(G_pred, G_gt):
    """Match nodes from the predicted graph to the ground-truth graph based on label similarity."""
    node_mapping = {}
    for n_pred, pred_data in G_pred.nodes(data=True):
        best_match = None
        best_score = float('inf')  # Levenshtein is a distance metric, so lower is better
        for n_gt, gt_data in G_gt.nodes(data=True):
            sim_score = DonutModelPLModule.normalized_levenshtein(pred_data['label'], gt_data['label'])
            if sim_score < best_score:
                best_score = sim_score
                best_match = n_gt
        if best_match is not None:
            node_mapping[n_pred] = best_match
    return node_mapping
# Also part of the reordering
def rebuild_graph_with_mapped_nodes(G_pred, node_mapping):
    """Rebuild the predicted graph with its nodes aligned to the ground truth."""
    G_aligned = nx.Graph()
    for node_pred, node_gt in node_mapping.items():
        G_aligned.add_node(node_gt, label=G_pred.nodes[node_pred]['label'])
    for u, v in G_pred.edges():
        if u in node_mapping and v in node_mapping:
            G_aligned.add_edge(node_mapping[u], node_mapping[v])
    return G_aligned
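# Illustrative alignment (hypothetical ids): if the prediction calls its
# "encapsulation" node 3 while the ground truth calls it 1, match_nodes_by_label
# maps 3 -> 1, and rebuild_graph_with_mapped_nodes re-keys the predicted graph
# so that edges are compared on ground-truth node ids.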
class DonutModelPLModule(pl.LightningModule):
    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.train_loss_epoch_total = 0.0
        self.val_loss_epoch_total = 0.0
        self.train_batch_count = 0
        self.val_batch_count = 0
        self.edit_distance_scores = []
        self.graph_metrics = {
            'fast_graph_similarity': [],
            'node_label_similarity': [],
            'edge_similarity': [],
            'degree_sequence_similarity': [],
            'node_coverage': [],
            'edge_precision': [],
            'edge_recall': []
        }
        self.lr = config["lr"]
        self.warmup_steps = config["warmup_steps"]

    def training_step(self, batch, batch_idx):
        pixel_values, labels, _ = batch
        outputs = self.model(pixel_values, labels=labels)
        loss = outputs.loss
        self.train_loss_epoch_total += loss.item()
        self.train_batch_count += 1
        self.log("train_loss", loss, prog_bar=True)
        return loss
    def validation_step(self, batch, batch_idx, dataset_idx=0):
        pixel_values, labels, answers = batch
        outputs = self.model(pixel_values, labels=labels)
        val_loss = outputs.loss
        self.val_loss_epoch_total += val_loss.item()
        self.val_batch_count += 1
        self.log("val_loss", val_loss)
        # Default to running the expensive generation-based validation every epoch
        # if no frequency is configured
        if (self.current_epoch + 1) % self.config.get("edit_distance_validation_frequency", 1) == 0:
            logger.info(f'Finished epoch: {self.current_epoch + 1}')
            print(f'Finished epoch: {self.current_epoch + 1}')
            batch_size = pixel_values.shape[0]
            decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
            try:
                outputs = self.model.generate(pixel_values,
                                              decoder_input_ids=decoder_input_ids,
                                              max_length=self.config.get("max_length", 512),
                                              early_stopping=True,
                                              pad_token_id=self.processor.tokenizer.pad_token_id,
                                              eos_token_id=self.processor.tokenizer.eos_token_id,
                                              use_cache=True,
                                              num_beams=1,
                                              bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                                              return_dict_in_generate=True)
                predictions = self.process_predictions(outputs)
                logger.info('Calculating graph metrics')
                print('Calculating graph metrics')
                levenshtein_scores, graph_scores = self.calculate_metrics(predictions, answers)
                logger.info('Finished calculating graph metrics')
                print('Finished calculating graph metrics')
                self.edit_distance_scores.append(np.mean(levenshtein_scores))
                for metric in self.graph_metrics:
                    self.graph_metrics[metric].append(np.mean([score[metric] for score in graph_scores if metric in score]))
                self.log("val_edit_distance", np.mean(levenshtein_scores), prog_bar=True)
                for metric in self.graph_metrics:
                    self.log(f"val_{metric}", self.graph_metrics[metric][-1], prog_bar=True)
            except Exception as e:
                logger.error(f"Error in validation step: {str(e)}")
                print(f"Error in validation step: {str(e)}")
    def process_predictions(self, outputs):
        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            try:
                seq = (
                    seq.replace(self.processor.tokenizer.eos_token, "")
                    .replace(self.processor.tokenizer.pad_token, "")
                    .replace('<n id=" ', '<n id="')
                    .replace('src=" ', 'src="')
                    .replace('tgt=" ', 'tgt="')
                    .replace('<newline>', '\n')
                )
                seq = re.sub(r"<s>", "", seq, count=1).strip()
                seq = seq.replace("<s>", "")
                predictions.append(seq)
            except Exception as e:
                logger.error(f"Error processing prediction: {str(e)}")
                print(f"Error processing prediction: {str(e)}")
                predictions.append("")  # Append an empty string if processing fails
        return predictions
    def calculate_metrics(self, predictions, answers):
        levenshtein_scores = []
        graph_scores = []
        for pred, answer in zip(predictions, answers):
            try:
                pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
                answer = answer.replace(self.processor.tokenizer.bos_token, "").replace(self.processor.tokenizer.eos_token, "").replace("<newline>", "\n")
                edit_dist = edit_distance(pred, answer) / max(len(pred), len(answer))
                logger.info(f"Prediction: {pred}")
                logger.info(f"    Answer: {answer}")
                logger.info(f" Normed ED: {edit_dist}")
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {edit_dist}")
                levenshtein_scores.append(edit_dist)

                pred_graph = self.create_graph_from_string(pred)
                answer_graph = self.create_graph_from_string(answer)
                # Match nodes based on labels and reorder, so the comparison
                # ignores node order in the predicted graph
                node_mapping = match_nodes_by_label(pred_graph, answer_graph)
                pred_graph_aligned = rebuild_graph_with_mapped_nodes(pred_graph, node_mapping)
                # Compare the aligned graphs
                logger.info('Calculating the GED')
                print('Calculating the GED')
                # graph_scores.append(self.compare_graphs_with_timeout(pred_graph, answer_graph, timeout=60))
                graph_scores.append(self.compare_graphs_with_timeout(pred_graph_aligned, answer_graph, timeout=60))
                logger.info('Got the GED results')
                print('Got the GED results')
            except Exception as e:
                logger.error(f"Error calculating metrics: {str(e)}")
                print(f"Error calculating metrics: {str(e)}")
                levenshtein_scores.append(1.0)  # Worst possible score
                graph_scores.append({metric: 0.0 for metric in self.graph_metrics})  # Worst possible scores
        return levenshtein_scores, graph_scores
    @staticmethod
    def compare_graphs_with_timeout(pred_graph, answer_graph, timeout=60):
        # Graph comparison can hang on degenerate predictions, so run it in a
        # separate process and give up after `timeout` seconds
        def wrapper(return_dict):
            return_dict['result'] = DonutModelPLModule.compare_graphs(pred_graph, answer_graph)

        manager = multiprocessing.Manager()
        return_dict = manager.dict()
        p = multiprocessing.Process(target=wrapper, args=(return_dict,))
        p.start()
        p.join(timeout)
        if p.is_alive():
            logger.warning('Graph comparison timed out. Returning default values.')
            print('Graph comparison timed out. Returning default values.')
            p.terminate()
            p.join()
            return {
                "fast_graph_similarity": 0.0,
                "node_label_similarity": 0.0,
                "edge_similarity": 0.0,
                "degree_sequence_similarity": 0.0,
                "node_coverage": 0.0,
                "edge_precision": 0.0,
                "edge_recall": 0.0
            }
        else:
            return return_dict.get('result', {
                "fast_graph_similarity": 0.0,
                "node_label_similarity": 0.0,
                "edge_similarity": 0.0,
                "degree_sequence_similarity": 0.0,
                "node_coverage": 0.0,
                "edge_precision": 0.0,
                "edge_recall": 0.0
            })
    @staticmethod
    def create_graph_from_string(xml_string):
        G = nx.Graph()
        try:
            # Extract nodes
            nodes = re.findall(r'<n id="(\d+)">(.*?)</n>', xml_string, re.DOTALL)
            for node_id, label in nodes:
                G.add_node(node_id, label=label.lower())
            # Extract edges
            edges = re.findall(r'<e src="(\d+)" tgt="(\d+)"/>', xml_string)
            for src, tgt in edges:
                G.add_edge(src, tgt)
        except Exception as e:
            logger.error(f"Error creating graph from string: {str(e)}")
            print(f"Error creating graph from string: {str(e)}")
        return G
    @staticmethod
    def normalized_levenshtein(s1, s2):
        distance = Levenshtein.distance(s1, s2)
        max_length = max(len(s1), len(s2))
        return distance / max_length if max_length > 0 else 0
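    # Worked example: Levenshtein("kitten", "sitting") is 3 edits and the
    # longer string has 7 characters, so the normalized distance is 3/7 ≈ 0.43
    # (0 means identical strings, 1 means maximally different).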
    @staticmethod
    def calculate_node_coverage(G1, G2, threshold=0.2):
        matched_nodes = 0
        for n1 in G1.nodes(data=True):
            if any(DonutModelPLModule.normalized_levenshtein(n1[1]['label'], n2[1]['label']) <= threshold
                   for n2 in G2.nodes(data=True)):
                matched_nodes += 1
        return matched_nodes / max(len(G1), len(G2))
    @staticmethod
    def node_label_similarity(G1, G2):
        labels1 = list(nx.get_node_attributes(G1, 'label').values())
        labels2 = list(nx.get_node_attributes(G2, 'label').values())
        total_similarity = 0
        for label1 in labels1:
            similarities = [1 - DonutModelPLModule.normalized_levenshtein(label1, label2) for label2 in labels2]
            total_similarity += max(similarities) if similarities else 0
        return total_similarity / len(labels1) if labels1 else 0
    @staticmethod
    def edge_similarity(G1, G2):
        return len(set(G1.edges()) & set(G2.edges())) / max(len(G1.edges()), len(G2.edges())) if max(len(G1.edges()), len(G2.edges())) > 0 else 1
    @staticmethod
    def degree_sequence_similarity(G1, G2):
        seq1 = sorted([d for n, d in G1.degree()], reverse=True)
        seq2 = sorted([d for n, d in G2.degree()], reverse=True)
        # If either sequence is empty, return 0 similarity
        if not seq1 or not seq2:
            return 0.0
        # Pad the sequences to the same length
        max_len = max(len(seq1), len(seq2))
        seq1 += [0] * (max_len - len(seq1))
        seq2 += [0] * (max_len - len(seq2))
        # Compare the padded degree sequences element-wise
        diff_sum = sum(abs(x - y) for x, y in zip(seq1, seq2))
        # Return the similarity, handling the edge case where the sum of degrees is zero
        return 1 - diff_sum / (2 * sum(seq1)) if sum(seq1) > 0 else 0.0
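    # Worked example (illustrative): degree sequences [2, 1, 1] and [2, 2] are
    # padded to [2, 1, 1] and [2, 2, 0]; the absolute differences sum to 2 and
    # sum(seq1) is 4, so the similarity is 1 - 2 / (2 * 4) = 0.75.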
    @staticmethod
    def fast_graph_similarity(G1, G2):
        node_sim = DonutModelPLModule.node_label_similarity(G1, G2)
        edge_sim = DonutModelPLModule.edge_similarity(G1, G2)
        degree_sim = DonutModelPLModule.degree_sequence_similarity(G1, G2)
        return (node_sim + edge_sim + degree_sim) / 3
    @staticmethod
    def compare_graphs(G1, G2):
        try:
            node_coverage = DonutModelPLModule.calculate_node_coverage(G1, G2)
            G1_edges = set(G1.edges())
            G2_edges = set(G2.edges())
            correct_edges = len(G1_edges & G2_edges)
            edge_precision = correct_edges / len(G2_edges) if G2_edges else 0
            edge_recall = correct_edges / len(G1_edges) if G1_edges else 0
            return {
                "fast_graph_similarity": DonutModelPLModule.fast_graph_similarity(G1, G2),
                "node_label_similarity": DonutModelPLModule.node_label_similarity(G1, G2),
                "edge_similarity": DonutModelPLModule.edge_similarity(G1, G2),
                "degree_sequence_similarity": DonutModelPLModule.degree_sequence_similarity(G1, G2),
                "node_coverage": node_coverage,
                "edge_precision": edge_precision,
                "edge_recall": edge_recall
            }
        except Exception as e:
            logger.error(f"Error comparing graphs: {str(e)}")
            print(f"Error comparing graphs: {str(e)}")
            return {
                "fast_graph_similarity": 0.0,
                "node_label_similarity": 0.0,
                "edge_similarity": 0.0,
                "degree_sequence_similarity": 0.0,
                "node_coverage": 0.0,
                "edge_precision": 0.0,
                "edge_recall": 0.0
            }
    def configure_optimizers(self):
        # Define the optimizer
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)

        # Define the warmup scheduler
        def lr_lambda(current_step):
            if current_step < self.warmup_steps:
                return float(current_step) / float(max(1, self.warmup_steps))
            return 1.0  # Constant after warmup; could be replaced with a decay function such as exponential decay

        scheduler = LambdaLR(optimizer, lr_lambda)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',  # Update the learning rate after every training step
                'frequency': 1,      # How often the scheduler is called (every step)
            }
        }
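    # With the config used below (lr=8e-4, warmup_steps=200), the learning
    # rate ramps linearly from 0 to 8e-4 over the first 200 optimizer steps
    # (e.g. at step 50 the factor is 50/200 = 0.25, i.e. lr = 2e-4) and then
    # stays constant.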
    def on_validation_epoch_end(self):
        avg_val_loss = self.val_loss_epoch_total / self.val_batch_count
        mlflow.log_metric("validation_crossentropy_loss", avg_val_loss, step=self.current_epoch)
        self.val_loss_epoch_total = 0.0
        self.val_batch_count = 0
        if (self.current_epoch + 1) % self.config.get("edit_distance_validation_frequency", 1) == 0:
            if self.edit_distance_scores:
                mlflow.log_metric("validation_edit_distance", self.edit_distance_scores[-1], step=self.current_epoch)
            for metric in self.graph_metrics:
                if self.graph_metrics[metric]:
                    mlflow.log_metric(f"validation_{metric}", self.graph_metrics[metric][-1], step=self.current_epoch)
        print('[INFO] - Finished the validation for epoch', self.current_epoch + 1)

    def on_train_epoch_end(self):
        print(f'[INFO] - Finished epoch {self.current_epoch + 1}')
        avg_train_loss = self.train_loss_epoch_total / self.train_batch_count
        print(f'[INFO] - Train loss: {avg_train_loss}')
        mlflow.log_metric("training_crossentropy_loss", avg_train_loss, step=self.current_epoch)
        self.train_loss_epoch_total = 0.0
        self.train_batch_count = 0
        if ((self.current_epoch + 1) % self.config.get("save_model_weights_frequency", 10)) == 0:
            self.save_model()

    def on_fit_end(self):
        self.save_model()
    def save_model(self):
        model_dir = "Donut_model"
        os.makedirs(model_dir, exist_ok=True)
        self.model.save_pretrained(model_dir)
        print('[INFO] - Saving the model to DagsHub using MLflow')
        mlflow.transformers.log_model(
            transformers_model={
                "model": self.model,
                "feature_extractor": self.processor.feature_extractor,
                "image_processor": self.processor.image_processor,
                "tokenizer": self.processor.tokenizer
            },
            artifact_path=model_dir,
            # Set the task explicitly, since MLflow cannot infer it from the loaded model
            task="image-to-text"
        )
        print('[INFO] - Saved the model to DagsHub using MLflow')
    def train_dataloader(self):
        # Note: these return module-level dataloaders that are only defined in
        # the training script; they are unused in this inference app
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader
config = {
    "max_epochs": 200,
    # "val_check_interval": 0.2,  # how many times we want to validate during an epoch
    "check_val_every_n_epoch": 1,
    "gradient_clip_val": 1.0,
    # "num_training_samples_per_epoch": 800,
    "lr": 8e-4,  # previously 3e-4, 3e-5
    "train_batch_sizes": [1],  # previously [8]
    "val_batch_sizes": [1],
    # "seed": 2022,
    "num_nodes": 1,
    "warmup_steps": 200,  # 800/8*30/10, i.e. 10%
    "verbose": True,
}
model_module = DonutModelPLModule(config, processor, model)

# Load the dataset split used by the demo
dataset = split_dataset['test']

# Set up the device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
def preprocess_image(image):
    # Convert to a PIL Image if it is not one already
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Apply sharpening (using the Sharpen transform defined above)
    sharpen = Sharpen()
    sharpened_image = sharpen(image)
    return sharpened_image
def perform_inference(image):
    # Preprocess the image
    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)
    # Prepare the decoder input ids
    batch_size = pixel_values.shape[0]
    decoder_input_ids = torch.full((batch_size, 1), model.config.decoder_start_token_id, device=device)
    # Generate the output sequence
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.config.decoder.max_length,  # adjust as needed
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    # Decode the output
    decoded_output = processor.batch_decode(outputs.sequences)[0]
    print("Raw model output:", decoded_output)
    return decoded_output
def display_example(index):
    example = dataset[index]
    img = example["image"]
    return img, None, None
def get_ground_truth(index):
    example = dataset[index]
    # Reshape the columnar example into node/edge records, then serialize it
    ground_truth = from_json_like_to_xml_like(reshape_json_data_to_fit_visualize_graph(example))
    print(f'Ground truth sequence: {ground_truth}')
    return ground_truth
def transform_image(img, index, physics_enabled):
    # Perform inference
    sequence = perform_inference(img)
    # Transform the sequence to graph data
    graph_data = transform_sequence(sequence)
    # Generate the graph visualization
    graph_html = visualize_graph(graph_data, physics_enabled)
    # Constrain the iframe to a fixed height
    graph_html = graph_html.replace('height: 100vh;', 'height: 500px;')
    # Convert graph_data to a formatted JSON string
    json_data = json.dumps(graph_data, indent=2)
    return graph_html, json_data, sequence
import re
from typing import Dict, List, Tuple

def transform_sequence(sequence: str) -> Dict[str, List[Dict[str, str]]]:
    # Extract the nodes and edges sections
    nodes_match = re.search(r'<nodes>(.*?)</nodes>', sequence, re.DOTALL)
    edges_match = re.search(r'<edges>(.*?)</edges>', sequence, re.DOTALL)
    if not nodes_match or not edges_match:
        raise ValueError("Invalid input sequence: nodes or edges not found")
    nodes_text = nodes_match.group(1)
    edges_text = edges_match.group(1)

    # Parse the nodes
    nodes = []
    for node_match in re.finditer(r'<n id="\s*(\d+)">(.*?)</n>', nodes_text):
        node_id, node_label = node_match.groups()
        nodes.append({
            "id": node_id.strip(),
            "label": node_label.strip()
        })

    # Parse the edges
    edges = []
    for edge_match in re.finditer(r'<e src="\s*(\d+)" tgt="\s*(\d+)"/>', edges_text):
        source, target = edge_match.groups()
        edges.append({
            "source": source.strip(),
            "target": target.strip(),
            "type": "->"
        })

    return {
        "nodes": nodes,
        "edges": edges
    }
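# A minimal sketch of the expected parse (hypothetical literal input):
#   transform_sequence('<nodes><n id="1">OOP</n><n id="2">Class</n></nodes>'
#                      '<edges><e src="1" tgt="2"/></edges>')
#   -> {"nodes": [{"id": "1", "label": "OOP"}, {"id": "2", "label": "Class"}],
#       "edges": [{"source": "1", "target": "2", "type": "->"}]}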
# Visualize the extracted graph
import json
from pyvis.network import Network

def create_graph(nodes, edges, physics_enabled=True):
    net = Network(
        notebook=True,
        height="100vh",
        width="100vw",
        bgcolor="#222222",
        font_color="white",
        cdn_resources="remote",
    )
    for node in nodes:
        net.add_node(
            node["id"],
            label=node["label"],
            title=node["label"],
            color="blue" if node["label"] == "OOP" else "green",
        )
    for edge in edges:
        net.add_edge(edge["source"], edge["target"], title=edge["type"])
    net.force_atlas_2based(
        gravity=-50,
        central_gravity=0.01,
        spring_length=100,
        spring_strength=0.08,
        damping=0.4,
    )
    options = {
        "nodes": {"physics": physics_enabled},
        "edges": {"smooth": True},
        "interaction": {"hover": True, "zoomView": True},
        "physics": {
            "enabled": physics_enabled,
            "stabilization": {"enabled": True, "iterations": 200},
        },
    }
    net.set_options(json.dumps(options))
    return net
def visualize_graph(json_data, physics_enabled=True):
    if isinstance(json_data, str):
        data = json.loads(json_data)
    else:
        data = json_data
    nodes = data["nodes"]
    edges = data["edges"]
    net = create_graph(nodes, edges, physics_enabled)
    html = net.generate_html()
    # The HTML is embedded in a single-quoted srcdoc attribute below, so swap
    # single quotes for double quotes to avoid terminating the attribute early
    html = html.replace("'", '"')
    html = html.replace(
        '<div id="mynetwork"', '<div id="mynetwork" style="height: 100vh; width: 100%;"'
    )
    return f"""<iframe style="width: 100%; height: 100vh; border: none; margin: 0; padding: 0;" srcdoc='{html}'></iframe>"""
def update_physics(json_data, physics_enabled):
    if json_data is None:
        return None
    data = json.loads(json_data)
    graph_html = visualize_graph(data, physics_enabled)
    graph_html = graph_html.replace('height: 100vh;', 'height: 500px;')
    return graph_html
# Calculate the graph similarity metrics between the prediction and the ground truth
def calculate_and_display_metrics(pred_graph, ground_truth_graph):
    if pred_graph is None or ground_truth_graph is None:
        return "Please generate a prediction and ensure a ground truth graph is available."
    # Remove the start token and normalize the raw sequence before comparison
    pred_graph = pred_graph.replace('<s>', "").replace("<newline>", "\n").replace('src=" ', 'src="').replace('tgt=" ', 'tgt="').replace('<n id=" ', '<n id="')
    print(f'Prediction: {pred_graph}')
    # calculate_metrics returns the normalized edit distances (index 0) and the
    # per-example graph metric dicts (index 1)
    metrics = model_module.calculate_metrics([pred_graph], [ground_truth_graph])
    # Format the metrics for display
    overall_metric = metrics[0][0]
    detailed_metrics = metrics[1][0]
    # output = f"Overall Metric: {overall_metric:.4f}\n\nDetailed Metrics:\n"
    output = "Detailed Metrics:\n"
    for key, value in detailed_metrics.items():
        output += f"{key}: {value:.4f}\n"
    return output
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Knowledge Graph Visualizer with Model Inference")
        with gr.Row():
            index_slider = gr.Slider(
                minimum=0,
                maximum=len(dataset) - 1,
                step=1,
                label="Example Index"
            )
        with gr.Row():
            image_output = gr.Image(type="pil", label="Image", height=500, interactive=False)
            graph_output = gr.HTML(label="Knowledge Graph")
        with gr.Row():
            transform_button = gr.Button("Transform")
            physics_toggle = gr.Checkbox(label="Enable Physics", value=True)
        with gr.Row():
            json_output = gr.Code(language="json", label="Graph JSON Data")
            ground_truth_output = gr.Textbox(visible=False)  # alternatively: gr.JSON(label="Ground Truth Graph", visible=False)
            predicted_raw_sequence = gr.Textbox(visible=False)
        with gr.Row():
            metrics_button = gr.Button("Calculate Metrics")
            metrics_output = gr.Textbox(label="Similarity Metrics", lines=10)

        index_slider.change(
            fn=display_example,
            inputs=[index_slider],
            outputs=[image_output, graph_output, json_output],
        ).then(
            fn=get_ground_truth,
            inputs=[index_slider],
            outputs=[ground_truth_output],
        )
        transform_button.click(
            fn=transform_image,
            inputs=[image_output, index_slider, physics_toggle],
            outputs=[graph_output, json_output, predicted_raw_sequence],
        ).then(
            fn=calculate_and_display_metrics,
            inputs=[predicted_raw_sequence, ground_truth_output],
            outputs=[metrics_output],
        )
        metrics_button.click(
            fn=calculate_and_display_metrics,
            inputs=[predicted_raw_sequence, ground_truth_output],
            outputs=[metrics_output],
        )
        physics_toggle.change(
            fn=update_physics,
            inputs=[json_output, physics_toggle],
            outputs=[graph_output],
        )
    return demo
# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True, debug=True)