# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import subprocess
import threading
from pathlib import Path

import numpy as np
import torch


def fasta_file_path(prefix_path):
    return prefix_path + ".fasta"


class FastaDataset(torch.utils.data.Dataset):
    """
    For loading protein sequence datasets in the common FASTA data format
    """

    def __init__(self, path: str, cache_indices=False):
        self.fn = fasta_file_path(path)
        self.threadlocal = threading.local()
        self.cache = Path(f"{path}.fasta.idx.npy")
        if cache_indices:
            if self.cache.exists():
                self.offsets, self.sizes = np.load(self.cache)
            else:
                self.offsets, self.sizes = self._build_index(path)
                np.save(self.cache, np.stack([self.offsets, self.sizes]))
        else:
            self.offsets, self.sizes = self._build_index(path)

    def _get_file(self):
        # One file handle per thread, so concurrent DataLoader workers
        # don't interleave seeks on a shared descriptor.
        if not hasattr(self.threadlocal, "f"):
            self.threadlocal.f = open(self.fn, "r")
        return self.threadlocal.f

    def __getitem__(self, idx):
        f = self._get_file()
        f.seek(self.offsets[idx])
        desc = f.readline().strip()
        line = f.readline()
        seq = ""
        while line != "" and line[0] != ">":
            seq += line.strip()
            line = f.readline()
        return desc, seq

    def __len__(self):
        return self.offsets.size

    def _build_index(self, path: str):
        # Use grep and awk to get 100M/s on local SSD.
        # Should process your enormous 100G fasta in ~10 min single core...
        path = fasta_file_path(path)
        bytes_offsets = subprocess.check_output(
            f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
            "| grep --byte-offset '^>' -o | cut -d: -f1",
            shell=True,
        )
        fasta_lengths = subprocess.check_output(
            f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
            "| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'",
            shell=True,
        )
        # np.fromstring is deprecated; parse the whitespace-separated
        # integers from the subprocess output directly instead.
        bytes_np = np.array(bytes_offsets.split(), dtype=np.int64)
        sizes_np = np.array(fasta_lengths.split(), dtype=np.int64)
        return bytes_np, sizes_np
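
    # A hypothetical pure-Python fallback (not part of the original module):
    # a sketch of the same index construction without the shell pipeline
    # above, for platforms that lack grep/awk/tqdm. It is much slower on
    # very large files, but returns the same pair of arrays: byte offsets
    # of '>' header lines, and per-record sequence lengths.
    def _build_index_python(self, path: str):
        path = fasta_file_path(path)
        offsets, sizes = [], []
        offset = 0
        with open(path, "rb") as f:
            for line in f:
                if line.startswith(b">"):
                    offsets.append(offset)  # byte offset of each header line
                    sizes.append(0)
                elif sizes:
                    sizes[-1] += len(line.strip())  # accumulate sequence length
                offset += len(line)
        return np.array(offsets, dtype=np.int64), np.array(sizes, dtype=np.int64)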

    def __setstate__(self, state):
        self.__dict__ = state
        self.threadlocal = threading.local()

    def __getstate__(self):
        # Open file handles cannot be pickled, so drop the thread-local
        # state here; __setstate__ recreates it in the receiving process.
        d = {}
        for i, v in self.__dict__.items():
            if i != "threadlocal":
                d[i] = v
        return d

    def __del__(self):
        if hasattr(self.threadlocal, "f"):
            self.threadlocal.f.close()
            del self.threadlocal.f


def exists(path):
    return os.path.exists(fasta_file_path(path))


class EncodedFastaDataset(FastaDataset):
    """
    The FastaDataset returns raw sequences - this allows us to return
    indices with a dictionary instead.
    """

    def __init__(self, path, dictionary):
        super().__init__(path, cache_indices=True)
        self.dictionary = dictionary

    def __getitem__(self, idx):
        desc, seq = super().__getitem__(idx)
        return self.dictionary.encode_line(seq, line_tokenizer=list).long()
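

if __name__ == "__main__":
    # A minimal usage sketch. "example" is a hypothetical prefix path: the
    # dataset opens "example.fasta" (see fasta_file_path above) and caches
    # its index in "example.fasta.idx.npy".
    dataset = FastaDataset("example", cache_indices=True)
    print(f"{len(dataset)} records")
    desc, seq = dataset[0]  # (description line, raw sequence string)
    print(desc, seq[:60])

    # The __getstate__/__setstate__ pair above keeps the dataset picklable,
    # so it can be shipped to DataLoader worker processes:
    loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=2)

    # EncodedFastaDataset additionally needs a fairseq-style dictionary
    # exposing encode_line(), e.g. (hypothetical):
    #
    #     from fairseq.data import Dictionary
    #     encoded = EncodedFastaDataset("example", Dictionary.load("dict.txt"))
    #     tokens = encoded[0]  # torch.LongTensor of token indices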