mgyigit commited on
Commit
1c867f8
·
verified ·
1 Parent(s): 7db6b02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +248 -225
app.py CHANGED
@@ -13,50 +13,50 @@ import time
13
 
14
  class DrugGENConfig:
15
  # Inference configuration
16
- submodel = 'DrugGEN'
17
- inference_model = "/home/user/app/experiments/models/DrugGEN/"
18
- sample_num = 100
19
 
20
  # Data configuration
21
- inf_smiles = '/home/user/app/data/chembl_test.smi'
22
- train_smiles = '/home/user/app/data/chembl_train.smi'
23
- inf_batch_size = 1
24
- mol_data_dir = '/home/user/app/data'
25
- features = False
26
 
27
  # Model configuration
28
- act = 'relu'
29
- max_atom = 45
30
- dim = 128
31
- depth = 1
32
- heads = 8
33
- mlp_ratio = 3
34
- dropout = 0.
35
 
36
  # Seed configuration
37
- set_seed = True
38
- seed = 10
39
 
40
- disable_correction = False
41
 
42
 
43
  class DrugGENAKT1Config(DrugGENConfig):
44
- submodel = 'DrugGEN'
45
- inference_model = "/home/user/app/experiments/models/DrugGEN-akt1/"
46
- train_drug_smiles = '/home/user/app/data/akt_train.smi'
47
- max_atom = 45
48
 
49
 
50
  class DrugGENCDK2Config(DrugGENConfig):
51
- submodel = 'DrugGEN'
52
- inference_model = "/home/user/app/experiments/models/DrugGEN-cdk2/"
53
- train_drug_smiles = '/home/user/app/data/cdk2_train.smi'
54
- max_atom = 38
55
 
56
 
57
  class NoTargetConfig(DrugGENConfig):
58
- submodel = "NoTarget"
59
- inference_model = "/home/user/app/experiments/models/NoTarget/"
60
 
61
 
62
  model_configs = {
@@ -66,34 +66,62 @@ model_configs = {
66
  }
67
 
68
 
69
- def run_inference(mode: str, model_name: str, num_molecules: int, seed_num: str, custom_smiles: str):
70
- """
71
- Depending on the selected mode, either generate new molecules or evaluate provided SMILES.
72
-
73
  Returns:
74
- image, file_path, basic_metrics, advanced_metrics
75
- """
 
 
 
76
  config = model_configs[model_name]
77
-
78
- if mode == "Custom Input SMILES":
79
- # Process the custom input SMILES
80
- smiles_list = [s.strip() for s in custom_smiles.strip().splitlines() if s.strip() != ""]
81
- if len(smiles_list) > 100:
82
- raise gr.Error("You have provided more than the allowed limit of 100 molecules. Please provide 100 or fewer.")
83
- # Write the custom SMILES to a temporary file and update config
84
- temp_input_file = "custom_input.smi"
85
- with open(temp_input_file, "w") as f:
86
- for s in smiles_list:
87
- f.write(s + "\n")
88
- config.inf_smiles = temp_input_file
89
- config.sample_num = len(smiles_list)
90
- # Always use a random seed for custom mode
91
- config.seed = random.randint(0, 10000)
92
- else:
93
- # Classical Generation mode
94
  config.sample_num = num_molecules
 
95
  if config.sample_num > 250:
96
  raise gr.Error("You have requested to generate more than the allowed limit of 250 molecules. Please reduce your request to 250 or fewer.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  if seed_num is None or seed_num.strip() == "":
98
  config.seed = random.randint(0, 10000)
99
  else:
@@ -102,24 +130,37 @@ def run_inference(mode: str, model_name: str, num_molecules: int, seed_num: str,
102
  except ValueError:
103
  raise gr.Error("The seed must be an integer value!")
104
 
105
- # Adjust model name for the inference if not using NoTarget
106
- if model_name != "DrugGEN-NoTarget":
107
- target_model_name = "DrugGEN"
108
- else:
109
- target_model_name = "NoTarget"
110
 
111
  inferer = Inference(config)
112
  start_time = time.time()
113
  scores = inferer.inference() # This returns a DataFrame with specific columns
114
  et = time.time() - start_time
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  # Create basic metrics dataframe
117
  basic_metrics = pd.DataFrame({
118
  "Validity": [scores["validity"].iloc[0]],
119
  "Uniqueness": [scores["uniqueness"].iloc[0]],
120
  "Novelty (Train)": [scores["novelty"].iloc[0]],
121
- "Novelty (Inference)": [scores["novelty_test"].iloc[0]],
122
- "Novelty (Real Inhibitors)": [scores["drug_novelty"].iloc[0]],
123
  "Runtime (s)": [round(et, 2)]
124
  })
125
 
@@ -129,13 +170,13 @@ def run_inference(mode: str, model_name: str, num_molecules: int, seed_num: str,
129
  "SA Score": [scores["sa"].iloc[0]],
130
  "Internal Diversity": [scores["IntDiv"].iloc[0]],
131
  "SNN ChEMBL": [scores["snn_chembl"].iloc[0]],
132
- "SNN Real Inhibitors": [scores["snn_drug"].iloc[0]],
133
- "Average Length": [scores["max_len"].iloc[0]]
134
  })
135
 
136
- # Process the output file from inference
137
- output_file_path = f'/home/user/app/experiments/inference/{target_model_name}/inference_drugs.txt'
138
- new_path = f'{target_model_name}_denovo_mols.smi'
139
  os.rename(output_file_path, new_path)
140
 
141
  with open(new_path) as f:
@@ -143,14 +184,13 @@ def run_inference(mode: str, model_name: str, num_molecules: int, seed_num: str,
143
 
144
  generated_molecule_list = inference_drugs.split("\n")[:-1]
145
 
146
- # Randomly select up to 12 molecules for display
147
  rng = random.Random(config.seed)
148
  if len(generated_molecule_list) > 12:
149
- selected_smiles = rng.choices(generated_molecule_list, k=12)
150
  else:
151
- selected_smiles = generated_molecule_list
152
-
153
- selected_molecules = [Chem.MolFromSmiles(mol) for mol in selected_smiles if Chem.MolFromSmiles(mol) is not None]
154
 
155
  drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
156
  drawOptions.prepareMolsBeforeDrawing = False
@@ -161,15 +201,21 @@ def run_inference(mode: str, model_name: str, num_molecules: int, seed_num: str,
161
  molsPerRow=3,
162
  subImgSize=(400, 400),
163
  maxMols=len(selected_molecules),
 
164
  returnPNG=False,
165
  drawOptions=drawOptions,
166
  highlightAtomLists=None,
167
  highlightBondLists=None,
168
  )
169
 
 
 
 
 
170
  return molecule_image, new_path, basic_metrics, advanced_metrics
171
 
172
 
 
173
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
174
  # Add custom CSS for styling
175
  gr.HTML("""
@@ -185,40 +231,44 @@ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
185
  </style>
186
  """)
187
 
188
- gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
189
-
190
- gr.HTML("""
191
- <div style="display: flex; gap: 10px; margin-bottom: 15px;">
192
- <!-- arXiv badge -->
193
- <a href="https://arxiv.org/abs/2302.07868" target="_blank" style="text-decoration: none;">
194
- <div style="
195
- display: inline-block;
196
- background-color: #b31b1b;
197
- color: #ffffff !important;
198
- padding: 5px 10px;
199
- border-radius: 5px;
200
- font-size: 14px;">
201
- <span style="font-weight: bold;">arXiv</span> 2302.07868
202
- </div>
203
- </a>
204
-
205
- <!-- GitHub badge -->
206
- <a href="https://github.com/HUBioDataLab/DrugGEN" target="_blank" style="text-decoration: none;">
207
- <div style="
208
- display: inline-block;
209
- background-color: #24292e;
210
- color: #ffffff !important;
211
- padding: 5px 10px;
212
- border-radius: 5px;
213
- font-size: 14px;">
214
- <span style="font-weight: bold;">GitHub</span> Repository
 
 
 
 
 
 
215
  </div>
216
- </a>
217
- </div>
218
- """)
219
-
220
- with gr.Accordion("About DrugGEN Models", open=False):
221
- gr.Markdown("""
222
  ## Model Variations
223
 
224
  ### DrugGEN-AKT1
@@ -228,158 +278,131 @@ This model is designed to generate molecules targeting the human AKT1 protein (U
228
  This model is designed to generate molecules targeting the human CDK2 protein (UniProt ID: P24941).
229
 
230
  ### DrugGEN-NoTarget
231
- This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein.
232
- - Useful for exploring chemical space, generating diverse scaffolds, and creating molecules with drug-like properties.
 
 
233
 
234
  For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
235
- """)
236
-
237
- with gr.Accordion("Understanding the Metrics", open=False):
238
- gr.Markdown("""
239
  ## Evaluation Metrics
240
 
241
  ### Basic Metrics
242
  - **Validity**: Percentage of generated molecules that are chemically valid
243
  - **Uniqueness**: Percentage of unique molecules among valid ones
244
- - **Runtime**: Time taken to generate or evaluate the molecules
245
 
246
  ### Novelty Metrics
247
  - **Novelty (Train)**: Percentage of molecules not found in the training set
248
- - **Novelty (Inference)**: Percentage of molecules not found in the test set
249
- - **Novelty (Real Inhibitors)**: Percentage of molecules not found in known inhibitors of the target protein
250
 
251
  ### Structural Metrics
252
- - **Average Length**: Average component length in the generated molecules
253
  - **Mean Atom Type**: Average distribution of atom types
254
  - **Internal Diversity**: Diversity within the generated set (higher is more diverse)
255
 
256
  ### Drug-likeness Metrics
257
  - **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
258
- - **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is better)
259
 
260
  ### Similarity Metrics
261
  - **SNN ChEMBL**: Similarity to ChEMBL molecules (higher means more similar to known drug-like compounds)
262
- - **SNN Real Inhibitors**: Similarity to known drugs (higher means more similar to approved drugs)
263
- """)
264
-
265
- # Use Gradio Tabs to separate the two modes.
266
- with gr.Tabs():
267
- with gr.TabItem("Classical Generation"):
268
- with gr.Row():
269
- with gr.Column(scale=1):
270
- model_name = gr.Radio(
271
- choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
272
- value="DrugGEN-AKT1",
273
- label="Select Target Model",
274
- info="Choose which protein target or general model to use for molecule generation"
275
- )
276
-
277
- num_molecules = gr.Slider(
278
- minimum=10,
279
- maximum=250,
280
- value=100,
281
- step=10,
282
- label="Number of Molecules to Generate",
283
- info="This space runs on a CPU, which may result in slower performance. Generating 200 molecules takes approximately 6 minutes. Therefore, we set a 250-molecule cap."
284
- )
285
-
286
- seed_num = gr.Textbox(
287
- label="Random Seed (Optional)",
288
- value="",
289
- info="Set a specific seed for reproducible results, or leave empty for random generation"
290
- )
291
-
292
- classical_submit = gr.Button(
293
- value="Generate Molecules",
294
- variant="primary",
295
- size="lg"
296
- )
297
- with gr.Column(scale=2):
298
- basic_metrics_df = gr.Dataframe(
299
- headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Inference)", "Novelty (Real Inhibitors)", "Runtime (s)"],
300
- elem_id="basic-metrics"
301
- )
302
-
303
- advanced_metrics_df = gr.Dataframe(
304
- headers=["QED", "SA Score", "Internal Diversity", "SNN (ChEMBL)", "SNN (Real Inhibitors)", "Average Length"],
305
- elem_id="advanced-metrics"
306
- )
307
-
308
- file_download = gr.File(
309
- label="Download All Generated Molecules (SMILES format)"
310
- )
311
-
312
- image_output = gr.Image(
313
- label="Structures of Randomly Selected Generated Molecules",
314
- elem_id="molecule_display"
315
- )
316
-
317
- with gr.TabItem("Custom Input SMILES"):
318
- with gr.Row():
319
- with gr.Column(scale=1):
320
- # Reuse model selection for custom input
321
- model_name_custom = gr.Radio(
322
- choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
323
- value="DrugGEN-AKT1",
324
- label="Select Target Model",
325
- info="Choose which protein target or general model to use for evaluation"
326
- )
327
- custom_smiles = gr.Textbox(
328
- label="Input SMILES (one per line, maximum 100 molecules)",
329
- placeholder="C(C(=O)O)N\nCCO\n...",
330
- lines=10
331
- )
332
- custom_submit = gr.Button(
333
- value="Evaluate Custom SMILES",
334
- variant="primary",
335
- size="lg"
336
- )
337
- with gr.Column(scale=2):
338
- basic_metrics_df_custom = gr.Dataframe(
339
- headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Inference)", "Novelty (Real Inhibitors)", "Runtime (s)"],
340
- elem_id="basic-metrics-custom"
341
- )
342
-
343
- advanced_metrics_df_custom = gr.Dataframe(
344
- headers=["QED", "SA Score", "Internal Diversity", "SNN (ChEMBL)", "SNN (Real Inhibitors)", "Average Length"],
345
- elem_id="advanced-metrics-custom"
346
- )
347
-
348
- file_download_custom = gr.File(
349
- label="Download All Molecules (SMILES format)"
350
- )
351
-
352
- image_output_custom = gr.Image(
353
- label="Structures of Randomly Selected Molecules",
354
- elem_id="molecule_display_custom"
355
- )
356
 
357
  gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")
358
 
359
- # Set up the click actions for each tab.
360
- classical_submit.click(
361
- run_inference,
362
- inputs=[gr.Variable("Generate Molecules"), model_name, num_molecules, seed_num, gr.Textbox.update(value="")],
363
  outputs=[
364
  image_output,
365
  file_download,
366
  basic_metrics_df,
367
  advanced_metrics_df
368
- ],
369
- api_name="inference_classical"
370
- )
371
-
372
- custom_submit.click(
373
- run_inference,
374
- inputs=[gr.Variable("Custom Input SMILES"), model_name_custom, 0, gr.Textbox.update(value=""), custom_smiles],
375
- outputs=[
376
- image_output_custom,
377
- file_download_custom,
378
- basic_metrics_df_custom,
379
- advanced_metrics_df_custom
380
- ],
381
- api_name="inference_custom"
382
  )
383
-
384
  demo.queue()
385
  demo.launch()
 
13
 
14
  class DrugGENConfig:
15
  # Inference configuration
16
+ submodel='DrugGEN'
17
+ inference_model="/home/user/app/experiments/models/DrugGEN/"
18
+ sample_num=100
19
 
20
  # Data configuration
21
+ inf_smiles='/home/user/app/data/chembl_test.smi'
22
+ train_smiles='/home/user/app/data/chembl_train.smi'
23
+ inf_batch_size=1
24
+ mol_data_dir='/home/user/app/data'
25
+ features=False
26
 
27
  # Model configuration
28
+ act='relu'
29
+ max_atom=45
30
+ dim=128
31
+ depth=1
32
+ heads=8
33
+ mlp_ratio=3
34
+ dropout=0.
35
 
36
  # Seed configuration
37
+ set_seed=True
38
+ seed=10
39
 
40
+ disable_correction=False
41
 
42
 
43
  class DrugGENAKT1Config(DrugGENConfig):
44
+ submodel='DrugGEN'
45
+ inference_model="/home/user/app/experiments/models/DrugGEN-akt1/"
46
+ train_drug_smiles='/home/user/app/data/akt_train.smi'
47
+ max_atom=45
48
 
49
 
50
  class DrugGENCDK2Config(DrugGENConfig):
51
+ submodel='DrugGEN'
52
+ inference_model="/home/user/app/experiments/models/DrugGEN-cdk2/"
53
+ train_drug_smiles='/home/user/app//data/cdk2_train.smi'
54
+ max_atom=38
55
 
56
 
57
  class NoTargetConfig(DrugGENConfig):
58
+ submodel="NoTarget"
59
+ inference_model="/home/user/app/experiments/models/NoTarget/"
60
 
61
 
62
  model_configs = {
 
66
  }
67
 
68
 
69
+
70
+ def function(model_name: str, input_mode: str, num_molecules: int = None, seed_num: str = None, smiles_input: str = None):
71
+ '''
 
72
  Returns:
73
+ image, metrics_df, file_path, basic_metrics, advanced_metrics
74
+ '''
75
+ if model_name == "DrugGEN-NoTarget":
76
+ model_name = "NoTarget"
77
+
78
  config = model_configs[model_name]
79
+
80
+ # Handle the input mode
81
+ if input_mode == "generate":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  config.sample_num = num_molecules
83
+
84
  if config.sample_num > 250:
85
  raise gr.Error("You have requested to generate more than the allowed limit of 250 molecules. Please reduce your request to 250 or fewer.")
86
+
87
+ if seed_num is None or seed_num.strip() == "":
88
+ config.seed = random.randint(0, 10000)
89
+ else:
90
+ try:
91
+ config.seed = int(seed_num)
92
+ except ValueError:
93
+ raise gr.Error("The seed must be an integer value!")
94
+ else: # input_mode == "smiles"
95
+ if not smiles_input or smiles_input.strip() == "":
96
+ raise gr.Error("Please enter at least one SMILES string.")
97
+
98
+ # Split by newlines and filter empty lines
99
+ smiles_list = [s.strip() for s in smiles_input.strip().split('\n') if s.strip()]
100
+
101
+ if len(smiles_list) > 100:
102
+ raise gr.Error("You have entered more than the allowed limit of 100 SMILES. Please reduce your input.")
103
+
104
+ # Validate all SMILES
105
+ invalid_smiles = []
106
+ for i, smi in enumerate(smiles_list):
107
+ mol = Chem.MolFromSmiles(smi)
108
+ if mol is None:
109
+ invalid_smiles.append((i+1, smi))
110
+
111
+ if invalid_smiles:
112
+ invalid_str = "\n".join([f"Line {i}: {smi}" for i, smi in invalid_smiles])
113
+ raise gr.Error(f"The following SMILES are invalid:\n{invalid_str}")
114
+
115
+ # Save SMILES to a temporary file that matches the expected input format
116
+ temp_smiles_file = f'/home/user/app/data/temp_input.smi'
117
+ with open(temp_smiles_file, 'w') as f:
118
+ f.write('\n'.join(smiles_list))
119
+
120
+ # Update config to use this file
121
+ config.inf_smiles = temp_smiles_file
122
+ config.sample_num = len(smiles_list)
123
+
124
+ # Set a random seed if not provided
125
  if seed_num is None or seed_num.strip() == "":
126
  config.seed = random.randint(0, 10000)
127
  else:
 
130
  except ValueError:
131
  raise gr.Error("The seed must be an integer value!")
132
 
133
+ if model_name != "NoTarget":
134
+ model_name = "DrugGEN"
 
 
 
135
 
136
  inferer = Inference(config)
137
  start_time = time.time()
138
  scores = inferer.inference() # This returns a DataFrame with specific columns
139
  et = time.time() - start_time
140
 
141
+ score_df = pd.DataFrame({
142
+ "Runtime (seconds)": [et],
143
+ "Validity": [scores["validity"].iloc[0]],
144
+ "Uniqueness": [scores["uniqueness"].iloc[0]],
145
+ "Novelty (Train)": [scores["novelty"].iloc[0]],
146
+ "Novelty (Test)": [scores["novelty_test"].iloc[0]],
147
+ "Drug Novelty": [scores["drug_novelty"].iloc[0]],
148
+ "Max Length": [scores["max_len"].iloc[0]],
149
+ "Mean Atom Type": [scores["mean_atom_type"].iloc[0]],
150
+ "SNN ChEMBL": [scores["snn_chembl"].iloc[0]],
151
+ "SNN Drug": [scores["snn_drug"].iloc[0]],
152
+ "Internal Diversity": [scores["IntDiv"].iloc[0]],
153
+ "QED": [scores["qed"].iloc[0]],
154
+ "SA Score": [scores["sa"].iloc[0]]
155
+ })
156
+
157
  # Create basic metrics dataframe
158
  basic_metrics = pd.DataFrame({
159
  "Validity": [scores["validity"].iloc[0]],
160
  "Uniqueness": [scores["uniqueness"].iloc[0]],
161
  "Novelty (Train)": [scores["novelty"].iloc[0]],
162
+ "Novelty (Test)": [scores["novelty_test"].iloc[0]],
163
+ "Drug Novelty": [scores["drug_novelty"].iloc[0]],
164
  "Runtime (s)": [round(et, 2)]
165
  })
166
 
 
170
  "SA Score": [scores["sa"].iloc[0]],
171
  "Internal Diversity": [scores["IntDiv"].iloc[0]],
172
  "SNN ChEMBL": [scores["snn_chembl"].iloc[0]],
173
+ "SNN Drug": [scores["snn_drug"].iloc[0]],
174
+ "Max Length": [scores["max_len"].iloc[0]]
175
  })
176
 
177
+ output_file_path = f'/home/user/app/experiments/inference/{model_name}/inference_drugs.txt'
178
+
179
+ new_path = f'{model_name}_denovo_mols.smi'
180
  os.rename(output_file_path, new_path)
181
 
182
  with open(new_path) as f:
 
184
 
185
  generated_molecule_list = inference_drugs.split("\n")[:-1]
186
 
 
187
  rng = random.Random(config.seed)
188
  if len(generated_molecule_list) > 12:
189
+ selected_molecules = rng.choices(generated_molecule_list, k=12)
190
  else:
191
+ selected_molecules = generated_molecule_list
192
+
193
+ selected_molecules = [Chem.MolFromSmiles(mol) for mol in selected_molecules if Chem.MolFromSmiles(mol) is not None]
194
 
195
  drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
196
  drawOptions.prepareMolsBeforeDrawing = False
 
201
  molsPerRow=3,
202
  subImgSize=(400, 400),
203
  maxMols=len(selected_molecules),
204
+ # legends=None,
205
  returnPNG=False,
206
  drawOptions=drawOptions,
207
  highlightAtomLists=None,
208
  highlightBondLists=None,
209
  )
210
 
211
+ # Clean up the temporary file if it was created
212
+ if input_mode == "smiles" and os.path.exists(temp_smiles_file):
213
+ os.remove(temp_smiles_file)
214
+
215
  return molecule_image, new_path, basic_metrics, advanced_metrics
216
 
217
 
218
+
219
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
220
  # Add custom CSS for styling
221
  gr.HTML("""
 
231
  </style>
232
  """)
233
 
234
+ with gr.Row():
235
+ with gr.Column(scale=1):
236
+ gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
237
+
238
+ gr.HTML("""
239
+ <div style="display: flex; gap: 10px; margin-bottom: 15px;">
240
+ <!-- arXiv badge -->
241
+ <a href="https://arxiv.org/abs/2302.07868" target="_blank" style="text-decoration: none;">
242
+ <div style="
243
+ display: inline-block;
244
+ background-color: #b31b1b;
245
+ color: #ffffff !important; /* Force white text */
246
+ padding: 5px 10px;
247
+ border-radius: 5px;
248
+ font-size: 14px;"
249
+ >
250
+ <span style="font-weight: bold;">arXiv</span> 2302.07868
251
+ </div>
252
+ </a>
253
+
254
+ <!-- GitHub badge -->
255
+ <a href="https://github.com/HUBioDataLab/DrugGEN" target="_blank" style="text-decoration: none;">
256
+ <div style="
257
+ display: inline-block;
258
+ background-color: #24292e;
259
+ color: #ffffff !important; /* Force white text */
260
+ padding: 5px 10px;
261
+ border-radius: 5px;
262
+ font-size: 14px;"
263
+ >
264
+ <span style="font-weight: bold;">GitHub</span> Repository
265
+ </div>
266
+ </a>
267
  </div>
268
+ """)
269
+
270
+ with gr.Accordion("About DrugGEN Models", open=False):
271
+ gr.Markdown("""
 
 
272
  ## Model Variations
273
 
274
  ### DrugGEN-AKT1
 
278
  This model is designed to generate molecules targeting the human CDK2 protein (UniProt ID: P24941).
279
 
280
  ### DrugGEN-NoTarget
281
+ This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein. It's useful for:
282
+ - Exploring chemical space
283
+ - Generating diverse scaffolds
284
+ - Creating molecules with drug-like properties
285
 
286
  For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
287
+ """)
288
+
289
+ with gr.Accordion("Understanding the Metrics", open=False):
290
+ gr.Markdown("""
291
  ## Evaluation Metrics
292
 
293
  ### Basic Metrics
294
  - **Validity**: Percentage of generated molecules that are chemically valid
295
  - **Uniqueness**: Percentage of unique molecules among valid ones
296
+ - **Runtime**: Time taken to generate the requested molecules
297
 
298
  ### Novelty Metrics
299
  - **Novelty (Train)**: Percentage of molecules not found in the training set
300
+ - **Novelty (Test)**: Percentage of molecules not found in the test set
301
+ - **Drug Novelty**: Percentage of molecules not found in known inhibitors of the target protein
302
 
303
  ### Structural Metrics
304
+ - **Max Length**: Maximum component length in the generated molecules
305
  - **Mean Atom Type**: Average distribution of atom types
306
  - **Internal Diversity**: Diversity within the generated set (higher is more diverse)
307
 
308
  ### Drug-likeness Metrics
309
  - **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
310
+ - **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is easier)
311
 
312
  ### Similarity Metrics
313
  - **SNN ChEMBL**: Similarity to ChEMBL molecules (higher means more similar to known drug-like compounds)
314
+ - **SNN Drug**: Similarity to known drugs (higher means more similar to approved drugs)
315
+ """)
316
+
317
+ model_name = gr.Radio(
318
+ choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
319
+ value="DrugGEN-AKT1",
320
+ label="Select Target Model",
321
+ info="Choose which protein target or general model to use for molecule generation"
322
+ )
323
+
324
+ input_mode = gr.Radio(
325
+ choices=["generate", "smiles"],
326
+ value="generate",
327
+ label="Input Mode",
328
+ info="Choose to generate new molecules or provide your own SMILES strings",
329
+ elem_id="input_mode"
330
+ )
331
+
332
+ # Create container for generation mode inputs
333
+ with gr.Group(visible=True) as generate_group:
334
+ num_molecules = gr.Slider(
335
+ minimum=10,
336
+ maximum=250,
337
+ value=100,
338
+ step=10,
339
+ label="Number of Molecules to Generate",
340
+ info="This space runs on a CPU, which may result in slower performance. Generating 200 molecules takes approximately 6 minutes. Therefore, We set a 250-molecule cap. On a GPU, the model can generate 10,000 molecules in the same amount of time. Please check our GitHub repo for running our models on GPU."
341
+ )
342
+
343
+ # Create container for SMILES input mode
344
+ with gr.Group(visible=False) as smiles_group:
345
+ smiles_input = gr.Textbox(
346
+ label="Input SMILES",
347
+ info="Enter up to 100 SMILES strings, one per line",
348
+ lines=10,
349
+ placeholder="CC(=O)OC1=CC=CC=C1C(=O)O\nCCO\nC1=CC=C(C=C1)C(=O)O\n...",
350
+ )
351
+
352
+ # Seed input is used by both modes
353
+ seed_num = gr.Textbox(
354
+ label="Random Seed (Optional)",
355
+ value="",
356
+ info="Set a specific seed for reproducible results, or leave empty for random generation"
357
+ )
358
+
359
+ # Handle visibility toggling between the two input modes
360
+ input_mode.change(
361
+ fn=lambda x: [x == "generate", x == "smiles"],
362
+ inputs=[input_mode],
363
+ outputs=[generate_group, smiles_group]
364
+ )
365
+
366
+ submit_button = gr.Button(
367
+ value="Generate Molecules",
368
+ variant="primary",
369
+ size="lg"
370
+ )
371
+
372
+ with gr.Column(scale=2):
373
+ basic_metrics_df = gr.Dataframe(
374
+ headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Test)", "Novelty (Drug)", "Runtime (s)"],
375
+ elem_id="basic-metrics"
376
+ )
377
+
378
+ advanced_metrics_df = gr.Dataframe(
379
+ headers=["QED", "SA Score", "Internal Diversity", "SNN (ChEMBL)", "SNN (Drug)", "Max Length"],
380
+ elem_id="advanced-metrics"
381
+ )
382
+
383
+ file_download = gr.File(
384
+ label="Download All Generated Molecules (SMILES format)",
385
+ )
386
+
387
+ image_output = gr.Image(
388
+ label="Structures of Randomly Selected Generated Molecules",
389
+ elem_id="molecule_display"
390
+ )
391
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
 
393
  gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")
394
 
395
+ submit_button.click(
396
+ function,
397
+ inputs=[model_name, input_mode, num_molecules, seed_num, smiles_input],
 
398
  outputs=[
399
  image_output,
400
  file_download,
401
  basic_metrics_df,
402
  advanced_metrics_df
403
+ ],
404
+ api_name="inference"
 
 
 
 
 
 
 
 
 
 
 
 
405
  )
406
+ #demo.queue(concurrency_count=1)
407
  demo.queue()
408
  demo.launch()