Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -41,13 +41,24 @@ test_texts = [example['text'] for example in test_data]
|
|
41 |
test_labels = [example['label'] for example in test_data]
|
42 |
label_mapping = {0: 'sadness', 1: 'joy', 2: 'love', 3: 'anger', 4: 'fear', 5: 'surprise'}
|
43 |
test_labels = [label_mapping[label].lower() for label in test_labels]
|
44 |
-
|
45 |
# Evaluate DistilBERT
|
46 |
distilbert_predictions = [distilbert_classifier(text)[0]['label'].lower() for text in test_texts]
|
47 |
print("DistilBERT Model Evaluation Metrics:")
|
48 |
print("Accuracy:", accuracy_score(test_labels, distilbert_predictions))
|
49 |
print("F1 Score:", f1_score(test_labels, distilbert_predictions, average='weighted'))
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
# Function to generate an image from a prompt using Stable Diffusion
|
53 |
def generate_image(prompt):
|
|
|
41 |
test_labels = [example['label'] for example in test_data]
|
42 |
label_mapping = {0: 'sadness', 1: 'joy', 2: 'love', 3: 'anger', 4: 'fear', 5: 'surprise'}
|
43 |
test_labels = [label_mapping[label].lower() for label in test_labels]
|
44 |
+
|
45 |
# Evaluate DistilBERT
|
46 |
distilbert_predictions = [distilbert_classifier(text)[0]['label'].lower() for text in test_texts]
|
47 |
print("DistilBERT Model Evaluation Metrics:")
|
48 |
print("Accuracy:", accuracy_score(test_labels, distilbert_predictions))
|
49 |
print("F1 Score:", f1_score(test_labels, distilbert_predictions, average='weighted'))
|
50 |
+
|
51 |
+
# Evaluate RoBERTa
|
52 |
+
roberta_predictions = []
|
53 |
+
for text in test_texts:
|
54 |
+
inputs = roberta_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
|
55 |
+
outputs = roberta_model(**inputs)
|
56 |
+
prediction = torch.argmax(outputs.logits, dim=-1).item()
|
57 |
+
roberta_predictions.append(label_mapping[prediction].lower())
|
58 |
+
|
59 |
+
print("RoBERTa Model Evaluation Metrics:")
|
60 |
+
print("Accuracy:", accuracy_score(test_labels, roberta_predictions))
|
61 |
+
print("F1 Score:", f1_score(test_labels, roberta_predictions, average='weighted'))
|
62 |
|
63 |
# Function to generate an image from a prompt using Stable Diffusion
|
64 |
def generate_image(prompt):
|