import re
from collections import Counter
from itertools import chain

from accelerate import Accelerator
from datasets import load_dataset, Dataset, DatasetDict
from tqdm import tqdm

LANGUAGES_TO_DECODE_FROM_BYTES = ["he", "fr", "uk"]
STREAMING_DATASETS = ["fineweb-edu"]


def load_pg19_val_and_test():
    # Load the dataset in streaming mode
    streaming_dataset = load_dataset("deepmind/pg19", split=None, streaming=True)

    # Materialize the test and validation splits
    test_split = list(streaming_dataset["test"])
    validation_split = list(streaming_dataset["validation"])

    # Convert them into regular (in-memory) datasets
    test_dataset = Dataset.from_list(test_split)
    validation_dataset = Dataset.from_list(validation_split)
    return DatasetDict({"validation": validation_dataset, "test": test_dataset})


def load_pubmed(n_samples=10000):
    # Load the dataset in streaming mode
    streaming_dataset = load_dataset("MedRAG/pubmed", streaming=True)

    # Take 4 * n_samples records and carve them into train/validation/test splits (2n / n / n)
    data = list(streaming_dataset["train"].take(n_samples * 4))
    train = data[:2 * n_samples]
    validation = data[2 * n_samples:3 * n_samples]
    test = data[3 * n_samples:]

    # Convert them into regular (in-memory) datasets
    train = Dataset.from_list(train)
    validation = Dataset.from_list(validation)
    test = Dataset.from_list(test)
    dataset = DatasetDict({"train": train, "validation": validation, "test": test})
    dataset = dataset.rename_column("content", "text")
    return dataset
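

# --- Illustrative usage sketch (added for clarity; not part of the original pipeline) ---
# Shows how the PubMed loader carves the streamed records into splits. The small
# n_samples value below is an arbitrary assumption chosen to keep the demo cheap.
def _demo_pubmed_split_sizes(n_samples: int = 100):
    """Sketch: load a small PubMed sample and report the resulting split sizes (2n / n / n)."""
    pubmed = load_pubmed(n_samples=n_samples)
    return {split_name: len(split) for split_name, split in pubmed.items()}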
""" if dataset_name.lower() == 'wikitext': return load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split=split) elif dataset_name.lower() == 'fineweb-edu': return load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT") elif dataset_name.lower() == 'wikitext-103': return load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split=split) elif dataset_name.lower() == 'cord19': return load_dataset("allenai/cord19", "fulltext", trust_remote_code=True) elif dataset_name.lower() == 'pubmed': return load_pubmed() elif dataset_name.lower() == 'wikilingua': dataset = load_dataset("GEM/wiki_lingua", trust_remote_code=True) dataset = dataset.filter(lambda ex: (ex['source_language'] == "en") & (ex['target_language'] == "en")) dataset = dataset.rename_column("source", "text") dataset = dataset.rename_column("target", "summary") return dataset elif dataset_name.lower() == 'xsum': dataset = load_dataset("EdinburghNLP/xsum") dataset = dataset.rename_column("document", "text") return dataset elif dataset_name.lower() == 'cnn': dataset = load_dataset("abisee/cnn_dailymail", "3.0.0") dataset = dataset.rename_column("article", "text") dataset = dataset.rename_column("highlights", "summary") dataset = dataset.map(lambda example: {"text": example["text"].replace("(CNN)", "")}) return dataset elif dataset_name.lower() == 'pg19': return load_pg19_val_and_test() elif dataset_name.lower() == 'wiki40b': dataset = load_dataset("google/wiki40b", language, split=split) if language in LANGUAGES_TO_DECODE_FROM_BYTES: dataset = dataset.map(lambda x: { "text": bytes(x["text"][2:-1], "utf-8").decode("unicode_escape").encode("latin1").decode("utf-8").replace("_NEWLINE_", "\n") }) return dataset else: raise ValueError( "Dataset not recognized. Available options: 'wikitext-2', 'wikitext-103', 'pg19', 'c4', 'wiki40b', 'mc4'.") def extract_new_words_from_dataset( dataset: Dataset, tokenizer, text_column: str = "text", max_samples: int = None, filter_func=(lambda word, token_count: True)): """ Loads a Hugging Face dataset and extracts all unique words from the specified text column. Args: dataset (Dataset): Name of the dataset to load. split (str): Dataset split to use, typically 'train' for training data. Defaults to 'train'. text_column (str): The column in the dataset containing text. Defaults to 'text'. max_samples (int): Number of samples from the dataset to go over. Returns: set: A set of unique words in the dataset. 
""" if max_samples: dataset = dataset.select(range(max_samples)) # Regular expression to split text into words (adjust as needed for specific languages) # word_pattern = re.compile(r"\b\w+\b") word_pattern = re.compile(r"\b\w+(?:[-']\w+)*\b") # Iterate over each entry in the dataset and extract unique words all_words = list() new_words = list() for record in tqdm(dataset, total=len(dataset), miniters=10, desc="Extracting all words from dataset...", unit="examples"): text = record.get(text_column, "") words = word_pattern.findall(text) all_words += words # all_words = list(dict.fromkeys(all_words)) word_frequencies = Counter(all_words) all_words = list(word_frequencies.keys()) token_counts = [len(x) for x in tokenizer(all_words, add_special_tokens=False)["input_ids"]] w_whitespace_token_counts = [len(x) for x in tokenizer([f" {w}" for w in all_words], add_special_tokens=False)["input_ids"]] new_words = [word for word, count, w_whitespace_count in zip(all_words, token_counts, w_whitespace_token_counts) if ((count > 1) and (w_whitespace_count > 1) and filter_func(word, count))] new_words_freq = {word: word_frequencies[word] for word in new_words} # for word, token_count in tqdm(all_words, total=len(all_words), miniters=10, desc="Finding new words...", unit="words"): # if (not tokenizer.vocab.get(word, False)) and : # new_words.append(word) # remove duplicates and return return new_words, new_words_freq def get_group_texts_func(block_size=1024): def group_texts(examples): # Concatenate all texts. concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. total_length = (total_length // block_size) * block_size # Split by chunks of max_len. result = { k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() } result["labels"] = result["input_ids"].copy() return result return group_texts def get_tokenize_func(tokenizer, text_col_name): def _tokenize(examples): output = tokenizer( examples[text_col_name], return_token_type_ids=False, add_special_tokens=False, ) return output return _tokenize def tokenize_and_prepare_dataset( dataset, tokenizer, accelerator=None, text_col_name: str = "text", max_length: int = 256, eval_max_samples: int = None, ): if tokenizer.bos_token is not None and max_length: # leave room for token to be added: max_tokenized_len = max_length - 1 else: max_tokenized_len = max_length tokenize_function = get_tokenize_func(tokenizer, text_col_name) column_names = dataset.column_names tokenized_dataset = dataset.map( tokenize_function, batched=True, remove_columns=column_names, load_from_cache_file=False, desc="Running tokenizer on dataset", ) group_texts = get_group_texts_func(block_size=max_tokenized_len) lm_dataset = tokenized_dataset.map( group_texts, batched=True, ) if eval_max_samples: lm_dataset = lm_dataset.select(range(eval_max_samples)) return lm_dataset