# Copyright 2024 Llamole Team
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
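
"""Single-query inference for Llamole: load the graph-aware causal LM with its
adapter, wrap one MolQA record in a DataLoader, and generate a molecule design
together with its retrosynthesis plan."""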

import re
import os
import json
import math
import torch
from torch.utils.data import DataLoader
from typing import TYPE_CHECKING, List, Optional, Dict, Any

from ..data import get_dataset, DataCollatorForSeqGraph, get_template_and_fix_tokenizer
from ..extras.constants import IGNORE_INDEX, NO_LABEL_INDEX
from ..extras.misc import get_logits_processor
from ..extras.ploting import plot_loss
from ..model import load_tokenizer, GraphLLMForCausalMLM
from ..hparams import get_train_args
from .dataset import MolQADataset

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments
    from ..hparams import (
        DataArguments,
        FinetuningArguments,
        GeneratingArguments,
        ModelArguments,
    )

def remove_extra_spaces(text):
    # Collapse any run of whitespace into a single space and trim the ends.
    cleaned_text = re.sub(r'\s+', ' ', text)
    return cleaned_text.strip()

def load_model_and_tokenizer(args):
    # Parse the Llamole argument groups, load the tokenizer in generate mode,
    # and build the graph-augmented causal LM with its trained adapter.
    model_args, data_args, training_args, finetuning_args, generating_args = (
        get_train_args(args)
    )
    tokenizer = load_tokenizer(model_args, generate_mode=True)["tokenizer"]
    tokenizer.pad_token = tokenizer.eos_token

    model = GraphLLMForCausalMLM.from_pretrained(
        tokenizer, model_args, data_args, training_args, finetuning_args, load_adapter=True
    )

    return model, tokenizer, generating_args

def process_input(input_data: Dict[str, Any], model, tokenizer, generating_args: "GeneratingArguments"):
    # Wrap a single MolQA record in a one-item DataLoader and assemble the generation kwargs.
    dataset = MolQADataset([input_data], tokenizer, generating_args.max_length)
    dataloader = DataLoader(
        dataset, batch_size=1, shuffle=False
    )

    gen_kwargs = generating_args.to_dict()
    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    gen_kwargs["logits_processor"] = get_logits_processor()

    return dataloader, gen_kwargs

def generate(model, dataloader, gen_kwargs):
    # Property columns expected in each batch's "property" tensor, in this order.
    property_names = ["BBBP", "HIV", "BACE", "CO2", "N2", "O2", "FFV", "TC", "SC", "SA"]

    for batch in dataloader:
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)
        property_data = batch["property"].to(model.device)

        model.eval()
        with torch.no_grad():
            # Generate the text response, design the molecule, and plan its retrosynthesis in one call.
            all_info_dict = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                molecule_properties=property_data,
                do_molecular_design=True,
                do_retrosynthesis=True,
                expansion_topk=50,
                iterations=100,
                max_planning_time=30,
                rollback=True,
                **gen_kwargs,
            )

            assert len(all_info_dict["smiles_list"]) == 1

            for i in range(len(all_info_dict["smiles_list"])):
                llm_response = "".join(item for item in all_info_dict["text_lists"][i] if item is not None)
                result = {
                    "llm_smiles": all_info_dict["smiles_list"][i],
                    "property": {},
                }
                for j, prop_name in enumerate(property_names):
                    prop_value = property_data[i][j].item()
                    if not math.isnan(prop_value):
                        result["property"][prop_name] = prop_value

                # Look up the retrosynthesis route planned for the designed molecule.
                retro_plan = all_info_dict["retro_plan_dict"][result["llm_smiles"]]
                result["llm_reactions"] = []
                if retro_plan["success"]:
                    for reaction, template, cost in zip(
                        retro_plan["reaction_list"],
                        retro_plan["templates"],
                        retro_plan["cost"],
                    ):
                        result["llm_reactions"].append(
                            {"reaction": reaction, "template": template, "cost": cost}
                        )
                result["llm_response"] = remove_extra_spaces(llm_response)
                return result
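

# Example usage (a minimal sketch, not part of the original module). The argument
# keys and the MolQA record fields below are illustrative assumptions; substitute
# the checkpoint paths and input fields your configuration actually uses.
if __name__ == "__main__":
    example_args = {
        "model_name_or_path": "path/to/base-llm",         # hypothetical checkpoint path
        "adapter_name_or_path": "path/to/graph-adapter",  # hypothetical adapter path
    }
    example_input = {
        "instruction": "Design a molecule that can cross the blood-brain barrier.",
        "input": "",
        "property": [1.0] + [float("nan")] * 9,  # assumed order: BBBP first, others unspecified
    }
    model, tokenizer, generating_args = load_model_and_tokenizer(example_args)
    dataloader, gen_kwargs = process_input(example_input, model, tokenizer, generating_args)
    result = generate(model, dataloader, gen_kwargs)
    print(json.dumps(result, indent=2))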