import os

from config import (
    CHATGPT,
    GEMINI,
    GEMINI_MODEL,
    IS_OUTPUT_NORMALIZATION,
    MODEL_PATHS,
    OPENAI_MODEL,
    TEMPERATURE,
    TOGETHER_API_KEY,
    TOGETHER_PATH,
)
from evaluation import extract_by_best_similarity
from openai import OpenAI
from utils import (
    generate_column_names,
    generate_file_name,
    get_column,
    normalize_text,
    print_and_log,
    read_csv_data,
    write_new_data,
    write_to_csv,
)


def abstract_proofread(model_path, temperature, base_url, api_key, prompt):
    """
    Send a prompt to a chat-completions endpoint and return the reply text.

    Args:
        model_path (str): The path or identifier of the AI model to use.
        temperature (float): Sampling temperature for the model's output.
        base_url (str): The base URL for the API endpoint.
        api_key (str): The API key for authentication.
        prompt (str): The text prompt to provide to the AI for proofreading.

    Returns:
        str: The message content of the first choice in the completion.
    """
    # A new client is built on every call from the supplied credentials.
    ai_client = OpenAI(api_key=api_key, base_url=base_url)

    # One fixed system message followed by the caller's prompt; the reply is
    # capped at 1024 tokens.
    chat_completion = ai_client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are an AI assistant",
            },
            {
                "role": "user",
                "content": prompt,
            },
        ],
        model=model_path,
        max_tokens=1024,
        temperature=temperature,
    )

    # Return the content of the first choice's message
    return chat_completion.choices[0].message.content


def _build_proofread_prompt(input_text, normalize_output):
    """
    Build the proofreading prompt shared by every model backend.

    Args:
        input_text (str): The text to be proofread.
        normalize_output (bool): Whether to append the "output only the
            proofread text" instruction via output_normalization().

    Returns:
        str: The prompt to send to the model.
    """
    prompt = f"Proofreading for the text: ```{input_text}```"
    if normalize_output:
        prompt = output_normalization(prompt)
    return prompt


def proofread_by_model_name(model_name, input_text, normalize_output):
    """
    Proofreads the given input text using the specified model.

    The model is looked up in MODEL_PATHS and queried through
    abstract_proofread() with TOGETHER_PATH, TOGETHER_API_KEY and
    TEMPERATURE.

    Args:
        model_name (str): The name of the model to use for proofreading.
        input_text (str): The text to be proofread.
        normalize_output (bool): Whether to normalize the output or not.

    Returns:
        str: The proofread text.

    Raises:
        ValueError: If model_name is not a key of MODEL_PATHS.
    """
    # Fail early, and say which name was rejected so a misconfigured model
    # list is easy to diagnose.
    if model_name not in MODEL_PATHS:
        raise ValueError(
            f"Model name not found in the dictionary: {model_name!r}",
        )
    model_path = MODEL_PATHS[model_name]

    prompt = _build_proofread_prompt(input_text, normalize_output)

    # Debugging: Print the prompt
    print(f"Prompt: {prompt}")

    return abstract_proofread(
        model_path,
        TEMPERATURE,
        TOGETHER_PATH,
        TOGETHER_API_KEY,
        prompt,
    )


def gemini_proofread(input_text, normalize_output):
    """
    Proofreads the given text using the GEMINI_MODEL.

    Args:
        input_text (str): The text to be proofread.
        normalize_output (bool): Flag indicating whether to normalize the
            output.

    Returns:
        str: The proofread text.
    """
    prompt = _build_proofread_prompt(input_text, normalize_output)
    response = GEMINI_MODEL.generate_content(prompt)
    return response.text


def chatGPT_proofread(input_text, normalize_output):
    """
    Proofreads the given text using OPENAI_MODEL.

    Args:
        input_text (str): The text to be proofread.
        normalize_output (bool): Flag indicating whether to normalize the
            output.

    Returns:
        str: The proofread text.
    """
    prompt = _build_proofread_prompt(input_text, normalize_output)

    print(f"Starting API call with prompt: {prompt}")
    result = OPENAI_MODEL.predict(prompt)
    print(f"Ending API call with prompt: {prompt}")

    return result


def output_normalization(prompt):
    """
    Normalizes the output by appending a specific instruction to the prompt.

    Args:
        prompt (str): The initial prompt.

    Returns:
        str: The prompt followed by the instruction to output only the
            proofread text.
    """
    return (
        prompt
        + " Please only output the proofread text without any explanation."
    )


def proofread_with_best_similarity(input_text, model_kind):
    """
    Proofreads the input text using the specified model and extracts the
    best-corrected text based on similarity.

    Args:
        input_text (str): The original text to be proofread.
        model_kind (str): The kind of model to use for proofreading.
            CHATGPT and GEMINI select their dedicated backends; any other
            value is passed to proofread_by_model_name() as a model name.

    Returns:
        tuple: A tuple containing the raw proofread text and the
            best-corrected text (normalized).
    """
    # Normalize the input text
    normalized_input_text = normalize_text(input_text)
    print_and_log(f"INPUT = {normalized_input_text}")

    # Select the proofreading model based on model_kind
    if model_kind == CHATGPT:
        raw_text = chatGPT_proofread(
            normalized_input_text,
            normalize_output=IS_OUTPUT_NORMALIZATION,
        )
    elif model_kind == GEMINI:
        raw_text = gemini_proofread(
            normalized_input_text,
            normalize_output=IS_OUTPUT_NORMALIZATION,
        )
    else:
        raw_text = proofread_by_model_name(
            model_kind,
            normalized_input_text,
            normalize_output=IS_OUTPUT_NORMALIZATION,
        )

    # Extract the best candidate text based on similarity
    result_text = extract_by_best_similarity(
        normalized_input_text,
        raw_text,
    )

    # The label stays "RAW_0": this code used to sit in a loop that always
    # ran exactly once, with index 0.
    print_and_log(f"RAW_0 = {raw_text}")

    # Normalize the result text
    return raw_text, normalize_text(result_text)


def generate_new_data_with_best_similarity(
    existing_data_file,
    existing_kinds,
    new_kinds,
):
    """
    Generates new data with the best similarity based on existing and new
    kinds, and writes the results to a CSV file.

    Args:
        existing_data_file (str): The path to the existing data file.
        existing_kinds (list): A list of existing kinds.
        new_kinds (list): A list of new kinds.

    Returns:
        None
    """
    # Combine existing and new kinds into a single list
    all_kinds = existing_kinds + new_kinds

    # Generate column names for the CSV file
    column_names = generate_column_names(all_kinds)

    # Generate column names for existing kinds
    existing_column_names = generate_column_names(existing_kinds)

    # Generate the output file name
    output_file = generate_file_name(
        existing_data_file,
        existing_kinds,
        new_kinds,
    )

    # Create the output file with column names if it doesn't exist
    if not os.path.exists(output_file):
        write_to_csv(output_file, column_names)

    # Read existing data from the file
    existing_data = {
        kind: get_column(existing_data_file, kind)
        for kind in existing_column_names
    }

    # Read input data from the output file
    input_data = read_csv_data(output_file)
    start_index = len(input_data)
    print(f"start_index = {start_index}")

    num_rows = len(existing_data["human"])
    global_generate_set = []
    global_reuse = []

    for index in range(start_index, num_rows):
        # Initialize generation and reuse sets
        generate_set = []
        reuse_set = []

        # Prepare the current generation dictionary
        current_generation = {
            kind: existing_data[kind][index] for kind in existing_column_names
        }
        print(f"current_generation before generation = {current_generation}")

        human_text = current_generation["human"]

        # Generate new kinds based on human text
        for kind in new_kinds:
            _, generated_text = proofread_with_best_similarity(
                human_text,
                kind,
            )
            current_generation[kind] = generated_text
            generate_set.append(kind)

        print(f"current_generation after generate one = {current_generation}")

        # Generate combinations of kinds
        for first_kind in all_kinds:
            for second_kind in all_kinds:
                combination_name = f"{first_kind}_{second_kind}"

                if combination_name not in current_generation:
                    if (
                        first_kind in current_generation
                        and current_generation[first_kind] == human_text
                    ):
                        generated_text = current_generation[second_kind]
                        reuse_set.append(
                            f"{combination_name} from {second_kind}",
                        )
                    else:
                        is_need_generation = True
                        for first_kind_2 in all_kinds:
                            if (
                                first_kind != first_kind_2
                                and current_generation[first_kind]
                                == current_generation[first_kind_2]
                            ):
                                combination_name_2 = (
                                    f"{first_kind_2}_{second_kind}"
                                )
                                if combination_name_2 in current_generation:
                                    generated_text = current_generation[
                                        combination_name_2
                                    ]
                                    reuse_set.append(
                                        f"{combination_name} from {combination_name_2}",  # noqa: E501
                                    )
                                    is_need_generation = False
                                    break
                        if is_need_generation:
                            _, generated_text = proofread_with_best_similarity(
                                current_generation[first_kind],
                                second_kind,
                            )
                            generate_set.append(f"{first_kind}_{second_kind}")

                    current_generation[combination_name] = generated_text

        # Write the current generation to the output file
        write_new_data(output_file, current_generation, column_names)

        # Update global sets
        global_generate_set.append(generate_set)
        global_reuse