"""Streamlit demo: predict reaction products with a pretrained seq2seq model.

The script reads a reaction string from a text area, runs beam search with
the model named in ``CFG``, and renders one table row containing the input,
every returned candidate, the first candidate that RDKit can parse, and the
corresponding beam scores.
"""
import os
import gc
import random
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import torch
import tokenizers
import transformers
from transformers import AutoTokenizer, EncoderDecoderModel, AutoModelForSeq2SeqLM
import sentencepiece
from rdkit import Chem
import rdkit
import streamlit as st


st.title('predictproduct-t5')
st.text('At this space, you can predict the products of reactions from their inputs.')
st.text('The format of the string is like "REACTANT:{reactants of the reaction}CATALYST:{catalysts of the reaction}REAGENT:{reagents of the reaction}SOLVENT:{solvent of the reaction}".')
st.text('If there are no catalyst or reagent, fill the blank with a space. And if there are multiple reactants, concatenate them with "."')

display_text = 'input the reaction smiles (e.g. REACTANT:CNc1nc(SC)ncc1CO.O.O=[Cr](=O)([O-])O[Cr](=O)(=O)[O-].[Na+]CATALYST: REAGENT: SOLVENT:CC(=O)O)'


class CFG:
    """Run configuration; ``input_data`` is read from the UI at class creation."""

    # Text typed by the user; the text area widget is created right here.
    input_data = st.text_area(display_text)
    # Checkpoint identifier handed to ``from_pretrained``.
    model_name_or_path = 'sagawa/ZINC-t5-productpredicition'
    # Selects which model class is used for loading ('t5' or 'deberta').
    model = 't5'
    # Beam-search width and number of candidates returned by ``generate``.
    num_beams = 5
    num_return_sequences = 5
    # Seed passed to ``seed_everything``.
    seed = 42


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


seed_everything(seed=CFG.seed)

tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path, return_tensors='pt')

if CFG.model == 't5':
    model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model_name_or_path).to(device)
elif CFG.model == 'deberta':
    model = EncoderDecoderModel.from_pretrained(CFG.model_name_or_path).to(device)
else:
    # Fail with a clear message instead of a later NameError on ``model``.
    raise ValueError(f'unsupported CFG.model: {CFG.model!r}')

input_compound = CFG.input_data

# min(..., 0) never exceeds 0, so ``min_length`` is never positive and
# ``max_length`` below is at most 50.
# NOTE(review): a positive lower bound derived from the reactant length may
# have been intended here (max instead of min) - confirm before changing,
# since that would alter the generated sequences.
min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)

inp = tokenizer(input_compound, return_tensors='pt').to(device)
generated = model.generate(
    **inp,
    min_length=min_length,
    max_length=min_length + 50,
    num_beams=CFG.num_beams,
    num_return_sequences=CFG.num_return_sequences,
    return_dict_in_generate=True,
    output_scores=True,
)

scores = generated['sequences_scores'].tolist()
predictions = [
    tokenizer.decode(sequence, skip_special_tokens=True).replace('. ', '.').rstrip('.')
    for sequence in generated['sequences']
]

# Pick the first candidate (in returned order) that RDKit can parse. When
# none parses, both placeholders stay None so the row still has one value
# per column.
valid_compound = None
valid_score = None
for candidate, score in zip(predictions, scores):
    if Chem.MolFromSmiles(candidate) is not None:
        valid_compound = candidate
        valid_score = score
        break

row = [input_compound] + predictions + [valid_compound] + scores + [valid_score]

# Size the headers from the actual number of candidates so that they always
# match the row, even if num_beams and num_return_sequences differ.
n_predictions = len(predictions)
columns = (
    ['input']
    + [f'{i}th' for i in range(n_predictions)]
    + ['valid compound']
    + [f'{i}th score' for i in range(n_predictions)]
    + ['valid compound score']
)

output_df = pd.DataFrame(np.array(row).reshape(1, -1), columns=columns)
st.table(output_df)