"""Flask app that augments discussion threads with AI-generated "nudges".

The submitted form selects a persona-classification model, a nudge-generation
model and an input source (an uploaded CSV or a Hugging Face dataset). Every
row whose ``user_persona`` is missing gets one predicted by the classifier, and
every reply flagged by ``is_polarized`` is followed by an extra "AI Nudge" row
produced by the generation model. The augmented table is returned as a CSV
download.
"""

import functools
import io
import logging
import os
import re
from datetime import datetime

import pandas as pd
import torch
from datasets import load_dataset
from flask import Flask, render_template, request, send_file
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Load Hugging Face token from environment variable
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN")

# Authenticate with Hugging Face; the app refuses to start without a token.
if HUGGING_FACE_TOKEN:
    login(token=HUGGING_FACE_TOKEN)
else:
    raise ValueError("Hugging Face token not found. Please set the HUGGING_FACE_TOKEN environment variable.")

# Initialize the Flask application
app = Flask(__name__)

# Set up the device (CUDA or CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Optional: Set up logging for debugging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

DEFAULT_PERSONA_MODEL = 'roberta-base'
DEFAULT_NUDGE_MODEL = 'facebook/bart-large-cnn'

# Class index -> persona label for the 3-label classifier loaded below.
PERSONA_MAPPING = {0: 'Persona A', 1: 'Persona B', 2: 'Persona C'}

POLARIZED_KEYWORDS = ("always", "never", "everyone", "nobody", "worst", "best")
# Whole-word, case-insensitive match so e.g. "whenever" or "asbestos" are not flagged.
_POLARIZED_PATTERN = re.compile(
    r"\b(?:" + "|".join(map(re.escape, POLARIZED_KEYWORDS)) + r")\b",
    re.IGNORECASE,
)

# Token budget for the causal-LM prompt, leaving room for the generated nudge
# inside the model's context window (GPT-2's is 1024 positions).
CAUSAL_PROMPT_MAX_TOKENS = 512


class _BadRequest(Exception):
    """User-input problem that the route reports as an HTTP 400 response."""


def _is_missing(value):
    """Return True for None/NaN-like scalars; containers (e.g. lists) are never missing."""
    return pd.api.types.is_scalar(value) and bool(pd.isna(value))


def classify_persona(text, model, tokenizer):
    """Predict a persona label for *text* with a sequence-classification model.

    Args:
        text: A string, or a list of strings (tokenized as a batch).
        model: Sequence-classification model already moved to ``device``.
        tokenizer: Tokenizer matching *model*.

    Returns:
        The label from ``PERSONA_MAPPING`` for the first item's argmax class,
        or 'Unknown' if the predicted index is not in the mapping.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512).to(device)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    logger.debug("Logits: %s", logits)
    logger.debug("Probabilities: %s", probabilities)
    predictions = torch.argmax(probabilities, dim=1)
    # A list input yields one prediction per element; only the first is used.
    return PERSONA_MAPPING.get(predictions[0].item(), 'Unknown')


def is_polarized(message):
    """Return True if *message* contains a polarizing keyword as a whole word.

    A list of strings is joined with spaces first. Anything that is not a
    string after that (e.g. a NaN cell) is treated as not polarized.
    """
    if isinstance(message, list):
        message = ' '.join(message)
    if not isinstance(message, str):
        return False
    return _POLARIZED_PATTERN.search(message) is not None


def generate_nudge(message, persona, topic, model, tokenizer, model_type, max_length=50, min_length=30,
                   temperature=0.7, top_p=0.9, repetition_penalty=1.1):
    """Generate a nudge reply for a polarized *message*.

    Args:
        message: The user's post.
        persona: Persona label, interpolated into the seq2seq prompt.
        topic: Discussion topic, interpolated into the seq2seq prompt.
        model: Generation model already moved to ``device``.
        tokenizer: Tokenizer matching *model*.
        model_type: "seq2seq" or "causal"; anything else yields a fixed message.
        max_length, min_length: Length bounds (in tokens) of the generated
            nudge. For causal models they are applied to the continuation
            only, not to the prompt.
        temperature, top_p, repetition_penalty: Sampling parameters.

    Returns:
        The decoded nudge text (special tokens removed).
    """
    # Ensure min_length is less than or equal to max_length
    min_length = min(min_length, max_length)
    sampling_kwargs = {
        'temperature': temperature,
        'top_p': top_p,
        'repetition_penalty': repetition_penalty,
        'do_sample': True,
    }

    if model_type == "seq2seq":
        prompt = f"As an AI assistant, provide a nudge for this {persona} message in a {topic} discussion: {message}"
        inputs = tokenizer(prompt, return_tensors='pt', max_length=1024, truncation=True).to(device)
        generated_ids = model.generate(
            inputs['input_ids'],
            attention_mask=inputs.get('attention_mask'),
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            early_stopping=True,
            **sampling_kwargs,
        )
        return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    if model_type == "causal":
        prompt = f"{message} [AI Nudge]:"
        inputs = tokenizer(prompt, return_tensors='pt', truncation=True,
                           max_length=CAUSAL_PROMPT_MAX_TOKENS).to(device)
        prompt_len = inputs['input_ids'].shape[1]
        # GPT-2 tokenizers define no pad token; fall back to EOS to avoid the warning.
        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        generated_ids = model.generate(
            inputs['input_ids'],
            attention_mask=inputs.get('attention_mask'),
            # Decoder-only outputs include the prompt, so offset the limits by
            # its length to bound the nudge itself (as for seq2seq).
            max_length=prompt_len + max_length,
            min_length=prompt_len + min_length,
            pad_token_id=pad_token_id,
            **sampling_kwargs,
        )
        # Decode only the continuation so the nudge doesn't echo the user's message.
        return tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)

    return "This model is not suitable for generating text."


# SECURITY NOTE(review): model ids below come straight from form input, so any
# public Hub repo can be downloaded, and legacy PyTorch .bin weights are loaded
# via pickle. Consider restricting the names to an allowlist.
@functools.lru_cache(maxsize=1)
def _load_persona_model(model_name):
    """Load (and cache across requests) a 3-label classifier and its tokenizer."""
    # NOTE(review): a base checkpoint such as 'roberta-base' ships no trained
    # 3-way head, so transformers initializes it randomly — predictions are
    # presumably meaningless unless a fine-tuned model is selected.
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


@functools.lru_cache(maxsize=1)
def _load_nudge_model(model_name):
    """Load (and cache) the generation model chosen by name.

    Returns:
        ``(model, tokenizer, model_type)``, or None when the name matches
        neither the seq2seq ("bart"/"t5") nor the causal ("gpt2") families.
    """
    if "bart" in model_name or "t5" in model_name:
        model_type = "seq2seq"
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    elif "gpt2" in model_name:
        model_type = "causal"
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    else:
        return None
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer, model_type


def _load_online_dataset(dataset_name):
    """Load a Hub dataset's train split as a DataFrame with 'topic'/'post_reply' columns."""
    if not dataset_name:
        raise _BadRequest("A dataset name is required when using an online dataset.")
    if dataset_name == 'personachat':
        # Use AlekseyKorshuk/persona-chat if 'personachat' is selected
        dataset_name = 'AlekseyKorshuk/persona-chat'
    # NOTE(review): dataset_name is user-supplied; any public Hub dataset may be fetched.
    dataset = load_dataset(dataset_name)
    df = pd.DataFrame(dataset['train'])  # Use the training split for processing
    df = df.rename(columns=lambda x: x.strip().lower())
    required = ['utterances', 'personality']
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise _BadRequest(f"The dataset is missing required columns: {', '.join(missing)}.")
    # NOTE(review): this maps 'utterances' -> topic and 'personality' -> post_reply,
    # which looks inverted for persona-chat; confirm the intended mapping.
    df = df[required]
    df.columns = ['topic', 'post_reply']  # Standardize column names for processing
    return df


def _load_uploaded_csv():
    """Read the uploaded CSV into a DataFrame with stripped, lower-cased column names."""
    uploaded_file = request.files.get('file')
    if uploaded_file is None or uploaded_file.filename == '':
        raise _BadRequest("Please upload a CSV file or select an online dataset.")
    logger.debug("File uploaded: %s", uploaded_file.filename)
    df = pd.read_csv(uploaded_file)
    df.columns = df.columns.str.strip().str.lower()
    return df


def _load_input_frame():
    """Return the DataFrame selected by the form (online dataset or uploaded CSV)."""
    if request.form.get('use_online_dataset') == 'yes':
        dataset_name = request.form.get('dataset_name')
        logger.debug("Selected online dataset: %s", dataset_name)
        return _load_online_dataset(dataset_name)
    return _load_uploaded_csv()


def _augment(df, persona_model, persona_tokenizer, nudge_model, nudge_tokenizer, model_type):
    """Fill in missing personas and insert an 'AI Nudge' row after each polarized reply."""
    augmented_rows = []
    for _, row in df.iterrows():
        record = row.to_dict()
        post_reply = record['post_reply']
        if _is_missing(post_reply):
            # Nothing to classify or nudge; keep the row as-is.
            augmented_rows.append(record)
            continue
        if _is_missing(record.get('user_persona')):
            # Classify user persona if not provided
            record['user_persona'] = classify_persona(post_reply, persona_model, persona_tokenizer)
        augmented_rows.append(record)

        if is_polarized(post_reply):
            # 'topic' is optional in uploaded CSVs (only 'post_reply' is required).
            topic = record.get('topic', '')
            nudge = generate_nudge(post_reply, record['user_persona'], topic,
                                   nudge_model, nudge_tokenizer, model_type)
            augmented_rows.append({
                'topic': topic,
                'user_persona': 'AI Nudge',
                'post_reply': nudge,
            })
    return pd.DataFrame(augmented_rows)


def _output_filename(persona_model_name, nudge_model_name):
    """Build the download name from the model names' last path part plus a timestamp."""
    def short(name):
        return name.split('/')[-1].replace('-', '_')

    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"DepolNudge_{short(persona_model_name)}_{short(nudge_model_name)}_{current_time}.csv"


@app.route('/', methods=['GET', 'POST'])
def home():
    """Render the form on GET; on POST, return the augmented CSV as a download.

    POST responses: the CSV attachment on success, 400 with a message for
    invalid input or an unsupported nudge model, 500 for any other failure.
    """
    logger.debug("Home route accessed.")
    if request.method != 'POST':
        logger.debug("Rendering index.html")
        return render_template('index.html')

    logger.debug("POST request received.")
    try:
        # Get the model names from the form (empty fields fall back to defaults)
        persona_model_name = request.form.get('persona_model_name') or DEFAULT_PERSONA_MODEL
        nudge_model_name = request.form.get('nudge_model_name') or DEFAULT_NUDGE_MODEL
        logger.debug("Selected persona model: %s", persona_model_name)
        logger.debug("Selected nudge model: %s", nudge_model_name)

        loaded_nudge = _load_nudge_model(nudge_model_name)
        if loaded_nudge is None:
            logger.error("Unsupported model selected.")
            return "Selected model is not supported for text generation tasks.", 400
        nudge_model, nudge_tokenizer, model_type = loaded_nudge
        persona_model, persona_tokenizer = _load_persona_model(persona_model_name)
        logger.debug("Models and tokenizers loaded.")

        df = _load_input_frame()
        if 'post_reply' not in df.columns:
            logger.error("Required column 'post_reply' is missing in the CSV.")
            return "The uploaded CSV file must contain 'post_reply' column.", 400

        augmented_df = _augment(df, persona_model, persona_tokenizer,
                                nudge_model, nudge_tokenizer, model_type)
        logger.debug("Processing completed.")

        # Build the CSV in memory instead of writing it to disk.
        csv_buffer = io.BytesIO(augmented_df.to_csv(index=False).encode('utf-8'))
        return send_file(
            csv_buffer,
            as_attachment=True,
            download_name=_output_filename(persona_model_name, nudge_model_name),
            mimetype='text/csv',
        )
    except _BadRequest as e:
        logger.error("Bad request: %s", e)
        return str(e), 400
    except Exception as e:
        logger.error(f"Error processing the request: {e}", exc_info=True)
        return "There was an error processing your request.", 500


if __name__ == '__main__':
    # debug=True enables the Werkzeug interactive debugger (arbitrary code
    # execution from the browser) — never expose this server publicly.
    app.run(debug=True)