DepolNudgeBot / app.py
learning4's picture
Update app.py
e8b8ff8 verified
raw
history blame contribute delete
9.23 kB
import io
import logging
import os
import re
from datetime import datetime

import pandas as pd
import torch
from datasets import load_dataset
from flask import Flask, request, render_template, send_file
from huggingface_hub import login
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
# Hugging Face authentication: the token is read once from the environment and
# startup aborts if it is unset or empty.
HUGGING_FACE_TOKEN = os.environ.get("HUGGING_FACE_TOKEN")
if not HUGGING_FACE_TOKEN:
    raise ValueError("Hugging Face token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
login(token=HUGGING_FACE_TOKEN)

# Flask application object that the route decorators below attach to.
app = Flask(__name__)

# Run the models on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Root logger emits DEBUG and above.
logging.basicConfig(level=logging.DEBUG)
def classify_persona(text, model, tokenizer):
    """Classify ``text`` into a persona label with a sequence-classification model.

    Args:
        text: Input passed straight to ``tokenizer``. If the tokenizer treats
            it as a batch, only the first item's label is returned.
        model: Callable invoked as ``model(**inputs)``; its result must expose
            a ``logits`` tensor with the class scores along dimension 1.
        tokenizer: Callable producing the model inputs; its result is moved to
            the module-level ``device``.

    Returns:
        ``'Persona A'``, ``'Persona B'`` or ``'Persona C'`` for class indices
        0, 1 and 2, or ``'Unknown'`` for any other index.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512).to(device)

    # Inference only: without no_grad() every call would build an autograd
    # graph that is never used, wasting memory on each classified row.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Convert logits to probabilities.
    probabilities = torch.nn.functional.softmax(logits, dim=1)

    # Route diagnostics through logging, like the rest of this module.
    logging.debug("Logits: %s", logits)
    logging.debug("Probabilities: %s", probabilities)

    predictions = torch.argmax(probabilities, dim=1)
    persona_mapping = {0: 'Persona A', 1: 'Persona B', 2: 'Persona C'}

    # Only the first prediction is reported, so only that one is mapped.
    return persona_mapping.get(predictions[0].item(), 'Unknown')
# Absolutist words that mark a message as polarized. They are matched as whole
# words so that e.g. "nevertheless" or "bestow" do not count as hits.
_POLARIZED_PATTERN = re.compile(
    r"\b(?:always|never|everyone|nobody|worst|best)\b",
    re.IGNORECASE,
)


def is_polarized(message):
    """Return True when ``message`` contains an absolutist keyword.

    Args:
        message: A string, or a list/tuple of parts that are joined with
            spaces before matching. Any other value (``None``, NaN from an
            empty CSV cell, numbers, ...) is treated as not polarized.

    Returns:
        ``True`` if any of "always", "never", "everyone", "nobody", "worst"
        or "best" occurs as a whole word (case-insensitive), else ``False``.
    """
    if isinstance(message, (list, tuple)):
        # str() keeps the join from failing on non-string parts.
        message = ' '.join(str(part) for part in message)
    if not isinstance(message, str):
        return False
    return _POLARIZED_PATTERN.search(message) is not None
def generate_nudge(message, persona, topic, model, tokenizer, model_type, max_length=50, min_length=30, temperature=0.7, top_p=0.9, repetition_penalty=1.1):
    """Generate an AI nudge for ``message`` with the selected generation model.

    Args:
        message: The post/reply text the nudge responds to.
        persona: Persona label, used only in the seq2seq prompt.
        topic: Discussion topic, used only in the seq2seq prompt.
        model: Model exposing ``generate(...)``.
        tokenizer: Tokenizer matching ``model``.
        model_type: ``"seq2seq"`` or ``"causal"``; any other value yields a
            fixed explanatory string instead of generated text.
        max_length: For ``"seq2seq"``, the maximum length of the generated
            sequence. For ``"causal"``, the maximum number of newly generated
            tokens (the prompt is not counted).
        min_length: Lower bound counterpart of ``max_length``; clamped so it
            never exceeds ``max_length``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The generated nudge text. For causal models only the continuation is
        returned, not the prompt.
    """
    # Ensure min_length is less than or equal to max_length.
    min_length = min(min_length, max_length)

    # Sampling options shared by both model families.
    sampling_options = {
        'temperature': temperature,
        'top_p': top_p,
        'repetition_penalty': repetition_penalty,
        'do_sample': True,
    }

    if model_type == "seq2seq":
        prompt = f"As an AI assistant, provide a nudge for this {persona} message in a {topic} discussion: {message}"
        inputs = tokenizer(prompt, return_tensors='pt', max_length=1024, truncation=True).to(device)
        generated_ids = model.generate(
            inputs['input_ids'],
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            early_stopping=True,
            **sampling_options,
        )
        return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    if model_type == "causal":
        prompt = f"{message} [AI Nudge]:"
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        prompt_token_count = inputs['input_ids'].shape[1]

        # Fall back to EOS for padding when the tokenizer defines no pad token.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        # A decoder-only model's output starts with the prompt, so a total
        # `max_length` budget would be eaten by any message longer than the
        # budget itself. Bound the number of *new* tokens instead.
        generated_ids = model.generate(
            inputs['input_ids'],
            attention_mask=inputs.get('attention_mask'),
            max_new_tokens=max_length,
            min_new_tokens=min_length,
            pad_token_id=pad_token_id,
            **sampling_options,
        )

        # Drop the echoed prompt tokens so only the nudge itself is returned.
        continuation_ids = generated_ids[0][prompt_token_count:]
        return tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()

    return "This model is not suitable for generating text."
# Topic used in the nudge prompt when the input provides none for a row.
_DEFAULT_TOPIC = 'general'


def _is_missing(value):
    """Return True when ``value`` is a scalar null (None, NaN, NA)."""
    return pd.api.types.is_scalar(value) and bool(pd.isna(value))


def _augment_with_nudges(df, persona_model, persona_tokenizer, nudge_model, nudge_tokenizer, model_type):
    """Return ``df``'s rows as dicts, with personas filled in and AI nudges interleaved."""
    augmented_rows = []
    for _, row in df.iterrows():
        record = row.to_dict()
        post_reply = record['post_reply']

        if _is_missing(post_reply):
            # Nothing to classify or nudge; keep the row so the output still
            # mirrors the input.
            augmented_rows.append(record)
            continue

        if _is_missing(record.get('user_persona')):
            # Classify user persona if not provided.
            record['user_persona'] = classify_persona(post_reply, persona_model, persona_tokenizer)
        augmented_rows.append(record)

        if is_polarized(post_reply):
            topic = record.get('topic')
            prompt_topic = _DEFAULT_TOPIC if _is_missing(topic) else topic
            nudge = generate_nudge(post_reply, record['user_persona'], prompt_topic, nudge_model, nudge_tokenizer, model_type)
            nudge_row = {'user_persona': 'AI Nudge', 'post_reply': nudge}
            if 'topic' in record:
                # Mirror the source row's topic; do not invent a topic column
                # when the input had none.
                nudge_row = {'topic': topic, **nudge_row}
            augmented_rows.append(nudge_row)
    return augmented_rows


@app.route('/', methods=['GET', 'POST'])
def home():
    """Serve the upload form (GET) or run the nudge pipeline (POST).

    On POST the form selects a persona-classification model, a nudge-generation
    model, and either an online dataset or an uploaded CSV. Rows without a
    ``user_persona`` are classified, and an "AI Nudge" row is inserted after
    each row that ``is_polarized`` flags.

    Returns:
        GET: the rendered ``index.html`` template.
        POST: the augmented data as a CSV download; a 400 response for an
        unsupported nudge model or unusable input; a 500 response for any
        other failure (the exception is logged).
    """
    logging.debug("Home route accessed.")
    if request.method != 'POST':
        logging.debug("Rendering index.html")
        return render_template('index.html')

    logging.debug("POST request received.")
    try:
        # NOTE(review): the model and dataset names come straight from the
        # form and are passed to from_pretrained()/load_dataset(); consider
        # restricting them to an allow-list.
        persona_model_name = request.form.get('persona_model_name', 'roberta-base')
        nudge_model_name = request.form.get('nudge_model_name', 'facebook/bart-large-cnn')
        logging.debug(f"Selected persona model: {persona_model_name}")
        logging.debug(f"Selected nudge model: {nudge_model_name}")

        # Decide the nudge model family before loading anything, so an
        # unsupported selection fails fast. Matching ignores case.
        lowered_nudge_name = nudge_model_name.lower()
        if "bart" in lowered_nudge_name or "t5" in lowered_nudge_name:
            model_type = "seq2seq"
            nudge_model_class = AutoModelForSeq2SeqLM
        elif "gpt2" in lowered_nudge_name:
            model_type = "causal"
            nudge_model_class = AutoModelForCausalLM
        else:
            logging.error("Unsupported model selected.")
            return "Selected model is not supported for text generation tasks.", 400

        # Validate the input before the model loads, so a bad request is
        # rejected without paying for them.
        use_online_dataset = request.form.get('use_online_dataset') == 'yes'
        if use_online_dataset:
            dataset_name = request.form.get('dataset_name')
            logging.debug(f"Selected online dataset: {dataset_name}")
            if not dataset_name:
                logging.error("No online dataset name was provided.")
                return "Please provide the name of the online dataset to use.", 400
            if dataset_name == 'personachat':
                # Use AlekseyKorshuk/persona-chat if 'personachat' is selected.
                dataset_name = 'AlekseyKorshuk/persona-chat'
            dataset = load_dataset(dataset_name)
            df = pd.DataFrame(dataset['train'])  # Use the training split for processing
            df = df.rename(columns=lambda x: x.strip().lower())
            df = df[['utterances', 'personality']]  # Modify this according to the dataset structure
            df.columns = ['topic', 'post_reply']  # Standardize column names for processing
        else:
            uploaded_file = request.files.get('file')
            if uploaded_file is None or uploaded_file.filename == '':
                # Previously this path left `df` unbound and surfaced as a 500.
                logging.error("No CSV file was uploaded.")
                return "Please upload a CSV file or choose an online dataset.", 400
            logging.debug(f"File uploaded: {uploaded_file.filename}")
            df = pd.read_csv(uploaded_file)
            df.columns = df.columns.str.strip().str.lower()

        if 'post_reply' not in df.columns:
            logging.error("Required column 'post_reply' is missing in the CSV.")
            return "The uploaded CSV file must contain 'post_reply' column.", 400

        # Load persona classification model.
        persona_model = AutoModelForSequenceClassification.from_pretrained(persona_model_name, num_labels=3).to(device)
        persona_tokenizer = AutoTokenizer.from_pretrained(persona_model_name)
        # Load nudge generation model.
        nudge_model = nudge_model_class.from_pretrained(nudge_model_name).to(device)
        nudge_tokenizer = AutoTokenizer.from_pretrained(nudge_model_name)
        logging.debug("Models and tokenizers loaded.")

        augmented_rows = _augment_with_nudges(
            df, persona_model, persona_tokenizer, nudge_model, nudge_tokenizer, model_type
        )
        augmented_df = pd.DataFrame(augmented_rows)
        logging.debug("Processing completed.")

        # Generate the output filename.
        persona_label = persona_model_name.split('/')[-1].replace('-', '_')
        nudge_label = nudge_model_name.split('/')[-1].replace('-', '_')
        current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_filename = f"DepolNudge_{persona_label}_{nudge_label}_{current_time}.csv"

        # Build the CSV in memory. Encoding the text explicitly avoids relying
        # on to_csv() accepting a binary buffer.
        csv_buffer = io.BytesIO(augmented_df.to_csv(index=False).encode('utf-8'))

        # Send the file for download without saving it to disk.
        return send_file(
            csv_buffer,
            as_attachment=True,
            download_name=output_filename,
            mimetype='text/csv'
        )
    except Exception as e:
        # Top-level request boundary: log the full traceback, return a generic 500.
        logging.error(f"Error processing the request: {e}", exc_info=True)
        return "There was an error processing your request.", 500
if __name__ == '__main__':
    # Debug mode enables the Werkzeug interactive debugger, which lets anyone
    # who can reach the server execute code. Keep it off unless explicitly
    # requested with FLASK_DEBUG=1.
    app.run(debug=os.getenv("FLASK_DEBUG", "0") == "1")