import os

from openai import OpenAI

from config import (
    CHATGPT,
    GEMINI,
    GEMINI_MODEL,
    IS_OUTPUT_NORMALIZATION,
    MODEL_PATHS,
    OPENAI_MODEL,
    TEMPERATURE,
    TOGETHER_API_KEY,
    TOGETHER_PATH,
)
from evaluation import extract_by_best_similarity
from utils import (
    generate_column_names,
    generate_file_name,
    get_column,
    normalize_text,
    print_and_log,
    read_csv_data,
    write_new_data,
    write_to_csv,
)


def abstract_proofread(model_path, temperature, base_url, api_key, prompt):
    """
    Proofread an abstract using an AI language model.

    Parameters:
        model_path (str): The path or identifier of the AI model to use.
        temperature (float): Sampling temperature for the model's output.
        base_url (str): The base URL for the API endpoint.
        api_key (str): The API key for authentication.
        prompt (str): The text prompt to provide to the AI for proofreading.

    Returns:
        str: The proofread abstract generated by the AI model.
    """
    # Initialize the AI client with the provided API key and base URL
    ai_client = OpenAI(api_key=api_key, base_url=base_url)
    # Create a chat completion request with the system message and user prompt
    chat_completion = ai_client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are an AI assistant",
            },
            {
                "role": "user",
                "content": prompt,
            },
        ],
        model=model_path,
        max_tokens=1024,
        temperature=temperature,
    )
    # Return the content of the first choice's message
    return chat_completion.choices[0].message.content
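

# Example usage (a hypothetical sketch; the model id, endpoint, and key
# below are placeholders, not values from this project's config):
#
#     text = abstract_proofread(
#         model_path="mistralai/Mixtral-8x7B-Instruct-v0.1",
#         temperature=0.0,
#         base_url="https://api.together.xyz/v1",
#         api_key="YOUR_API_KEY",
#         prompt="Proofreading for the text: ```Ths is a smple sentnce.```",
#     )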


def proofread_by_model_name(model_name, input_text, normalize_output):
    """
    Proofread the given input text using the specified model, resolved
    through MODEL_PATHS.

    Args:
        model_name (str): The name of the model to use for proofreading.
        input_text (str): The text to be proofread.
        normalize_output (bool): Whether to normalize the output or not.

    Returns:
        str: The proofread text.
    """
    # Constants for API access
    base_url = TOGETHER_PATH
    api_key = TOGETHER_API_KEY
    temperature = TEMPERATURE
    # Retrieve the model path from the MODEL_PATHS dictionary
    if model_name in MODEL_PATHS:
        model_path = MODEL_PATHS[model_name]
    else:
        raise ValueError(
            f"Model name '{model_name}' not found in MODEL_PATHS.",
        )
    # Formulate the prompt for the model
    prompt = f"Proofreading for the text: ```{input_text}```"
    # Apply output normalization if required
    if normalize_output:
        prompt = output_normalization(prompt)
    # Debugging: print the prompt
    print(f"Prompt: {prompt}")
    # Call the abstract proofreading function with the prepared parameters
    return abstract_proofread(
        model_path,
        temperature,
        base_url,
        api_key,
        prompt,
    )
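

# A minimal usage sketch (the model name "llama" below is hypothetical;
# valid names are whatever keys MODEL_PATHS defines in config):
#
#     corrected = proofread_by_model_name(
#         "llama",
#         "Ths sentnce has sevral typos.",
#         normalize_output=True,
#     )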


def gemini_proofread(input_text, normalize_output):
    """
    Proofreads the given text using GEMINI_MODEL.

    Parameters:
        input_text (str): The text to be proofread.
        normalize_output (bool): Flag indicating whether to normalize the output.

    Returns:
        str: The proofread text.
    """
    prompt = f"Proofreading for the text: ```{input_text}```"
    if normalize_output:
        prompt = output_normalization(prompt)
    response = GEMINI_MODEL.generate_content(prompt)
    return response.text


def chatGPT_proofread(input_text, normalize_output):
    """
    Proofreads the given text using the configured OpenAI chat model
    (OPENAI_MODEL).

    Parameters:
        input_text (str): The text to be proofread.
        normalize_output (bool): Flag indicating whether to normalize the output.

    Returns:
        str: The proofread text.
    """
    prompt = f"Proofreading for the text: ```{input_text}```"
    if normalize_output:
        prompt = output_normalization(prompt)
    print(f"Starting API call with prompt: {prompt}")
    result = OPENAI_MODEL.predict(prompt)
    print(f"Finished API call for prompt: {prompt}")
    return result
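

# Both wrappers above share the same call shape (a hedged sketch; the
# sample text is made up, and GEMINI_MODEL / OPENAI_MODEL must be
# configured in config for the calls to succeed):
#
#     fixed = gemini_proofread("Ths has a typo.", normalize_output=True)
#     fixed = chatGPT_proofread("Ths has a typo.", normalize_output=True)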


def output_normalization(prompt):
    """
    Normalizes the output by appending a specific instruction to the prompt.

    Parameters:
        prompt (str): The initial prompt.

    Returns:
        str: The modified prompt.
    """
    return (
        prompt
        + " Please only output the proofread text without any explanation."
    )
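

# For illustration, output_normalization only appends an instruction, e.g.:
#
#     output_normalization("Proofreading for the text: ```Hi```")
#     # -> "Proofreading for the text: ```Hi``` Please only output the
#     #     proofread text without any explanation."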


def proofread_with_best_similarity(input_text, model_kind):
    """
    Proofreads the input text using the specified model and extracts the
    best-corrected text based on similarity.

    Args:
        input_text (str): The original text to be proofread.
        model_kind (str): The kind of model to use for proofreading
            (e.g., CHATGPT, GEMINI).

    Returns:
        tuple: A tuple containing the raw proofread text and the
            best-corrected text.
    """
    # Normalize the input text
    normalized_input_text = normalize_text(input_text)
    print_and_log(f"INPUT = {normalized_input_text}")
    result_text = ""
    raw_text = ""
    # NOTE: range(1) runs only once; the loop is kept as a retry hook and
    # can be widened (e.g., range(3)) to retry until a non-empty result
    # is found.
    for i in range(1):
        # Select the proofreading model based on model_kind
        if model_kind == CHATGPT:
            raw_text = chatGPT_proofread(
                normalized_input_text,
                normalize_output=IS_OUTPUT_NORMALIZATION,
            )
        elif model_kind == GEMINI:
            raw_text = gemini_proofread(
                normalized_input_text,
                normalize_output=IS_OUTPUT_NORMALIZATION,
            )
        else:
            raw_text = proofread_by_model_name(
                model_kind,
                normalized_input_text,
                normalize_output=IS_OUTPUT_NORMALIZATION,
            )
        # Extract the best candidate text based on similarity
        result_text = extract_by_best_similarity(
            normalized_input_text,
            raw_text,
        )
        # Log the raw text
        print_and_log(f"RAW_{i} = {raw_text}")
        # Normalize the result text
        result_text = normalize_text(result_text)
        # If a valid (non-empty) result is obtained, return it early
        if result_text != "":
            return raw_text, result_text
    # Fall back to whatever the last attempt produced
    return raw_text, result_text
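

# Typical usage (a hedged sketch; the sample text is made up, and GEMINI
# must be one of the configured model kinds):
#
#     raw, best = proofread_with_best_similarity(
#         "Ths is a smple abstrct.",
#         GEMINI,
#     )
#     print_and_log(f"BEST = {best}")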


def generate_new_data_with_best_similarity(
    existing_data_file,
    existing_kinds,
    new_kinds,
):
    """
    Generates new data with the best similarity based on existing and new
    kinds, and writes the results to a CSV file.

    Args:
        existing_data_file (str): The path to the existing data file.
        existing_kinds (list): A list of existing kinds.
        new_kinds (list): A list of new kinds.

    Returns:
        None
    """
    # Combine existing and new kinds into a single list
    all_kinds = existing_kinds + new_kinds
    # Generate column names for the CSV file
    column_names = generate_column_names(all_kinds)
    # Generate column names for existing kinds
    existing_column_names = generate_column_names(existing_kinds)
    # Generate the output file name
    output_file = generate_file_name(
        existing_data_file,
        existing_kinds,
        new_kinds,
    )
    # Create the output file with column names if it doesn't exist
    if not os.path.exists(output_file):
        write_to_csv(output_file, column_names)
    # Read existing data from the file
    existing_data = {
        kind: get_column(existing_data_file, kind)
        for kind in existing_column_names
    }
    # Read input data from the output file so a previous run can be resumed
    input_data = read_csv_data(output_file)
    start_index = len(input_data)
    print(f"start_index = {start_index}")
    num_rows = len(existing_data["human"])
    global_generate_set = []
    global_reuse = []
    for index in range(start_index, num_rows):
        # Initialize generation and reuse sets
        generate_set = []
        reuse_set = []
        # Prepare the current generation dictionary
        current_generation = {
            kind: existing_data[kind][index] for kind in existing_column_names
        }
        print(f"current_generation before generation = {current_generation}")
        human_text = current_generation["human"]
        # Generate new kinds based on human text
        for kind in new_kinds:
            _, generated_text = proofread_with_best_similarity(
                human_text,
                kind,
            )
            current_generation[kind] = generated_text
            generate_set.append(kind)
            print(
                f"current_generation after generate one = {current_generation}",
            )
        # Generate combinations of kinds, reusing earlier results where
        # two kinds produced identical text
        for first_kind in all_kinds:
            for second_kind in all_kinds:
                combination_name = f"{first_kind}_{second_kind}"
                if combination_name not in current_generation:
                    if (
                        first_kind in current_generation
                        and current_generation[first_kind] == human_text
                    ):
                        generated_text = current_generation[second_kind]
                        reuse_set.append(
                            f"{combination_name} from {second_kind}",
                        )
                    else:
                        is_need_generation = True
                        for first_kind_2 in all_kinds:
                            if (
                                first_kind != first_kind_2
                                and current_generation[first_kind]
                                == current_generation[first_kind_2]
                            ):
                                combination_name_2 = (
                                    f"{first_kind_2}_{second_kind}"
                                )
                                if combination_name_2 in current_generation:
                                    generated_text = current_generation[
                                        combination_name_2
                                    ]
                                    reuse_set.append(
                                        f"{combination_name} from {combination_name_2}",  # noqa: E501
                                    )
                                    is_need_generation = False
                                    break
                        if is_need_generation:
                            _, generated_text = proofread_with_best_similarity(
                                current_generation[first_kind],
                                second_kind,
                            )
                            generate_set.append(f"{first_kind}_{second_kind}")
                    current_generation[combination_name] = generated_text
        # Write the current generation to the output file
        write_new_data(output_file, current_generation, column_names)
        # Update global sets
        global_generate_set.append(generate_set)
        global_reuse.append(reuse_set)
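

# A minimal driver sketch (assumptions: "data.csv" is a hypothetical input
# file whose "human" column holds the seed texts, and CHATGPT / GEMINI are
# valid kinds from config):
if __name__ == "__main__":
    generate_new_data_with_best_similarity(
        existing_data_file="data.csv",
        existing_kinds=["human"],
        new_kinds=[CHATGPT, GEMINI],
    )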