Zaherrr committed
Commit 7ab6281 · verified · 1 Parent(s): 5b20ed5

Update app.py

Files changed (1)
  1. app.py +320 -0
app.py CHANGED
@@ -0,0 +1,320 @@
+ import gradio as gr
+ import json
+ import re
+ from typing import Dict, List, Tuple
+ 
+ import torch
+ from datasets import load_dataset
+ from PIL import Image, ImageFilter  # ImageFilter is required by the Sharpen transform below
+ from pyvis.network import Network
+ from torchvision import transforms
+ 
+ # Load dataset. NOTE: `split_dataset` was used here without being defined;
+ # the repo id below is a placeholder -- substitute the actual dataset id.
+ split_dataset = load_dataset("path/to/dataset")  # placeholder repo id
+ dataset = split_dataset["test"]
+ 
+ # Set up device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
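+ 
+ # NOTE: `model`, `processor`, and `max_length` are used below but never defined
+ # in this file. As a minimal sketch, a Donut-style VisionEncoderDecoderModel
+ # and processor pair is assumed here; the checkpoint id is a placeholder, not
+ # the author's actual model.
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
+ processor = DonutProcessor.from_pretrained("path/to/checkpoint")  # placeholder
+ model = VisionEncoderDecoderModel.from_pretrained("path/to/checkpoint")  # placeholder
+ max_length = 512  # assumed generation budget; adjust as needed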
+ model.to(device)
+ 
+ class Sharpen:
+     """Apply PIL's SHARPEN filter to an image."""
+ 
+     def __call__(self, img):
+         return img.filter(ImageFilter.SHARPEN)
+ 
+ def preprocess_image(image):
+     # Convert to PIL Image if it's not already
+     if not isinstance(image, Image.Image):
+         image = Image.fromarray(image)
+ 
+     # Apply sharpening
+     sharpen = Sharpen()
+     sharpened_image = sharpen(image)
+ 
+     return sharpened_image
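+ 
+ # NOTE: preprocess_image is defined but never called below -- transform_image
+ # feeds the raw image straight into perform_inference. To apply sharpening
+ # before inference, one could call preprocess_image(img) there first.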
+ 
+ def perform_inference(image):
+     # Preprocess the image into pixel values
+     inputs = processor(images=image, return_tensors="pt")
+     pixel_values = inputs.pixel_values.to(device)
+ 
+     # Prepare decoder input ids
+     batch_size = pixel_values.shape[0]
+     decoder_input_ids = torch.full(
+         (batch_size, 1), model.config.decoder_start_token_id, device=device
+     )
+ 
+     # Generate output
+     outputs = model.generate(
+         pixel_values,
+         decoder_input_ids=decoder_input_ids,
+         max_length=max_length,  # adjust as needed
+         early_stopping=True,
+         pad_token_id=processor.tokenizer.pad_token_id,
+         eos_token_id=processor.tokenizer.eos_token_id,
+         use_cache=True,
+         num_beams=1,
+         bad_words_ids=[[processor.tokenizer.unk_token_id]],
+         return_dict_in_generate=True,
+     )
+ 
+     # Decode the output
+     decoded_output = processor.batch_decode(outputs.sequences)[0]
+     print("Raw model output:", decoded_output)
+ 
+     return decoded_output
+ 
+ def display_example(index):
+     example = dataset[index]
+     img = example["image"]
+     return img, None, None
+ 
+ def from_json_like_to_xml_like(data):
+     def parse_nodes(nodes):
+         node_elements = []
+         for node in nodes:
+             label = node["label"]
+             node_elements.append(f'<n id="{node["id"]}">{label}</n>')
+         return "<nodes>\n" + "".join(node_elements) + "\n</nodes>"
+ 
+     def parse_edges(edges):
+         edge_elements = []
+         for edge in edges:
+             edge_elements.append(f'<e src="{edge["source"]}" tgt="{edge["target"]}"/>')
+         return "<edges>\n" + "".join(edge_elements) + "\n</edges>"
+ 
+     nodes_xml = parse_nodes(data["nodes"])
+     edges_xml = parse_edges(data["edges"])
+     return nodes_xml + "\n" + edges_xml
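+ 
+ # For illustration, a hypothetical graph
+ #   {"nodes": [{"id": "0", "label": "A"}, {"id": "1", "label": "B"}],
+ #    "edges": [{"source": "0", "target": "1"}]}
+ # serializes to:
+ #   <nodes>
+ #   <n id="0">A</n><n id="1">B</n>
+ #   </nodes>
+ #   <edges>
+ #   <e src="0" tgt="1"/>
+ #   </edges>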
+ 
+ 
+ def reshape_json_data_to_fit_visualize_graph(graph_data):
+     nodes = graph_data["nodes"]
+     edges = graph_data["edges"]
+     transformed_nodes = [
+         {"id": nodes["id"][idx], "label": nodes["label"][idx]}
+         for idx in range(len(nodes["id"]))
+     ]
+     transformed_edges = [
+         {"source": edges["source"][idx], "target": edges["target"][idx], "type": "->"}
+         for idx in range(len(edges["source"]))
+     ]
+     return {"nodes": transformed_nodes, "edges": transformed_edges}
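+ 
+ # This converts the columnar layout that datasets-style examples use, e.g.
+ #   {"nodes": {"id": ["0", "1"], "label": ["A", "B"]},
+ #    "edges": {"source": ["0"], "target": ["1"]}}
+ # into the list-of-dicts layout that visualize_graph expects.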
+ 
+ def get_ground_truth(index):
+     example = dataset[index]
+     ground_truth = from_json_like_to_xml_like(
+         reshape_json_data_to_fit_visualize_graph(example)
+     )
+     print(f'Ground truth sequence: {ground_truth}')
+     return ground_truth
+ 
+ def transform_image(img, index, physics_enabled):
+     # Perform inference (index is accepted to match the Gradio inputs but is unused)
+     sequence = perform_inference(img)
+ 
+     # Transform the sequence to graph data
+     graph_data = transform_sequence(sequence)
+ 
+     # Generate the graph visualization
+     graph_html = visualize_graph(graph_data, physics_enabled)
+ 
+     # Modify the iframe to have a fixed height
+     graph_html = graph_html.replace('height: 100vh;', 'height: 500px;')
+ 
+     # Convert graph_data to a formatted JSON string
+     json_data = json.dumps(graph_data, indent=2)
+ 
+     return graph_html, json_data, sequence
+ 
+ def transform_sequence(sequence: str) -> Dict[str, List[Dict[str, str]]]:
+     # Extract nodes and edges
+     nodes_match = re.search(r'<nodes>(.*?)</nodes>', sequence, re.DOTALL)
+     edges_match = re.search(r'<edges>(.*?)</edges>', sequence, re.DOTALL)
+ 
+     if not nodes_match or not edges_match:
+         raise ValueError("Invalid input sequence: nodes or edges not found")
+ 
+     nodes_text = nodes_match.group(1)
+     edges_text = edges_match.group(1)
+ 
+     # Parse nodes
+     nodes = []
+     for node_match in re.finditer(r'<n id="\s*(\d+)">(.*?)</n>', nodes_text):
+         node_id, node_label = node_match.groups()
+         nodes.append({
+             "id": node_id.strip(),
+             "label": node_label.strip(),
+         })
+ 
+     # Parse edges
+     edges = []
+     for edge_match in re.finditer(r'<e src="\s*(\d+)" tgt="\s*(\d+)"/>', edges_text):
+         source, target = edge_match.groups()
+         edges.append({
+             "source": source.strip(),
+             "target": target.strip(),
+             "type": "->",
+         })
+ 
+     return {
+         "nodes": nodes,
+         "edges": edges,
+     }
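+ 
+ # For example, a (hypothetical) decoded sequence
+ #   '<nodes>\n<n id="0">A</n>\n</nodes>\n<edges>\n<e src="0" tgt="0"/>\n</edges>'
+ # parses to
+ #   {"nodes": [{"id": "0", "label": "A"}],
+ #    "edges": [{"source": "0", "target": "0", "type": "->"}]}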
+ 
+ # function to visualize the extracted graph
+ 
+ def create_graph(nodes, edges, physics_enabled=True):
+     net = Network(
+         notebook=True,
+         height="100vh",
+         width="100vw",
+         bgcolor="#222222",
+         font_color="white",
+         cdn_resources="remote",
+     )
+ 
+     for node in nodes:
+         net.add_node(
+             node["id"],
+             label=node["label"],
+             title=node["label"],
+             color="blue" if node["label"] == "OOP" else "green",
+         )
+ 
+     for edge in edges:
+         net.add_edge(edge["source"], edge["target"], title=edge["type"])
+ 
+     net.force_atlas_2based(
+         gravity=-50,
+         central_gravity=0.01,
+         spring_length=100,
+         spring_strength=0.08,
+         damping=0.4,
+     )
+ 
+     options = {
+         "nodes": {"physics": physics_enabled},
+         "edges": {"smooth": True},
+         "interaction": {"hover": True, "zoomView": True},
+         "physics": {
+             "enabled": physics_enabled,
+             "stabilization": {"enabled": True, "iterations": 200},
+         },
+     }
+ 
+     net.set_options(json.dumps(options))
+     return net
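+ 
+ # Caveat: net.set_options replaces the network's options wholesale, so it may
+ # discard the force_atlas_2based parameters configured above; if the layout
+ # looks off, fold the solver settings into the options dict instead.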
+ 
+ 
+ def visualize_graph(json_data, physics_enabled=True):
+     if isinstance(json_data, str):
+         data = json.loads(json_data)
+     else:
+         data = json_data
+     nodes = data["nodes"]
+     edges = data["edges"]
+     net = create_graph(nodes, edges, physics_enabled)
+     html = net.generate_html()
+     # Swap quotes so the HTML can be embedded in the single-quoted srcdoc below
+     html = html.replace("'", '"')
+     html = html.replace(
+         '<div id="mynetwork"', '<div id="mynetwork" style="height: 100vh; width: 100%;"'
+     )
+     return f"""<iframe style="width: 100%; height: 100vh; border: none; margin: 0; padding: 0;" srcdoc='{html}'></iframe>"""
+ 
+ def update_physics(json_data, physics_enabled):
+     if json_data is None:
+         return None
+ 
+     data = json.loads(json_data)
+     graph_html = visualize_graph(data, physics_enabled)
+     graph_html = graph_html.replace('height: 100vh;', 'height: 500px;')
+     return graph_html
+ 
+ 
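+ # NOTE: `model_module` used below is never defined in this file. It is assumed
+ # to be an external evaluation module exposing
+ # calculate_metrics(predictions, references) -> (overall_scores, per_example_metric_dicts);
+ # import or define it before running.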
+ # function to calculate the graph similarity metrics between the prediction and the ground truth
+ def calculate_and_display_metrics(pred_graph, ground_truth_graph):
+     if pred_graph is None or ground_truth_graph is None:
+         return "Please generate a prediction and ensure a ground truth graph is available."
+ 
+     # Remove the start token and stray spaces after attribute quotes
+     pred_graph = (
+         pred_graph.replace('<s>', "")
+         .replace("<newline>", "\n")
+         .replace('src=" ', 'src="')
+         .replace('tgt=" ', 'tgt="')
+         .replace('<n id=" ', '<n id="')
+     )
+     print(f'Prediction: {pred_graph}')
+ 
+     # Assuming the graphs are in the correct format for the calculate_metrics function
+     metrics = model_module.calculate_metrics([pred_graph], [ground_truth_graph])
+ 
+     # Format the metrics for display
+     overall_metric = metrics[0][0]  # currently unused in the displayed output
+     detailed_metrics = metrics[1][0]
+ 
+     output = "Detailed Metrics:\n"
+ 
+     for key, value in detailed_metrics.items():
+         output += f"{key}: {value:.4f}\n"
+ 
+     return output
+ 
+ 
+ def create_interface():
+     with gr.Blocks() as demo:
+         gr.Markdown("# Knowledge Graph Visualizer with Model Inference")
+ 
+         with gr.Row():
+             index_slider = gr.Slider(
+                 minimum=0,
+                 maximum=len(dataset) - 1,
+                 step=1,
+                 label="Example Index",
+             )
+ 
+         with gr.Row():
+             image_output = gr.Image(type="pil", label="Image", height=500, interactive=False)
+             graph_output = gr.HTML(label="Knowledge Graph")
+ 
+         with gr.Row():
+             transform_button = gr.Button("Transform")
+             physics_toggle = gr.Checkbox(label="Enable Physics", value=True)
+ 
+         with gr.Row():
+             json_output = gr.Code(language="json", label="Graph JSON Data")
+             ground_truth_output = gr.Textbox(visible=False)
+             predicted_raw_sequence = gr.Textbox(visible=False)
+ 
+         with gr.Row():
+             metrics_button = gr.Button("Calculate Metrics")
+             metrics_output = gr.Textbox(label="Similarity Metrics", lines=10)
+ 
+         index_slider.change(
+             fn=display_example,
+             inputs=[index_slider],
+             outputs=[image_output, graph_output, json_output],
+         ).then(
+             fn=get_ground_truth,
+             inputs=[index_slider],
+             outputs=[ground_truth_output],
+         )
+ 
+         transform_button.click(
+             fn=transform_image,
+             inputs=[image_output, index_slider, physics_toggle],
+             outputs=[graph_output, json_output, predicted_raw_sequence],
+         ).then(
+             fn=calculate_and_display_metrics,
+             inputs=[predicted_raw_sequence, ground_truth_output],
+             outputs=[metrics_output],
+         )
+         metrics_button.click(
+             fn=calculate_and_display_metrics,
+             inputs=[predicted_raw_sequence, ground_truth_output],
+             outputs=[metrics_output],
+         )
+         physics_toggle.change(
+             fn=update_physics,
+             inputs=[json_output, physics_toggle],
+             outputs=[graph_output],
+         )
+     return demo
+ 
+ 
+ # Create and launch the interface
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch(share=True, debug=True)