Zaherrr commited on
Commit
3251231
·
verified ·
1 Parent(s): aa8d81c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +568 -0
app.py CHANGED
@@ -5,6 +5,574 @@ import json
5
  import torch
6
  from torchvision import transforms
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  # Load dataset
9
  dataset = split_dataset['test']
10
 
 
5
  import torch
6
  from torchvision import transforms
7
 
8
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
9
+
10
+ import dagshub
11
+ import mlflow
12
+ import time
13
+ import os
14
+
15
+ # from kaggle_secrets import UserSecretsClient
16
+ # user_secrets = UserSecretsClient()
17
+ # token = user_secrets.get_secret("dags_hub_token")
18
+ # from google.colab import userdata
19
+ # token = userdata.get('dags_hub_token')
20
+ token = os.getenv('dags_hub_token')
21
+ dagshub.auth.add_app_token(token)
22
+
23
+ dagshub.init(repo_owner='zaheramasha',
24
+ repo_name='Finetuning_paligemma_Zaka_capstone',
25
+ mlflow=True)
26
+
27
+ # Define the MLflow run ID and artifact path
28
+ run_id = "c41cfd149a8c44f3a92d8e0f1253af35" # Donut model trained on the PyvizAndMarkMap dataset for 27 epochs reaching a train loss of 0.168
29
+ run_id = "89bafd5e525a4d3e9d004e13c9574198" # Donut model trained on the PyvizAndMarkMap dataset for 27 + 51 = 78 epochs reaching a train loss of 0.0353. This run was a continuation of the 27 epoch one
30
+
31
+ artifact_path = "Donut_model/model"
32
+
33
+ # Create the model URI using the run ID and artifact path
34
+ model_uri = f"runs:/{run_id}/{artifact_path}"
35
+ print(mlflow.artifacts.list_artifacts(run_id=run_id, artifact_path=artifact_path))
36
+ # Load the model and processors from the MLflow artifact
37
+ # loaded_model_bundle = mlflow.transformers.load_model(artifact_path=artifact_path, run_id=run_id)
38
+ # for the 20 epochs trained model
39
+ model_uri = f"mlflow-artifacts:/0a5d0550f55c4169b80cd6439556be8b/c41cfd149a8c44f3a92d8e0f1253af35/artifacts/Donut_model"
40
+
41
+ # for the fully 70 epochs trained model
42
+ model_uri = f"mlflow-artifacts:/17c375f6eab34c63b2a2e7792803132e/89bafd5e525a4d3e9d004e13c9574198/artifacts/Donut_model"
43
+ loaded_model_bundle = mlflow.transformers.load_model(model_uri=model_uri, device='cuda')
44
+
45
+ model = loaded_model_bundle.model
46
+ processor = DonutProcessor(tokenizer=loaded_model_bundle.tokenizer, feature_extractor=loaded_model_bundle.feature_extractor, image_processor=loaded_model_bundle.image_processor)
47
+ print(model.config.encoder.image_size)
48
+ print(model.config.decoder.max_length)
49
+
50
+
51
+ import json
52
+ import random
53
+ from typing import Any, List, Tuple, Dict
54
+ import torch
55
+ from torch.utils.data import Dataset
56
+ from datasets import load_dataset, DatasetDict, concatenate_datasets
57
+ from PIL import Image, ImageFilter
58
+ from torchvision import transforms
59
+ import re
60
+
61
+ # Load and split the dataset
62
+ Pyviz_dataset = load_dataset("Zaherrr/OOP_KG_Pyviz_Synthetic_Dataset", revision="Sorted_edges")
63
+ MarkMap_dataset = load_dataset("Zaherrr/OOP_KG_MarkMap_Synthetic_Dataset")
64
+ combined_dataset = concatenate_datasets([Pyviz_dataset['data'], MarkMap_dataset['data']])
65
+
66
+ train_test_split = combined_dataset.train_test_split(test_size=0.2, seed=42)
67
+ train_val_split = train_test_split["train"].train_test_split(test_size=0.125, seed=42)
68
+ split_dataset = DatasetDict(
69
+ {
70
+ "train": train_val_split["train"],
71
+ "val": train_val_split["test"],
72
+ "test": train_test_split["test"],
73
+ }
74
+ )
75
+
76
+ def reshape_json_data_to_fit_visualize_graph(graph_data):
77
+ nodes = graph_data["nodes"]
78
+ edges = graph_data["edges"]
79
+ transformed_nodes = [
80
+ {"id": nodes["id"][idx], "label": nodes["label"][idx]}
81
+ for idx in range(len(nodes["id"]))
82
+ ]
83
+ transformed_edges = [
84
+ {"source": edges["source"][idx], "target": edges["target"][idx], "type": "->"}
85
+ for idx in range(len(edges["source"]))
86
+ ]
87
+ return {"nodes": transformed_nodes, "edges": transformed_edges}
88
+
89
+ def from_json_like_to_xml_like(data):
90
+ def parse_nodes(nodes):
91
+ node_elements = []
92
+ for node in nodes:
93
+ label = node["label"]
94
+ node_elements.append(f'<n id="{node["id"]}">{label}</n>')
95
+ return "<nodes>\n" + "".join(node_elements) + "\n</nodes>"
96
+
97
+ def parse_edges(edges):
98
+ edge_elements = []
99
+ for edge in edges:
100
+ edge_elements.append(f'<e src="{edge["source"]}" tgt="{edge["target"]}"/>')
101
+ return "<edges>\n" + "".join(edge_elements) + "\n</edges>"
102
+
103
+ nodes_xml = parse_nodes(data["nodes"])
104
+ edges_xml = parse_edges(data["edges"])
105
+ return nodes_xml + "\n" + edges_xml
106
+
107
+
108
+ # function to shuffle the nodes on the fly in an attempt to reduce the bias from random node extraction
109
+ def flexible_node_shuffle(sequence):
110
+ # Split the sequence into nodes and edges
111
+ nodes_match = re.search(r'<nodes>(.*?)</nodes>', sequence, re.DOTALL)
112
+ edges_match = re.search(r'<edges>(.*?)</edges>', sequence, re.DOTALL)
113
+
114
+ if not nodes_match or not edges_match:
115
+ print("Error: Could not find nodes or edges in the sequence.")
116
+ return sequence
117
+
118
+ nodes_content = nodes_match.group(1)
119
+ edges_content = edges_match.group(1)
120
+
121
+ # Extract individual nodes
122
+ nodes = re.findall(r'<n id="(\d+)">(.*?)</n>', nodes_content, re.DOTALL)
123
+
124
+ # Shuffle the nodes
125
+ random.shuffle(nodes)
126
+
127
+ # Create a mapping of old ids to new ids
128
+ id_mapping = {old_id: str(new_id) for new_id, (old_id, _) in enumerate(nodes, start=1)}
129
+
130
+ # Reconstruct the nodes section with new ids
131
+ new_nodes_content = "".join(f'<n id="{new_id}">{content}</n>' for new_id, (_, content) in enumerate(nodes, start=1))
132
+
133
+ # Extract and update edge information
134
+ edges = re.findall(r'<e src="(\d+)" tgt="(\d+)"/>', edges_content)
135
+ new_edges = []
136
+ for src, tgt in edges:
137
+ new_src = int(id_mapping[src])
138
+ new_tgt = int(id_mapping[tgt])
139
+ # Append edge as tuple (original_src, original_tgt)
140
+ new_edges.append((new_src, new_tgt))
141
+
142
+ # Sort edges: first by the new src node id, then by the new tgt node id (preserving the original direction)
143
+ new_edges.sort(key=lambda x: (min(x[0], x[1]), max(x[0], x[1])))
144
+
145
+ # Reconstruct the edges section, preserving original direction
146
+ new_edges_content = "".join(f'<e src="{src}" tgt="{tgt}"/>' if src < tgt else f'<e src="{tgt}" tgt="{src}"/>' for src, tgt in new_edges)
147
+
148
+ # Reconstruct the full sequence
149
+ new_sequence = f'<nodes><newline>{new_nodes_content}<newline></nodes><newline><edges><newline>{new_edges_content}<newline></edges>'
150
+
151
+ return new_sequence
152
+
153
+ class Sharpen:
154
+ def __call__(self, img):
155
+ return img.filter(ImageFilter.SHARPEN)
156
+
157
+ # with the graph edit distance validation
158
+ import re
159
+ from nltk import edit_distance
160
+ import numpy as np
161
+ import torch
162
+ import pytorch_lightning as pl
163
+ import mlflow
164
+ import networkx as nx
165
+ import Levenshtein
166
+ import xml.etree.ElementTree as ET
167
+ import multiprocessing
168
+ import logging
169
+ from torch.optim.lr_scheduler import LambdaLR
170
+
171
+ logging.basicConfig(level=logging.INFO)
172
+ logger = logging.getLogger(__name__)
173
+
174
+
175
+ # for the node matching and reordering to align with the ground truth graph
176
+ def match_nodes_by_label(G_pred, G_gt):
177
+ """Match nodes from predicted graph to ground truth graph based on label similarity."""
178
+ node_mapping = {}
179
+ for n_pred, pred_data in G_pred.nodes(data=True):
180
+ best_match = None
181
+ best_score = float('inf') # Levenshtein is a distance metric, lower is better
182
+ for n_gt, gt_data in G_gt.nodes(data=True):
183
+ sim_score = DonutModelPLModule.normalized_levenshtein(pred_data['label'], gt_data['label'])
184
+ if sim_score < best_score:
185
+ best_score = sim_score
186
+ best_match = n_gt
187
+ if best_match:
188
+ node_mapping[n_pred] = best_match
189
+ return node_mapping
190
+
191
+ # also for the reodering
192
+ def rebuild_graph_with_mapped_nodes(G_pred, node_mapping):
193
+ """Rebuild the predicted graph with nodes aligned to the ground truth."""
194
+ G_aligned = nx.Graph()
195
+ for node_pred, node_gt in node_mapping.items():
196
+ G_aligned.add_node(node_gt, label=G_pred.nodes[node_pred]['label'])
197
+
198
+ for u, v in G_pred.edges():
199
+ if u in node_mapping and v in node_mapping:
200
+ G_aligned.add_edge(node_mapping[u], node_mapping[v])
201
+
202
+ return G_aligned
203
+
204
+ class DonutModelPLModule(pl.LightningModule):
205
+ def __init__(self, config, processor, model):
206
+ super().__init__()
207
+ self.config = config
208
+ self.processor = processor
209
+ self.model = model
210
+ self.train_loss_epoch_total = 0.0
211
+ self.val_loss_epoch_total = 0.0
212
+ self.train_batch_count = 0
213
+ self.val_batch_count = 0
214
+ self.edit_distance_scores = []
215
+ self.graph_metrics = {
216
+ 'fast_graph_similarity': [],
217
+ 'node_label_similarity': [],
218
+ 'edge_similarity': [],
219
+ 'degree_sequence_similarity': [],
220
+ 'node_coverage': [],
221
+ 'edge_precision': [],
222
+ 'edge_recall': []
223
+ }
224
+ self.lr = config["lr"]
225
+ self.warmup_steps = config["warmup_steps"]
226
+
227
+
228
+ def training_step(self, batch, batch_idx):
229
+ pixel_values, labels, _ = batch
230
+ outputs = self.model(pixel_values, labels=labels)
231
+ loss = outputs.loss
232
+ self.train_loss_epoch_total += loss.item()
233
+ self.train_batch_count += 1
234
+ self.log("train_loss", loss, prog_bar=True)
235
+ return loss
236
+
237
+ def validation_step(self, batch, batch_idx, dataset_idx=0):
238
+ pixel_values, labels, answers = batch
239
+ outputs = self.model(pixel_values, labels=labels)
240
+ val_loss = outputs.loss
241
+ self.val_loss_epoch_total += val_loss.item()
242
+ self.val_batch_count += 1
243
+ self.log("val_loss", val_loss)
244
+
245
+ if (self.current_epoch + 1) % self.config.get("edit_distance_validation_frequency") == 0:
246
+ logger.info(f'Finished epoch: {self.current_epoch + 1}')
247
+ print(f'Finished epoch: {self.current_epoch + 1}')
248
+ batch_size = pixel_values.shape[0]
249
+ decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
250
+
251
+ try:
252
+ outputs = self.model.generate(pixel_values,
253
+ decoder_input_ids=decoder_input_ids,
254
+ max_length=self.config.get("max_length", 512),
255
+ early_stopping=True,
256
+ pad_token_id=self.processor.tokenizer.pad_token_id,
257
+ eos_token_id=self.processor.tokenizer.eos_token_id,
258
+ use_cache=True,
259
+ num_beams=1,
260
+ bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
261
+ return_dict_in_generate=True,)
262
+
263
+ predictions = self.process_predictions(outputs)
264
+ logger.info('Calculating graph metrics')
265
+ print('Calculating graph metrics')
266
+ levenshtein_scores, graph_scores = self.calculate_metrics(predictions, answers)
267
+ logger.info('Finished calculating graph metrics')
268
+ print('Finished calculating graph metrics')
269
+
270
+ self.edit_distance_scores.append(np.mean(levenshtein_scores))
271
+ for metric in self.graph_metrics:
272
+ self.graph_metrics[metric].append(np.mean([score[metric] for score in graph_scores if metric in score]))
273
+
274
+ self.log("val_edit_distance", np.mean(levenshtein_scores), prog_bar=True)
275
+ for metric in self.graph_metrics:
276
+ self.log(f"val_{metric}", self.graph_metrics[metric][-1], prog_bar=True)
277
+ except Exception as e:
278
+ logger.error(f"Error in validation step: {str(e)}")
279
+ print(f"Error in validation step: {str(e)}")
280
+
281
+ def process_predictions(self, outputs):
282
+ predictions = []
283
+ for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
284
+ try:
285
+ seq = (
286
+ seq.replace(self.processor.tokenizer.eos_token, "")
287
+ .replace(self.processor.tokenizer.pad_token, "")
288
+ .replace('<n id=" ', '<n id="')
289
+ .replace('src=" ', 'src="')
290
+ .replace('tgt=" ', 'tgt="')
291
+ .replace('<newline>', '\n')
292
+ )
293
+ seq = re.sub(r"<s>", "", seq, count=1).strip()
294
+ seq = seq.replace("<s>", "")
295
+ predictions.append(seq)
296
+ except Exception as e:
297
+ logger.error(f"Error processing prediction: {str(e)}")
298
+ print(f"Error processing prediction: {str(e)}")
299
+ predictions.append("") # Append empty string if processing fails
300
+ return predictions
301
+
302
+
303
+ def calculate_metrics(self, predictions, answers):
304
+ levenshtein_scores = []
305
+ graph_scores = []
306
+ for pred, answer in zip(predictions, answers):
307
+ try:
308
+ pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
309
+ answer = answer.replace(self.processor.tokenizer.bos_token, "").replace(self.processor.tokenizer.eos_token, "").replace("<newline>", "\n")
310
+ edit_dist = edit_distance(pred, answer) / max(len(pred), len(answer))
311
+
312
+ logger.info(f"Prediction: {pred}")
313
+ logger.info(f" Answer: {answer}")
314
+ logger.info(f" Normed ED: {edit_dist}")
315
+ print(f"Prediction: {pred}")
316
+ print(f" Answer: {answer}")
317
+ print(f" Normed ED: {edit_dist}")
318
+ levenshtein_scores.append(edit_dist)
319
+
320
+ pred_graph = self.create_graph_from_string(pred)
321
+ answer_graph = self.create_graph_from_string(answer)
322
+
323
+ # Added this to reorder the predicted graphs ignoring the node order for better validation
324
+ # Match nodes based on labels and reorder
325
+ node_mapping = match_nodes_by_label(pred_graph, answer_graph)
326
+ pred_graph_aligned = rebuild_graph_with_mapped_nodes(pred_graph, node_mapping)
327
+
328
+ # Compare the aligned graphs
329
+ # graph_scores.append(self.compare_graphs_with_timeout(pred_graph_aligned, answer_graph, timeout=60))
330
+
331
+ logger.info('Calculating the GED')
332
+ print('Calculating the GED')
333
+ # graph_scores.append(self.compare_graphs_with_timeout(pred_graph, answer_graph, timeout=60))
334
+ graph_scores.append(self.compare_graphs_with_timeout(pred_graph_aligned, answer_graph, timeout=60))
335
+ logger.info('Got the GED results')
336
+ print('Got the GED results')
337
+ except Exception as e:
338
+ logger.error(f"Error calculating metrics: {str(e)}")
339
+ print(f"Error calculating metrics: {str(e)}")
340
+ levenshtein_scores.append(1.0) # Worst possible score
341
+ graph_scores.append({metric: 0.0 for metric in self.graph_metrics}) # Worst possible scores
342
+ return levenshtein_scores, graph_scores
343
+
344
+ @staticmethod
345
+ def compare_graphs_with_timeout(pred_graph, answer_graph, timeout=60):
346
+ def wrapper(return_dict):
347
+ return_dict['result'] = DonutModelPLModule.compare_graphs(pred_graph, answer_graph)
348
+
349
+ manager = multiprocessing.Manager()
350
+ return_dict = manager.dict()
351
+ p = multiprocessing.Process(target=wrapper, args=(return_dict,))
352
+ p.start()
353
+ p.join(timeout)
354
+
355
+ if p.is_alive():
356
+ logger.warning('Graph comparison timed out. Returning default values.')
357
+ print('Graph comparison timed out. Returning default values.')
358
+ p.terminate()
359
+ p.join()
360
+ return {
361
+ "fast_graph_similarity": 0.0,
362
+ "node_label_similarity": 0.0,
363
+ "edge_similarity": 0.0,
364
+ "degree_sequence_similarity": 0.0,
365
+ "node_coverage": 0.0,
366
+ "edge_precision": 0.0,
367
+ "edge_recall": 0.0
368
+ }
369
+ else:
370
+ return return_dict.get('result', {
371
+ "fast_graph_similarity": 0.0,
372
+ "node_label_similarity": 0.0,
373
+ "edge_similarity": 0.0,
374
+ "degree_sequence_similarity": 0.0,
375
+ "node_coverage": 0.0,
376
+ "edge_precision": 0.0,
377
+ "edge_recall": 0.0
378
+ })
379
+
380
+ @staticmethod
381
+ def create_graph_from_string(xml_string):
382
+ G = nx.Graph()
383
+ try:
384
+ # Extract nodes
385
+ nodes = re.findall(r'<n id="(\d+)">(.*?)</n>', xml_string, re.DOTALL)
386
+ for node_id, label in nodes:
387
+ G.add_node(node_id, label=label.lower())
388
+
389
+ # Extract edges
390
+ edges = re.findall(r'<e src="(\d+)" tgt="(\d+)"/>', xml_string)
391
+ for src, tgt in edges:
392
+ G.add_edge(src, tgt)
393
+ except Exception as e:
394
+ logger.error(f"Error creating graph from string: {str(e)}")
395
+ print(f"Error creating graph from string: {str(e)}")
396
+ return G
397
+
398
+ @staticmethod
399
+ def normalized_levenshtein(s1, s2):
400
+ distance = Levenshtein.distance(s1, s2)
401
+ max_length = max(len(s1), len(s2))
402
+ return distance / max_length if max_length > 0 else 0
403
+
404
+ @staticmethod
405
+ def calculate_node_coverage(G1, G2, threshold=0.2):
406
+ matched_nodes = 0
407
+ for n1 in G1.nodes(data=True):
408
+ if any(DonutModelPLModule.normalized_levenshtein(n1[1]['label'], n2[1]['label']) <= threshold
409
+ for n2 in G2.nodes(data=True)):
410
+ matched_nodes += 1
411
+ return matched_nodes / max(len(G1), len(G2))
412
+
413
+ @staticmethod
414
+ def node_label_similarity(G1, G2):
415
+ labels1 = list(nx.get_node_attributes(G1, 'label').values())
416
+ labels2 = list(nx.get_node_attributes(G2, 'label').values())
417
+
418
+ total_similarity = 0
419
+ for label1 in labels1:
420
+ similarities = [1 - DonutModelPLModule.normalized_levenshtein(label1, label2) for label2 in labels2]
421
+ total_similarity += max(similarities) if similarities else 0
422
+
423
+ return total_similarity / len(labels1) if labels1 else 0
424
+
425
+ @staticmethod
426
+ def edge_similarity(G1, G2):
427
+ return len(set(G1.edges()) & set(G2.edges())) / max(len(G1.edges()), len(G2.edges())) if max(len(G1.edges()), len(G2.edges())) > 0 else 1
428
+
429
+ @staticmethod
430
+ def degree_sequence_similarity(G1, G2):
431
+ seq1 = sorted([d for n, d in G1.degree()], reverse=True)
432
+ seq2 = sorted([d for n, d in G2.degree()], reverse=True)
433
+
434
+ # If either sequence is empty, return 0 similarity
435
+ if not seq1 or not seq2:
436
+ return 0.0
437
+
438
+ # Padding sequences to make them the same length
439
+ max_len = max(len(seq1), len(seq2))
440
+ seq1 += [0] * (max_len - len(seq1))
441
+ seq2 += [0] * (max_len - len(seq2))
442
+
443
+ # Calculate degree sequence similarity
444
+ diff_sum = sum(abs(x - y) for x, y in zip(seq1, seq2))
445
+
446
+ # Return similarity, handle edge case where the sum of degrees is zero
447
+ return 1 - diff_sum / (2 * sum(seq1)) if sum(seq1) > 0 else 0.0
448
+
449
+ @staticmethod
450
+ def fast_graph_similarity(G1, G2):
451
+ node_sim = DonutModelPLModule.node_label_similarity(G1, G2)
452
+ edge_sim = DonutModelPLModule.edge_similarity(G1, G2)
453
+ degree_sim = DonutModelPLModule.degree_sequence_similarity(G1, G2)
454
+ return (node_sim + edge_sim + degree_sim) / 3
455
+
456
+ @staticmethod
457
+ def compare_graphs(G1, G2):
458
+ try:
459
+ node_coverage = DonutModelPLModule.calculate_node_coverage(G1, G2)
460
+ G1_edges = set(G1.edges())
461
+ G2_edges = set(G2.edges())
462
+ correct_edges = len(G1_edges & G2_edges)
463
+ edge_precision = correct_edges / len(G2_edges) if G2_edges else 0
464
+ edge_recall = correct_edges / len(G1_edges) if G1_edges else 0
465
+ return {
466
+ "fast_graph_similarity": DonutModelPLModule.fast_graph_similarity(G1, G2),
467
+ "node_label_similarity": DonutModelPLModule.node_label_similarity(G1, G2),
468
+ "edge_similarity": DonutModelPLModule.edge_similarity(G1, G2),
469
+ "degree_sequence_similarity": DonutModelPLModule.degree_sequence_similarity(G1, G2),
470
+ "node_coverage": node_coverage,
471
+ "edge_precision": edge_precision,
472
+ "edge_recall": edge_recall
473
+ }
474
+ except Exception as e:
475
+ logger.error(f"Error comparing graphs: {str(e)}")
476
+ print(f"Error comparing graphs: {str(e)}")
477
+ return {
478
+ "fast_graph_similarity": 0.0,
479
+ "node_label_similarity": 0.0,
480
+ "edge_similarity": 0.0,
481
+ "degree_sequence_similarity": 0.0,
482
+ "node_coverage": 0.0,
483
+ "edge_precision": 0.0,
484
+ "edge_recall": 0.0
485
+ }
486
+
487
+ def configure_optimizers(self):
488
+ # Define the optimizer
489
+ optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
490
+
491
+ # Define the warmup + decay scheduler
492
+ def lr_lambda(current_step):
493
+ if current_step < self.warmup_steps:
494
+ return float(current_step) / float(max(1, self.warmup_steps))
495
+ return 1.0 # You can replace this with a decay function like exponential decay
496
+
497
+ scheduler = LambdaLR(optimizer, lr_lambda)
498
+
499
+ return {
500
+ 'optimizer': optimizer,
501
+ 'lr_scheduler': {
502
+ 'scheduler': scheduler,
503
+ 'interval': 'step', # Update the learning rate after every training step
504
+ 'frequency': 1, # How often the scheduler is called (every step)
505
+ }
506
+ }
507
+
508
+ def on_validation_epoch_end(self):
509
+ avg_val_loss = self.val_loss_epoch_total / self.val_batch_count
510
+ mlflow.log_metric("validation_crossentropy_loss", avg_val_loss, step=self.current_epoch)
511
+ self.val_loss_epoch_total = 0.0
512
+ self.val_batch_count = 0
513
+
514
+ if (self.current_epoch + 1) % self.config.get("edit_distance_validation_frequency") == 0:
515
+ if self.edit_distance_scores:
516
+ mlflow.log_metric("validation_edit_distance", self.edit_distance_scores[-1], step=self.current_epoch)
517
+ for metric in self.graph_metrics:
518
+ if self.graph_metrics[metric]:
519
+ mlflow.log_metric(f"validation_{metric}", self.graph_metrics[metric][-1], step=self.current_epoch)
520
+ print('[INFO] - Finished the validation for epoch ', self.current_epoch + 1)
521
+
522
+ def on_train_epoch_end(self):
523
+ print(f'[INFO] - Finished epoch {self.current_epoch + 1}')
524
+ avg_train_loss = self.train_loss_epoch_total / self.train_batch_count
525
+ print(f'[INFO] - Train loss: {avg_train_loss}')
526
+ mlflow.log_metric("training_crossentropy_loss", avg_train_loss, step=self.current_epoch)
527
+ self.train_loss_epoch_total = 0.0
528
+ self.train_batch_count = 0
529
+
530
+ if ((self.current_epoch + 1) % self.config.get("save_model_weights_frequency", 10)) == 0:
531
+ self.save_model()
532
+
533
+ def on_fit_end(self):
534
+ self.save_model()
535
+
536
+ def save_model(self):
537
+ model_dir = "Donut_model"
538
+ os.makedirs(model_dir, exist_ok=True)
539
+ self.model.save_pretrained(model_dir)
540
+ print('[INFO] - Saving the model to dagshub using mlflow')
541
+ mlflow.transformers.log_model(
542
+ transformers_model={
543
+ "model": self.model,
544
+ "feature_extractor": self.processor.feature_extractor,
545
+ "image_processor": self.processor.image_processor,
546
+ "tokenizer": self.processor.tokenizer
547
+ },
548
+ artifact_path=model_dir,
549
+ # Set task explicitly since MLflow cannot infer it from the loaded model
550
+ task = "image-to-text"
551
+ )
552
+ print('[INFO] - Saved the model to dagshub using mlflow')
553
+
554
+ def train_dataloader(self):
555
+ return train_dataloader
556
+
557
+ def val_dataloader(self):
558
+ return val_dataloader
559
+
560
+ config = {"max_epochs":200,
561
+ # "val_check_interval":0.2, # how many times we want to validate during an epoch
562
+ "check_val_every_n_epoch":1,
563
+ "gradient_clip_val":1.0,
564
+ # "num_training_samples_per_epoch": 800,
565
+ "lr":8e-4, #3e-4, #3e-5,
566
+ "train_batch_sizes": [1], #[8], #[1],#[8],
567
+ "val_batch_sizes": [1],
568
+ # "seed":2022,
569
+ "num_nodes": 1,
570
+ "warmup_steps": 200, # 800/8*30/10, 10%
571
+ "verbose": True,
572
+ }
573
+
574
+ model_module = DonutModelPLModule(config, processor, model)
575
+
576
  # Load dataset
577
  dataset = split_dataset['test']
578