Spaces:
Runtime error
Runtime error
import functools
import io
import logging
import os
import re
from datetime import datetime

import pandas as pd
import torch
from datasets import load_dataset
from flask import Flask, render_template, request, send_file
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
# Hugging Face Hub authentication: the token is mandatory and is read from the
# HUGGING_FACE_TOKEN environment variable; startup aborts without it.
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN")
if not HUGGING_FACE_TOKEN:
    raise ValueError("Hugging Face token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
login(token=HUGGING_FACE_TOKEN)

# The Flask application object.
app = Flask(__name__)

# Models and tensors go to the GPU when torch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Verbose (DEBUG-level) logging for troubleshooting.
logging.basicConfig(level=logging.DEBUG)
def classify_persona(text, model, tokenizer):
    """Predict a persona label for ``text`` with a sequence-classification model.

    Args:
        text: The text to classify. A list of strings is tokenized as a batch,
            and only the first item's prediction is returned.
        model: A sequence-classification model placed on the module-level ``device``.
        tokenizer: The tokenizer matching ``model``.

    Returns:
        One of "Persona A", "Persona B" or "Persona C" (class indices 0-2), or
        "Unknown" if the predicted index falls outside that mapping.
    """
    persona_mapping = {0: "Persona A", 1: "Persona B", 2: "Persona C"}

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Convert logits to probabilities (per row of the batch).
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    logging.debug("Logits: %s", logits)
    logging.debug("Probabilities: %s", probabilities)

    # For batched input, only the first row's prediction is used.
    predicted_index = torch.argmax(probabilities[0]).item()
    return persona_mapping.get(predicted_index, "Unknown")
# Absolutist words treated as markers of a polarized message.
POLARIZED_KEYWORDS = ("always", "never", "everyone", "nobody", "worst", "best")


def is_polarized(message, keywords=POLARIZED_KEYWORDS):
    """Return True if ``message`` contains any of ``keywords`` as a whole word.

    Matching is case-insensitive and respects word boundaries. So "best" matches
    "the best." but not "bestow", and "never" does not match "nevertheless".

    Args:
        message: A string, or a list/tuple of parts joined with spaces (each part
            converted with ``str``). Non-text values such as ``None`` or a pandas
            NaN (an empty CSV cell) are treated as not polarized.
        keywords: Words to look for; defaults to ``POLARIZED_KEYWORDS``. An empty
            collection never matches.

    Returns:
        bool: whether at least one keyword occurs as a whole word.
    """
    if isinstance(message, (list, tuple)):
        message = " ".join(str(part) for part in message)
    if not isinstance(message, str) or not keywords:
        return False
    # \b anchors prevent substring hits inside longer words.
    pattern = r"\b(?:" + "|".join(re.escape(word) for word in keywords) + r")\b"
    return re.search(pattern, message, flags=re.IGNORECASE) is not None
def generate_nudge(message, persona, topic, model, tokenizer, model_type, max_length=50, min_length=30, temperature=0.7, top_p=0.9, repetition_penalty=1.1):
    """Generate an AI "nudge" in response to a polarized message.

    Args:
        message: Text of the polarized post.
        persona: Persona label of the post's author; only used in the seq2seq prompt.
        topic: Discussion topic; only used in the seq2seq prompt.
        model: Generation model placed on the module-level ``device``.
        tokenizer: Tokenizer matching ``model``.
        model_type: "seq2seq" (encoder-decoder) or "causal" (decoder-only).
            Any other value returns a fixed explanatory message.
        max_length: Upper bound on the generated sequence length; the prompt is not counted.
        min_length: Lower bound on the generated sequence length; clamped to ``max_length``.
        temperature, top_p, repetition_penalty: Sampling controls forwarded to
            ``model.generate``.

    Returns:
        The decoded nudge text (without the prompt), or
        "This model is not suitable for generating text." for an unknown ``model_type``.
    """
    # Ensure min_length is less than or equal to max_length
    min_length = min(min_length, max_length)

    if model_type == "seq2seq":
        prompt = f"As an AI assistant, provide a nudge for this {persona} message in a {topic} discussion: {message}"
        inputs = tokenizer(prompt, return_tensors='pt', max_length=1024, truncation=True).to(device)
        # For encoder-decoder models the length bounds apply to the decoder output only.
        generated_ids = model.generate(
            inputs['input_ids'],
            attention_mask=inputs.get('attention_mask'),
            max_length=max_length,
            min_length=min_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            num_beams=4,
            early_stopping=True,
        )
        return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    if model_type == "causal":
        prompt = f"{message} [AI Nudge]:"
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        prompt_length = inputs['input_ids'].shape[1]

        # GPT-2 style tokenizers define no pad token; EOS is the conventional stand-in.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        # Decoder-only models count the prompt inside max_length/min_length. Offset
        # both by the prompt length so they bound only the newly generated tokens.
        generated_ids = model.generate(
            inputs['input_ids'],
            attention_mask=inputs.get('attention_mask'),
            max_length=prompt_length + max_length,
            min_length=prompt_length + min_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=pad_token_id,
        )
        # generate() returns prompt + continuation; keep only the continuation.
        return tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True).strip()

    return "This model is not suitable for generating text."
class _BadRequest(Exception):
    """Client-side input problem; its message is returned to the user with HTTP 400."""


def _nudge_model_type(nudge_model_name):
    """Map a nudge model name to the ``model_type`` that ``generate_nudge`` expects.

    Returns "seq2seq" for names containing "bart" or "t5", "causal" for names
    containing "gpt2", and None for anything else (unsupported).
    """
    if "bart" in nudge_model_name or "t5" in nudge_model_name:
        return "seq2seq"
    if "gpt2" in nudge_model_name:
        return "causal"
    return None


@functools.lru_cache(maxsize=2)
def _load_models(persona_model_name, nudge_model_name, model_type):
    """Load the persona classifier, the nudge generator and their tokenizers.

    Results are cached so repeated requests with the same model names do not
    reload the weights. ``maxsize`` is bounded because the names come from user input.

    Returns:
        (persona_model, persona_tokenizer, nudge_model, nudge_tokenizer)
    """
    # NOTE(review): a base checkpoint such as 'roberta-base' gets a freshly
    # initialised 3-label head here, so persona predictions are arbitrary unless a
    # fine-tuned checkpoint is selected.
    persona_model = AutoModelForSequenceClassification.from_pretrained(persona_model_name, num_labels=3).to(device)
    persona_tokenizer = AutoTokenizer.from_pretrained(persona_model_name)

    model_class = AutoModelForSeq2SeqLM if model_type == "seq2seq" else AutoModelForCausalLM
    nudge_model = model_class.from_pretrained(nudge_model_name).to(device)
    nudge_tokenizer = AutoTokenizer.from_pretrained(nudge_model_name)
    return persona_model, persona_tokenizer, nudge_model, nudge_tokenizer


def _load_dataframe():
    """Build the DataFrame to augment from the current POST request.

    Uses the online dataset named in the form when ``use_online_dataset`` is
    "yes"; otherwise uses the CSV uploaded in the ``file`` field.

    Raises:
        _BadRequest: if no dataset/file was supplied or the CSV lacks ``post_reply``.
    """
    if request.form.get('use_online_dataset') == 'yes':
        dataset_name = request.form.get('dataset_name')
        logging.debug(f"Selected online dataset: {dataset_name}")
        if not dataset_name:
            logging.error("No online dataset selected.")
            raise _BadRequest("Please select an online dataset.")
        if dataset_name == 'personachat':
            # 'personachat' is an alias for this Hub dataset.
            dataset_name = 'AlekseyKorshuk/persona-chat'
        # NOTE(review): the dataset name is user-controlled, so any Hub dataset can
        # be fetched here; consider an allow-list.
        dataset = load_dataset(dataset_name)
        df = pd.DataFrame(dataset['train'])  # Use the training split for processing
        df = df.rename(columns=lambda x: x.strip().lower())
        # Column selection/renaming is specific to the persona-chat layout.
        df = df[['utterances', 'personality']]
        df.columns = ['topic', 'post_reply']  # Standardize column names for processing
        return df

    uploaded_file = request.files.get('file')
    if uploaded_file is None or not uploaded_file.filename:
        logging.error("No CSV file was uploaded.")
        raise _BadRequest("Please upload a CSV file.")
    logging.debug(f"File uploaded: {uploaded_file.filename}")
    df = pd.read_csv(uploaded_file)
    df.columns = df.columns.str.strip().str.lower()
    if 'post_reply' not in df.columns:
        logging.error("Required column 'post_reply' is missing in the CSV.")
        raise _BadRequest("The uploaded CSV file must contain 'post_reply' column.")
    return df


def _augment_with_nudges(df, persona_model, persona_tokenizer, nudge_model, nudge_tokenizer, model_type):
    """Return a new DataFrame with personas filled in and a nudge row after each polarized reply."""
    augmented_rows = []
    for _, row in df.iterrows():
        if 'user_persona' not in row or pd.isna(row['user_persona']):
            # Classify user persona if not provided
            row['user_persona'] = classify_persona(row['post_reply'], persona_model, persona_tokenizer)
        augmented_rows.append(row.to_dict())
        if is_polarized(row['post_reply']):
            # 'topic' is optional in uploaded CSVs; fall back to an empty topic.
            topic = row.get('topic', '')
            nudge = generate_nudge(row['post_reply'], row['user_persona'], topic, nudge_model, nudge_tokenizer, model_type)
            augmented_rows.append({
                'topic': topic,
                'user_persona': 'AI Nudge',
                'post_reply': nudge,
            })
    return pd.DataFrame(augmented_rows)


@app.route('/', methods=['GET', 'POST'])
def home():
    """Serve the form (GET) or return a nudge-augmented CSV download (POST).

    POST responses: the CSV attachment on success, HTTP 400 with a message for bad
    input or an unsupported nudge model, and HTTP 500 for any other failure.
    """
    logging.debug("Home route accessed.")
    if request.method != 'POST':
        logging.debug("Rendering index.html")
        return render_template('index.html')

    logging.debug("POST request received.")
    try:
        persona_model_name = request.form.get('persona_model_name', 'roberta-base')
        nudge_model_name = request.form.get('nudge_model_name', 'facebook/bart-large-cnn')
        logging.debug(f"Selected persona model: {persona_model_name}")
        logging.debug(f"Selected nudge model: {nudge_model_name}")

        # Validate cheap inputs before spending time on model loading.
        model_type = _nudge_model_type(nudge_model_name)
        if model_type is None:
            logging.error("Unsupported model selected.")
            return "Selected model is not supported for text generation tasks.", 400
        df = _load_dataframe()

        persona_model, persona_tokenizer, nudge_model, nudge_tokenizer = _load_models(
            persona_model_name, nudge_model_name, model_type
        )
        logging.debug("Models and tokenizers loaded.")

        augmented_df = _augment_with_nudges(
            df, persona_model, persona_tokenizer, nudge_model, nudge_tokenizer, model_type
        )
        logging.debug("Processing completed.")

        # Output filename: short model names (dashes -> underscores) plus a timestamp.
        persona_tag = persona_model_name.split('/')[-1].replace('-', '_')
        nudge_tag = nudge_model_name.split('/')[-1].replace('-', '_')
        current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_filename = f"DepolNudge_{persona_tag}_{nudge_tag}_{current_time}.csv"

        # Build the CSV in memory. Encoding to bytes first avoids relying on
        # pandas being able to write directly into a binary buffer.
        csv_buffer = io.BytesIO(augmented_df.to_csv(index=False).encode('utf-8'))
        return send_file(
            csv_buffer,
            as_attachment=True,
            download_name=output_filename,
            mimetype='text/csv',
        )
    except _BadRequest as err:
        return str(err), 400
    except Exception as e:
        logging.exception(f"Error processing the request: {e}")
        return "There was an error processing your request.", 500
if __name__ == '__main__':
    # The Werkzeug debugger lets a browser execute arbitrary Python, so it is opt-in
    # (FLASK_DEBUG=1) instead of always on. HOST/PORT let a hosting platform choose
    # the bind address; the defaults are Flask's own (127.0.0.1:5000).
    app.run(
        host=os.getenv('HOST', '127.0.0.1'),
        port=int(os.getenv('PORT', '5000')),
        debug=os.getenv('FLASK_DEBUG') == '1',
    )