|
import datetime |
|
import os |
|
import subprocess |
|
import sys |
|
from timeit import default_timer as timer |
|
|
|
# Prepend the current working directory to the import search path.
# NOTE(review): nothing visible in this script imports a local module, and the
# subprocesses launched later do not inherit sys.path -- possibly vestigial.
sys.path[:0] = [os.getcwd()]

# High-resolution start time of the whole sweep; the elapsed total is printed at the end.
start = timer()
|
|
|
# Common prefix for every file this sweep produces. A non-empty
# RESULT_FILENAME_PREFIX environment variable wins; otherwise a local-time
# timestamp is used so consecutive runs do not overwrite each other.
result_filename_prefix = os.getenv("RESULT_FILENAME_PREFIX") or (
    f"Tune_{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}"
)
print(f"Result filename prefix: {result_filename_prefix}")

# Aggregate CSV path handed to every child run via ALL_RESULTS_CSV_FILE.
all_results_filename = f"./data/results/{result_filename_prefix}.csv"
|
|
|
# Repetition-penalty (RP) sweep grid: step size and inclusive upper bound.
repetition_penalty_delta = 0.02
repetition_penalty_end = 1.3

# The sweep starts at 1.0 unless REPETITION_PENALTY_START is set to a non-empty
# value (e.g. to resume an interrupted sweep part-way through).
repetition_penalty_start = os.getenv("REPETITION_PENALTY_START")
if not repetition_penalty_start:
    repetition_penalty = 1.0
else:
    repetition_penalty = float(repetition_penalty_start)
    print(f"Starting from RP: {repetition_penalty}")
|
|
|
# Run qa_chain_test.py once per repetition-penalty (RP) value, from the current
# `repetition_penalty` up to `repetition_penalty_end` (inclusive) in steps of
# `repetition_penalty_delta`. Each run receives the RP and its output CSV paths
# through environment variables; the child's stdout goes to a per-RP log file.

# NUM_QUESTIONS is forwarded as extra command-line argument(s). It is the same
# for every run, so read it once. Whitespace splitting gives the same argv as the
# former shell command for plain values such as "50", without letting the
# environment variable be interpreted by a shell.
num_questions = os.getenv("NUM_QUESTIONS") or ""
num_questions_args = num_questions.split()

# Compute each RP from the start value and a step count rather than repeatedly
# adding the float delta, so rounding error cannot accumulate across the sweep.
rp_sweep_start = repetition_penalty
rp_step = 0
# The 1e-5 tolerance keeps the end point itself inside the sweep.
while repetition_penalty <= repetition_penalty_end + 1e-5:
    new_env = os.environ.copy()

    repetition_penalty_str = f"{repetition_penalty:.3f}"
    # The same RP is exported under each of these names
    # (presumably one per model backend -- confirm against qa_chain_test.py).
    for rp_var in ("HFTGI_RP", "HF_RP", "ML_RP", "SL_RP"):
        new_env[rp_var] = repetition_penalty_str

    log_file = "./data/logs/{}_RP_{}.txt".format(
        result_filename_prefix, repetition_penalty_str
    )
    test_results_filename = "./data/results/{}_RP_{}.csv".format(
        result_filename_prefix, repetition_penalty_str
    )
    new_env["TEST_RESULTS_CSV_FILE"] = test_results_filename
    new_env["ALL_RESULTS_CSV_FILE"] = all_results_filename

    # open(log_file, "w") raises FileNotFoundError if its directory is missing;
    # make sure the log and result directories exist before each run.
    for path in (log_file, test_results_filename, all_results_filename):
        os.makedirs(os.path.dirname(path), exist_ok=True)

    with open(log_file, "w") as f_obj:
        # Argument list instead of shell=True, and sys.executable so the child
        # uses the interpreter running this script, not whatever "python" is on PATH.
        completed = subprocess.run(
            [sys.executable, "qa_chain_test.py", *num_questions_args],
            env=new_env,
            stdout=f_obj,
        )

    # Report a failed run instead of ignoring it, but continue with the sweep.
    if completed.returncode != 0:
        print(
            f"qa_chain_test.py exited with code {completed.returncode} "
            f"for RP {repetition_penalty_str}; see {log_file}",
            file=sys.stderr,
        )

    rp_step += 1
    repetition_penalty = rp_sweep_start + rp_step * repetition_penalty_delta
|
|
|
# Final report: where the aggregate CSV lives and how long the whole sweep took
# (measured from `start`, which is taken when the script begins).
print(f"All results saved to {all_results_filename}")

total_time = timer() - start
print(f"Total time used: {total_time:.3f} s")
|
|