learning4 committed on
Commit
ffd5d34
·
verified ·
1 Parent(s): 1eb077e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -0
app.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import io
import logging
import os
import re
from datetime import datetime

import pandas as pd
import torch
from datasets import load_dataset
from flask import Flask, request, render_template, send_file
from huggingface_hub import login
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
10
+
11
# Read the Hugging Face access token from the environment and refuse to
# start without it; otherwise authenticate with the Hub.
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN")
if not HUGGING_FACE_TOKEN:
    raise ValueError("Hugging Face token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
login(token=HUGGING_FACE_TOKEN)

# Flask application object that the route below is registered on.
app = Flask(__name__)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Verbose logging for debugging.
logging.basicConfig(level=logging.DEBUG)
28
+
29
+ # Define a function to classify user persona based on the selected model
30
def classify_persona(text, model, tokenizer):
    """Classify ``text`` into a persona label with a sequence-classification model.

    Args:
        text: Input handed straight to ``tokenizer``.
        model: Callable model; its output must expose a ``logits`` tensor with
            one row per input and one column per class.
        tokenizer: Callable returning an encoding that supports ``.to(device)``
            and ``**`` unpacking.

    Returns:
        'Persona A', 'Persona B' or 'Persona C' for class ids 0, 1 and 2, or
        'Unknown' for any other id. When the model returns several rows, only
        the first row's prediction is reported.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512).to(device)

    # Pure inference: disable autograd so no computation graph is kept alive.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Convert logits to probabilities.
    probabilities = torch.nn.functional.softmax(logits, dim=1)

    # Emit diagnostics through logging (configured by the app) instead of print.
    logging.debug("Logits: %s", logits)
    logging.debug("Probabilities: %s", probabilities)

    predictions = torch.argmax(probabilities, dim=1)
    persona_mapping = {0: 'Persona A', 1: 'Persona B', 2: 'Persona C'}

    # Only the first prediction is used.
    return persona_mapping.get(predictions[0].item(), 'Unknown')
52
+
53
+ # Define the function to determine if a message is polarized
54
# Absolute / superlative words treated as markers of polarized language.
_POLARIZED_KEYWORDS = ("always", "never", "everyone", "nobody", "worst", "best")

# Whole-word, case-insensitive matcher so that e.g. "bestow" or
# "nevertheless" are not mistaken for "best" / "never".
_POLARIZED_PATTERN = re.compile(
    r"\b(?:" + "|".join(_POLARIZED_KEYWORDS) + r")\b",
    re.IGNORECASE,
)


def is_polarized(message):
    """Return True when ``message`` contains a polarizing keyword.

    Args:
        message: A string, or a list/tuple of parts that are joined with
            single spaces before matching. Any other value (``None``, a NaN
            from an empty CSV cell, a number) is treated as not polarized.

    Returns:
        bool: True if one of the keywords occurs as a whole word,
        case-insensitively; False otherwise.
    """
    if isinstance(message, (list, tuple)):
        message = ' '.join(str(part) for part in message)

    # Non-text input cannot be polarized; avoid AttributeError on NaN/None.
    if not isinstance(message, str):
        return False

    return _POLARIZED_PATTERN.search(message) is not None
61
+
62
+
63
+ # Define the function to generate AI-based nudges using the selected transformer model
64
def generate_nudge(message, persona, topic, model, tokenizer, model_type, max_length=50, min_length=30, temperature=0.7, top_p=0.9, repetition_penalty=1.1):
    """Generate an AI "nudge" reply for ``message`` with the given model.

    Args:
        message: The user message to respond to.
        persona: Persona label, interpolated into the seq2seq prompt.
        topic: Discussion topic, interpolated into the seq2seq prompt.
        model: Model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        model_type: "seq2seq" or "causal"; anything else yields a fixed
            fallback sentence without calling the model.
        max_length: Maximum number of generated tokens.
        min_length: Minimum number of generated tokens; clamped to
            ``max_length``.
        temperature, top_p, repetition_penalty: Sampling parameters forwarded
            to ``model.generate``.

    Returns:
        str: The generated nudge. For causal models only the continuation is
        returned, not the echoed prompt.
    """
    # Ensure min_length is less than or equal to max_length.
    min_length = min(min_length, max_length)

    if model_type == "seq2seq":
        prompt = f"As an AI assistant, provide a nudge for this {persona} message in a {topic} discussion: {message}"
        inputs = tokenizer(prompt, return_tensors='pt', max_length=1024, truncation=True).to(device)
        generated_ids = model.generate(
            inputs['input_ids'],
            max_length=max_length,
            min_length=min_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            num_beams=4,
            early_stopping=True
        )
        return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    if model_type == "causal":
        prompt = f"{message} [AI Nudge]:"
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        input_ids = inputs['input_ids']
        prompt_len = input_ids.shape[1]

        # For decoder-only models the length limits count the prompt tokens
        # too, so offset them by the prompt length; otherwise a prompt of
        # max_length tokens or more leaves no room to generate anything.
        generate_kwargs = {
            'max_length': prompt_len + max_length,
            'min_length': prompt_len + min_length,
            'temperature': temperature,
            'top_p': top_p,
            'repetition_penalty': repetition_penalty,
            'do_sample': True,
        }
        if 'attention_mask' in inputs:
            generate_kwargs['attention_mask'] = inputs['attention_mask']
        # GPT-2 style tokenizers define no pad token; fall back to EOS.
        if getattr(tokenizer, 'pad_token_id', None) is None and getattr(tokenizer, 'eos_token_id', None) is not None:
            generate_kwargs['pad_token_id'] = tokenizer.eos_token_id

        generated_ids = model.generate(input_ids, **generate_kwargs)

        # The output starts with the prompt; keep only the newly generated
        # tokens so the nudge does not repeat the user's message.
        continuation = generated_ids[0][prompt_len:]
        return tokenizer.decode(continuation, skip_special_tokens=True).strip()

    return "This model is not suitable for generating text."
100
+
101
+
102
def _is_missing(value):
    """Return True for ``None`` or a scalar that pandas treats as missing."""
    return value is None or (pd.api.types.is_scalar(value) and bool(pd.isna(value)))


def _nudge_model_type(nudge_model_name):
    """Map a nudge model name to "seq2seq", "causal", or None if unsupported."""
    if "bart" in nudge_model_name or "t5" in nudge_model_name:
        return "seq2seq"
    if "gpt2" in nudge_model_name:
        return "causal"
    return None


def _load_online_dataset(dataset_name):
    """Load the train split of a Hub dataset as a (topic, post_reply) frame."""
    if dataset_name == 'personachat':
        # Use AlekseyKorshuk/persona-chat if 'personachat' is selected.
        dataset_name = 'AlekseyKorshuk/persona-chat'

    dataset = load_dataset(dataset_name)
    df = pd.DataFrame(dataset['train'])  # Use the training split for processing
    df = df.rename(columns=lambda x: x.strip().lower())
    df = df[['utterances', 'personality']]  # Modify this according to the dataset structure
    df.columns = ['topic', 'post_reply']  # Standardize column names for processing
    return df


def _augment_rows(df, persona_model, persona_tokenizer, nudge_model, nudge_tokenizer, model_type):
    """Return the input rows plus an 'AI Nudge' row after each polarized one."""
    augmented_rows = []
    for _, row in df.iterrows():
        record = row.to_dict()
        post_reply = record['post_reply']

        # Rows without any text are passed through untouched; they cannot be
        # classified or nudged and must not abort the whole request.
        if _is_missing(post_reply):
            augmented_rows.append(record)
            continue

        if _is_missing(record.get('user_persona')):
            # Classify user persona if not provided.
            record['user_persona'] = classify_persona(post_reply, persona_model, persona_tokenizer)
        augmented_rows.append(record)

        if is_polarized(post_reply):
            # 'topic' is optional in uploaded CSVs; default to an empty string.
            topic = record.get('topic', '')
            nudge = generate_nudge(post_reply, record['user_persona'], topic, nudge_model, nudge_tokenizer, model_type)
            augmented_rows.append({
                'topic': topic,
                'user_persona': 'AI Nudge',
                'post_reply': nudge
            })
    return augmented_rows


@app.route('/', methods=['GET', 'POST'])
def home():
    """Serve the upload form (GET) or build the nudge-augmented CSV (POST).

    On POST the form selects a persona-classification model, a nudge
    generation model, and either an online dataset or an uploaded CSV with a
    'post_reply' column. Each row gets a persona when none is present, and an
    'AI Nudge' row is inserted after every polarized message.

    Returns:
        The rendered ``index.html`` on GET; on POST a CSV download, a 400
        response for invalid input, or a 500 response on unexpected errors.
    """
    logging.debug("Home route accessed.")
    if request.method != 'POST':
        logging.debug("Rendering index.html")
        return render_template('index.html')

    logging.debug("POST request received.")
    try:
        # Get the model names from the form; empty fields fall back to defaults.
        # NOTE(review): both names come from the client and are passed to
        # from_pretrained/load_dataset unchecked — consider an allow-list.
        persona_model_name = request.form.get('persona_model_name') or 'roberta-base'
        nudge_model_name = request.form.get('nudge_model_name') or 'facebook/bart-large-cnn'
        logging.debug(f"Selected persona model: {persona_model_name}")
        logging.debug(f"Selected nudge model: {nudge_model_name}")

        # Reject unsupported models before any expensive loading happens.
        model_type = _nudge_model_type(nudge_model_name)
        if model_type is None:
            logging.error("Unsupported model selected.")
            return "Selected model is not supported for text generation tasks.", 400

        # Obtain and validate the input data before loading any model weights.
        use_online_dataset = request.form.get('use_online_dataset') == 'yes'
        if use_online_dataset:
            dataset_name = request.form.get('dataset_name')
            logging.debug(f"Selected online dataset: {dataset_name}")
            if not dataset_name:
                logging.error("No dataset name provided.")
                return "No dataset name provided.", 400
            df = _load_online_dataset(dataset_name)
        else:
            uploaded_file = request.files.get('file')
            if uploaded_file is None or uploaded_file.filename == '':
                logging.error("No file uploaded.")
                return "No file uploaded.", 400
            logging.debug(f"File uploaded: {uploaded_file.filename}")
            df = pd.read_csv(uploaded_file)
            df.columns = df.columns.str.strip().str.lower()

        if 'post_reply' not in df.columns:
            logging.error("Required column 'post_reply' is missing in the CSV.")
            return "The uploaded CSV file must contain 'post_reply' column.", 400

        # Load persona classification model.
        persona_model = AutoModelForSequenceClassification.from_pretrained(persona_model_name, num_labels=3).to(device)
        persona_tokenizer = AutoTokenizer.from_pretrained(persona_model_name)

        # Load nudge generation model.
        if model_type == "seq2seq":
            nudge_model = AutoModelForSeq2SeqLM.from_pretrained(nudge_model_name).to(device)
        else:
            nudge_model = AutoModelForCausalLM.from_pretrained(nudge_model_name).to(device)
        nudge_tokenizer = AutoTokenizer.from_pretrained(nudge_model_name)
        logging.debug("Models and tokenizers loaded.")

        augmented_df = pd.DataFrame(
            _augment_rows(df, persona_model, persona_tokenizer, nudge_model, nudge_tokenizer, model_type)
        )
        logging.debug("Processing completed.")

        # Generate the output filename.
        persona_tag = persona_model_name.split('/')[-1].replace('-', '_')
        nudge_tag = nudge_model_name.split('/')[-1].replace('-', '_')
        current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_filename = f"DepolNudge_{persona_tag}_{nudge_tag}_{current_time}.csv"

        # Build the CSV in memory instead of saving it to a directory.
        csv_buffer = io.BytesIO(augmented_df.to_csv(index=False).encode('utf-8'))

        # Send the file for download directly from the buffer.
        return send_file(
            csv_buffer,
            as_attachment=True,
            download_name=output_filename,
            mimetype='text/csv'
        )
    except Exception as e:
        # Top-level boundary: log the full traceback, return a generic error.
        logging.error(f"Error processing the request: {e}", exc_info=True)
        return "There was an error processing your request.", 500
203
+
204
if __name__ == '__main__':
    # The Werkzeug debugger lets anyone who can reach the server execute
    # arbitrary Python, so debug mode is opt-in through FLASK_DEBUG rather
    # than always on.
    debug_enabled = os.getenv("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(debug=debug_enabled)