mgyigit commited on
Commit
39488b0
·
verified ·
1 Parent(s): 3efbb3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -367
app.py CHANGED
@@ -13,50 +13,50 @@ import time
13
 
14
  class DrugGENConfig:
15
  # Inference configuration
16
- submodel='DrugGEN'
17
- inference_model="/home/user/app/experiments/models/DrugGEN/"
18
- sample_num=100
19
 
20
  # Data configuration
21
- inf_smiles='/home/user/app/data/chembl_test.smi'
22
- train_smiles='/home/user/app/data/chembl_train.smi'
23
- inf_batch_size=1
24
- mol_data_dir='/home/user/app/data'
25
- features=False
26
 
27
  # Model configuration
28
- act='relu'
29
- max_atom=45
30
- dim=128
31
- depth=1
32
- heads=8
33
- mlp_ratio=3
34
- dropout=0.
35
 
36
  # Seed configuration
37
- set_seed=True
38
- seed=10
39
 
40
- disable_correction=False
41
 
42
 
43
  class DrugGENAKT1Config(DrugGENConfig):
44
- submodel='DrugGEN'
45
- inference_model="/home/user/app/experiments/models/DrugGEN-akt1/"
46
- train_drug_smiles='/home/user/app/data/akt_train.smi'
47
- max_atom=45
48
 
49
 
50
  class DrugGENCDK2Config(DrugGENConfig):
51
- submodel='DrugGEN'
52
- inference_model="/home/user/app/experiments/models/DrugGEN-cdk2/"
53
- train_drug_smiles='/home/user/app//data/cdk2_train.smi'
54
- max_atom=38
55
 
56
 
57
  class NoTargetConfig(DrugGENConfig):
58
- submodel="NoTarget"
59
- inference_model="/home/user/app/experiments/models/NoTarget/"
60
 
61
 
62
  model_configs = {
@@ -66,24 +66,34 @@ model_configs = {
66
  }
67
 
68
 
69
-
70
- def function(model_name: str, input_mode: str, num_molecules: int = None, seed_num: str = None, smiles_input: str = None):
71
- '''
72
- Returns:
73
- image, metrics_df, file_path, basic_metrics, advanced_metrics
74
- '''
75
- if model_name == "DrugGEN-NoTarget":
76
- model_name = "NoTarget"
77
 
 
 
 
78
  config = model_configs[model_name]
79
-
80
- # Handle the input mode
81
- if input_mode == "generate":
82
- config.sample_num = num_molecules
83
-
84
- if config.sample_num > 250:
85
- raise gr.Error("You have requested to generate more than the allowed limit of 250 molecules. Please reduce your request to 250 or fewer.")
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  if seed_num is None or seed_num.strip() == "":
88
  config.seed = random.randint(0, 10000)
89
  else:
@@ -91,70 +101,25 @@ def function(model_name: str, input_mode: str, num_molecules: int = None, seed_n
91
  config.seed = int(seed_num)
92
  except ValueError:
93
  raise gr.Error("The seed must be an integer value!")
94
- else: # input_mode == "smiles"
95
- if not smiles_input or smiles_input.strip() == "":
96
- raise gr.Error("Please enter at least one SMILES string.")
97
-
98
- # Split by newlines and filter empty lines
99
- smiles_list = [s.strip() for s in smiles_input.strip().split('\n') if s.strip()]
100
-
101
- if len(smiles_list) > 100:
102
- raise gr.Error("You have entered more than the allowed limit of 100 SMILES. Please reduce your input.")
103
-
104
- # Validate all SMILES
105
- invalid_smiles = []
106
- for i, smi in enumerate(smiles_list):
107
- mol = Chem.MolFromSmiles(smi)
108
- if mol is None:
109
- invalid_smiles.append((i+1, smi))
110
-
111
- if invalid_smiles:
112
- invalid_str = "\n".join([f"Line {i}: {smi}" for i, smi in invalid_smiles])
113
- raise gr.Error(f"The following SMILES are invalid:\n{invalid_str}")
114
-
115
- # Save SMILES to a temporary file that matches the expected input format
116
- temp_smiles_file = f'/home/user/app/data/temp_input.smi'
117
- with open(temp_smiles_file, 'w') as f:
118
- f.write('\n'.join(smiles_list))
119
-
120
- # Update config to use this file
121
- config.inf_smiles = temp_smiles_file
122
- config.sample_num = len(smiles_list)
123
-
124
- # Always use a fixed seed for SMILES mode
125
- config.seed = 42
126
 
127
- if model_name != "NoTarget":
128
- model_name = "DrugGEN"
 
 
 
129
 
130
  inferer = Inference(config)
131
  start_time = time.time()
132
  scores = inferer.inference() # This returns a DataFrame with specific columns
133
  et = time.time() - start_time
134
 
135
- score_df = pd.DataFrame({
136
- "Runtime (seconds)": [et],
137
- "Validity": [scores["validity"].iloc[0]],
138
- "Uniqueness": [scores["uniqueness"].iloc[0]],
139
- "Novelty (Train)": [scores["novelty"].iloc[0]],
140
- "Novelty (Test)": [scores["novelty_test"].iloc[0]],
141
- "Drug Novelty": [scores["drug_novelty"].iloc[0]],
142
- "Max Length": [scores["max_len"].iloc[0]],
143
- "Mean Atom Type": [scores["mean_atom_type"].iloc[0]],
144
- "SNN ChEMBL": [scores["snn_chembl"].iloc[0]],
145
- "SNN Drug": [scores["snn_drug"].iloc[0]],
146
- "Internal Diversity": [scores["IntDiv"].iloc[0]],
147
- "QED": [scores["qed"].iloc[0]],
148
- "SA Score": [scores["sa"].iloc[0]]
149
- })
150
-
151
  # Create basic metrics dataframe
152
  basic_metrics = pd.DataFrame({
153
  "Validity": [scores["validity"].iloc[0]],
154
  "Uniqueness": [scores["uniqueness"].iloc[0]],
155
  "Novelty (Train)": [scores["novelty"].iloc[0]],
156
- "Novelty (Test)": [scores["novelty_test"].iloc[0]],
157
- "Drug Novelty": [scores["drug_novelty"].iloc[0]],
158
  "Runtime (s)": [round(et, 2)]
159
  })
160
 
@@ -164,13 +129,13 @@ def function(model_name: str, input_mode: str, num_molecules: int = None, seed_n
164
  "SA Score": [scores["sa"].iloc[0]],
165
  "Internal Diversity": [scores["IntDiv"].iloc[0]],
166
  "SNN ChEMBL": [scores["snn_chembl"].iloc[0]],
167
- "SNN Drug": [scores["snn_drug"].iloc[0]],
168
- "Max Length": [scores["max_len"].iloc[0]]
169
  })
170
 
171
- output_file_path = f'/home/user/app/experiments/inference/{model_name}/inference_drugs.txt'
172
-
173
- new_path = f'{model_name}_denovo_mols.smi'
174
  os.rename(output_file_path, new_path)
175
 
176
  with open(new_path) as f:
@@ -178,13 +143,14 @@ def function(model_name: str, input_mode: str, num_molecules: int = None, seed_n
178
 
179
  generated_molecule_list = inference_drugs.split("\n")[:-1]
180
 
 
181
  rng = random.Random(config.seed)
182
  if len(generated_molecule_list) > 12:
183
- selected_molecules = rng.choices(generated_molecule_list, k=12)
184
  else:
185
- selected_molecules = generated_molecule_list
186
-
187
- selected_molecules = [Chem.MolFromSmiles(mol) for mol in selected_molecules if Chem.MolFromSmiles(mol) is not None]
188
 
189
  drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
190
  drawOptions.prepareMolsBeforeDrawing = False
@@ -195,21 +161,15 @@ def function(model_name: str, input_mode: str, num_molecules: int = None, seed_n
195
  molsPerRow=3,
196
  subImgSize=(400, 400),
197
  maxMols=len(selected_molecules),
198
- # legends=None,
199
  returnPNG=False,
200
  drawOptions=drawOptions,
201
  highlightAtomLists=None,
202
  highlightBondLists=None,
203
  )
204
 
205
- # Clean up the temporary file if it was created
206
- if input_mode == "smiles" and os.path.exists(temp_smiles_file):
207
- os.remove(temp_smiles_file)
208
-
209
  return molecule_image, new_path, basic_metrics, advanced_metrics
210
 
211
 
212
-
213
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
214
  # Add custom CSS for styling
215
  gr.HTML("""
@@ -225,44 +185,40 @@ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
225
  </style>
226
  """)
227
 
228
- with gr.Row():
229
- with gr.Column(scale=1):
230
- gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
231
-
232
- gr.HTML("""
233
- <div style="display: flex; gap: 10px; margin-bottom: 15px;">
234
- <!-- arXiv badge -->
235
- <a href="https://arxiv.org/abs/2302.07868" target="_blank" style="text-decoration: none;">
236
- <div style="
237
- display: inline-block;
238
- background-color: #b31b1b;
239
- color: #ffffff !important; /* Force white text */
240
- padding: 5px 10px;
241
- border-radius: 5px;
242
- font-size: 14px;"
243
- >
244
- <span style="font-weight: bold;">arXiv</span> 2302.07868
245
- </div>
246
- </a>
247
-
248
- <!-- GitHub badge -->
249
- <a href="https://github.com/HUBioDataLab/DrugGEN" target="_blank" style="text-decoration: none;">
250
- <div style="
251
- display: inline-block;
252
- background-color: #24292e;
253
- color: #ffffff !important; /* Force white text */
254
- padding: 5px 10px;
255
- border-radius: 5px;
256
- font-size: 14px;"
257
- >
258
- <span style="font-weight: bold;">GitHub</span> Repository
259
- </div>
260
- </a>
261
  </div>
262
- """)
263
-
264
- with gr.Accordion("About DrugGEN Models", open=False):
265
- gr.Markdown("""
 
 
266
  ## Model Variations
267
 
268
  ### DrugGEN-AKT1
@@ -272,256 +228,159 @@ This model is designed to generate molecules targeting the human AKT1 protein (U
272
  This model is designed to generate molecules targeting the human CDK2 protein (UniProt ID: P24941).
273
 
274
  ### DrugGEN-NoTarget
275
- This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein. It's useful for:
276
- - Exploring chemical space
277
- - Generating diverse scaffolds
278
- - Creating molecules with drug-like properties
279
 
280
  For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
281
- """)
282
-
283
- with gr.Accordion("Understanding the Metrics", open=False):
284
- gr.Markdown("""
285
  ## Evaluation Metrics
286
 
287
  ### Basic Metrics
288
  - **Validity**: Percentage of generated molecules that are chemically valid
289
  - **Uniqueness**: Percentage of unique molecules among valid ones
290
- - **Runtime**: Time taken to generate the requested molecules
291
 
292
  ### Novelty Metrics
293
  - **Novelty (Train)**: Percentage of molecules not found in the training set
294
- - **Novelty (Test)**: Percentage of molecules not found in the test set
295
- - **Drug Novelty**: Percentage of molecules not found in known inhibitors of the target protein
296
 
297
  ### Structural Metrics
298
- - **Max Length**: Maximum component length in the generated molecules
299
  - **Mean Atom Type**: Average distribution of atom types
300
  - **Internal Diversity**: Diversity within the generated set (higher is more diverse)
301
 
302
  ### Drug-likeness Metrics
303
  - **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
304
- - **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is easier)
305
 
306
  ### Similarity Metrics
307
  - **SNN ChEMBL**: Similarity to ChEMBL molecules (higher means more similar to known drug-like compounds)
308
- - **SNN Drug**: Similarity to known drugs (higher means more similar to approved drugs)
309
- """)
310
-
311
- model_name = gr.Radio(
312
- choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
313
- value="DrugGEN-AKT1",
314
- label="Select Target Model",
315
- info="Choose which protein target or general model to use for molecule generation"
316
- )
317
-
318
- # Add a separator between model selection and input mode
319
- gr.Markdown("---")
320
- gr.Markdown("## Input Settings")
321
-
322
- # Replace radio with switch using a better layout
323
- with gr.Row(equal_height=True):
324
- with gr.Column(scale=1, min_width=150):
325
- gr.Markdown("### Classic Generation", elem_id="generate-mode-label")
326
-
327
- with gr.Column(scale=1, min_width=150):
328
- input_mode_switch = gr.Checkbox(
329
- value=False,
330
- label="Switch Input Mode",
331
- elem_id="input-mode-switch"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  )
333
 
334
- with gr.Column(scale=1, min_width=150):
335
- gr.Markdown("### Custom SMILES Input", elem_id="smiles-mode-label")
336
-
337
- # Add custom CSS and JavaScript for better styling
338
- gr.HTML("""
339
- <style>
340
- #input-mode-switch {
341
- margin: 20px auto;
342
- display: flex;
343
- justify-content: center;
344
- }
345
-
346
- #generate-mode-label, #smiles-mode-label {
347
- text-align: center;
348
- margin-top: 10px;
349
- font-weight: bold;
350
- transition: opacity 0.3s ease;
351
- }
352
-
353
- /* Make the inactive mode label more subtle */
354
- #generate-mode-label {
355
- opacity: 1;
356
- color: #4CAF50;
357
- }
358
-
359
- #smiles-mode-label {
360
- opacity: 0.5;
361
- color: #2196F3;
362
- }
363
-
364
- .active-mode {
365
- text-decoration: underline;
366
- font-size: 1.1em;
367
- }
368
-
369
- /* Style for the input boxes */
370
- .input-box {
371
- border: 2px solid rgba(128, 128, 228, 0.3);
372
- border-radius: 10px;
373
- padding: 15px;
374
- margin-top: 15px;
375
- background-color: rgba(32, 36, 45, 0.7);
376
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
377
- transition: all 0.3s ease;
378
- }
379
-
380
- .input-box:hover {
381
- border-color: rgba(128, 128, 228, 0.6);
382
- box-shadow: 0 6px 8px rgba(0, 0, 0, 0.15);
383
- }
384
-
385
- /* Style the checkbox */
386
- #input-mode-switch label {
387
- font-weight: bold;
388
- font-size: 1.1em;
389
- color: rgba(128, 128, 228, 0.9);
390
- }
391
-
392
- /* Add a hint to indicate the toggle functionality */
393
- #input-mode-switch::after {
394
- content: 'Click to toggle between modes';
395
- display: block;
396
- text-align: center;
397
- font-size: 0.8em;
398
- opacity: 0.7;
399
- margin-top: 5px;
400
- }
401
- </style>
402
-
403
- <script>
404
- // Add JavaScript to enhance the mode switching UI
405
- document.addEventListener('DOMContentLoaded', function() {
406
- // Get references to elements
407
- const checkbox = document.querySelector('#input-mode-switch input[type="checkbox"]');
408
- const generateLabel = document.querySelector('#generate-mode-label');
409
- const smilesLabel = document.querySelector('#smiles-mode-label');
410
-
411
- // Add initial active class
412
- generateLabel.classList.add('active-mode');
413
-
414
- // Add event listener to checkbox
415
- if (checkbox) {
416
- checkbox.addEventListener('change', function() {
417
- if (this.checked) {
418
- // SMILES mode is active
419
- generateLabel.style.opacity = '0.5';
420
- smilesLabel.style.opacity = '1';
421
- generateLabel.classList.remove('active-mode');
422
- smilesLabel.classList.add('active-mode');
423
- } else {
424
- // Generate mode is active
425
- generateLabel.style.opacity = '1';
426
- smilesLabel.style.opacity = '0.5';
427
- generateLabel.classList.add('active-mode');
428
- smilesLabel.classList.remove('active-mode');
429
- }
430
- });
431
- }
432
- });
433
- </script>
434
- """)
435
-
436
- # Create container for generation mode inputs
437
- with gr.Group(visible=True, elem_id="generate-box", elem_classes="input-box") as generate_group:
438
- num_molecules = gr.Slider(
439
- minimum=10,
440
- maximum=250,
441
- value=100,
442
- step=10,
443
- label="Number of Molecules to Generate",
444
- info="This space runs on a CPU, which may result in slower performance. Generating 200 molecules takes approximately 6 minutes. Therefore, We set a 250-molecule cap. On a GPU, the model can generate 10,000 molecules in the same amount of time. Please check our GitHub repo for running our models on GPU."
445
- )
446
-
447
- # Seed input used in generate mode
448
- seed_num_generate = gr.Textbox(
449
- label="Random Seed (Optional)",
450
- value="",
451
- info="Set a specific seed for reproducible results, or leave empty for random generation"
452
- )
453
-
454
- # Create container for SMILES input mode
455
- with gr.Group(visible=False, elem_id="smiles-box", elem_classes="input-box") as smiles_group:
456
- smiles_input = gr.Textbox(
457
- label="Input SMILES",
458
- info="Enter up to 100 SMILES strings, one per line",
459
- lines=10,
460
- placeholder="CC(=O)OC1=CC=CC=C1C(=O)O\nCCO\nC1=CC=C(C=C1)C(=O)O\n...",
461
- )
462
-
463
- # Handle visibility toggling between the two input modes
464
- def toggle_visibility(checkbox_value):
465
- return not checkbox_value, checkbox_value
466
-
467
- input_mode_switch.change(
468
- fn=toggle_visibility,
469
- inputs=[input_mode_switch],
470
- outputs=[generate_group, smiles_group]
471
- )
472
-
473
- submit_button = gr.Button(
474
- value="Generate Molecules",
475
- variant="primary",
476
- size="lg"
477
- )
478
-
479
- # Helper function to determine which mode is active and which seed to use
480
- def get_inputs(checkbox_value, num_mols, seed_gen, smiles):
481
- mode = "smiles" if checkbox_value else "generate"
482
- seed = "42" if checkbox_value else seed_gen # Use default seed 42 for SMILES mode
483
- return [mode, num_mols, seed, smiles]
484
-
485
- with gr.Column(scale=2):
486
- basic_metrics_df = gr.Dataframe(
487
- headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Test)", "Novelty (Drug)", "Runtime (s)"],
488
- elem_id="basic-metrics"
489
- )
490
-
491
- advanced_metrics_df = gr.Dataframe(
492
- headers=["QED", "SA Score", "Internal Diversity", "SNN (ChEMBL)", "SNN (Drug)", "Max Length"],
493
- elem_id="advanced-metrics"
494
- )
495
-
496
- file_download = gr.File(
497
- label="Download All Generated Molecules (SMILES format)",
498
- )
499
-
500
- image_output = gr.Image(
501
- label="Structures of Randomly Selected Generated Molecules",
502
- elem_id="molecule_display"
503
- )
504
 
 
 
 
 
 
 
 
 
505
 
506
  gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")
507
 
508
- submit_button.click(
509
- fn=lambda model, checkbox, num_mols, seed_gen, smiles: function(
510
- model,
511
- "smiles" if checkbox else "generate",
512
- num_mols,
513
- "42" if checkbox else seed_gen, # Use default seed 42 for SMILES mode
514
- smiles
515
- ),
516
- inputs=[model_name, input_mode_switch, num_molecules, seed_num_generate, smiles_input],
517
  outputs=[
518
  image_output,
519
  file_download,
520
  basic_metrics_df,
521
  advanced_metrics_df
522
- ],
523
- api_name="inference"
 
 
 
 
 
 
 
 
 
 
 
 
524
  )
525
- #demo.queue(concurrency_count=1)
526
  demo.queue()
527
- demo.launch()
 
 
13
 
14
  class DrugGENConfig:
15
  # Inference configuration
16
+ submodel = 'DrugGEN'
17
+ inference_model = "/home/user/app/experiments/models/DrugGEN/"
18
+ sample_num = 100
19
 
20
  # Data configuration
21
+ inf_smiles = '/home/user/app/data/chembl_test.smi'
22
+ train_smiles = '/home/user/app/data/chembl_train.smi'
23
+ inf_batch_size = 1
24
+ mol_data_dir = '/home/user/app/data'
25
+ features = False
26
 
27
  # Model configuration
28
+ act = 'relu'
29
+ max_atom = 45
30
+ dim = 128
31
+ depth = 1
32
+ heads = 8
33
+ mlp_ratio = 3
34
+ dropout = 0.
35
 
36
  # Seed configuration
37
+ set_seed = True
38
+ seed = 10
39
 
40
+ disable_correction = False
41
 
42
 
43
  class DrugGENAKT1Config(DrugGENConfig):
44
+ submodel = 'DrugGEN'
45
+ inference_model = "/home/user/app/experiments/models/DrugGEN-akt1/"
46
+ train_drug_smiles = '/home/user/app/data/akt_train.smi'
47
+ max_atom = 45
48
 
49
 
50
  class DrugGENCDK2Config(DrugGENConfig):
51
+ submodel = 'DrugGEN'
52
+ inference_model = "/home/user/app/experiments/models/DrugGEN-cdk2/"
53
+ train_drug_smiles = '/home/user/app/data/cdk2_train.smi'
54
+ max_atom = 38
55
 
56
 
57
  class NoTargetConfig(DrugGENConfig):
58
+ submodel = "NoTarget"
59
+ inference_model = "/home/user/app/experiments/models/NoTarget/"
60
 
61
 
62
  model_configs = {
 
66
  }
67
 
68
 
69
+ def run_inference(mode: str, model_name: str, num_molecules: int, seed_num: str, custom_smiles: str):
70
+ """
71
+ Depending on the selected mode, either generate new molecules or evaluate provided SMILES.
 
 
 
 
 
72
 
73
+ Returns:
74
+ image, file_path, basic_metrics, advanced_metrics
75
+ """
76
  config = model_configs[model_name]
 
 
 
 
 
 
 
77
 
78
+ if mode == "Custom Input SMILES":
79
+ # Process the custom input SMILES
80
+ smiles_list = [s.strip() for s in custom_smiles.strip().splitlines() if s.strip() != ""]
81
+ if len(smiles_list) > 100:
82
+ raise gr.Error("You have provided more than the allowed limit of 100 molecules. Please provide 100 or fewer.")
83
+ # Write the custom SMILES to a temporary file and update config
84
+ temp_input_file = "custom_input.smi"
85
+ with open(temp_input_file, "w") as f:
86
+ for s in smiles_list:
87
+ f.write(s + "\n")
88
+ config.inf_smiles = temp_input_file
89
+ config.sample_num = len(smiles_list)
90
+ # Always use a random seed for custom mode
91
+ config.seed = random.randint(0, 10000)
92
+ else:
93
+ # Classical Generation mode
94
+ config.sample_num = num_molecules
95
+ if config.sample_num > 200:
96
+ raise gr.Error("You have requested to generate more than the allowed limit of 200 molecules. Please reduce your request to 200 or fewer.")
97
  if seed_num is None or seed_num.strip() == "":
98
  config.seed = random.randint(0, 10000)
99
  else:
 
101
  config.seed = int(seed_num)
102
  except ValueError:
103
  raise gr.Error("The seed must be an integer value!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
+ # Adjust model name for the inference if not using NoTarget
106
+ if model_name != "DrugGEN-NoTarget":
107
+ target_model_name = "DrugGEN"
108
+ else:
109
+ target_model_name = "NoTarget"
110
 
111
  inferer = Inference(config)
112
  start_time = time.time()
113
  scores = inferer.inference() # This returns a DataFrame with specific columns
114
  et = time.time() - start_time
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  # Create basic metrics dataframe
117
  basic_metrics = pd.DataFrame({
118
  "Validity": [scores["validity"].iloc[0]],
119
  "Uniqueness": [scores["uniqueness"].iloc[0]],
120
  "Novelty (Train)": [scores["novelty"].iloc[0]],
121
+ "Novelty (Inference)": [scores["novelty_test"].iloc[0]],
122
+ "Novelty (Real Inhibitors)": [scores["drug_novelty"].iloc[0]],
123
  "Runtime (s)": [round(et, 2)]
124
  })
125
 
 
129
  "SA Score": [scores["sa"].iloc[0]],
130
  "Internal Diversity": [scores["IntDiv"].iloc[0]],
131
  "SNN ChEMBL": [scores["snn_chembl"].iloc[0]],
132
+ "SNN Real Inhibitors": [scores["snn_drug"].iloc[0]],
133
+ "Average Length": [scores["max_len"].iloc[0]]
134
  })
135
 
136
+ # Process the output file from inference
137
+ output_file_path = f'/home/user/app/experiments/inference/{target_model_name}/inference_drugs.txt'
138
+ new_path = f'{target_model_name}_denovo_mols.smi'
139
  os.rename(output_file_path, new_path)
140
 
141
  with open(new_path) as f:
 
143
 
144
  generated_molecule_list = inference_drugs.split("\n")[:-1]
145
 
146
+ # Randomly select up to 12 molecules for display
147
  rng = random.Random(config.seed)
148
  if len(generated_molecule_list) > 12:
149
+ selected_smiles = rng.choices(generated_molecule_list, k=12)
150
  else:
151
+ selected_smiles = generated_molecule_list
152
+
153
+ selected_molecules = [Chem.MolFromSmiles(mol) for mol in selected_smiles if Chem.MolFromSmiles(mol) is not None]
154
 
155
  drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
156
  drawOptions.prepareMolsBeforeDrawing = False
 
161
  molsPerRow=3,
162
  subImgSize=(400, 400),
163
  maxMols=len(selected_molecules),
 
164
  returnPNG=False,
165
  drawOptions=drawOptions,
166
  highlightAtomLists=None,
167
  highlightBondLists=None,
168
  )
169
 
 
 
 
 
170
  return molecule_image, new_path, basic_metrics, advanced_metrics
171
 
172
 
 
173
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
174
  # Add custom CSS for styling
175
  gr.HTML("""
 
185
  </style>
186
  """)
187
 
188
+ gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
189
+
190
+ gr.HTML("""
191
+ <div style="display: flex; gap: 10px; margin-bottom: 15px;">
192
+ <!-- arXiv badge -->
193
+ <a href="https://arxiv.org/abs/2302.07868" target="_blank" style="text-decoration: none;">
194
+ <div style="
195
+ display: inline-block;
196
+ background-color: #b31b1b;
197
+ color: #ffffff !important;
198
+ padding: 5px 10px;
199
+ border-radius: 5px;
200
+ font-size: 14px;">
201
+ <span style="font-weight: bold;">arXiv</span> 2302.07868
202
+ </div>
203
+ </a>
204
+
205
+ <!-- GitHub badge -->
206
+ <a href="https://github.com/HUBioDataLab/DrugGEN" target="_blank" style="text-decoration: none;">
207
+ <div style="
208
+ display: inline-block;
209
+ background-color: #24292e;
210
+ color: #ffffff !important;
211
+ padding: 5px 10px;
212
+ border-radius: 5px;
213
+ font-size: 14px;">
214
+ <span style="font-weight: bold;">GitHub</span> Repository
 
 
 
 
 
 
215
  </div>
216
+ </a>
217
+ </div>
218
+ """)
219
+
220
+ with gr.Accordion("About DrugGEN Models", open=False):
221
+ gr.Markdown("""
222
  ## Model Variations
223
 
224
  ### DrugGEN-AKT1
 
228
  This model is designed to generate molecules targeting the human CDK2 protein (UniProt ID: P24941).
229
 
230
  ### DrugGEN-NoTarget
231
+ This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein.
232
+ - Useful for exploring chemical space, generating diverse scaffolds, and creating molecules with drug-like properties.
 
 
233
 
234
  For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
235
+ """)
236
+
237
+ with gr.Accordion("Understanding the Metrics", open=False):
238
+ gr.Markdown("""
239
  ## Evaluation Metrics
240
 
241
  ### Basic Metrics
242
  - **Validity**: Percentage of generated molecules that are chemically valid
243
  - **Uniqueness**: Percentage of unique molecules among valid ones
244
+ - **Runtime**: Time taken to generate or evaluate the molecules
245
 
246
  ### Novelty Metrics
247
  - **Novelty (Train)**: Percentage of molecules not found in the training set
248
+ - **Novelty (Inference)**: Percentage of molecules not found in the test set
249
+ - **Novelty (Real Inhibitors)**: Percentage of molecules not found in known inhibitors of the target protein
250
 
251
  ### Structural Metrics
252
+ - **Average Length**: Average component length in the generated molecules
253
  - **Mean Atom Type**: Average distribution of atom types
254
  - **Internal Diversity**: Diversity within the generated set (higher is more diverse)
255
 
256
  ### Drug-likeness Metrics
257
  - **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
258
+ - **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is better)
259
 
260
  ### Similarity Metrics
261
  - **SNN ChEMBL**: Similarity to ChEMBL molecules (higher means more similar to known drug-like compounds)
262
+ - **SNN Real Inhibitors**: Similarity to known drugs (higher means more similar to approved drugs)
263
+ """)
264
+
265
+ # Use Gradio Tabs to separate the two modes.
266
+ with gr.Tabs():
267
+ with gr.TabItem("Classical Generation"):
268
+ with gr.Row():
269
+ with gr.Column(scale=1):
270
+ model_name = gr.Radio(
271
+ choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
272
+ value="DrugGEN-AKT1",
273
+ label="Select Target Model",
274
+ info="Choose which protein target or general model to use for molecule generation"
275
+ )
276
+
277
+ num_molecules = gr.Slider(
278
+ minimum=10,
279
+ maximum=200,
280
+ value=100,
281
+ step=10,
282
+ label="Number of Molecules to Generate",
283
+ info="This space runs on a CPU, which may result in slower performance. Generating 100 molecules takes approximately 6 minutes. Therefore, we set a 200-molecule cap."
284
+ )
285
+
286
+ seed_num = gr.Textbox(
287
+ label="Random Seed (Optional)",
288
+ value="",
289
+ info="Set a specific seed for reproducible results, or leave empty for random generation"
290
+ )
291
+
292
+ classical_submit = gr.Button(
293
+ value="Generate Molecules",
294
+ variant="primary",
295
+ size="lg"
296
+ )
297
+ with gr.Column(scale=2):
298
+ basic_metrics_df = gr.Dataframe(
299
+ headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Inference)", "Novelty (Real Inhibitors)", "Runtime (s)"],
300
+ elem_id="basic-metrics"
301
+ )
302
+
303
+ advanced_metrics_df = gr.Dataframe(
304
+ headers=["QED", "SA Score", "Internal Diversity", "SNN (ChEMBL)", "SNN (Real Inhibitors)", "Average Length"],
305
+ elem_id="advanced-metrics"
306
+ )
307
+
308
+ file_download = gr.File(
309
+ label="Download All Generated Molecules (SMILES format)"
310
  )
311
 
312
+ image_output = gr.Image(
313
+ label="Structures of Randomly Selected Generated Molecules",
314
+ elem_id="molecule_display"
315
+ )
316
+
317
+ with gr.TabItem("Custom Input SMILES"):
318
+ with gr.Row():
319
+ with gr.Column(scale=1):
320
+ # Reuse model selection for custom input
321
+ model_name_custom = gr.Radio(
322
+ choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
323
+ value="DrugGEN-AKT1",
324
+ label="Select Target Model",
325
+ info="Choose which protein target or general model to use for evaluation"
326
+ )
327
+ custom_smiles = gr.Textbox(
328
+ label="Input SMILES (one per line, maximum 100 molecules)",
329
+ placeholder="C(C(=O)O)N\nCCO\n...",
330
+ lines=10
331
+ )
332
+ custom_submit = gr.Button(
333
+ value="Evaluate Custom SMILES",
334
+ variant="primary",
335
+ size="lg"
336
+ )
337
+ with gr.Column(scale=2):
338
+ basic_metrics_df_custom = gr.Dataframe(
339
+ headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Inference)", "Novelty (Real Inhibitors)", "Runtime (s)"],
340
+ elem_id="basic-metrics-custom"
341
+ )
342
+
343
+ advanced_metrics_df_custom = gr.Dataframe(
344
+ headers=["QED", "SA Score", "Internal Diversity", "SNN (ChEMBL)", "SNN (Real Inhibitors)", "Average Length"],
345
+ elem_id="advanced-metrics-custom"
346
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
+ file_download_custom = gr.File(
349
+ label="Download All Molecules (SMILES format)"
350
+ )
351
+
352
+ image_output_custom = gr.Image(
353
+ label="Structures of Randomly Selected Molecules",
354
+ elem_id="molecule_display_custom"
355
+ )
356
 
357
  gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")
358
 
359
+ # Set up the click actions for each tab.
360
+ classical_submit.click(
361
+ run_inference,
362
+ inputs=[gr.State("Generate Molecules"), model_name, num_molecules, seed_num, gr.State("")],
 
 
 
 
 
363
  outputs=[
364
  image_output,
365
  file_download,
366
  basic_metrics_df,
367
  advanced_metrics_df
368
+ ],
369
+ api_name="inference_classical"
370
+ )
371
+
372
+ custom_submit.click(
373
+ run_inference,
374
+ inputs=[gr.State("Custom Input SMILES"), model_name_custom, gr.State(0), gr.State(""), custom_smiles],
375
+ outputs=[
376
+ image_output_custom,
377
+ file_download_custom,
378
+ basic_metrics_df_custom,
379
+ advanced_metrics_df_custom
380
+ ],
381
+ api_name="inference_custom"
382
  )
383
+
384
  demo.queue()
385
+ demo.launch()
386
+