Update demo_model.py
demo_model.py CHANGED (+28 -2)
```diff
@@ -10,6 +10,7 @@ import utils
 import networkx as nx
 from sentence_transformers import SentenceTransformer
 import pytorch_lightning as pl
+from transformers import BertTokenizer, BertForSequenceClassification
 
 
 class LGGMText2Graph_Demo(pl.LightningModule):
@@ -55,7 +56,7 @@ class LGGMText2Graph_Demo(pl.LightningModule):
         self.limit_dist = utils.PlaceHolder(X=x_limit, E=e_limit, y=y_limit)
 
 
-    def generate(self, text, num_nodes) -> None:
+    def generate_basic(self, text, num_nodes) -> None:
         print(num_nodes)
         prompt_emb = torch.tensor(self.text_encoder.encode([text])).to(self.device)
         samples = self.sample_batch(5, cond_emb = prompt_emb, num_nodes = num_nodes)
@@ -70,8 +71,33 @@ class LGGMText2Graph_Demo(pl.LightningModule):
 
         return nx_graphs
 
-    def init_prompt_encoder(self):
+    def generate_pretrained(self, text, num_nodes) -> None:
+        encoded_input = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
+        encoded_input = {key: val.to(self.text_encoder.device) for key, val in encoded_input.items()}
+
+        # Get the model output
+        with torch.no_grad():
+            prompt_emb = self.text_encoder(**encoded_input).hidden_states[-1][:, 0]
+
+        samples = self.sample_batch(5, cond_emb = prompt_emb.to(self.device), num_nodes = num_nodes)
+
+        nx_graphs = []
+        for graph in samples:
+            node_types, edge_types = graph
+            A = edge_types.bool().cpu().numpy()
+
+            nx_graph = nx.from_numpy_array(A)
+            nx_graphs.append(nx_graph)
+
+        return nx_graphs
+
+    def init_prompt_encoder_basic(self):
         self.text_encoder = SentenceTransformer("all-MiniLM-L6-v2")
+
+    def init_prompt_encoder_pretrained(self):
+        model_name = f"./checkpoint-900" # or "bert-base-uncased" if starting from the base model
+        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+        self.text_encoder = BertForSequenceClassification.from_pretrained(model_name, num_labels=8, output_hidden_states=True, device_map = 'cpu')
 
 
     @torch.no_grad()
```
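For orientation, here is a minimal usage sketch of the two conditioning paths this commit wires up. The checkpoint path, prompt text, and node count below are placeholders, not values from the commit; note that despite the `-> None` annotations, both generate methods return a list of networkx graphs.

```python
# Hypothetical usage sketch: the checkpoint path, prompt, and num_nodes are
# placeholders. Assumes a trained LGGMText2Graph_Demo checkpoint is available.
model = LGGMText2Graph_Demo.load_from_checkpoint("path/to/lggm.ckpt")
model.eval()

# Path 1: SentenceTransformer prompt embedding (all-MiniLM-L6-v2, 384-dim).
model.init_prompt_encoder_basic()
graphs = model.generate_basic("a sparse social network", num_nodes=30)

# Path 2: fine-tuned BERT classifier; the final-layer [CLS] hidden state
# (768-dim) serves as the conditioning embedding.
model.init_prompt_encoder_pretrained()
graphs = model.generate_pretrained("a sparse social network", num_nodes=30)

# Each call samples 5 graphs and returns them as networkx.Graph objects.
print([g.number_of_edges() for g in graphs])
```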
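The `hidden_states[-1][:, 0]` indexing in `generate_pretrained` is [CLS] pooling: with `output_hidden_states=True` the model returns the hidden states of every layer, `[-1]` selects the final layer, and `[:, 0]` keeps the embedding at sequence position 0, which is BERT's `[CLS]` token. A self-contained sketch of the same pooling, using the base model since the fine-tuned `./checkpoint-900` directory is local to the Space; the input text is illustrative:

```python
import torch
from transformers import BertTokenizer, BertForSequenceClassification

# The base model stands in for the fine-tuned ./checkpoint-900 checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=8, output_hidden_states=True
)
model.eval()

inputs = tokenizer("a dense citation network", return_tensors="pt",
                   padding=True, truncation=True, max_length=512)
with torch.no_grad():
    out = model(**inputs)

# hidden_states is a tuple of 13 tensors for bert-base (embedding layer plus
# 12 encoder layers), each of shape (batch, seq_len, 768).
cls_emb = out.hidden_states[-1][:, 0]
print(cls_emb.shape)  # torch.Size([1, 768])
```

Fine-tuning the classifier head (`num_labels=8`) presumably shapes this [CLS] space around the prompt categories, which would be the motivation for conditioning on it rather than on the raw SentenceTransformer embedding.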