Tassawar commited on
Commit
e775794
·
verified ·
1 Parent(s): 99ecdb1

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +98 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch
4
+ from torchvision import transforms
5
+ from transformers import AutoModelForImageSegmentation
6
+ import io
7
+ import os
8
+
9
+ # Set matmul precision (important for performance on some systems)
10
+ torch.set_float32_matmul_precision(["high", "highest"][0])
11
+
12
+ # Load the model (outside the function for efficiency)
13
+ @st.cache_resource # Cache the model to avoid reloading on every run
14
+ def load_model():
15
+ model = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
16
+ model.to("cuda" if torch.cuda.is_available() else "cpu") # Use CUDA if available
17
+ return model
18
+
19
+ birefnet = load_model()
20
+
21
+ # Image transformation
22
+ transform_image = transforms.Compose([
23
+ transforms.Resize((1024, 1024)),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
26
+ ])
27
+
28
+ @st.cache_data # Cache the processed images.
29
+ def process(image):
30
+ image_size = image.size
31
+ input_images = transform_image(image).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
32
+ with torch.no_grad():
33
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
34
+ pred = preds[0].squeeze()
35
+ pred_pil = transforms.ToPILImage()(pred)
36
+ mask = pred_pil.resize(image_size)
37
+ image.putalpha(mask)
38
+ return image
39
+
40
+ def process_file(uploaded_file):
41
+ try:
42
+ image = Image.open(uploaded_file).convert("RGB")
43
+ transparent = process(image)
44
+
45
+ # Convert to bytes for download
46
+ img_bytes = io.BytesIO()
47
+ transparent.save(img_bytes, format="PNG")
48
+ img_bytes = img_bytes.getvalue()
49
+
50
+ return img_bytes, transparent # Return bytes for download and PIL image for display
51
+
52
+ except Exception as e:
53
+ st.error(f"An error occurred: {e}")
54
+ return None, None
55
+
56
+
57
+ st.title("Background Removal Tool")
58
+
59
+ # Tabs for different input methods
60
+ tabs = ["Image Upload", "URL Input", "File Output"]
61
+ selected_tab = st.sidebar.radio("Select Input Method", tabs)
62
+
63
+ if selected_tab == "Image Upload":
64
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
65
+ if uploaded_file is not None:
66
+ image = Image.open(uploaded_file).convert("RGB")
67
+ processed_image = process(image)
68
+ st.image(processed_image, caption="Processed Image")
69
+
70
+ elif selected_tab == "URL Input":
71
+ image_url = st.text_input("Paste an image URL")
72
+ if image_url:
73
+ try:
74
+ import requests
75
+ from io import BytesIO
76
+ response = requests.get(image_url, stream=True)
77
+ response.raise_for_status() # Raise an exception for bad status codes
78
+ image = Image.open(BytesIO(response.content)).convert("RGB")
79
+ processed_image = process(image)
80
+ st.image(processed_image, caption="Processed Image from URL")
81
+ except requests.exceptions.RequestException as e:
82
+ st.error(f"Error fetching image from URL: {e}")
83
+ except Exception as e:
84
+ st.error(f"Error processing image: {e}")
85
+
86
+
87
+ elif selected_tab == "File Output":
88
+ uploaded_file = st.file_uploader("Upload an image for file output", type=["jpg", "jpeg", "png"])
89
+ if uploaded_file is not None:
90
+ file_bytes, processed_image = process_file(uploaded_file)
91
+ if file_bytes:
92
+ st.image(processed_image, caption="Processed Image") # Display the image
93
+ st.download_button(
94
+ label="Download PNG",
95
+ data=file_bytes,
96
+ file_name=f"{uploaded_file.name.rsplit('.', 1)[0]}.png",
97
+ mime="image/png",
98
+ )
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ Pillow==11.1.0
2
+ Requests==2.32.3
3
+ streamlit==1.42.0
4
+ torch==2.6.0
5
+ torchvision==0.21.0
6
+ transformers==4.48.3