# Source: toxic-comment-classifier_rlhf / refine_paraphrases.py
# Author: JanviMl
# Last commit: "Update refine_paraphrases.py" (77f7351, verified)
# refine_paraphrases.py
from datasets import load_dataset
import pandas as pd
from paraphraser import paraphrase_comment
from metrics import compute_reward_scores
import os
# Configuration
# Hugging Face Hub dataset holding the comments and their current refined paraphrases.
DATA_PATH = "JanviMl/toxi_refined_paraphrases"
# Local CSV that main() writes before uploading it to the Hub.
OUTPUT_PATH = "iterated_paraphrases.csv"
# Upper bound on generate-and-score rounds per comment.
MAX_ITERATIONS = 1
# Thresholds checked by meets_targets and quoted in the generation prompt:
# empathy and reward are floors (>=); toxicity, bias and hallucination are ceilings (<=).
TARGET_SCORES = dict(
    empathy=0.9,
    toxicity=0.1,
    bias=0.1,
    hallucination=0.1,
    reward=0.25,
)
def meets_targets(scores, targets=None):
    """Return True when every score satisfies its target threshold.

    Empathy and reward must be at least their targets; toxicity, bias and
    hallucination must be at most theirs.

    Args:
        scores: Mapping with the keys "empathy", "toxicity", "bias",
            "hallucination" and "reward".
        targets: Optional mapping of the same keys to thresholds. When
            omitted, the module-level ``TARGET_SCORES`` is used.

    Returns:
        True only if all five checks pass, otherwise False.

    Raises:
        KeyError: if a metric that gets checked is missing from ``scores`` or
            ``targets``. Checks run in the order listed below and stop at the
            first failing metric, so later keys are not looked up after a failure.
    """
    if targets is None:
        targets = TARGET_SCORES
    # (metric, True when the score must be >= target, False when it must be <= target)
    criteria = (
        ("empathy", True),
        ("toxicity", False),
        ("bias", False),
        ("hallucination", False),
        ("reward", True),
    )
    # all() over a generator short-circuits, so evaluation stops at the first miss.
    return all(
        scores[name] >= targets[name] if higher_is_better else scores[name] <= targets[name]
        for name, higher_is_better in criteria
    )
def generate_new_paraphrase(original, current_paraphrase, current_scores, issues):
    """Request an improved paraphrase of ``original`` from the paraphraser.

    Builds one instruction prompt that embeds the original comment, the current
    paraphrase, its scores, the human-feedback text and the thresholds from
    ``TARGET_SCORES``. Passes the prompt to ``paraphrase_comment`` and returns
    whatever that call returns.
    """
    goals = TARGET_SCORES
    prompt_parts = [
        f"Original comment: '{original}'. ",
        f"Current paraphrase: '{current_paraphrase}'. ",
        f"Current scores: {current_scores}. ",
        f"Human feedback: {issues}. ",
        f"Generate a new paraphrase that improves empathy (>= {goals['empathy']}), ",
        f"reduces toxicity (<= {goals['toxicity']}), bias (<= {goals['bias']}), ",
        f"and hallucination (<= {goals['hallucination']}), and increases reward (>= {goals['reward']}).",
    ]
    return paraphrase_comment("".join(prompt_parts))
def refine_paraphrase(row: pd.Series) -> tuple:
    """
    Try to improve the refined paraphrase stored in ``row``.

    Runs at most ``MAX_ITERATIONS`` generate-and-score rounds. It stops early
    once the current scores satisfy ``meets_targets``, or when the paraphraser
    returns text containing its error marker. A candidate replaces the current
    paraphrase only when its reward score is strictly higher.

    Returns:
        (paraphrase, scores, reasoning): the kept paraphrase, its score dict,
        and a "; "-joined log of every round. The log is empty if no round ran.
    """
    comment = row["Comment"]
    best_text = row["Refined_Paraphrase"]
    score_columns = {
        "empathy": "Refined_Empathy",
        "toxicity": "Refined_Toxicity",
        "bias": "Refined_Bias",
        "hallucination": "Refined_Hallucination",
        "reward": "Refined_Reward_Score",
    }
    best_scores = {metric: row[column] for metric, column in score_columns.items()}
    feedback = row["Human_Evaluation_Reasoning"]
    log = []
    print(f"Processing comment: {comment}")
    round_no = 0
    while round_no < MAX_ITERATIONS:
        if meets_targets(best_scores):
            break
        round_no += 1
        print(f"Starting iteration {round_no} for comment: {comment}")
        candidate = generate_new_paraphrase(comment, best_text, best_scores, feedback)
        print(f"Generated paraphrase: {candidate}")
        # The paraphraser reports failure through its returned text, not an exception.
        if "Error: Unable to generate paraphrase" in candidate:
            log.append(f"Iteration {round_no}: Paraphrasing failed - {candidate}")
            break
        candidate_scores = compute_reward_scores(comment, candidate)
        print(f"New scores: {candidate_scores}")
        log.append(
            f"Iteration {round_no}: Generated '{candidate}' with scores {candidate_scores}. "
            f"Previous scores {best_scores}."
        )
        # Keep the candidate only on a strict reward improvement.
        if candidate_scores["reward"] > best_scores["reward"]:
            best_text, best_scores = candidate, candidate_scores
            log.append("Accepted new paraphrase due to improved reward score.")
        else:
            log.append("Rejected new paraphrase; no improvement in reward score.")
    print(f"Finished processing comment: {comment}")
    return best_text, best_scores, "; ".join(log)
def main(max_rows=1):
    """Refine paraphrases from ``DATA_PATH`` and publish the results.

    Loads the ``train`` split of ``DATA_PATH`` and runs ``refine_paraphrase``
    on its first ``max_rows`` rows. Writes one result row per comment to
    ``OUTPUT_PATH``, then uploads that CSV to the Hub repository
    ``JanviMl/toxi_iterated_paraphrases``. Errors while loading or pushing are
    printed, not raised.

    Args:
        max_rows: Number of leading rows to process. The default of 1 keeps
            the previous single-row behaviour; pass ``None`` to process the
            whole split.
    """
    hub_repo = "JanviMl/toxi_iterated_paraphrases"
    # Load dataset from Hugging Face Hub
    try:
        df = load_dataset(DATA_PATH, split="train").to_pandas()
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        return
    if max_rows is not None:
        df = df.iloc[:max_rows]
    results = []
    for _, row in df.iterrows():
        new_paraphrase, new_scores, reasoning = refine_paraphrase(row)
        results.append({
            "Comment": row["Comment"],
            "Original_Paraphrase": row["Original_Paraphrase"],
            "Refined_Paraphrase": row["Refined_Paraphrase"],
            "Iterated_Paraphrase": new_paraphrase,
            "Original_Reward_Score": row["Original_Reward_Score"],
            "Refined_Reward_Score": row["Refined_Reward_Score"],
            "Iterated_Reward_Score": new_scores["reward"],
            "Iterated_Empathy": new_scores["empathy"],
            "Iterated_Toxicity": new_scores["toxicity"],
            "Iterated_Bias": new_scores["bias"],
            "Iterated_Hallucination": new_scores["hallucination"],
            "Iteration_Reasoning": reasoning,
        })
    # Save results to CSV
    result_df = pd.DataFrame(results)
    result_df.to_csv(OUTPUT_PATH, index=False)
    print(f"Refinement complete. Results saved to {OUTPUT_PATH}")
    # Push to Hugging Face Hub. OUTPUT_PATH is a CSV, so it must be read back
    # with the "csv" builder. The "pandas" builder expects pickled DataFrames
    # and cannot parse a CSV, so the push would never run.
    try:
        dataset = load_dataset("csv", data_files=OUTPUT_PATH)
        dataset.push_to_hub(hub_repo, token=os.getenv("HF_TOKEN"))
        print(f"Pushed to Hugging Face Hub: {hub_repo}")
    except Exception as e:
        print(f"Error pushing to Hugging Face Hub: {str(e)}")
# Run the refinement pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()