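"""DepolNudge: a small Flask app that classifies user personas in discussion
data and appends AI-generated nudges to polarized posts, returning the
augmented rows as a downloadable CSV."""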
import logging
from flask import Flask, request, render_template, send_file
import pandas as pd
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
import torch
import io
import os
from datetime import datetime
from datasets import load_dataset
from huggingface_hub import login

# Load Hugging Face token from environment variable
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN")

# Authenticate with Hugging Face
if HUGGING_FACE_TOKEN:
    login(token=HUGGING_FACE_TOKEN)
else:
    raise ValueError("Hugging Face token not found. Please set the HUGGING_FACE_TOKEN environment variable.")

# Initialize the Flask application
app = Flask(__name__)

# Set up the device (CUDA or CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Optional: Set up logging for debugging
logging.basicConfig(level=logging.DEBUG)


# Classify the user persona of a text with the selected classification model
def classify_persona(text, model, tokenizer):
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512).to(device)
    with torch.no_grad():  # inference only, so skip gradient tracking
        outputs = model(**inputs)
    logits = outputs.logits
    # Convert logits to probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    # Log logits and probabilities for debugging
    logging.debug(f"Logits: {logits}")
    logging.debug(f"Probabilities: {probabilities}")
    # Get the predicted class indices
    predictions = torch.argmax(probabilities, dim=1)
    persona_mapping = {0: 'Persona A', 1: 'Persona B', 2: 'Persona C'}
    # Map each prediction to a persona label, falling back to 'Unknown'
    predicted_personas = [persona_mapping.get(pred.item(), 'Unknown') for pred in predictions]
    # A single input yields a single prediction, so return the first one
    return predicted_personas[0]
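
# Example (hypothetical) usage:
#   classify_persona("I disagree with this policy", persona_model, persona_tokenizer)
#   -> 'Persona A', 'Persona B', or 'Persona C', depending on the model's logits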


# Determine whether a message is polarized, using a simple keyword heuristic
def is_polarized(message):
    # If message is a list of utterances, join it into a single string
    if isinstance(message, list):
        message = ' '.join(message)
    polarized_keywords = ["always", "never", "everyone", "nobody", "worst", "best"]
    return any(keyword in message.lower() for keyword in polarized_keywords)
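
# Example: is_polarized("Everyone always ignores the data") -> True
# Note that this is plain substring matching, so "best" also matches inside
# words such as "bestowed"; a token-level check would be stricter.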


# Generate an AI-based nudge for a message using the selected transformer model
def generate_nudge(message, persona, topic, model, tokenizer, model_type, max_length=50, min_length=30, temperature=0.7, top_p=0.9, repetition_penalty=1.1):
    # Ensure min_length does not exceed max_length
    min_length = min(min_length, max_length)
    if model_type == "seq2seq":
        prompt = f"As an AI assistant, provide a nudge for this {persona} message in a {topic} discussion: {message}"
        inputs = tokenizer(prompt, return_tensors='pt', max_length=1024, truncation=True).to(device)
        generated_ids = model.generate(
            inputs['input_ids'],
            max_length=max_length,
            min_length=min_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            num_beams=4,
            early_stopping=True
        )
        nudge = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    elif model_type == "causal":
        prompt = f"{message} [AI Nudge]:"
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        # Cap the number of newly generated tokens rather than the total length,
        # so a long prompt cannot exhaust the length budget before generation starts
        generated_ids = model.generate(
            inputs['input_ids'],
            max_new_tokens=max_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
        )
        # Decode only the newly generated tokens, not the echoed prompt
        nudge = tokenizer.decode(generated_ids[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    else:
        nudge = "This model is not suitable for generating text."
    return nudge
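
# Example (hypothetical) usage with a seq2seq model:
#   generate_nudge("Everyone always argues in bad faith", "Persona A", "politics",
#                  nudge_model, nudge_tokenizer, "seq2seq")
# The seq2seq branch combines sampling with beam search (do_sample=True,
# num_beams=4), which transformers runs as beam-search multinomial sampling.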


@app.route('/', methods=['GET', 'POST'])
def home():
    logging.debug("Home route accessed.")
    if request.method == 'POST':
        logging.debug("POST request received.")
        try:
            # Get the model names from the form
            persona_model_name = request.form.get('persona_model_name', 'roberta-base')
            nudge_model_name = request.form.get('nudge_model_name', 'facebook/bart-large-cnn')
            logging.debug(f"Selected persona model: {persona_model_name}")
            logging.debug(f"Selected nudge model: {nudge_model_name}")

            # Load the persona classification model; note that loading a base
            # checkpoint with num_labels=3 attaches a freshly initialized
            # classification head unless a fine-tuned checkpoint is selected
            persona_model = AutoModelForSequenceClassification.from_pretrained(persona_model_name, num_labels=3).to(device)
            persona_tokenizer = AutoTokenizer.from_pretrained(persona_model_name)

            # Load the nudge generation model
            if "bart" in nudge_model_name or "t5" in nudge_model_name:
                model_type = "seq2seq"
                nudge_model = AutoModelForSeq2SeqLM.from_pretrained(nudge_model_name).to(device)
            elif "gpt2" in nudge_model_name:
                model_type = "causal"
                nudge_model = AutoModelForCausalLM.from_pretrained(nudge_model_name).to(device)
            else:
                logging.error("Unsupported model selected.")
                return "Selected model is not supported for text generation tasks.", 400
            nudge_tokenizer = AutoTokenizer.from_pretrained(nudge_model_name)
            logging.debug("Models and tokenizers loaded.")

            use_online_dataset = request.form.get('use_online_dataset') == 'yes'
            if use_online_dataset:
                # Attempt to load the specified online dataset
                dataset_name = request.form.get('dataset_name')
                logging.debug(f"Selected online dataset: {dataset_name}")
                if dataset_name == 'personachat':
                    # Use AlekseyKorshuk/persona-chat if 'personachat' is selected
                    dataset_name = 'AlekseyKorshuk/persona-chat'
                dataset = load_dataset(dataset_name)
                df = pd.DataFrame(dataset['train'])  # Use the training split for processing
                df = df.rename(columns=lambda x: x.strip().lower())
                df = df[['utterances', 'personality']]  # Modify this according to the dataset structure
                df.columns = ['topic', 'post_reply']  # Standardize column names for processing
            else:
                uploaded_file = request.files.get('file')
                if uploaded_file is None or uploaded_file.filename == '':
                    logging.error("No CSV file was uploaded.")
                    return "Please upload a CSV file or select an online dataset.", 400
                logging.debug(f"File uploaded: {uploaded_file.filename}")
                df = pd.read_csv(uploaded_file)
                df.columns = df.columns.str.strip().str.lower()
                if 'post_reply' not in df.columns:
                    logging.error("Required column 'post_reply' is missing in the CSV.")
                    return "The uploaded CSV file must contain a 'post_reply' column.", 400

            augmented_rows = []
            for _, row in df.iterrows():
                if 'user_persona' not in row or pd.isna(row['user_persona']):
                    # Classify the user persona if it is not provided
                    row['user_persona'] = classify_persona(row['post_reply'], persona_model, persona_tokenizer)
                augmented_rows.append(row.to_dict())
                if is_polarized(row['post_reply']):
                    topic = row.get('topic', '')  # fall back to an empty topic if the column is absent
                    nudge = generate_nudge(row['post_reply'], row['user_persona'], topic, nudge_model, nudge_tokenizer, model_type)
                    augmented_rows.append({
                        'topic': topic,
                        'user_persona': 'AI Nudge',
                        'post_reply': nudge
                    })
            augmented_df = pd.DataFrame(augmented_rows)
            logging.debug("Processing completed.")

            # Generate the output filename from the selected model names
            persona_model_tag = persona_model_name.split('/')[-1].replace('-', '_')
            nudge_model_tag = nudge_model_name.split('/')[-1].replace('-', '_')
            current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_filename = f"DepolNudge_{persona_model_tag}_{nudge_model_tag}_{current_time}.csv"

            # Build the CSV in memory instead of writing it to disk
            csv_buffer = io.BytesIO()
            augmented_df.to_csv(csv_buffer, index=False)
            csv_buffer.seek(0)  # Reset the buffer position to the start

            # Send the file directly for download
            return send_file(
                csv_buffer,
                as_attachment=True,
                download_name=output_filename,
                mimetype='text/csv'
            )
        except Exception as e:
            logging.error(f"Error processing the request: {e}", exc_info=True)
            return "There was an error processing your request.", 500
    logging.debug("Rendering index.html")
    return render_template('index.html')


if __name__ == '__main__':
    app.run(debug=True)
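
# A minimal smoke test, assuming the app runs locally on Flask's default port
# and that index.html exposes the form fields read above ('sample.csv' is a
# hypothetical input file with a 'post_reply' column):
#   curl -X POST http://127.0.0.1:5000/ \
#        -F "persona_model_name=roberta-base" \
#        -F "nudge_model_name=facebook/bart-large-cnn" \
#        -F "use_online_dataset=no" \
#        -F "file=@sample.csv" \
#        -o DepolNudge_output.csv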