Tahir5 commited on
Commit
dfa48cf
·
verified ·
1 Parent(s): 63c2909

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from torch import nn
4
+ from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
5
+ from PIL import Image
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+
9
+ def load_model():
10
+ """Load the Segformer model and processor."""
11
+ processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
12
+ model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
13
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
14
+ model.to(device)
15
+ return processor, model, device
16
+
17
+ def process_image(image: Image.Image, processor, model, device):
18
+ """Run inference on the image and return the segmentation mask."""
19
+ inputs = processor(images=image, return_tensors="pt").to(device)
20
+ outputs = model(**inputs)
21
+ logits = outputs.logits
22
+ upsampled_logits = nn.functional.interpolate(
23
+ logits, size=image.size[::-1], mode="bilinear", align_corners=False
24
+ )
25
+ labels = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
26
+ return labels
27
+
28
+ def visualize_segmentation(labels: np.ndarray):
29
+ """Visualize segmentation mask using Matplotlib."""
30
+ fig, ax = plt.subplots()
31
+ ax.imshow(labels, cmap="jet", alpha=0.7)
32
+ ax.axis("off")
33
+ st.pyplot(fig)
34
+
35
+ # Streamlit UI
36
+ st.title("Face Parsing using Segformer")
37
+ st.write("Upload an image to perform semantic segmentation on faces.")
38
+
39
+ # Load model
40
+ processor, model, device = load_model()
41
+
42
+ # File uploader
43
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
44
+
45
+ if uploaded_file:
46
+ image = Image.open(uploaded_file).convert("RGB")
47
+ st.image(image, caption="Uploaded Image", use_column_width=True)
48
+
49
+ # Process image
50
+ with st.spinner("Processing..."):
51
+ labels = process_image(image, processor, model, device)
52
+
53
+ # Display result
54
+ visualize_segmentation(labels)