import torch
from torch.nn import Linear
from torch_geometric.nn import HGTConv, MLP
import pandas as pd
import yaml
import os
from datasets import load_dataset
import gdown
import copy


class ProtHGT(torch.nn.Module):
    """Heterogeneous Graph Transformer that scores (Protein, target) node pairs.

    Every node type is first projected to ``hidden_channels`` by its own linear
    layer, then refined by ``num_layers`` HGT convolutions. A pair is scored by
    an MLP applied to the concatenation of the two node embeddings.
    """

    def __init__(self, data, hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
        """Build the model from the shape of ``data``.

        Args:
            data: PyG ``HeteroData`` graph. Its node types, per-type feature
                widths (``data[node_type].x.size(1)``) and ``metadata()``
                determine the layer shapes.
            hidden_channels: Embedding width after projection and in every HGT layer.
            num_heads: Attention heads per HGT layer.
            num_layers: Number of stacked HGT layers.
            mlp_hidden_layers: Channel list given to ``torch_geometric.nn.MLP``
                for the pair-scoring head.
            mlp_dropout: Dropout used inside the scoring MLP.
        """
        super().__init__()

        # One input projection per node type; the input width comes from the data.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            input_dim = data[node_type].x.size(1)
            self.lin_dict[node_type] = Linear(input_dim, hidden_channels)

        # NOTE: HGTConv is built from data.metadata(); a saved state_dict presumably
        # only loads if the graph exposes the same node/edge types as at training time.
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(), num_heads, group='sum')
            self.convs.append(conv)

        self.mlp = MLP(mlp_hidden_layers, dropout=mlp_dropout, norm=None)

    def generate_embeddings(self, x_dict, edge_index_dict):
        """Project node features per type (with ReLU) and run them through the HGT layers.

        Returns:
            Dict mapping node type to its updated embedding tensor.
        """
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return x_dict

    def forward(self, x_dict, edge_index_dict, tr_edge_label_index, target_type, test=False):
        """Score (Protein, ``target_type``) pairs.

        Args:
            x_dict: Node features keyed by node type.
            edge_index_dict: Message-passing edges keyed by edge type.
            tr_edge_label_index: Two-row index tensor; row 0 holds ``Protein``
                node indices and row 1 holds ``target_type`` node indices.
            target_type: Node type on the target side of each pair.
            test: Unused; kept for interface compatibility.

        Returns:
            ``(scores, x_dict)``: one raw MLP score per pair (flattened; no
            sigmoid applied here) and the updated node embeddings.
        """
        x_dict = self.generate_embeddings(x_dict, edge_index_dict)

        row, col = tr_edge_label_index
        z = torch.cat([x_dict["Protein"][row], x_dict[target_type][col]], dim=-1)

        return self.mlp(z).view(-1), x_dict


def _load_data(heterodata, protein_ids, go_category):
    """Set the Protein→``go_category`` label edges to every (query protein, GO term) pair.

    Edges are ordered protein-major: for each protein in ``protein_ids`` order,
    all term indices ``0 .. n_terms-1``. ``_create_prediction_df`` relies on
    this order. ``heterodata`` is modified in place and returned, so callers
    should pass a copy.

    Raises:
        KeyError: if any protein ID is missing from ``heterodata['Protein']['id_mapping']``.
    """
    protein_mapping = heterodata['Protein']['id_mapping']
    missing = [pid for pid in protein_ids if pid not in protein_mapping]
    if missing:
        raise KeyError(f"Protein IDs not found in the knowledge graph: {missing}")

    protein_indices = torch.tensor([protein_mapping[pid] for pid in protein_ids], dtype=torch.long)
    n_terms = len(heterodata[go_category]['id_mapping'])
    term_indices = torch.arange(n_terms, dtype=torch.long)

    # cartesian_prod yields rows [p0,t0], [p0,t1], ... (protein-major), shape (P*T, 2);
    # transpose to PyG's (2, num_edges) layout. Also well-defined when P == 0.
    edge_index = torch.cartesian_prod(protein_indices, term_indices).t().contiguous()

    heterodata[('Protein', 'protein_function', go_category)].edge_index = edge_index
    heterodata[(go_category, 'rev_protein_function', 'Protein')].edge_index = torch.stack([edge_index[1], edge_index[0]])

    return heterodata


def get_available_proteins(protein_list_file='data/available_proteins.txt'):
    """Return the stripped lines of ``protein_list_file`` (one protein ID per line)."""
    with open(protein_list_file, 'r', encoding='utf-8') as file:
        return [line.strip() for line in file]


def _generate_predictions(heterodata, model, target_type):
    """Return sigmoid probabilities (on CPU) for every Protein→``target_type`` edge.

    The scored pairs are the edges stored under
    ``('Protein', 'protein_function', target_type)``. The model and the data
    are moved to CUDA when it is available.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.to(device)
    model.eval()
    heterodata = heterodata.to(device)

    with torch.no_grad():
        edge_label_index = heterodata.edge_index_dict[('Protein', 'protein_function', target_type)]
        predictions, _ = model(heterodata.x_dict, heterodata.edge_index_dict, edge_label_index, target_type)
        predictions = torch.sigmoid(predictions)

    return predictions.cpu()


def _create_prediction_df(predictions, heterodata, protein_ids, go_category):
    """Turn flat, protein-major predictions into a long-format DataFrame.

    ``predictions[i * n_terms + j]`` is the probability for ``protein_ids[i]``
    and the GO term whose node index is ``j``, matching ``_load_data``.

    Returns:
        DataFrame with columns Protein, GO_term, GO_category, Probability.
    """
    go_category_dict = {
        'GO_term_F': 'Molecular Function',
        'GO_term_P': 'Biological Process',
        'GO_term_C': 'Cellular Component'
    }

    term_mapping = heterodata[go_category]['id_mapping']
    n_go_terms = len(term_mapping)
    n_proteins = len(protein_ids)
    total = n_proteins * n_go_terms

    # Order labels by their mapped node index, so label j names the term scored
    # at index j in _load_data. This holds regardless of the dict's insertion order.
    go_terms = sorted(term_mapping, key=term_mapping.__getitem__)

    return pd.DataFrame({
        'Protein': [pid for pid in protein_ids for _ in range(n_go_terms)],
        'GO_term': go_terms * n_proteins,
        'GO_category': [go_category_dict[go_category]] * total,
        'Probability': predictions[:total].tolist()
    })


def _as_list(value):
    """Wrap a single string/path in a list; materialize any other iterable as a list."""
    if isinstance(value, (str, os.PathLike)):
        return [value]
    return list(value)


def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_category):
    """Predict GO-term probabilities for proteins, using one ProtHGT model per GO category.

    Args:
        protein_ids: One protein ID or an iterable of IDs. Each must be a key
            of ``heterodata['Protein']['id_mapping']``.
        model_paths: Path(s) to model state_dict files, one per category.
        model_config_paths: Path(s) to YAML model configs, one per category.
        go_category: GO category node type(s), e.g. ``'GO_term_F'``. Either a
            single string or an iterable aligned with the two path arguments.

    Returns:
        DataFrame with columns Protein, GO_term, GO_category, Probability.
        It is empty (with these columns) when no category is requested.

    Raises:
        ValueError: if the per-category arguments differ in length.
        KeyError: if a protein ID is not in the knowledge graph.
    """
    # Normalize inputs. A bare string would otherwise be zipped character by
    # character, and a generator of IDs would be exhausted after one category.
    protein_ids = _as_list(protein_ids)
    go_category = _as_list(go_category)
    model_paths = _as_list(model_paths)
    model_config_paths = _as_list(model_config_paths)
    if not (len(go_category) == len(model_paths) == len(model_config_paths)):
        raise ValueError(
            "go_category, model_paths and model_config_paths must have the same length, got "
            f"{len(go_category)}, {len(model_paths)} and {len(model_config_paths)}"
        )

    all_predictions = []

    # The KG is also published as the HF dataset 'HUBioDataLab/ProtHGT-KG'
    # (load_dataset(..., data_files="prothgt-kg.json.gz")); here it is fetched from Google Drive.
    print('Loading data...')
    file_id = "18u1o2sm8YjMo9joFw4Ilwvg0-rUU0PXK"
    output = "data/prothgt-kg.pt"

    if not os.path.exists(output):
        url = f"https://drive.google.com/uc?id={file_id}"
        print(f"Downloading file from {url}...")
        try:
            os.makedirs(os.path.dirname(output) or '.', exist_ok=True)
            downloaded = gdown.download(url, output, quiet=False)
            if downloaded is None or not os.path.exists(output):
                raise RuntimeError(f"Download did not produce {output}")
        except Exception as e:
            print(f"Error downloading file: {e}")
            raise
    else:
        print(f"File already exists at {output}")

    # NOTE(review): this is a pickled HeteroData object. Newer torch releases
    # default torch.load to weights_only=True, which may refuse it; confirm the
    # torch version in use. Only load this file from a trusted source.
    heterodata = torch.load(output)

    # NOTE(review): this loop does NOT remove anything from `heterodata`.
    # `heterodata.edge_index_dict` builds a fresh dict on each access, so `del`
    # only affects that temporary dict. Actually deleting the edge types
    # (`del heterodata[edge_type]`) would change the graph metadata the HGT
    # layers are built from and may break loading the trained weights. Confirm
    # against the training setup before changing it.
    edge_types_to_remove = [
        ('Protein', 'protein_function', 'GO_term_F'),
        ('Protein', 'protein_function', 'GO_term_P'),
        ('Protein', 'protein_function', 'GO_term_C'),
        ('GO_term_F', 'rev_protein_function', 'Protein'),
        ('GO_term_P', 'rev_protein_function', 'Protein'),
        ('GO_term_C', 'rev_protein_function', 'Protein')
    ]
    for edge_type in edge_types_to_remove:
        if edge_type in heterodata.edge_index_dict:
            del heterodata.edge_index_dict[edge_type]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for go_cat, model_config_path, model_path in zip(go_category, model_config_paths, model_paths):
        print(f'Generating predictions for {go_cat}...')

        # Work on a copy: _load_data overwrites this category's label edges in place.
        processed_data = _load_data(copy.deepcopy(heterodata), protein_ids, go_cat)

        with open(model_config_path, 'r') as file:
            model_config = yaml.safe_load(file)

        # hidden_channels in the config holds [HGT width, MLP channel list].
        model = ProtHGT(
            processed_data,
            hidden_channels=model_config['hidden_channels'][0],
            num_heads=model_config['num_heads'],
            num_layers=model_config['num_layers'],
            mlp_hidden_layers=model_config['hidden_channels'][1],
            mlp_dropout=model_config['mlp_dropout']
        )

        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f'Loaded model weights from {model_path}')

        predictions = _generate_predictions(processed_data, model, go_cat)
        prediction_df = _create_prediction_df(predictions, processed_data, protein_ids, go_cat)
        all_predictions.append(prediction_df)

        # Free per-category tensors before building the next model.
        del processed_data
        del model
        del predictions
        torch.cuda.empty_cache()

    del heterodata

    if not all_predictions:
        return pd.DataFrame(columns=['Protein', 'GO_term', 'GO_category', 'Probability'])

    final_df = pd.concat(all_predictions, ignore_index=True)

    del all_predictions
    torch.cuda.empty_cache()

    return final_df