import os
import pickle
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

import torch
from rdkit import Chem
from torch_geometric.data import Data

from ...extras.constants import BOND_INDEX, IGNORE_INDEX, NO_LABEL_INDEX
from ...extras.logging import get_logger

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..template import Template


logger = get_logger(__name__)
|
|
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    """Computes how many source and target tokens to keep under ``cutoff_len``.

    A short side keeps its full length; when both sides are long, the budget is
    split proportionally to the original lengths.
    """
    if target_len * 2 < cutoff_len:  # target is short: let it keep its full length
        max_target_len = cutoff_len
    elif source_len * 2 < cutoff_len:  # source is short: give the target the remainder
        max_target_len = cutoff_len - source_len
    else:  # both are long: split the budget proportionally
        max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))

    new_target_len = min(max_target_len, target_len)
    new_source_len = max(cutoff_len - new_target_len, 0)
    return new_source_len, new_target_len
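# Worked example (illustrative values): with cutoff_len=10, a short target keeps
# its full length and the source absorbs the rest of the budget:
#   infer_seqlen(source_len=8, target_len=3, cutoff_len=10) -> (7, 3)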
|
|
|
def encode_graph_pyg(
    data_path: Optional[str] = None, mol_id_to_smiles: Optional[Dict[str, str]] = None
) -> Dict[str, Data]:
    """
    Converts molecules to a dictionary of PyTorch Geometric ``Data`` objects,
    caching the result on disk. Graphs use a sparse (COO) edge representation.

    Args:
        data_path (Optional[str]): Path to the Hugging Face dataset folder, used
            for reading and writing the cache file.
        mol_id_to_smiles (Optional[Dict[str, str]]): Mapping from molecule IDs
            to SMILES strings.

    Returns:
        Dict[str, Data]: Mapping from molecule IDs to PyTorch Geometric Data objects.

    Raises:
        ValueError: If both arguments are None, if a SMILES string cannot be
            parsed, or if the cache misses and no SMILES mapping is available.
    """
    logger.debug(f"Current execution directory: {os.getcwd()}")

    if data_path is None and mol_id_to_smiles is None:
        raise ValueError("Either data_path or mol_id_to_smiles must be provided.")

    cache_file = None
    if data_path is not None:
        cache_file = os.path.join(data_path, "pyg_molecule.pickle")
        if os.path.exists(cache_file):
            try:
                with open(cache_file, "rb") as f:
                    return pickle.load(f)
            except Exception as e:
                logger.warning(f"Failed to load cached data: {e}")

    if mol_id_to_smiles is None:
        raise ValueError(f"Cache miss at {cache_file} and no mol_id_to_smiles provided.")

    mol_id_to_pyg = {}

    for mol_id, smiles in mol_id_to_smiles.items():
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Invalid SMILES string for molecule {mol_id}: {smiles}")

        # Node features: atomic numbers shifted by -2 (hydrogens are skipped);
        # the wildcard atom "*" takes the last slot (119 - 2 = 117).
        type_idx = []
        heavy_atom_indices = []
        for atom in mol.GetAtoms():
            if atom.GetAtomicNum() != 1:
                type_idx.append(
                    119 - 2 if atom.GetSymbol() == "*" else atom.GetAtomicNum() - 2
                )
                heavy_atom_indices.append(atom.GetIdx())

        x = torch.LongTensor(type_idx)

        # Map original RDKit atom indices to positions among the kept heavy atoms
        # (dict lookup instead of repeated list.index, which is quadratic).
        old_to_new = {old: new for new, old in enumerate(heavy_atom_indices)}

        # Edges: every bond between heavy atoms is stored in both directions.
        edge_index = []
        edge_attr = []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            if start in old_to_new and end in old_to_new:
                start_new, end_new = old_to_new[start], old_to_new[end]
                edge_index.extend([[start_new, end_new], [end_new, start_new]])
                bond_type = BOND_INDEX[bond.GetBondType()]
                edge_attr.extend([bond_type, bond_type])

        if edge_index:
            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
            edge_attr = torch.tensor(edge_attr, dtype=torch.long)
        else:  # single-atom molecules have no bonds; keep the expected [2, 0] shape
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_attr = torch.empty((0,), dtype=torch.long)

        mol_id_to_pyg[mol_id] = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    if cache_file is not None:
        with open(cache_file, "wb") as f:
            pickle.dump(mol_id_to_pyg, f)
        logger.info(f"Saved PyG data to {cache_file}")

    return mol_id_to_pyg
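# Illustrative usage (hypothetical molecule ID and SMILES, not from any dataset):
#   graphs = encode_graph_pyg(mol_id_to_smiles={"mol_0": "CCO"})
#   graphs["mol_0"].x           -> tensor([4, 4, 6])  # C, C, O shifted by -2
#   graphs["mol_0"].edge_index  -> shape [2, 4]       # 2 bonds, both directions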
|
|
|
def encode_supervised_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    molecule_ids: List[int],
    retro_product_ids: List[int],
    retro_labels: List[int],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    data_args: "DataArguments",
) -> Tuple[List[int], List[int], List[int], List[int], List[int]]:
    messages = prompt + response
    input_ids, labels = [], []
    final_molecule_ids = []
    final_product_ids = []
    final_retro_labels = []

    encoded_pairs = template.encode_multiturn(tokenizer, messages, system)
    special_tokens = [
        "<design_start>",
        "<design_end>",
        "<design_body>",
        "<molecule>",
        "<retro_start>",
        "<retro_end>",
        "<retro_body>",
    ]
    special_token_ids = template._convert_elements_to_ids(tokenizer, special_tokens)
    special_token_dict = dict(zip(special_tokens, special_token_ids))

    # Offsets into the per-example molecule / retro lists, advanced per turn so
    # that later turns do not re-consume entries from the head of the lists.
    molecule_offset = 0
    retro_offset = 0

    total_length = 1 if template.efficient_eos else 0
    for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
        if total_length >= data_args.cutoff_len:
            break

        source_len, target_len = infer_seqlen(
            len(source_ids), len(target_ids), data_args.cutoff_len - total_length
        )
        source_ids = source_ids[:source_len]

        # Never truncate inside a <retro_start> ... <retro_end> span: either keep
        # up to the last complete pair, or cut before the first incomplete pair.
        retro_start_indices = [
            i
            for i, token_id in enumerate(target_ids)
            if token_id == special_token_dict["<retro_start>"]
        ]
        retro_end_indices = [
            i
            for i, token_id in enumerate(target_ids)
            if token_id == special_token_dict["<retro_end>"]
        ]

        if retro_start_indices and retro_end_indices:
            last_pair_index = -1
            for start, end in zip(retro_start_indices, retro_end_indices):
                if end < target_len:
                    last_pair_index = end
                else:
                    break

            if last_pair_index >= 0:
                target_len = last_pair_index + 1
            else:
                target_len = min(target_len, retro_start_indices[0])

        target_ids = target_ids[:target_len]

        # Count the multimodal placeholders that survived truncation in this turn.
        molecules_in_turn = target_ids.count(special_token_dict["<molecule>"])
        retro_start_in_turn = target_ids.count(special_token_dict["<retro_start>"])
        retro_end_in_turn = target_ids.count(special_token_dict["<retro_end>"])

        assert retro_start_in_turn == retro_end_in_turn, (
            "Unbalanced <retro_start>/<retro_end> tokens after truncation."
        )

        final_molecule_ids.extend(
            molecule_ids[molecule_offset : molecule_offset + molecules_in_turn]
        )
        final_product_ids.extend(
            retro_product_ids[retro_offset : retro_offset + retro_end_in_turn]
        )
        final_retro_labels.extend(
            retro_labels[retro_offset : retro_offset + retro_end_in_turn]
        )
        molecule_offset += molecules_in_turn
        retro_offset += retro_end_in_turn

        total_length += source_len + target_len

        # Mask prompt tokens unless training on the prompt. Special tokens are
        # never supervised, except <retro_start> and <design_start>, which the
        # model must learn to emit.
        if data_args.train_on_prompt:
            source_mask = source_ids
        elif turn_idx != 0 and template.efficient_eos:
            source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (
                len(source_ids) - 1
            )
        else:
            source_mask = [IGNORE_INDEX] * len(source_ids)

        source_mask = [
            IGNORE_INDEX if token_id in special_token_dict.values() else token_id
            for token_id in source_mask
        ]
        target_ids_mask = [
            token_id
            if token_id
            in (special_token_dict["<retro_start>"], special_token_dict["<design_start>"])
            else (IGNORE_INDEX if token_id in special_token_dict.values() else token_id)
            for token_id in target_ids
        ]

        input_ids += source_ids + target_ids
        labels += source_mask + target_ids_mask

    if template.efficient_eos:
        input_ids += [tokenizer.eos_token_id]
        labels += [tokenizer.eos_token_id]

    return input_ids, labels, final_molecule_ids, final_product_ids, final_retro_labels
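# Masking sketch (hypothetical tokens, for illustration only): for target tokens
#   [..., <retro_start>, t1, t2, <retro_end>, ...]
# the labels keep <retro_start> (and <design_start>) as supervised targets, keep
# ordinary text tokens such as t1 and t2, and replace every other special token
# with IGNORE_INDEX so it contributes no loss.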
|
|
|
|
|
def preprocess_mmsupervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    model_inputs = {
        "input_ids": [],
        "attention_mask": [],
        "labels": [],
        "molecule_ids": [],
        "molecule_properties": [],
        "retro_labels": [],
        "retro_product_ids": [],
    }

    for i in range(len(examples["prompt"])):
        # A valid example has an odd number of prompt messages (alternating
        # roles ending with a user turn) and exactly one response.
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
            logger.warning(
                "Dropped invalid example: {}".format(
                    examples["prompt"][i] + examples["response"][i]
                )
            )
            continue

        # Replace missing labels and properties with the NO_LABEL_INDEX sentinel.
        retro_product_ids = examples["retro_products"][i]
        retro_labels = [
            NO_LABEL_INDEX if label is None else label
            for label in examples["retro_labels"][i]
        ]
        properties = [
            NO_LABEL_INDEX if prop is None else prop for prop in examples["property"][i]
        ]

        input_ids, labels, molecule_ids, retro_product_ids, retro_labels = (
            encode_supervised_example(
                prompt=examples["prompt"][i],
                response=examples["response"][i],
                system=examples["system"][i],
                molecule_ids=examples["molecules"][i],
                retro_product_ids=retro_product_ids,
                retro_labels=retro_labels,
                template=template,
                tokenizer=tokenizer,
                data_args=data_args,
            )
        )

        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["molecule_ids"].append(molecule_ids)
        model_inputs["molecule_properties"].append(properties)
        model_inputs["retro_labels"].append(retro_labels)
        model_inputs["retro_product_ids"].append(retro_product_ids)

    return model_inputs
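# Typical wiring (a sketch; `dataset`, `template`, and `data_args` are assumed to
# come from the surrounding pipeline):
#   from functools import partial
#   preprocess_func = partial(
#       preprocess_mmsupervised_dataset,
#       template=template, tokenizer=tokenizer, data_args=data_args,
#   )
#   dataset = dataset.map(preprocess_func, batched=True,
#                         remove_columns=dataset.column_names)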
|
|
|
def print_supervised_dataset_example(
    example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer"
) -> None:
    valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
    print("Supervised dataset example:")
    print("input_ids:\n{}".format(example["input_ids"]))
    print(
        "inputs:\n{}".format(
            tokenizer.decode(example["input_ids"], skip_special_tokens=False)
        )
    )
    print("label_ids:\n{}".format(example["labels"]))
    print(
        "labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False))
    )
    print("molecule_ids:\n{}".format(example["molecule_ids"]))
    print("molecule_properties:\n{}".format(example["molecule_properties"]))
    print("retro_labels:\n{}".format(example["retro_labels"]))
    print("retro_product_ids:\n{}".format(example["retro_product_ids"]))
|