# Spaces: Sleeping
import numpy as np
import streamlit as st
import torch
from PIL import Image

from model import CycleGAN, de_normalize, get_val_transform

# Configure the browser tab and page layout. Streamlit expects
# set_page_config to run before any other st.* command.
st.set_page_config(
    page_title="CycleGAN Image Converter",
    page_icon="🎨",
    layout="wide",
)
# Get the best available device
def get_device():
    """Pick the fastest available torch device and announce it in the sidebar.

    Preference order: CUDA GPU, then Apple-Silicon MPS, then CPU.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        st.sidebar.success("Using GPU 🚀")
        return torch.device("cuda")

    # torch.backends.mps was added in PyTorch 1.12; look it up defensively so
    # older installs fall back to the CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        st.sidebar.success("Using Apple Silicon 🍎")
        return torch.device("mps")

    st.sidebar.info("Using CPU 💻")
    return torch.device("cpu")
# Custom CSS: center the app and cap its width at 1200px (even though the
# page is configured with layout="wide"), and pad the main area.
# NOTE(review): `.stApp` / `.main` are Streamlit-internal class names and may
# change between Streamlit versions.
st.markdown(
    """
<style>
.stApp {
    max-width: 1200px;
    margin: 0 auto;
}
.main {
    padding: 2rem;
}
</style>
""",
    unsafe_allow_html=True,
)
# Title and a short description of what the app does.
st.title("CycleGAN Image Converter 🎨")
st.markdown("""
Transform images between different domains using CycleGAN.
Upload an image and see it converted in real-time!

*Note: Images will be resized to 256x256 pixels during conversion.*
""")
# Available models and their configurations. Each entry is a dict with:
#   name        - label shown in the "Conversion Type" selectbox
#   id          - short identifier for the model (appears unused by the UI code)
#   model_path  - identifier handed to load_model / CycleGAN.from_pretrained
#   description - human-readable summary (appears unused by the UI code)
MODELS = [
    {
        "name": "Cezanne ↔ Photo",
        "id": "cezanne2photo",
        "model_path": "waleko/cyclegan",
        "description": "Convert between Cezanne's painting style and photographs",
    },
]
# Sidebar controls: which model to use and which way to convert.
st.sidebar.header("Settings")

# Model selection: the widget's value is an index into MODELS, rendered by name.
selected_model = st.sidebar.selectbox(
    "Conversion Type",
    options=range(len(MODELS)),
    format_func=lambda idx: MODELS[idx]["name"],
)

# Direction selection: the chosen label string is later used to pick the
# conversion direction, so keep these labels exactly as they are.
direction = st.sidebar.radio(
    "Conversion Direction",
    options=["A β B", "B β A"],
    help="A β B: Convert from domain A to B\nB β A: Convert from domain B to A",
)
# Load model
@st.cache_resource(show_spinner="Loading model...")
def load_model(model_path):
    """Load a pretrained CycleGAN, move it to the best device, and set eval mode.

    Streamlit reruns the whole script on every widget interaction, so the result
    is cached with ``st.cache_resource`` (keyed on ``model_path``). The model is
    then built once per process instead of on every rerun.

    Args:
        model_path: identifier passed to ``CycleGAN.from_pretrained`` (the MODELS
            entries use an "owner/name" style id).

    Returns:
        The model in eval mode, on the device chosen by ``get_device``.
    """
    device = get_device()
    model = CycleGAN.from_pretrained(model_path).to(device)
    model.eval()
    return model
# Process image
def process_image(image, model, direction):
    """Translate ``image`` with one of the model's two generators.

    Args:
        image: PIL image to convert.
        model: CycleGAN exposing ``generator_ab`` and ``generator_ba``.
        direction: label chosen in the sidebar radio. Labels starting with "A"
            run ``generator_ab``; any other label runs ``generator_ba``. The label
            is also forwarded unchanged to ``get_val_transform`` and
            ``de_normalize``.

    Returns:
        Whatever ``de_normalize`` produces for the first (only) item of the batch.
    """
    # Uploaded PNGs may be RGBA, palette ("P") or grayscale. Force 3-channel RGB
    # so the array handed to the transform always has the same layout.
    # NOTE(review): assumes the transform/generators expect HxWx3 RGB input.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    transform = get_val_transform(model, direction)
    # Add a batch dimension and move the tensor to wherever the model lives.
    device = next(model.parameters()).device
    batch = transform(np.array(image)).unsqueeze(0).to(device)

    # Compare only the leading domain letter, so the choice does not depend on
    # the exact arrow character used in the radio labels.
    generator = model.generator_ab if direction.startswith("A") else model.generator_ba
    with torch.no_grad():
        output = generator(batch)

    return de_normalize(output[0], model, direction)
# Main interface: uploaded image on the left, converted result on the right.
col1, col2 = st.columns(2)

with col1:
    st.subheader("Input Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    input_image = None
    if uploaded_file is not None:
        try:
            input_image = Image.open(uploaded_file)
            # Image.open is lazy; decode now so truncated or corrupt files fail
            # here with a readable message instead of crashing the app later.
            input_image.load()
        except (OSError, Image.DecompressionBombError) as e:
            # OSError also covers PIL's UnidentifiedImageError (not an image at
            # all); DecompressionBombError rejects absurdly large images.
            input_image = None
            st.error(f"Could not read the uploaded image: {e}")
        else:
            st.image(input_image, use_column_width=True)

with col2:
    st.subheader("Converted Image")
    if uploaded_file is None:
        st.info("Upload an image to see the conversion result")
    elif input_image is not None:
        try:
            with st.spinner("Converting..."):
                model = load_model(MODELS[selected_model]["model_path"])
                result = process_image(input_image, model, direction)
            st.image(result, use_column_width=True)
        except Exception as e:  # UI boundary: report any failure instead of crashing
            st.error(f"Error during conversion: {e}")