Huhujingjing commited on
Commit
420dbd3
·
1 Parent(s): cf18ce2

Upload model

Browse files
Files changed (2) hide show
  1. configuration_gcn.py +8 -2
  2. modeling_gcn.py +99 -17
configuration_gcn.py CHANGED
@@ -1,5 +1,5 @@
1
  from transformers import PretrainedConfig
2
-
3
  class GCNConfig(PretrainedConfig):
4
  model_type = "gcn"
5
 
@@ -10,6 +10,9 @@ class GCNConfig(PretrainedConfig):
10
  hidden_size: int=64,
11
  n_layers: int=6,
12
  num_classes: int=1,
 
 
 
13
  **kwargs,
14
  ):
15
 
@@ -19,9 +22,12 @@ class GCNConfig(PretrainedConfig):
19
  self.n_layers = n_layers # the number of GCN layers
20
  self.num_classes = num_classes # the number of output classes
21
 
 
 
 
22
  super().__init__(**kwargs)
23
 
24
 
25
  if __name__ == "__main__":
26
- gcn_config = GCNConfig(input_feature=64, emb_input=20, hidden_size=64, n_layers=6, num_classes=1)
27
  gcn_config.save_pretrained("custom-gcn")
 
1
  from transformers import PretrainedConfig
2
+ from typing import List
3
  class GCNConfig(PretrainedConfig):
4
  model_type = "gcn"
5
 
 
10
  hidden_size: int=64,
11
  n_layers: int=6,
12
  num_classes: int=1,
13
+
14
+ smiles: List[str] = None,
15
+ processor_class: str = "SmilesProcessor",
16
  **kwargs,
17
  ):
18
 
 
22
  self.n_layers = n_layers # the number of GCN layers
23
  self.num_classes = num_classes # the number of output classes
24
 
25
+ self.smiles = smiles # process smiles
26
+ self.processor_class = processor_class
27
+
28
  super().__init__(**kwargs)
29
 
30
 
31
  if __name__ == "__main__":
32
+ gcn_config = GCNConfig(input_feature=64, emb_input=20, hidden_size=64, n_layers=6, num_classes=1, smiles=["C", "CC", "CCC"], processor_class="SmilesProcessor")
33
  gcn_config.save_pretrained("custom-gcn")
modeling_gcn.py CHANGED
@@ -5,7 +5,98 @@ from torch_scatter import scatter
5
  from transformers import PreTrainedModel
6
  from gcn_model.configuration_gcn import GCNConfig
7
  import torch
8
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  """
11
  MLP Layer used after graph vector representation
@@ -67,26 +158,17 @@ class GCNModel(PreTrainedModel):
67
  n_layers=config.n_layers,
68
  num_classes=config.num_classes,
69
  )
 
 
 
70
 
71
  def forward(self, tensor):
72
  return self.model.forward_features(tensor)
73
 
74
- # class GCNModelForMolecularPrediction(PreTrainedModel):
75
- # config_class = GCNConfig
76
- #
77
- # def __init__(self, config):
78
- # super().__init__(config)
79
- #
80
- # self.model = GCNNet(
81
- # input_feature=config.input_feature,
82
- # emb_input=config.emb_input,
83
- # hidden_size=config.hidden_size,
84
- # n_layers=config.n_layers,
85
- # num_classes=config.num_classes,
86
- # )
87
- #
88
- # def forward(self, tensor):
89
- # return self.model.forward_features(tensor)
90
 
91
 
92
  if __name__ == "__main__":
 
5
  from transformers import PreTrainedModel
6
  from gcn_model.configuration_gcn import GCNConfig
7
  import torch
8
+ from rdkit import Chem
9
+ from rdkit.Chem import AllChem
10
+ import torch
11
+ from torch_geometric.data import Data
12
+
13
+
14
+ class SmilesDataset(torch.utils.data.Dataset):
15
+ def __init__(self, smiles):
16
+ self.smiles_list = smiles
17
+ self.data_list = []
18
+
19
+
20
+ def __len__(self):
21
+ return len(self.data_list)
22
+
23
+ def __getitem__(self, idx):
24
+ return self.data_list[idx]
25
+
26
+ def get_data(self, smiles):
27
+ self.smiles_list = smiles
28
+ # self.data_list = []
29
+ # bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
30
+ types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'S': 4}
31
+
32
+ for i in range(len(self.smiles_list)):
33
+ # 将 SMILES 表示转换为 RDKit 的分子对象
34
+ # print(self.smiles_list[i])
35
+ mol = Chem.MolFromSmiles(self.smiles_list[i]) # 从smiles编码中获取结构信息
36
+ if mol is None:
37
+ print("无法创建Mol对象", self.smiles_list[i])
38
+ else:
39
+
40
+ mol3d = Chem.AddHs(
41
+ mol) # 在rdkit中,分子在默认情况下是不显示氢的,但氢原子对于真实的几何构象计算有很大的影响,所以在计算3D构象前,需要使用Chem.AddHs()方法加上氢原子
42
+ if mol3d is None:
43
+ print("无法创建mol3d对象", self.smiles_list[i])
44
+ else:
45
+ AllChem.EmbedMolecule(mol3d, randomSeed=1) # 生成3D构象
46
+
47
+ N = mol3d.GetNumAtoms()
48
+ # 获取原子坐标信息
49
+ if mol3d.GetNumConformers() > 0:
50
+ conformer = mol3d.GetConformer()
51
+ pos = conformer.GetPositions()
52
+ pos = torch.tensor(pos, dtype=torch.float)
53
+
54
+ type_idx = []
55
+ # atomic_number = []
56
+ # aromatic = []
57
+ # sp = []
58
+ # sp2 = []
59
+ # sp3 = []
60
+ for atom in mol3d.GetAtoms():
61
+ type_idx.append(types[atom.GetSymbol()])
62
+ # atomic_number.append(atom.GetAtomicNum())
63
+ # aromatic.append(1 if atom.GetIsAromatic() else 0)
64
+ # hybridization = atom.GetHybridization()
65
+ # sp.append(1 if hybridization == HybridizationType.SP else 0)
66
+ # sp2.append(1 if hybridization == HybridizationType.SP2 else 0)
67
+ # sp3.append(1 if hybridization == HybridizationType.SP3 else 0)
68
+
69
+ # z = torch.tensor(atomic_number, dtype=torch.long)
70
+
71
+ row, col, edge_type = [], [], []
72
+ for bond in mol3d.GetBonds():
73
+ start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
74
+ row += [start, end]
75
+ col += [end, start]
76
+ # edge_type += 2 * [bonds[bond.GetBondType()]]
77
+
78
+ edge_index = torch.tensor([row, col], dtype=torch.long)
79
+ # edge_type = torch.tensor(edge_type, dtype=torch.long)
80
+ # edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float)
81
+
82
+ perm = (edge_index[0] * N + edge_index[1]).argsort()
83
+ edge_index = edge_index[:, perm]
84
+ # edge_type = edge_type[perm]
85
+ # edge_attr = edge_attr[perm]
86
+ #
87
+ # row, col = edge_index
88
+ # hs = (z == 1).to(torch.float)
89
+
90
+ x = torch.tensor(type_idx).to(torch.float)
91
+
92
+ # y = self.y_list[i]
93
+
94
+ data = Data(x=x, pos=pos, edge_index=edge_index, smiles=self.smiles_list[i])
95
+
96
+ self.data_list.append(data)
97
+ else:
98
+ print("无法创建comfor", self.smiles_list[i])
99
+ return self.data_list
100
 
101
  """
102
  MLP Layer used after graph vector representation
 
158
  n_layers=config.n_layers,
159
  num_classes=config.num_classes,
160
  )
161
+ self.process = SmilesDataset(
162
+ smiles=config.smiles,
163
+ )
164
 
165
  def forward(self, tensor):
166
  return self.model.forward_features(tensor)
167
 
168
+ def process_smiles(self, smiles):
169
+ return self.process.get_data(smiles)
170
+
171
+
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
 
174
  if __name__ == "__main__":