# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd
import pytest
from customization_dataset_preparation import (
    convert_into_prompt_completion_only,
    convert_into_template,
    drop_duplicated_rows,
    drop_unrequired_fields,
    get_common_suffix,
    get_prepared_filename,
    parse_template,
    recommend_hyperparameters,
    show_first_example_in_df,
    split_into_train_validation,
    template_mapper,
    validate_template,
    warn_and_drop_long_samples,
    warn_completion_is_not_empty,
    warn_duplicated_rows,
    warn_imbalanced_completion,
    warn_low_n_samples,
    warn_missing_suffix,
)


def test_recommend_hyperparameters():
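    """The recommended batch_size scales with dataset size: 2 at 100 and 1,000 rows, 16 at 10,000, 128 at 100,000."""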
    df_100 = pd.DataFrame({'prompt': ['prompt'] * 100, 'completion': ['completion'] * 100})
    assert recommend_hyperparameters(df_100) == "TODO: A batch_size=2 is recommended"
    df_1000 = pd.DataFrame({'prompt': ['prompt'] * 1000, 'completion': ['completion'] * 1000})
    assert recommend_hyperparameters(df_1000) == "TODO: A batch_size=2 is recommended"
    df_10000 = pd.DataFrame({'prompt': ['prompt'] * 10000, 'completion': ['completion'] * 10000})
    assert recommend_hyperparameters(df_10000) == "TODO: A batch_size=16 is recommended"
    df_100000 = pd.DataFrame({'prompt': ['prompt'] * 100000, 'completion': ['completion'] * 100000})
    assert recommend_hyperparameters(df_100000) == "TODO: A batch_size=128 is recommended"


def test_warn_completion_is_not_empty():
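    """Warn when every completion is empty, list the affected rows when only some are, and return None when none are."""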
    df_all_empty = pd.DataFrame({'prompt': ['prompt'] * 2, 'completion': [''] * 2})
    msg_all_empty = (
        "TODO: Note all completion fields are empty. This is possibly expected for inference but not for training"
    )
    assert warn_completion_is_not_empty(df_all_empty) == msg_all_empty
    df_some_empty = pd.DataFrame({'prompt': ['prompt'] * 2, 'completion': ['', 'completion']})
    msg_some_empty = f"""TODO: completion contains {1} empty values at rows ({[0]})
Please check the original file that the fields for prompt template are
not empty and rerun dataset validation"""
    assert warn_completion_is_not_empty(df_some_empty) == msg_some_empty
    df_no_empty = pd.DataFrame({'prompt': ['prompt'] * 2, 'completion': ['completion'] * 2})
    assert warn_completion_is_not_empty(df_no_empty) is None


def test_warn_imbalanced_completion():
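    """All-unique completions raise no warning; few unique completions produce a report of the five most common with their sample percentages."""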
    df_generation = pd.DataFrame(
        {'prompt': [f'prompt{i}' for i in range(100)], 'completion': [f'completion{i}' for i in range(100)]}
    )
    assert warn_imbalanced_completion(df_generation) is None
    df_classification_balanced = pd.DataFrame(
        {'prompt': [f'prompt{i}' for i in range(100)], 'completion': [f'completion{i}' for i in range(5)] * 20}
    )
    msg_classification_balanced = (
        f"There are {5} unique completions over {100} samples.\nThe five most common completions are:"
    )
    for i in range(5):
        msg_classification_balanced += f"\n {20} samples ({20.0}%) with completion: completion{i}"
    assert warn_imbalanced_completion(df_classification_balanced) == msg_classification_balanced
    df_classification_imbalanced = pd.DataFrame(
        {
            'prompt': [f'prompt{i}' for i in range(100)],
            'completion': ['completion0'] * 95 + [f'completion{i}' for i in range(5)],
        }
    )
    msg_classification_imbalanced = (
        f"There are {5} unique completions over {100} samples.\nThe five most common completions are:"
    )
    msg_classification_imbalanced += f"\n {96} samples ({96.0}%) with completion: completion0"
    for i in range(1, 5):
        msg_classification_imbalanced += f"\n {1} samples ({1.0}%) with completion: completion{i}"
    assert warn_imbalanced_completion(df_classification_imbalanced) == msg_classification_imbalanced


def test_get_common_suffix():
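    """The longest shared suffix is returned (' answer:'); columns without one, or containing empty values, yield ''."""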
    df = pd.DataFrame(
        {
            'prompt': [f'prompt{i} answer:' for i in range(100)],
            'completion': [f'completion{i}' for i in range(100)],
            'empty_completion': [''] * 100,
            'some_empty_completion': ['', 'completion'] * 50,
        }
    )
    assert get_common_suffix(df.prompt) == " answer:"
    assert get_common_suffix(df.completion) == ""
    assert get_common_suffix(df.empty_completion) == ""
    assert get_common_suffix(df.some_empty_completion) == ""


def test_warn_missing_suffix():
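    """Both prompt and completion are flagged when they lack a common suffix; None is returned when both have one."""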
    df_no_common = pd.DataFrame(
        {'prompt': [f'prompt{i}' for i in range(100)], 'completion': [f'completion{i}' for i in range(100)]}
    )
    message = "TODO: prompt does not have common suffix, please add one (e.g. \\n) at the end of prompt_template\n"
    message += (
        "TODO: completion does not have common suffix, please add one (e.g. \\n) at the end of completion_template\n"
    )
    assert warn_missing_suffix(df_no_common) == message
    df_common = pd.DataFrame(
        {'prompt': [f'prompt{i} answer:' for i in range(100)], 'completion': [f'completion{i}\n' for i in range(100)]}
    )
    assert warn_missing_suffix(df_common) is None


def test_parse_template():
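    """Field names are extracted from {braced} placeholders in the order they appear in the template."""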
    template_qa_prompt = "Context: {context}, Question: {question} Answer:"
    template_qa_completion = "{answer}"
    template_prompt = "{prompt}"
    template_completion = "{completion}"
    assert parse_template(template_qa_prompt) == ['context', 'question']
    assert parse_template(template_qa_completion) == ['answer']
    assert parse_template(template_prompt) == ['prompt']
    assert parse_template(template_completion) == ['completion']


def test_validate_template():
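    """A balanced template passes; unmatched, doubled, or nested braces each raise ValueError."""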
    template = "{prompt}"
    template_missing_left = "prompt}"
    template_missing_right = "{prompt"
    template_twice = "{{prompt}}"
    template_enclosed = "{prompt{enclosed}}"
    assert validate_template(template) is None
    with pytest.raises(ValueError):
        validate_template(template_missing_left)
    with pytest.raises(ValueError):
        validate_template(template_missing_right)
    with pytest.raises(ValueError):
        validate_template(template_twice)
    with pytest.raises(ValueError):
        validate_template(template_enclosed)


def test_warn_duplicated_rows():
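    """Only fully duplicated rows are flagged (with their row indexes); a repeated prompt with a different completion is not."""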
    df_duplicated = pd.DataFrame({'prompt': ['prompt'] * 2, 'completion': ['completion'] * 2})
    message_duplicated = f"TODO: There are {1} duplicated rows "
    message_duplicated += "at rows ([1]) \n"
    message_duplicated += "Please check the original file to make sure that is expected\n"
    message_duplicated += "If it is not, please add the argument --drop_duplicate"
    assert warn_duplicated_rows(df_duplicated) == message_duplicated
    df_unique = pd.DataFrame({'prompt': ['prompt', 'prompt1'], 'completion': ['completion', 'completion1']})
    assert warn_duplicated_rows(df_unique) is None
    assert warn_duplicated_rows(df_only_prompt_duplicated) is None


def test_drop_duplicated_rows():
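    """Exact duplicate rows are dropped with a summary message; unique rows come back unchanged with message None."""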
    df_deduplicated = pd.DataFrame({'prompt': ['prompt'], 'completion': ['completion']})
    df_duplicated = pd.DataFrame({'prompt': ['prompt'] * 2, 'completion': ['completion'] * 2})
    message_duplicated = "There are 1 duplicated rows\n"
    message_duplicated += "Removed 1 duplicate rows"
    assert drop_duplicated_rows(df_duplicated)[0].equals(df_deduplicated)
    assert drop_duplicated_rows(df_duplicated)[1] == message_duplicated
    df_unique = pd.DataFrame({'prompt': ['prompt', 'prompt1'], 'completion': ['completion', 'completion1']})
    assert drop_duplicated_rows(df_unique) == (df_unique, None)
    assert drop_duplicated_rows(df_only_prompt_duplicated) == (df_only_prompt_duplicated, None)


def test_template_mapper():
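    """A template's placeholders are filled from the named fields of a single row."""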
    df = pd.DataFrame({'prompt': ['prompt sample']})
    template = "{prompt}"
    field_names = ['prompt']
    assert template_mapper(df.iloc[0], field_names, template) == 'prompt sample'
    df_qa = pd.DataFrame({'question': ['question sample'], 'context': ['context sample']})
    template_qa = "Context: {context} Question: {question} Answer:"
    field_names_qa = ['context', 'question']
    assert (
        template_mapper(df_qa.iloc[0], field_names_qa, template_qa)
        == "Context: context sample Question: question sample Answer:"
    )


def test_drop_unrequired_fields():
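    """Only the prompt and completion columns are kept; source fields such as question and context are dropped."""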
    df = pd.DataFrame(
        {'question': ['question'], 'context': ['context'], 'prompt': ['prompt'], 'completion': ['completion']}
    )
    df_dropped_unnecessary_fields = pd.DataFrame({'prompt': ['prompt'], 'completion': ['completion']})
    assert df_dropped_unnecessary_fields.equals(drop_unrequired_fields(df))


def test_convert_into_template():
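    """A template referencing a field missing from the DataFrame raises ValueError; otherwise the rendered prompt column is appended."""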
    df_nonexistent_field_name = pd.DataFrame({'question': ['question']})
    template = "Context: {context} Question: {question} Answer:"
    with pytest.raises(ValueError):
        convert_into_template(df_nonexistent_field_name, template)
    df = pd.DataFrame({'question': ['question sample'], 'context': ['context sample']})
    df_prompt = pd.DataFrame(
        {
            'question': ['question sample'],
            'context': ['context sample'],
            'prompt': ["Context: context sample Question: question sample Answer:"],
        }
    )
    assert convert_into_template(df, template).equals(df_prompt)


def test_convert_into_prompt_completion_only():
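    """Source columns are rendered into prompt/completion via the two templates; a frame already in that form is returned as-is."""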
    df = pd.DataFrame({'question': ['question sample'], 'context': ['context sample'], 'answer': ['answer sample']})
    df_prompt = pd.DataFrame(
        {'prompt': ["Context: context sample Question: question sample Answer:"], 'completion': ["answer sample"]}
    )
    prompt_template = "Context: {context} Question: {question} Answer:"
    completion_template = "{answer}"
    assert df_prompt.equals(
        convert_into_prompt_completion_only(
            df, prompt_template=prompt_template, completion_template=completion_template
        )
    )
    assert df_prompt.equals(convert_into_prompt_completion_only(df_prompt))


def get_indexes_of_long_examples(df, max_total_char_length):
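    """Helper: positional indexes of rows whose combined prompt and completion length exceeds max_total_char_length."""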
    long_examples = df.apply(lambda x: len(x.prompt) + len(x.completion) > max_total_char_length, axis=1)
    return df.reset_index().index[long_examples].tolist()


def test_warn_and_drop_long_samples():
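    """Rows whose prompt plus completion exceed the character limit are dropped with a warning; short frames return (df, None)."""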
    df = pd.DataFrame({'prompt': ['a' * 12000, 'a' * 9000, 'a'], 'completion': ['b' * 12000, 'b' * 2000, 'b']})
    expected_df = pd.DataFrame({'prompt': ['a'], 'completion': ['b']})
    message = f"""TODO: There are {2} / {3}
samples that have its prompt and completion too long
(over {10000} chars), which have been dropped.
If this proportion is too high, please prepare data again using the flag
--long_seq_model for use with a model with longer context length of 8,000 tokens"""
    assert expected_df.equals(warn_and_drop_long_samples(df, 10000)[0])
    assert warn_and_drop_long_samples(df, 10000)[1] == message
    df_short = pd.DataFrame({'prompt': ['a'] * 2, 'completion': ['b'] * 2})
    assert warn_and_drop_long_samples(df_short, 10000) == (df_short, None)


def test_warn_low_n_samples():
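    """Fewer than the recommended 64 samples triggers a warning; larger files return None."""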
    df_low = pd.DataFrame({'prompt': ['a'] * 10, 'completion': ['b'] * 10})
    df_high = pd.DataFrame({'prompt': ['a'] * 100, 'completion': ['b'] * 100})
    message = (
        "TODO: We would recommend having more samples (>64) if possible but current_file only contains 10 samples. "
    )
    assert warn_low_n_samples(df_low) == message
    assert warn_low_n_samples(df_high) is None


def test_show_first_example_in_df():
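    """The first row is rendered one '-->Column <name>:' block per field."""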
    df = pd.DataFrame({'question': ['question sample'], 'context': ['context sample'], 'answer': ['answer sample']})
    message = "-->Column question:\nquestion sample\n"
    message += "-->Column context:\ncontext sample\n"
    message += "-->Column answer:\nanswer sample\n"
    assert message == show_first_example_in_df(df)


def test_get_prepared_filename():
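    """Output names get a _prepared suffix (_train/_val when splitting) and always end in .jsonl, even for .csv input."""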
    filename = "tmp/sample.jsonl"
    prepared_filename = "tmp/sample_prepared.jsonl"
    prepared_train_filename = "tmp/sample_prepared_train.jsonl"
    prepared_val_filename = "tmp/sample_prepared_val.jsonl"
    assert get_prepared_filename(filename) == prepared_filename
    assert get_prepared_filename(filename, split_train_validation=True) == [
        prepared_train_filename,
        prepared_val_filename,
    ]
    csv_filename = "tmp/sample.csv"
    prepared_filename = "tmp/sample_prepared.jsonl"
    assert get_prepared_filename(csv_filename) == prepared_filename


def test_split_into_train_validation():
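    """val_proportion controls the split sizes: 0.1 gives 9/1 and 0.2 gives 8/2 on a 10-row frame."""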
    df = pd.DataFrame({'prompt': ['a'] * 10, 'completion': ['b'] * 10})
    df_train, df_val = split_into_train_validation(df, val_proportion=0.1)
    assert len(df_train) == 9
    assert len(df_val) == 1
    df_train, df_val = split_into_train_validation(df, val_proportion=0.2)
    assert len(df_train) == 8
    assert len(df_val) == 2
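

# Minimal sketch for running this suite directly (assumes pytest is installed);
# the usual entry point is simply `pytest` on the command line.
if __name__ == '__main__':
    pytest.main([__file__])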