File size: 2,961 Bytes
5ce695c
bb2ae93
5ce695c
 
 
0d30669
5ce695c
 
 
0d30669
84d90da
5ce695c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d30669
5ce695c
 
 
0d30669
5ce695c
 
0d30669
da4a279
1394cbb
 
 
 
 
 
 
 
0d30669
2cd6b7d
 
0d30669
 
5ce695c
 
0077eb0
5e5196a
3d80d6c
5ce695c
0d30669
5ce695c
 
0d30669
 
 
 
5ce695c
0d30669
5ce695c
c6b2d7d
5e5196a
efcd777
 
 
 
 
 
 
 
0d30669
5ce695c
 
 
2cd6b7d
5ce695c
0d30669
5ce695c
 
 
 
 
0d30669
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
from datasets import load_dataset, Dataset
from PIL import Image
import io
import base64
import json 
from graph_visualization import visualize_graph

# Fetch the "data" split of the knowledge-graph dataset once, at import time;
# the UI functions below index into this module-level object.
dataset = load_dataset(path="Zaherrr/OOP_KG_Dataset", split="data")
print("This is the dataset: {}".format(dataset))

def reshape_json_data_to_fit_visualize_graph(graph_data, *, edge_type="->"):
    """Convert column-oriented graph data into per-node / per-edge records.

    ``graph_data["nodes"]`` must map ``"id"`` and ``"label"`` to parallel
    sequences. ``graph_data["edges"]`` must map ``"source"`` and ``"target"``
    to parallel sequences. Columns are assumed to be the same length; iteration
    stops at the shorter column.

    Args:
        graph_data: Mapping with ``"nodes"`` and ``"edges"`` column dicts.
        edge_type: Value stored under ``"type"`` on every edge record.
            Defaults to ``"->"``.

    Returns:
        A new dict ``{"nodes": [{"id", "label"}, ...],
        "edges": [{"source", "target", "type"}, ...]}``. The input is not
        modified.
    """
    nodes = graph_data["nodes"]
    edges = graph_data["edges"]
    # Walk the parallel columns in lockstep rather than indexing by position.
    transformed_nodes = [
        {"id": node_id, "label": label}
        for node_id, label in zip(nodes["id"], nodes["label"])
    ]
    transformed_edges = [
        {"source": source, "target": target, "type": edge_type}
        for source, target in zip(edges["source"], edges["target"])
    ]
    return {"nodes": transformed_nodes, "edges": transformed_edges}

def display_example(index):
    """Load one dataset example and build every view the UI shows for it.

    Args:
        index: Row position in ``dataset``. A Gradio slider may deliver the
            value as a float (e.g. ``3.0``), so it is coerced to ``int``
            before indexing.

    Returns:
        A 4-tuple ``(image, graph_html, json_text, graph_data)``:
        - the example's ``"image"`` field;
        - the HTML produced by ``visualize_graph``;
        - the reshaped graph pretty-printed as JSON;
        - the reshaped graph dict itself.
    """
    # The HF Dataset index check accepts ints/slices/iterables, not floats.
    example = dataset[int(index)]
    img = example["image"]

    graph_data = {"nodes": example["nodes"], "edges": example["edges"]}
    transformed_graph_data = reshape_json_data_to_fit_visualize_graph(graph_data)

    # The graph panel's fixed height comes from the ``#graph-output`` CSS rule
    # set in ``create_interface``, so the HTML is returned unwrapped.
    graph_html = visualize_graph(transformed_graph_data)

    json_data = json.dumps(transformed_graph_data, indent=2)

    return img, graph_html, json_data, transformed_graph_data

def create_interface():
    """Build the Gradio Blocks UI for browsing the knowledge-graph dataset.

    Layout:
    - an index slider over ``dataset``;
    - the example's image next to its rendered graph;
    - the graph as JSON next to a read-only text box.

    ``display_example`` refreshes every output whenever the slider changes.
    It also runs once on page load, so example 0 is shown immediately.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(css="#graph-output {height: 200px; overflow: hidden;}") as demo:
        gr.Markdown("# Knowledge Graph Visualizer")

        with gr.Row():
            index_slider = gr.Slider(
                minimum=0,
                maximum=len(dataset) - 1,
                step=1,
                value=0,
                label="Example Index",
            )

        with gr.Row():
            image_output = gr.Image(type="pil", label="Image", height=300)
            graph_output = gr.HTML(label="Knowledge Graph", elem_id="graph-output")

        with gr.Row():
            json_output = gr.Code(language="json", label="Graph JSON Data")
            text_output = gr.Textbox(
                label="Graph Text Data",
                placeholder="Text data will appear here",
                interactive=False,
            )

        # Order must match the 4-tuple returned by display_example.
        outputs = [image_output, graph_output, json_output, text_output]
        index_slider.change(fn=display_example, inputs=[index_slider], outputs=outputs)
        # .change does not fire for the slider's initial value; without this the
        # panels stay empty until the user first moves the slider.
        demo.load(fn=display_example, inputs=[index_slider], outputs=outputs)

    return demo

# Create and launch the interface
if __name__ == "__main__":
    # Script entry point: build the UI, keep it bound to the module-level
    # name ``demo``, then start the Gradio server.
    demo = create_interface()
    demo.launch()