waleko commited on
Commit
1121140
Β·
1 Parent(s): 0958450
Files changed (2) hide show
  1. app.py +23 -27
  2. model.py +2 -4
app.py CHANGED
@@ -4,14 +4,14 @@ import numpy as np
4
  from PIL import Image
5
  from model import CycleGAN, get_val_transform, de_normalize
6
 
7
- # Configure page
8
  st.set_page_config(
9
  page_title="CycleGAN Image Converter",
10
  page_icon="🎨",
11
  layout="wide"
12
  )
13
 
14
- # Get the best available device
15
  @st.cache_resource
16
  def get_device():
17
  if torch.cuda.is_available():
@@ -25,7 +25,7 @@ def get_device():
25
  st.sidebar.info("Using CPU πŸ’»")
26
  return device
27
 
28
- # Add custom CSS
29
  st.markdown("""
30
  <style>
31
  .stApp {
@@ -38,7 +38,7 @@ st.markdown("""
38
  </style>
39
  """, unsafe_allow_html=True)
40
 
41
- # Title and description
42
  st.title("CycleGAN Image Converter 🎨")
43
  st.markdown("""
44
  Transform images between different domains using CycleGAN.
@@ -47,7 +47,7 @@ st.markdown("""
47
  *Note: Images will be resized to 256x256 pixels during conversion.*
48
  """)
49
 
50
- # Available models and their configurations
51
  MODELS = [
52
  {
53
  "name": "Cezanne ↔ Photo",
@@ -57,25 +57,25 @@ MODELS = [
57
  }
58
  ]
59
 
60
- # Sidebar controls
61
  with st.sidebar:
62
  st.header("Settings")
63
 
64
- # Model selection
65
  selected_model = st.selectbox(
66
  "Conversion Type",
67
  options=range(len(MODELS)),
68
  format_func=lambda x: MODELS[x]["name"]
69
  )
70
 
71
- # Direction selection
72
  direction = st.radio(
73
  "Conversion Direction",
74
  options=["A β†’ B", "B β†’ A"],
75
  help="A β†’ B: Convert from domain A to B\nB β†’ A: Convert from domain B to A"
76
  )
77
 
78
- # Load model
79
  @st.cache_resource
80
  def load_model(model_path):
81
  device = get_device()
@@ -84,29 +84,21 @@ def load_model(model_path):
84
  model.eval()
85
  return model
86
 
87
- # Process image
88
  def process_image(image, model, direction):
89
- # Prepare transform
90
  transform = get_val_transform(model, direction)
91
-
92
- # Convert PIL image to tensor
93
  tensor = transform(np.array(image)).unsqueeze(0)
94
-
95
- # Move to appropriate device
96
  tensor = tensor.to(next(model.parameters()).device)
97
-
98
- # Process
99
  with torch.no_grad():
100
  if direction == "A β†’ B":
101
  output = model.generator_ab(tensor)
102
  else:
103
  output = model.generator_ba(tensor)
104
-
105
- # Convert back to image
106
  result = de_normalize(output[0], model, direction)
107
- return result
 
 
108
 
109
- # Main interface
110
  col1, col2 = st.columns(2)
111
 
112
  with col1:
@@ -115,19 +107,23 @@ with col1:
115
 
116
  if uploaded_file is not None:
117
  input_image = Image.open(uploaded_file)
118
- st.image(input_image, use_column_width=True)
 
 
 
119
 
120
  with col2:
121
  st.subheader("Converted Image")
122
  if uploaded_file is not None:
123
  try:
124
- # Load and process
125
  model = load_model(MODELS[selected_model]["model_path"])
126
  result = process_image(input_image, model, direction)
127
 
128
- # Display
129
- st.image(result, use_column_width=True)
130
  except Exception as e:
131
- st.error(f"Error during conversion: {str(e)}")
 
132
  else:
133
- st.info("Upload an image to see the conversion result")
 
4
  from PIL import Image
5
  from model import CycleGAN, get_val_transform, de_normalize
6
 
7
+
8
  st.set_page_config(
9
  page_title="CycleGAN Image Converter",
10
  page_icon="🎨",
11
  layout="wide"
12
  )
13
 
14
+
15
  @st.cache_resource
16
  def get_device():
17
  if torch.cuda.is_available():
 
25
  st.sidebar.info("Using CPU πŸ’»")
26
  return device
27
 
28
+
29
  st.markdown("""
30
  <style>
31
  .stApp {
 
38
  </style>
39
  """, unsafe_allow_html=True)
40
 
41
+
42
  st.title("CycleGAN Image Converter 🎨")
43
  st.markdown("""
44
  Transform images between different domains using CycleGAN.
 
47
  *Note: Images will be resized to 256x256 pixels during conversion.*
48
  """)
49
 
50
+
51
  MODELS = [
52
  {
53
  "name": "Cezanne ↔ Photo",
 
57
  }
58
  ]
59
 
60
+
61
  with st.sidebar:
62
  st.header("Settings")
63
 
64
+
65
  selected_model = st.selectbox(
66
  "Conversion Type",
67
  options=range(len(MODELS)),
68
  format_func=lambda x: MODELS[x]["name"]
69
  )
70
 
71
+
72
  direction = st.radio(
73
  "Conversion Direction",
74
  options=["A β†’ B", "B β†’ A"],
75
  help="A β†’ B: Convert from domain A to B\nB β†’ A: Convert from domain B to A"
76
  )
77
 
78
+
79
  @st.cache_resource
80
  def load_model(model_path):
81
  device = get_device()
 
84
  model.eval()
85
  return model
86
 
87
+
88
  def process_image(image, model, direction):
 
89
  transform = get_val_transform(model, direction)
 
 
90
  tensor = transform(np.array(image)).unsqueeze(0)
 
 
91
  tensor = tensor.to(next(model.parameters()).device)
 
 
92
  with torch.no_grad():
93
  if direction == "A β†’ B":
94
  output = model.generator_ab(tensor)
95
  else:
96
  output = model.generator_ba(tensor)
 
 
97
  result = de_normalize(output[0], model, direction)
98
+ image = Image.fromarray(result.cpu().detach().numpy(), 'RGB')
99
+ return image
100
+
101
 
 
102
  col1, col2 = st.columns(2)
103
 
104
  with col1:
 
107
 
108
  if uploaded_file is not None:
109
  input_image = Image.open(uploaded_file)
110
+
111
+ if input_image.mode == 'RGBA':
112
+ input_image = input_image.convert('RGB')
113
+ st.image(input_image)
114
 
115
  with col2:
116
  st.subheader("Converted Image")
117
  if uploaded_file is not None:
118
  try:
119
+
120
  model = load_model(MODELS[selected_model]["model_path"])
121
  result = process_image(input_image, model, direction)
122
 
123
+
124
+ st.image(result)
125
  except Exception as e:
126
+ st.error(f"Error during conversion: {str(e)} {e.__traceback__}")
127
+ raise
128
  else:
129
+ st.info("Upload an image to see the conversion result")
model.py CHANGED
@@ -77,7 +77,6 @@ class CycleGAN(nn.Module, PyTorchModelHubMixin, pipeline_tag="image-to-image"):
77
  def get_val_transform(model, direction="a_to_b", size=256):
78
  mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
79
  std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b
80
-
81
  return tr.Compose([
82
  tr.ToPILImage(),
83
  tr.Resize(size),
@@ -87,9 +86,8 @@ def get_val_transform(model, direction="a_to_b", size=256):
87
  ])
88
 
89
  def de_normalize(tensor, model, direction="a_to_b"):
90
- img_tensor = tensor.cpu().detach().clone()
91
  mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
92
  std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b
93
-
94
  img_tensor = img_tensor * std[:, None, None] + mean[:, None, None]
95
- return torch.clamp(img_tensor.permute(1, 2, 0), 0.0, 1.0)
 
77
  def get_val_transform(model, direction="a_to_b", size=256):
78
  mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
79
  std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b
 
80
  return tr.Compose([
81
  tr.ToPILImage(),
82
  tr.Resize(size),
 
86
  ])
87
 
88
  def de_normalize(tensor, model, direction="a_to_b"):
89
+ img_tensor = tensor
90
  mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
91
  std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b
 
92
  img_tensor = img_tensor * std[:, None, None] + mean[:, None, None]
93
+ return torch.clamp(img_tensor.permute(1, 2, 0) * 255.0, 0.0, 255.0).to(torch.uint8)