Spaces:
Sleeping
Sleeping
fix bugs
Browse files
app.py
CHANGED
@@ -4,14 +4,14 @@ import numpy as np
|
|
4 |
from PIL import Image
|
5 |
from model import CycleGAN, get_val_transform, de_normalize
|
6 |
|
7 |
-
|
8 |
st.set_page_config(
|
9 |
page_title="CycleGAN Image Converter",
|
10 |
page_icon="🎨",
|
11 |
layout="wide"
|
12 |
)
|
13 |
|
14 |
-
|
15 |
@st.cache_resource
|
16 |
def get_device():
|
17 |
if torch.cuda.is_available():
|
@@ -25,7 +25,7 @@ def get_device():
|
|
25 |
st.sidebar.info("Using CPU 💻")
|
26 |
return device
|
27 |
|
28 |
-
|
29 |
st.markdown("""
|
30 |
<style>
|
31 |
.stApp {
|
@@ -38,7 +38,7 @@ st.markdown("""
|
|
38 |
</style>
|
39 |
""", unsafe_allow_html=True)
|
40 |
|
41 |
-
|
42 |
st.title("CycleGAN Image Converter 🎨")
|
43 |
st.markdown("""
|
44 |
Transform images between different domains using CycleGAN.
|
@@ -47,7 +47,7 @@ st.markdown("""
|
|
47 |
*Note: Images will be resized to 256x256 pixels during conversion.*
|
48 |
""")
|
49 |
|
50 |
-
|
51 |
MODELS = [
|
52 |
{
|
53 |
"name": "Cezanne ↔ Photo",
|
@@ -57,25 +57,25 @@ MODELS = [
|
|
57 |
}
|
58 |
]
|
59 |
|
60 |
-
|
61 |
with st.sidebar:
|
62 |
st.header("Settings")
|
63 |
|
64 |
-
|
65 |
selected_model = st.selectbox(
|
66 |
"Conversion Type",
|
67 |
options=range(len(MODELS)),
|
68 |
format_func=lambda x: MODELS[x]["name"]
|
69 |
)
|
70 |
|
71 |
-
|
72 |
direction = st.radio(
|
73 |
"Conversion Direction",
|
74 |
options=["A → B", "B → A"],
|
75 |
help="A → B: Convert from domain A to B\nB → A: Convert from domain B to A"
|
76 |
)
|
77 |
|
78 |
-
|
79 |
@st.cache_resource
|
80 |
def load_model(model_path):
|
81 |
device = get_device()
|
@@ -84,29 +84,21 @@ def load_model(model_path):
|
|
84 |
model.eval()
|
85 |
return model
|
86 |
|
87 |
-
|
88 |
def process_image(image, model, direction):
|
89 |
-
# Prepare transform
|
90 |
transform = get_val_transform(model, direction)
|
91 |
-
|
92 |
-
# Convert PIL image to tensor
|
93 |
tensor = transform(np.array(image)).unsqueeze(0)
|
94 |
-
|
95 |
-
# Move to appropriate device
|
96 |
tensor = tensor.to(next(model.parameters()).device)
|
97 |
-
|
98 |
-
# Process
|
99 |
with torch.no_grad():
|
100 |
if direction == "A → B":
|
101 |
output = model.generator_ab(tensor)
|
102 |
else:
|
103 |
output = model.generator_ba(tensor)
|
104 |
-
|
105 |
-
# Convert back to image
|
106 |
result = de_normalize(output[0], model, direction)
|
107 |
-
|
|
|
|
|
108 |
|
109 |
-
# Main interface
|
110 |
col1, col2 = st.columns(2)
|
111 |
|
112 |
with col1:
|
@@ -115,19 +107,23 @@ with col1:
|
|
115 |
|
116 |
if uploaded_file is not None:
|
117 |
input_image = Image.open(uploaded_file)
|
118 |
-
|
|
|
|
|
|
|
119 |
|
120 |
with col2:
|
121 |
st.subheader("Converted Image")
|
122 |
if uploaded_file is not None:
|
123 |
try:
|
124 |
-
|
125 |
model = load_model(MODELS[selected_model]["model_path"])
|
126 |
result = process_image(input_image, model, direction)
|
127 |
|
128 |
-
|
129 |
-
st.image(result
|
130 |
except Exception as e:
|
131 |
-
st.error(f"Error during conversion: {str(e)}")
|
|
|
132 |
else:
|
133 |
-
st.info("Upload an image to see the conversion result")
|
|
|
4 |
from PIL import Image
|
5 |
from model import CycleGAN, get_val_transform, de_normalize
|
6 |
|
7 |
+
|
8 |
st.set_page_config(
|
9 |
page_title="CycleGAN Image Converter",
|
10 |
page_icon="🎨",
|
11 |
layout="wide"
|
12 |
)
|
13 |
|
14 |
+
|
15 |
@st.cache_resource
|
16 |
def get_device():
|
17 |
if torch.cuda.is_available():
|
|
|
25 |
st.sidebar.info("Using CPU 💻")
|
26 |
return device
|
27 |
|
28 |
+
|
29 |
st.markdown("""
|
30 |
<style>
|
31 |
.stApp {
|
|
|
38 |
</style>
|
39 |
""", unsafe_allow_html=True)
|
40 |
|
41 |
+
|
42 |
st.title("CycleGAN Image Converter 🎨")
|
43 |
st.markdown("""
|
44 |
Transform images between different domains using CycleGAN.
|
|
|
47 |
*Note: Images will be resized to 256x256 pixels during conversion.*
|
48 |
""")
|
49 |
|
50 |
+
|
51 |
MODELS = [
|
52 |
{
|
53 |
"name": "Cezanne ↔ Photo",
|
|
|
57 |
}
|
58 |
]
|
59 |
|
60 |
+
|
61 |
with st.sidebar:
|
62 |
st.header("Settings")
|
63 |
|
64 |
+
|
65 |
selected_model = st.selectbox(
|
66 |
"Conversion Type",
|
67 |
options=range(len(MODELS)),
|
68 |
format_func=lambda x: MODELS[x]["name"]
|
69 |
)
|
70 |
|
71 |
+
|
72 |
direction = st.radio(
|
73 |
"Conversion Direction",
|
74 |
options=["A → B", "B → A"],
|
75 |
help="A → B: Convert from domain A to B\nB → A: Convert from domain B to A"
|
76 |
)
|
77 |
|
78 |
+
|
79 |
@st.cache_resource
|
80 |
def load_model(model_path):
|
81 |
device = get_device()
|
|
|
84 |
model.eval()
|
85 |
return model
|
86 |
|
87 |
+
|
88 |
def process_image(image, model, direction):
|
|
|
89 |
transform = get_val_transform(model, direction)
|
|
|
|
|
90 |
tensor = transform(np.array(image)).unsqueeze(0)
|
|
|
|
|
91 |
tensor = tensor.to(next(model.parameters()).device)
|
|
|
|
|
92 |
with torch.no_grad():
|
93 |
if direction == "A → B":
|
94 |
output = model.generator_ab(tensor)
|
95 |
else:
|
96 |
output = model.generator_ba(tensor)
|
|
|
|
|
97 |
result = de_normalize(output[0], model, direction)
|
98 |
+
image = Image.fromarray(result.cpu().detach().numpy(), 'RGB')
|
99 |
+
return image
|
100 |
+
|
101 |
|
|
|
102 |
col1, col2 = st.columns(2)
|
103 |
|
104 |
with col1:
|
|
|
107 |
|
108 |
if uploaded_file is not None:
|
109 |
input_image = Image.open(uploaded_file)
|
110 |
+
|
111 |
+
if input_image.mode == 'RGBA':
|
112 |
+
input_image = input_image.convert('RGB')
|
113 |
+
st.image(input_image)
|
114 |
|
115 |
with col2:
|
116 |
st.subheader("Converted Image")
|
117 |
if uploaded_file is not None:
|
118 |
try:
|
119 |
+
|
120 |
model = load_model(MODELS[selected_model]["model_path"])
|
121 |
result = process_image(input_image, model, direction)
|
122 |
|
123 |
+
|
124 |
+
st.image(result)
|
125 |
except Exception as e:
|
126 |
+
st.error(f"Error during conversion: {str(e)} {e.__traceback__}")
|
127 |
+
raise
|
128 |
else:
|
129 |
+
st.info("Upload an image to see the conversion result")
|
model.py
CHANGED
@@ -77,7 +77,6 @@ class CycleGAN(nn.Module, PyTorchModelHubMixin, pipeline_tag="image-to-image"):
|
|
77 |
def get_val_transform(model, direction="a_to_b", size=256):
|
78 |
mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
|
79 |
std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b
|
80 |
-
|
81 |
return tr.Compose([
|
82 |
tr.ToPILImage(),
|
83 |
tr.Resize(size),
|
@@ -87,9 +86,8 @@ def get_val_transform(model, direction="a_to_b", size=256):
|
|
87 |
])
|
88 |
|
89 |
def de_normalize(tensor, model, direction="a_to_b"):
|
90 |
-
img_tensor = tensor
|
91 |
mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
|
92 |
std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b
|
93 |
-
|
94 |
img_tensor = img_tensor * std[:, None, None] + mean[:, None, None]
|
95 |
-
return torch.clamp(img_tensor.permute(1, 2, 0), 0.0,
|
|
|
77 |
def get_val_transform(model, direction="a_to_b", size=256):
|
78 |
mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
|
79 |
std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b
|
|
|
80 |
return tr.Compose([
|
81 |
tr.ToPILImage(),
|
82 |
tr.Resize(size),
|
|
|
86 |
])
|
87 |
|
88 |
def de_normalize(tensor, model, direction="a_to_b"):
|
89 |
+
img_tensor = tensor
|
90 |
mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
|
91 |
std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b
|
|
|
92 |
img_tensor = img_tensor * std[:, None, None] + mean[:, None, None]
|
93 |
+
return torch.clamp(img_tensor.permute(1, 2, 0) * 255.0, 0.0, 255.0).to(torch.uint8)
|