YuWang0103 commited on
Commit
bc5c298
·
verified ·
1 Parent(s): 38ed701

Update demo_model.py

Browse files
Files changed (1) hide show
  1. demo_model.py +28 -2
demo_model.py CHANGED
@@ -10,6 +10,7 @@ import utils
10
  import networkx as nx
11
  from sentence_transformers import SentenceTransformer
12
  import pytorch_lightning as pl
 
13
 
14
 
15
  class LGGMText2Graph_Demo(pl.LightningModule):
@@ -55,7 +56,7 @@ class LGGMText2Graph_Demo(pl.LightningModule):
55
  self.limit_dist = utils.PlaceHolder(X=x_limit, E=e_limit, y=y_limit)
56
 
57
 
58
- def generate(self, text, num_nodes) -> None:
59
  print(num_nodes)
60
  prompt_emb = torch.tensor(self.text_encoder.encode([text])).to(self.device)
61
  samples = self.sample_batch(5, cond_emb = prompt_emb, num_nodes = num_nodes)
@@ -70,8 +71,33 @@ class LGGMText2Graph_Demo(pl.LightningModule):
70
 
71
  return nx_graphs
72
 
73
- def init_prompt_encoder(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  self.text_encoder = SentenceTransformer("all-MiniLM-L6-v2")
 
 
 
 
 
75
 
76
 
77
  @torch.no_grad()
 
10
  import networkx as nx
11
  from sentence_transformers import SentenceTransformer
12
  import pytorch_lightning as pl
13
+ from transformers import BertTokenizer, BertForSequenceClassification
14
 
15
 
16
  class LGGMText2Graph_Demo(pl.LightningModule):
 
56
  self.limit_dist = utils.PlaceHolder(X=x_limit, E=e_limit, y=y_limit)
57
 
58
 
59
+ def generate_basic(self, text, num_nodes) -> None:
60
  print(num_nodes)
61
  prompt_emb = torch.tensor(self.text_encoder.encode([text])).to(self.device)
62
  samples = self.sample_batch(5, cond_emb = prompt_emb, num_nodes = num_nodes)
 
71
 
72
  return nx_graphs
73
 
74
+ def generate_pretrained(self, text, num_nodes) -> None:
75
+ encoded_input = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
76
+ encoded_input = {key: val.to(self.text_encoder.device) for key, val in encoded_input.items()}
77
+
78
+ # Get the model output
79
+ with torch.no_grad():
80
+ prompt_emb = self.text_encoder(**encoded_input).hidden_states[-1][:, 0]
81
+
82
+ samples = self.sample_batch(5, cond_emb = prompt_emb.to(self.device), num_nodes = num_nodes)
83
+
84
+ nx_graphs = []
85
+ for graph in samples:
86
+ node_types, edge_types = graph
87
+ A = edge_types.bool().cpu().numpy()
88
+
89
+ nx_graph = nx.from_numpy_array(A)
90
+ nx_graphs.append(nx_graph)
91
+
92
+ return nx_graphs
93
+
94
+ def init_prompt_encoder_basic(self):
95
  self.text_encoder = SentenceTransformer("all-MiniLM-L6-v2")
96
+
97
+ def init_prompt_encoder_pretrained(self):
98
+ model_name = f"./checkpoint-900" # or "bert-base-uncased" if starting from the base model
99
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
100
+ self.text_encoder = BertForSequenceClassification.from_pretrained(model_name, num_labels=8, output_hidden_states=True, device_map = 'cpu')
101
 
102
 
103
  @torch.no_grad()