YuWang0103 commited on
Commit
38ed701
·
verified ·
1 Parent(s): f223b75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -8,6 +8,7 @@ from analysis.spectre_utils import CrossDomainSamplingMetrics
8
  import networkx as nx
9
  import numpy as np
10
  import matplotlib.pyplot as plt
 
11
 
12
 
13
  cfg = OmegaConf.load('./config.yaml')
@@ -22,9 +23,9 @@ input_dims, output_dims = compute_input_output_dims(data_loaders['train'], extra
22
 
23
  sampling_metrics = CrossDomainSamplingMetrics(data_loaders)
24
 
25
- model = LGGMText2Graph_Demo.load_from_checkpoint('last-v1.ckpt')
26
 
27
- model.init_prompt_encoder()
28
 
29
  def calculate_average_degree(graph):
30
  num_nodes = graph.number_of_nodes()
@@ -34,7 +35,7 @@ def calculate_average_degree(graph):
34
 
35
  def predict(text, num_nodes = None):
36
  # Assuming model.generate and other processes are defined as before
37
- graphs = model.generate(text, int(num_nodes))
38
  ccs = []
39
  degs = []
40
  images = []
@@ -51,8 +52,11 @@ def predict(text, num_nodes = None):
51
  plt.close(fig)
52
 
53
  images.append(image)
 
 
 
54
 
55
- return images[0], images[1], images[2], images[3], images[4], ccs[0], ccs[1], ccs[2], ccs[3], ccs[4], degs[0], degs[1], degs[2], degs[3], degs[4]
56
 
57
  def clear(input_text):
58
  return None, None
@@ -67,7 +71,7 @@ with gr.Blocks() as demo:
67
  input_num = gr.Slider(5, 200, value=10, label="Count", info="Number of nodes in the graph to be generated")
68
  with gr.Column():
69
  gr.Markdown("### Suggested Prompts")
70
- gr.Markdown("1. Create a complex network with high clustering coefficient, exhibiting a very dense connection.\n2. Create a graph with extremely low number of triangles, which means it has very low clustering coefficient.")
71
 
72
  with gr.Row() as output_row:
73
  output_images = [gr.Image(label = f"Generated Network #{_}") for _ in range(5)]
@@ -75,15 +79,19 @@ with gr.Blocks() as demo:
75
  output_texts_cc = [gr.Textbox(label=f"CC #{_}") for _ in range(5)]
76
  with gr.Row():
77
  output_texts_deg = [gr.Textbox(label=f"DEG #{_}") for _ in range(5)]
 
 
 
 
78
 
79
  with gr.Row():
80
  submit_button = gr.Button("Submit")
81
  clear_button = gr.Button("Clear")
82
 
83
  # Change function is linked to the submit button
84
- submit_button.click(fn=predict, inputs=[input_text, input_num], outputs=output_images + output_texts_cc + output_texts_deg)
85
 
86
  # Clear function resets the text input and clears the outputs
87
- clear_button.click(fn=clear, inputs=input_text, outputs=output_images + output_texts_cc + output_texts_deg)
88
 
89
  demo.launch()
 
8
  import networkx as nx
9
  import numpy as np
10
  import matplotlib.pyplot as plt
11
+ import torch
12
 
13
 
14
  cfg = OmegaConf.load('./config.yaml')
 
23
 
24
  sampling_metrics = CrossDomainSamplingMetrics(data_loaders)
25
 
26
+ model = LGGMText2Graph_Demo.load_from_checkpoint('last-v1.ckpt', map_location=torch.device("cpu"))
27
 
28
+ model.init_prompt_encoder_pretrained()
29
 
30
  def calculate_average_degree(graph):
31
  num_nodes = graph.number_of_nodes()
 
35
 
36
  def predict(text, num_nodes = None):
37
  # Assuming model.generate and other processes are defined as before
38
+ graphs = model.generate_pretrained(text, int(num_nodes))
39
  ccs = []
40
  degs = []
41
  images = []
 
52
  plt.close(fig)
53
 
54
  images.append(image)
55
+
56
+ avg_deg = np.mean(degs)
57
+ avg_cc = np.mean(ccs)
58
 
59
+ return images[0], images[1], images[2], images[3], images[4], ccs[0], ccs[1], ccs[2], ccs[3], ccs[4], degs[0], degs[1], degs[2], degs[3], degs[4], avg_cc, avg_deg
60
 
61
  def clear(input_text):
62
  return None, None
 
71
  input_num = gr.Slider(5, 200, value=10, label="Count", info="Number of nodes in the graph to be generated")
72
  with gr.Column():
73
  gr.Markdown("### Suggested Prompts")
74
+ gr.Markdown("1. Create a complex network with high clustering coefficient.\n2. Create a graph with extremely low number of triangles.\n 3. Please give me a Power Network with extremely low number of triangles but with medium level of average degree.")
75
 
76
  with gr.Row() as output_row:
77
  output_images = [gr.Image(label = f"Generated Network #{_}") for _ in range(5)]
 
79
  output_texts_cc = [gr.Textbox(label=f"CC #{_}") for _ in range(5)]
80
  with gr.Row():
81
  output_texts_deg = [gr.Textbox(label=f"DEG #{_}") for _ in range(5)]
82
+
83
+ with gr.Row():
84
+ avg_cc_text = gr.Textbox(label="Average Clustering Coefficient")
85
+ avg_deg_text = gr.Textbox(label="Average Degree")
86
 
87
  with gr.Row():
88
  submit_button = gr.Button("Submit")
89
  clear_button = gr.Button("Clear")
90
 
91
  # Change function is linked to the submit button
92
+ submit_button.click(fn=predict, inputs=[input_text, input_num], outputs=output_images + output_texts_cc + output_texts_deg + [avg_cc_text, avg_deg_text])
93
 
94
  # Clear function resets the text input and clears the outputs
95
+ clear_button.click(fn=clear, inputs=input_text, outputs=output_images + output_texts_cc + output_texts_deg + [avg_cc_text, avg_deg_text])
96
 
97
  demo.launch()