File size: 2,669 Bytes
5ce695c
bb2ae93
5ce695c
 
 
0d30669
5ce695c
 
 
0d30669
84d90da
5ce695c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d30669
5ce695c
 
 
0d30669
5ce695c
 
0d30669
 
 
18bc43a
0d30669
 
 
 
2cd6b7d
 
0d30669
 
5ce695c
 
 
 
0d30669
5ce695c
 
0d30669
 
 
 
5ce695c
0d30669
5ce695c
8252945
6adaa4f
efcd777
 
 
 
 
 
 
 
0d30669
5ce695c
 
 
2cd6b7d
5ce695c
0d30669
5ce695c
 
 
 
 
0d30669
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
from datasets import load_dataset, Dataset
from PIL import Image
import io
import base64
import json 
from graph_visualization import visualize_graph

# Load the 'data' split of the "Zaherrr/OOP_KG_Dataset" dataset once, at
# import time; the functions in this module read this module-level `dataset`.
dataset = load_dataset("Zaherrr/OOP_KG_Dataset", split='data')
# Print the loaded dataset object to stdout (this runs on every import).
print(f'This is the dataset: {dataset}')

def reshape_json_data_to_fit_visualize_graph(graph_data):
    """Convert column-oriented graph data into lists of per-item dicts.

    Args:
        graph_data: Mapping with a "nodes" entry holding parallel "id" and
            "label" sequences, and an "edges" entry holding parallel
            "source" and "target" sequences.

    Returns:
        A new dict ``{"nodes": [...], "edges": [...]}`` in which every node
        is ``{"id": ..., "label": ...}`` and every edge is
        ``{"source": ..., "target": ..., "type": "->"}``. The input mapping
        is left unmodified.

    Raises:
        KeyError: If an expected key is missing from the input.
        IndexError: For list inputs, if "label" is shorter than "id" or
            "target" is shorter than "source".
    """
    nodes = graph_data["nodes"]
    edges = graph_data["edges"]

    # Walk the driving column directly; the companion column is looked up by
    # position so a too-short companion still fails loudly instead of being
    # silently truncated.
    transformed_nodes = [
        {"id": node_id, "label": nodes["label"][position]}
        for position, node_id in enumerate(nodes["id"])
    ]
    # Every edge gets the same fixed "->" type marker.
    transformed_edges = [
        {"source": source, "target": edges["target"][position], "type": "->"}
        for position, source in enumerate(edges["source"])
    ]
    return {"nodes": transformed_nodes, "edges": transformed_edges}

def display_example(index):
    """Build every UI output for the dataset example at ``index``.

    Args:
        index: Position of the example in the module-level ``dataset``.
            Coerced with ``int()`` because a slider may deliver its value
            as a float (e.g. ``3.0``), which may be rejected as a row index.

    Returns:
        A 4-tuple ``(img, graph_html_with_style, json_data,
        transformed_graph_data)``:

        * ``img`` -- the example's "image" field, passed through unchanged.
        * ``graph_html_with_style`` -- the result of ``visualize_graph``
          (presumably an HTML string; confirm in ``graph_visualization``)
          wrapped in a 500px-high ``<div>``.
        * ``json_data`` -- the reshaped graph as an indented JSON string.
        * ``transformed_graph_data`` -- the reshaped graph as a dict.
    """
    example = dataset[int(index)]
    img = example["image"]

    # Prepare the graph data: reshape the example's column-style "nodes" and
    # "edges" fields into the list-of-dicts layout passed to visualize_graph.
    graph_data = {"nodes": example["nodes"], "edges": example["edges"]}
    transformed_graph_data = reshape_json_data_to_fit_visualize_graph(graph_data)

    # Generate the graph visualization
    graph_html = visualize_graph(transformed_graph_data)

    # Wrap the graph HTML in a div with a fixed 500px height (the style sets
    # no overflow rule, so the wrapper itself does not add scrolling).
    graph_html_with_style = f"""
    <div style="height: 500px;">
        {graph_html}
    </div>
    """

    # Convert the reshaped graph to a formatted JSON string
    json_data = json.dumps(transformed_graph_data, indent=2)

    return img, graph_html_with_style, json_data, transformed_graph_data

def create_interface():
    """Assemble the Gradio Blocks UI for browsing dataset examples.

    The layout is a title, a slider that picks the example index, a row with
    the example image next to the rendered graph, and a row with the graph
    as JSON next to a read-only text box. Moving the slider calls
    ``display_example`` and routes its four return values to those four
    output components.

    Returns:
        The constructed ``gr.Blocks`` app; it is not launched here.
    """
    with gr.Blocks() as app:
        gr.Markdown("# Knowledge Graph Visualizer")

        # Example selector: one slider step per dataset row.
        with gr.Row():
            last_index = len(dataset) - 1
            example_picker = gr.Slider(
                minimum=0, maximum=last_index, step=1, label="Example Index"
            )

        # Source image beside its graph rendering.
        with gr.Row():
            image_panel = gr.Image(type="pil", label="Image", height=300)
            graph_panel = gr.HTML(label="Knowledge Graph")

        # Graph data as JSON and as text.
        with gr.Row():
            json_panel = gr.Code(language="json", label="Graph JSON Data")
            text_panel = gr.Textbox(
                label="Graph Text Data",
                placeholder="Text data will appear here",
                interactive=False,
            )

        # Order must match the tuple returned by display_example.
        output_panels = [image_panel, graph_panel, json_panel, text_panel]
        example_picker.change(
            fn=display_example,
            inputs=[example_picker],
            outputs=output_panels,
        )

    return app

# Create and launch the interface
# Runs only when this file is executed as a script, not when it is imported:
# build the Blocks app and call launch() with no arguments.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()