pmkhanh7890's picture
1st
22e1b62
raw
history blame
11.1 kB
import os
from config import (
CHATGPT,
GEMINI,
GEMINI_MODEL,
IS_OUTPUT_NORMALIZATION,
MODEL_PATHS,
OPENAI_MODEL,
TEMPERATURE,
TOGETHER_API_KEY,
TOGETHER_PATH,
)
from evaluation import extract_by_best_similarity
from openai import OpenAI
from utils import (
generate_column_names,
generate_file_name,
get_column,
normalize_text,
print_and_log,
read_csv_data,
write_new_data,
write_to_csv,
)
def abstract_proofread(model_path, temperature, base_url, api_key, prompt):
    """
    Send a proofreading prompt to an AI language model and return its reply.

    Parameters:
        model_path (str): Identifier or path of the model to query.
        temperature (float): Sampling temperature for generation.
        base_url (str): Base URL of the API endpoint.
        api_key (str): API key used for authentication.
        prompt (str): Proofreading prompt passed as the user message.

    Returns:
        str: Content of the first completion choice produced by the model.
    """
    # Build the client for this endpoint, then issue one chat completion.
    client = OpenAI(api_key=api_key, base_url=base_url)
    messages = [
        {"role": "system", "content": "You are an AI assistant"},
        {"role": "user", "content": prompt},
    ]
    response = client.chat.completions.create(
        messages=messages,
        model=model_path,
        max_tokens=1024,
        temperature=temperature,
    )
    # Only the first choice is ever requested/used.
    return response.choices[0].message.content
def proofread_by_model_name(model_name, input_text, normalize_output):
    """
    Proofreads the given input text using the specified model.

    Args:
        model_name (str): The name of the model to use for proofreading;
            must be a key of ``MODEL_PATHS``.
        input_text (str): The text to be proofread.
        normalize_output (bool): Whether to append the output-normalization
            instruction to the prompt.

    Returns:
        str: The proofread text.

    Raises:
        ValueError: If ``model_name`` is not present in ``MODEL_PATHS``.
    """
    # Constants for API access (all come from config).
    base_url = TOGETHER_PATH
    api_key = TOGETHER_API_KEY
    temperature = TEMPERATURE

    # EAFP lookup: single dict access instead of `in` check + lookup, and
    # the error now names the missing model for easier diagnosis.
    try:
        model_path = MODEL_PATHS[model_name]
    except KeyError as err:
        raise ValueError(
            f"Model name not found in the dictionary: {model_name!r}",
        ) from err

    # Formulate the prompt for the model.
    prompt = f"Proofreading for the text: ```{input_text}```"

    # Apply output normalization if required.
    if normalize_output:
        prompt = output_normalization(prompt)

    # Debugging: Print the prompt.
    print(f"Prompt: {prompt}")

    # Call the abstract proofreading function with the prepared parameters.
    return abstract_proofread(
        model_path,
        temperature,
        base_url,
        api_key,
        prompt,
    )
def gemini_proofread(input_text, normalize_output):
    """
    Proofread text via the configured Gemini model.

    Parameters:
        input_text (str): The text to be proofread.
        normalize_output (bool): When True, append the output-normalization
            instruction to the prompt.

    Returns:
        str: The proofread text returned by the model.
    """
    prompt = f"Proofreading for the text: ```{input_text}```"
    if normalize_output:
        prompt = output_normalization(prompt)
    # Single generation call; only the text payload is returned.
    return GEMINI_MODEL.generate_content(prompt).text
def chatGPT_proofread(input_text, normalize_output):
    """
    Proofread text via the configured OpenAI chat model.

    Parameters:
        input_text (str): The text to be proofread.
        normalize_output (bool): When True, append the output-normalization
            instruction to the prompt.

    Returns:
        str: The proofread text returned by the model.
    """
    prompt = f"Proofreading for the text: ```{input_text}```"
    if normalize_output:
        prompt = output_normalization(prompt)
    # Bracket the API call with progress messages for debugging.
    print(f"Starting API call with prompt: {prompt}")
    proofread_text = OPENAI_MODEL.predict(prompt)
    print(f"Ending API call with prompt: {prompt}")
    return proofread_text
def output_normalization(prompt):
    """
    Append the output-normalization instruction to a prompt.

    Parameters:
        prompt (str): The initial prompt.

    Returns:
        str: The prompt with the instruction suffix appended.
    """
    # The suffix asks the model to answer with the corrected text only.
    suffix = " Please only output the proofread text without any explanation."
    return prompt + suffix
def proofread_with_best_similarity(input_text, model_kind):
    """
    Proofreads the input text using the specified model and extracts the
    best-corrected text based on similarity.

    Args:
        input_text (str): The original text to be proofread.
        model_kind (str): The kind of model to use for proofreading
            (e.g., CHATGPT, GEMINI, or a key accepted by
            ``proofread_by_model_name``).

    Returns:
        tuple: ``(raw_text, result_text)`` — the raw proofread model output
        and the normalized best-corrected text (empty string when no valid
        candidate was extracted).
    """
    # Normalize the input text before prompting/logging.
    normalized_input_text = normalize_text(input_text)
    print_and_log(f"INPUT = {normalized_input_text}")

    # Fix: the original wrapped this body in `for i in range(1,)`, which the
    # author's own comment noted runs exactly once — the loop is removed.
    # Select the proofreading model based on model_kind.
    if model_kind == CHATGPT:
        raw_text = chatGPT_proofread(
            normalized_input_text,
            normalize_output=IS_OUTPUT_NORMALIZATION,
        )
    elif model_kind == GEMINI:
        raw_text = gemini_proofread(
            normalized_input_text,
            normalize_output=IS_OUTPUT_NORMALIZATION,
        )
    else:
        raw_text = proofread_by_model_name(
            model_kind,
            normalized_input_text,
            normalize_output=IS_OUTPUT_NORMALIZATION,
        )

    # Extract the best candidate text based on similarity.
    result_text = extract_by_best_similarity(
        normalized_input_text,
        raw_text,
    )

    # Log the raw output; label kept as "RAW_0" to match the original
    # single-pass loop's message. (A stray no-op bare `print` expression
    # that followed this line was removed.)
    print_and_log(f"RAW_0 = {raw_text}")

    # Normalize the result text; callers treat "" as "no valid result".
    result_text = normalize_text(result_text)
    return raw_text, result_text
def generate_new_data_with_best_similarity(
existing_data_file,
existing_kinds,
new_kinds,
):
"""
Generates new data with the best similarity based on existing and new
kinds, and writes the results to a CSV file.
Args:
existing_data_file (str): The path to the existing data file.
existing_kinds (list): A list of existing kinds.
new_kinds (list): A list of new kinds.
Returns:
None
"""
# Combine existing and new kinds into a single list
all_kinds = existing_kinds + new_kinds
# Generate column names for the CSV file
column_names = generate_column_names(all_kinds)
# Generate column names for existing kinds
existing_column_names = generate_column_names(existing_kinds)
# Generate the output file name
output_file = generate_file_name(
existing_data_file,
existing_kinds,
new_kinds,
)
# Create the output file with column names if it doesn't exist
if not os.path.exists(output_file):
write_to_csv(output_file, column_names)
# Read existing data from the file
existing_data = {
kind: get_column(existing_data_file, kind)
for kind in existing_column_names
}
# Read input data from the output file
input_data = read_csv_data(output_file)
start_index = len(input_data)
print(f"start_index = {start_index}")
num_rows = len(existing_data["human"])
global_generate_set = []
global_reuse = []
for index in range(start_index, num_rows):
# Initialize generation and reuse sets
generate_set = []
reuse_set = []
# Prepare the current generation dictionary
current_generation = {
kind: existing_data[kind][index] for kind in existing_column_names
}
print(f"current_generation before generation = {current_generation}")
human_text = current_generation["human"]
# Generate new kinds based on human text
for kind in new_kinds:
_, generated_text = proofread_with_best_similarity(
human_text,
kind,
)
current_generation[kind] = generated_text
generate_set.append(kind)
print(f"current_generation after generate one = {current_generation}")
# Generate combinations of kinds
for first_kind in all_kinds:
for second_kind in all_kinds:
combination_name = f"{first_kind}_{second_kind}"
if combination_name not in current_generation:
if (
first_kind in current_generation
and current_generation[first_kind] == human_text
):
generated_text = current_generation[second_kind]
reuse_set.append(
f"{combination_name} from {second_kind}",
)
else:
is_need_generation = True
for first_kind_2 in all_kinds:
if (
first_kind != first_kind_2
and current_generation[first_kind]
== current_generation[first_kind_2]
):
combination_name_2 = (
f"{first_kind_2}_{second_kind}"
)
if combination_name_2 in current_generation:
generated_text = current_generation[
combination_name_2
]
reuse_set.append(
f"{combination_name} from {combination_name_2}", # noqa: E501
)
is_need_generation = False
break
if is_need_generation:
_, generated_text = proofread_with_best_similarity(
current_generation[first_kind],
second_kind,
)
generate_set.append(f"{first_kind}_{second_kind}")
current_generation[combination_name] = generated_text
# Write the current generation to the output file
write_new_data(output_file, current_generation, column_names)
# Update global sets
global_generate_set.append(generate_set)
global_reuse