import logging
from flask import Flask, request, render_template, send_file
import pandas as pd
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
import torch
import io
import os
from datetime import datetime
from datasets import load_dataset
from huggingface_hub import login

# Load Hugging Face token from environment variable
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN")

# Authenticate with Hugging Face
if HUGGING_FACE_TOKEN:
    login(token=HUGGING_FACE_TOKEN)
else:
    raise ValueError("Hugging Face token not found. Please set the HUGGING_FACE_TOKEN environment variable.")

# Initialize the Flask application
app = Flask(__name__)

# Set up the device (CUDA or CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Optional: Set up logging for debugging
logging.basicConfig(level=logging.DEBUG)

# Define a function to classify user persona based on the selected model
def classify_persona(text, model, tokenizer):
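    """
    Classify a post into one of three personas with the selected
    sequence-classification model.

    Note: base checkpoints such as roberta-base get a freshly initialized
    3-label classification head here, so predictions are only meaningful if a
    checkpoint fine-tuned for this persona task is supplied.
    """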
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512).to(device)

    # Run inference without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # Convert logits to probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=1)

    # Log logits and probabilities for debugging
    logging.debug(f"Logits: {logits}")
    logging.debug(f"Probabilities: {probabilities}")

    # Get the predicted classes
    predictions = torch.argmax(probabilities, dim=1)
    
    persona_mapping = {0: 'Persona A', 1: 'Persona B', 2: 'Persona C'}
    
    # Map each predicted class index to its persona label
    predicted_personas = [persona_mapping.get(pred.item(), 'Unknown') for pred in predictions]

    # If the input tokenized to a batch (e.g. a list of sentences), keep only the first prediction
    return predicted_personas[0]

# Define the function to determine if a message is polarized
def is_polarized(message):
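    """
    Heuristic polarization check: flags a message that contains any
    absolutist keyword. A crude proxy for polarization; a fine-tuned
    classifier would be more reliable.
    """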
    # If message is a list, join it into a single string
    if isinstance(message, list):
        message = ' '.join(message)
    
    polarized_keywords = ["always", "never", "everyone", "nobody", "worst", "best"]
    return any(keyword in message.lower() for keyword in polarized_keywords)


# Define the function to generate AI-based nudges using the selected transformer model
def generate_nudge(message, persona, topic, model, tokenizer, model_type, max_length=50, min_length=30, temperature=0.7, top_p=0.9, repetition_penalty=1.1):
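    """
    Generate a short depolarizing nudge for a flagged message.

    Seq2seq models (BART/T5) receive an instruction-style prompt and beam-sample
    a response; causal models (GPT-2) receive the message followed by an
    "[AI Nudge]:" cue and only the continuation is decoded.
    """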
    # Ensure min_length is less than or equal to max_length
    min_length = min(min_length, max_length)
    
    if model_type == "seq2seq":
        prompt = f"As an AI assistant, provide a nudge for this {persona} message in a {topic} discussion: {message}"
        inputs = tokenizer(prompt, return_tensors='pt', max_length=1024, truncation=True).to(device)
        generated_ids = model.generate(
            inputs['input_ids'],
            max_length=max_length,
            min_length=min_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            num_beams=4,
            early_stopping=True
        )
        nudge = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    elif model_type == "causal":
        prompt = f"{message} [AI Nudge]:"
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        prompt_length = inputs['input_ids'].shape[-1]
        generated_ids = model.generate(
            inputs['input_ids'],
            max_length=prompt_length + max_length,  # for causal models, max_length counts the prompt too
            min_length=prompt_length + min_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token, so fall back to EOS
        )
        # Decode only the newly generated tokens so the nudge does not echo the prompt
        nudge = tokenizer.decode(generated_ids[0][prompt_length:], skip_special_tokens=True)
    else:
        nudge = "This model is not suitable for generating text."
    
    return nudge


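# Main route: on POST, load the selected models, read the input (an online
# dataset or an uploaded CSV), classify a persona for each post that lacks one,
# append an AI nudge after every polarized post, and return the augmented CSV
# as a download.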
@app.route('/', methods=['GET', 'POST'])
def home():
    logging.debug("Home route accessed.")
    if request.method == 'POST':
        logging.debug("POST request received.")
        try:
            # Get the model names from the form
            persona_model_name = request.form.get('persona_model_name', 'roberta-base')
            nudge_model_name = request.form.get('nudge_model_name', 'facebook/bart-large-cnn')
            logging.debug(f"Selected persona model: {persona_model_name}")
            logging.debug(f"Selected nudge model: {nudge_model_name}")

            # Load persona classification model
            persona_model = AutoModelForSequenceClassification.from_pretrained(persona_model_name, num_labels=3).to(device)
            persona_tokenizer = AutoTokenizer.from_pretrained(persona_model_name)

            # Load nudge generation model
            if "bart" in nudge_model_name or "t5" in nudge_model_name:
                model_type = "seq2seq"
                nudge_model = AutoModelForSeq2SeqLM.from_pretrained(nudge_model_name).to(device)
            elif "gpt2" in nudge_model_name:
                model_type = "causal"
                nudge_model = AutoModelForCausalLM.from_pretrained(nudge_model_name).to(device)
            else:
                logging.error("Unsupported model selected.")
                return "Selected model is not supported for text generation tasks.", 400

            nudge_tokenizer = AutoTokenizer.from_pretrained(nudge_model_name)
            logging.debug("Models and tokenizers loaded.")

            use_online_dataset = request.form.get('use_online_dataset') == 'yes'
            
            if use_online_dataset:
                # Attempt to load the specified online dataset
                dataset_name = request.form.get('dataset_name')
                logging.debug(f"Selected online dataset: {dataset_name}")
                
                if dataset_name == 'personachat':
                    # Use AlekseyKorshuk/persona-chat if 'personachat' is selected
                    dataset_name = 'AlekseyKorshuk/persona-chat'

                dataset = load_dataset(dataset_name)
                df = pd.DataFrame(dataset['train'])  # Use the training split for processing
                df = df.rename(columns=lambda x: x.strip().lower())
                df = df[['utterances', 'personality']]  # Modify this according to the dataset structure
                df.columns = ['topic', 'post_reply']  # Standardize column names for processing
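                # Note: with persona-chat this puts the list of persona sentences in
                # 'post_reply'; is_polarized() joins such lists, and the batch handling
                # in classify_persona() covers list inputs.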

            else:
                uploaded_file = request.files.get('file')
                if uploaded_file is None or uploaded_file.filename == '':
                    # Guard against a missing upload; otherwise df would be undefined below
                    logging.error("No file uploaded.")
                    return "Please upload a CSV file or select an online dataset.", 400

                logging.debug(f"File uploaded: {uploaded_file.filename}")

                df = pd.read_csv(uploaded_file)
                df.columns = df.columns.str.strip().str.lower()

                if 'post_reply' not in df.columns:
                    logging.error("Required column 'post_reply' is missing in the CSV.")
                    return "The uploaded CSV file must contain a 'post_reply' column.", 400

            augmented_rows = []
            for _, row in df.iterrows():
                if 'user_persona' not in row or pd.isna(row['user_persona']):
                    # Classify user persona if not provided
                    row['user_persona'] = classify_persona(row['post_reply'], persona_model, persona_tokenizer)
                augmented_rows.append(row.to_dict())
                
                if is_polarized(row['post_reply']):
                    # Uploaded CSVs may lack a 'topic' column; fall back to an empty string
                    topic = row.get('topic', '')
                    nudge = generate_nudge(row['post_reply'], row['user_persona'], topic, nudge_model, nudge_tokenizer, model_type)
                    augmented_rows.append({
                        'topic': topic,
                        'user_persona': 'AI Nudge',
                        'post_reply': nudge
                    })

            augmented_df = pd.DataFrame(augmented_rows)
            logging.debug("Processing completed.")

            # Build the output filename from the model names already read above
            persona_tag = persona_model_name.split('/')[-1].replace('-', '_')
            nudge_tag = nudge_model_name.split('/')[-1].replace('-', '_')
            current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_filename = f"DepolNudge_{persona_tag}_{nudge_tag}_{current_time}.csv"
        
            # Instead of saving to a directory, create the CSV in memory
            csv_buffer = io.BytesIO()
            augmented_df.to_csv(csv_buffer, index=False)
            csv_buffer.seek(0)  # Reset buffer position to the start

            # Directly send the file for download without saving to a specific folder
            return send_file(
                csv_buffer,
                as_attachment=True,
                download_name=output_filename,
                mimetype='text/csv'
            )
        except Exception as e:
            logging.error(f"Error processing the request: {e}", exc_info=True)
            return "There was an error processing your request.", 500

    logging.debug("Rendering index.html")
    return render_template('index.html')

if __name__ == '__main__':
    app.run(debug=True)
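
# Example usage, as a sketch: assumes this file is saved as app.py and that
# templates/index.html exposes form fields matching the names read above.
#
#   export HUGGING_FACE_TOKEN=hf_...          # your Hugging Face access token
#   python app.py
#
# Then submit the form at http://127.0.0.1:5000/, or POST a CSV directly
# (posts.csv must contain a 'post_reply' column):
#
#   curl -o nudged.csv \
#        -F "file=@posts.csv" \
#        -F "persona_model_name=roberta-base" \
#        -F "nudge_model_name=facebook/bart-large-cnn" \
#        http://127.0.0.1:5000/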