Spaces:
Sleeping
Sleeping
File size: 3,545 Bytes
7cf938c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import streamlit as st
import torch
import numpy as np
from PIL import Image
from model import CycleGAN, get_val_transform, de_normalize
# Configure the browser tab (title and favicon emoji) and use the wide layout.
# The icon was previously the mis-encoded string "π¨" (UTF-8 "🎨" decoded
# through a single-byte codepage), which is not an emoji Streamlit can render.
st.set_page_config(
    page_title="CycleGAN Image Converter",
    page_icon="🎨",
    layout="wide",
)
@st.cache_resource
def get_device():
    """Return the preferred torch device and announce it in the sidebar.

    Preference order: CUDA, then Apple MPS, then CPU. Decorated with
    ``st.cache_resource``, so the selection is computed once and reused.

    Returns:
        torch.device: ``cuda``, ``mps`` or ``cpu``.
    """
    # NOTE(review): the emoji in these status messages look mis-encoded
    # (e.g. "π»"); restore the intended characters.
    if torch.cuda.is_available():
        st.sidebar.success("Using GPU π")
        return torch.device("cuda")

    if torch.backends.mps.is_available():
        st.sidebar.success("Using Apple Silicon π")
        return torch.device("mps")

    # Nothing accelerated is available; fall back to the CPU.
    st.sidebar.info("Using CPU π»")
    return torch.device("cpu")
# Page-level CSS injected once: caps the app at 1200px wide, centres it, and
# pads elements matching the `.main` class.
_CUSTOM_CSS = """
<style>
.stApp {
max-width: 1200px;
margin: 0 auto;
}
.main {
padding: 2rem;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Page heading and short usage description.
# The title emoji was previously the mis-encoded string "π¨"; restored to "🎨".
st.title("CycleGAN Image Converter 🎨")
st.markdown("""
Transform images between different domains using CycleGAN.
Upload an image and see it converted in real-time!
*Note: Images will be resized to 256x256 pixels during conversion.*
""")
# Registry of selectable conversions. Each entry provides a display "name",
# a short "id", the "model_path" handed to CycleGAN.from_pretrained, and a
# human-readable "description".
MODELS = [
    dict(
        name="Cezanne β Photo",
        id="cezanne2photo",
        model_path="waleko/cyclegan",
        description="Convert between Cezanne's painting style and photographs",
    ),
]
def _model_label(index):
    """Return the display name of MODELS[index] for the model selectbox."""
    return MODELS[index]["name"]


# Direction labels. NOTE: process_image compares against the first label
# verbatim, so any change here must be mirrored there.
_DIRECTION_OPTIONS = ["A β B", "B β A"]

# Sidebar: pick the model (by index into MODELS) and the translation direction.
with st.sidebar:
    st.header("Settings")

    selected_model = st.selectbox(
        "Conversion Type",
        options=range(len(MODELS)),
        format_func=_model_label,
    )

    direction = st.radio(
        "Conversion Direction",
        options=_DIRECTION_OPTIONS,
        help="A β B: Convert from domain A to B\nB β A: Convert from domain B to A",
    )
@st.cache_resource
def load_model(model_path):
    """Load a CycleGAN checkpoint, move it to the best device, set eval mode.

    Cached with ``st.cache_resource`` and keyed on ``model_path``, so each
    checkpoint is loaded once and reused.

    Args:
        model_path: Identifier passed to ``CycleGAN.from_pretrained``.

    Returns:
        The CycleGAN module on the device from ``get_device()``, in eval mode.
    """
    target_device = get_device()
    network = CycleGAN.from_pretrained(model_path)
    # nn.Module.to() and .eval() both return the module itself.
    return network.to(target_device).eval()
def process_image(image, model, direction):
    """Translate a PIL image with one of the model's two generators.

    Args:
        image: Input PIL image. It is converted to RGB first, so RGBA or
            palette PNGs and grayscale JPEGs still produce a 3-channel array
            (assumes the transform expects H x W x 3 input; TODO confirm
            against model.get_val_transform).
        model: Loaded CycleGAN exposing ``generator_ab``/``generator_ba``.
            Its first parameter's device decides where inference runs.
        direction: ``"A β B"`` selects ``generator_ab``; any other value
            selects ``generator_ba``. Also forwarded to ``get_val_transform``
            and ``de_normalize``.

    Returns:
        The result of ``de_normalize`` on the single generated sample.
    """
    # np.array() on an RGBA/"P"/"L" image gives 4 channels or a 2-D array,
    # which a 3-channel normalisation transform cannot handle.
    if image.mode != "RGB":
        image = image.convert("RGB")

    transform = get_val_transform(model, direction)

    # Add a batch dimension and move the input to wherever the model lives.
    batch = transform(np.array(image)).unsqueeze(0)
    batch = batch.to(next(model.parameters()).device)

    # NOTE: this label must match the sidebar radio option verbatim.
    generator = model.generator_ab if direction == "A β B" else model.generator_ba
    with torch.no_grad():
        output = generator(batch)

    # Undo the normalisation on the first (only) sample of the batch.
    return de_normalize(output[0], model, direction)
# Main interface: the upload on the left, the conversion result on the right.
col1, col2 = st.columns(2)

# The decoded upload, or None when nothing readable has been uploaded.
input_image = None

with col1:
    st.subheader("Input Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            input_image = Image.open(uploaded_file)
            # Image.open is lazy; load() forces decoding here so a corrupt or
            # truncated file fails inside this try instead of crashing later.
            input_image.load()
        except OSError as err:
            # Pillow's UnidentifiedImageError is a subclass of OSError.
            input_image = None
            st.error(f"Could not read the uploaded image: {err}")
        else:
            st.image(input_image, use_column_width=True)

with col2:
    st.subheader("Converted Image")
    if input_image is not None:
        try:
            model = load_model(MODELS[selected_model]["model_path"])
            result = process_image(input_image, model, direction)
            st.image(result, use_column_width=True)
        except Exception as e:
            # UI boundary: show any conversion failure instead of a traceback.
            st.error(f"Error during conversion: {str(e)}")
    elif uploaded_file is None:
        st.info("Upload an image to see the conversion result")