fajarah commited on
Commit
49238a6
·
verified ·
1 Parent(s): c776a9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -30
app.py CHANGED
@@ -1,44 +1,31 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  from torch.nn.functional import sigmoid
4
  import torch
5
  from PIL import Image
6
- from torchvision import transforms
7
- import requests
8
 
9
  # Load text emotion model
10
  model_name = "SamLowe/roberta-base-go_emotions"
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
13
 
14
- # Load image emotion model (fine-tuned ResNet-50)
15
- image_model_name = "Celal11/resnet-50-finetuned-FER2013CKPlus-0.003"
16
- image_emotion_model = AutoModelForSequenceClassification.from_pretrained(image_model_name)
17
- image_tokenizer = AutoTokenizer.from_pretrained("microsoft/resnet-50")
18
 
19
- # Transform for image preprocessing
20
- image_transform = transforms.Compose([
21
- transforms.Resize((224, 224)),
22
- transforms.ToTensor(),
23
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
24
- ])
25
-
26
- # FER labels
27
- image_labels = [
28
- "Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral", "Contempt"
29
- ]
30
-
31
- # Analyze image emotion
32
- def analyze_image_emotion(image_path):
33
- if image_path is None:
34
  return "No image provided."
35
- image = Image.open(image_path).convert("RGB")
36
- img_tensor = image_transform(image).unsqueeze(0)
37
  with torch.no_grad():
38
- output = image_emotion_model(img_tensor)
39
- probs = sigmoid(output.logits)[0]
40
- top_idx = torch.argmax(probs).item()
41
- return f"{image_labels[top_idx]} ({probs[top_idx]:.2f})"
 
 
42
 
43
  # Emotion label to icon mapping (subset)
44
  emotion_icons = {
@@ -126,7 +113,7 @@ demo = gr.Interface(
126
  inputs=[
127
  gr.Textbox(lines=5, placeholder="Write a sentence or a full paragraph...", label="Your Text"),
128
  gr.Slider(minimum=0.1, maximum=0.9, value=0.3, step=0.05, label="Threshold"),
129
- gr.Image(type="filepath", label="Upload Face Photo")
130
  ],
131
  outputs=[
132
  gr.Textbox(label="Detected Text Emotions", elem_classes=["output-textbox"]),
@@ -138,4 +125,4 @@ demo = gr.Interface(
138
  css=custom_css
139
  )
140
 
141
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoImageProcessor, AutoModelForImageClassification
3
  from torch.nn.functional import sigmoid
4
  import torch
5
  from PIL import Image
 
 
6
 
7
  # Load text emotion model
8
  model_name = "SamLowe/roberta-base-go_emotions"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
11
 
12
+ # Load image emotion model
13
+ image_model_name = "Celal11/resnet-50-finetuned-FER2013-0.001"
14
+ image_processor = AutoImageProcessor.from_pretrained(image_model_name)
15
+ image_model = AutoModelForImageClassification.from_pretrained(image_model_name)
16
 
17
+ # Analyze image emotion using processor and model
18
+ def analyze_image_emotion(image):
19
+ if image is None:
 
 
 
 
 
 
 
 
 
 
 
 
20
  return "No image provided."
21
+ inputs = image_processor(images=image, return_tensors="pt")
 
22
  with torch.no_grad():
23
+ logits = image_model(**inputs).logits
24
+ probs = torch.nn.functional.softmax(logits, dim=1)[0]
25
+ pred_idx = torch.argmax(probs).item()
26
+ label = image_model.config.id2label[pred_idx]
27
+ score = probs[pred_idx].item()
28
+ return f"{label} ({score:.2f})"
29
 
30
  # Emotion label to icon mapping (subset)
31
  emotion_icons = {
 
113
  inputs=[
114
  gr.Textbox(lines=5, placeholder="Write a sentence or a full paragraph...", label="Your Text"),
115
  gr.Slider(minimum=0.1, maximum=0.9, value=0.3, step=0.05, label="Threshold"),
116
+ gr.Image(type="pil", label="Upload Face Photo")
117
  ],
118
  outputs=[
119
  gr.Textbox(label="Detected Text Emotions", elem_classes=["output-textbox"]),
 
125
  css=custom_css
126
  )
127
 
128
+ demo.launch()