File size: 3,545 Bytes
7cf938c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import streamlit as st
import torch
import numpy as np
from PIL import Image
from model import CycleGAN, get_val_transform, de_normalize

# Page setup: browser-tab title and icon, plus the full-width ("wide") layout.
st.set_page_config(layout="wide", page_title="CycleGAN Image Converter", page_icon="🎨")

# Get the best available device
@st.cache_resource
def get_device():
    """Return the best available torch device: CUDA, then Apple MPS, then CPU.

    Also shows a one-line sidebar note naming the chosen backend. Wrapped in
    ``st.cache_resource`` so the probe is not repeated on every rerun.

    Returns:
        torch.device: ``cuda``, ``mps`` or ``cpu``.
    """
    # torch.backends.mps only exists in newer PyTorch releases (added in 1.12);
    # probe it defensively so older installs fall back to CPU instead of
    # raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)

    if torch.cuda.is_available():
        device = torch.device("cuda")
        st.sidebar.success("Using GPU 🚀")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
        st.sidebar.success("Using Apple Silicon 🍎")
    else:
        device = torch.device("cpu")
        st.sidebar.info("Using CPU 💻")
    return device

# Custom CSS: cap the app at 1200px wide, centre it horizontally, and pad the
# main container. Injected as raw HTML, hence unsafe_allow_html=True.
_CUSTOM_CSS = """
    <style>
    .stApp {
        max-width: 1200px;
        margin: 0 auto;
    }
    .main {
        padding: 2rem;
    }
    </style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# Page heading followed by a short usage blurb.
# NOTE(review): the 256x256 claim below depends on get_val_transform in
# model.py; keep the text in sync if that transform changes.
_INTRO_MD = """
    Transform images between different domains using CycleGAN.
    Upload an image and see it converted in real-time!
    
    *Note: Images will be resized to 256x256 pixels during conversion.*
"""
st.title("CycleGAN Image Converter 🎨")
st.markdown(_INTRO_MD)

# Registry of selectable models. In this file, "name" is the selectbox label
# and "model_path" is handed to load_model / CycleGAN.from_pretrained.
# "id" and "description" are not read by the code shown here.
MODELS = [
    dict(
        name="Cezanne ↔ Photo",
        id="cezanne2photo",
        model_path="waleko/cyclegan",
        description="Convert between Cezanne's painting style and photographs",
    ),
]

# Sidebar controls: choose the model and the translation direction.
st.sidebar.header("Settings")

# The widget returns an index into MODELS; the main interface looks the entry
# up with MODELS[selected_model].
selected_model = st.sidebar.selectbox(
    "Conversion Type",
    options=range(len(MODELS)),
    format_func=lambda idx: MODELS[idx]["name"],
)

# NOTE(review): these labels look mis-encoded (a UTF-8 "→" decoded as cp1253
# shows up as "β†’"). process_image compares `direction` against these exact
# bytes and forwards it to get_val_transform / de_normalize, so any fix must
# change every occurrence together.
direction = st.sidebar.radio(
    "Conversion Direction",
    options=["A β†’ B", "B β†’ A"],
    help="A β†’ B: Convert from domain A to B\nB β†’ A: Convert from domain B to A",
)

# Load model
@st.cache_resource
def load_model(model_path):
    """Load a pretrained CycleGAN onto the best device, ready for inference.

    ``st.cache_resource`` keys the cache on ``model_path``, so reruns reuse the
    same model object instead of reloading the weights.

    Args:
        model_path: Identifier passed to ``CycleGAN.from_pretrained``, taken
            from the ``model_path`` field of a ``MODELS`` entry by the caller.

    Returns:
        The model moved to the device chosen by ``get_device`` and switched to
        evaluation mode.
    """
    target = get_device()
    net = CycleGAN.from_pretrained(model_path).to(target)
    # Inference only: put layers with train/eval differences (e.g. dropout,
    # normalisation) into eval behaviour.
    net.eval()
    return net

# Process image
def process_image(image, model, direction):
    """Translate ``image`` with the CycleGAN generator selected by ``direction``.

    Args:
        image: Input picture, normally a PIL image (the decoded upload).
            Non-RGB PIL images (RGBA/palette PNGs, greyscale) are converted to
            RGB first; objects without a PIL ``mode`` are used unchanged.
        model: Loaded CycleGAN exposing ``generator_ab``, ``generator_ba`` and
            ``parameters()``.
        direction: Sidebar direction label. The "A to B" label selects
            ``generator_ab``; any other value selects ``generator_ba``. It is
            also forwarded to ``get_val_transform`` and ``de_normalize``.

    Returns:
        The value ``de_normalize`` produces for the single output image; the
        caller hands it straight to ``st.image``.
    """
    # Uploaded PNGs are often RGBA or palette-based and some images are
    # greyscale, so np.array(image) would yield 4 or 1 channels. The
    # generators are assumed to take 3-channel RGB input (confirm in model.py).
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    transform = get_val_transform(model, direction)

    # Convert the image to a tensor and add a leading batch dimension of size 1.
    tensor = transform(np.array(image)).unsqueeze(0)

    # Run on whichever device holds the model's weights.
    tensor = tensor.to(next(model.parameters()).device)

    with torch.no_grad():
        # NOTE(review): this literal looks mis-encoded ("→" read as cp1253) but
        # must stay byte-identical to the sidebar radio options.
        if direction == "A β†’ B":
            output = model.generator_ab(tensor)
        else:
            output = model.generator_ba(tensor)

    # Drop the batch dimension and map back to a displayable image.
    return de_normalize(output[0], model, direction)

# Main interface: upload on the left, conversion result on the right.
col1, col2 = st.columns(2)

# Decoded upload. Stays None when nothing was uploaded or decoding failed, so
# the result column only runs the model on a usable image.
input_image = None

with col1:
    st.subheader("Input Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        try:
            # Image.open is lazy and reads only the header. convert() forces a
            # full decode, so corrupt or truncated files fail here rather than
            # inside st.image. RGB also normalises RGBA/palette PNGs and
            # greyscale uploads.
            input_image = Image.open(uploaded_file).convert("RGB")
        except Exception as e:  # UI boundary: report instead of crashing the page
            st.error(f"Could not read the uploaded image: {e}")
        else:
            # NOTE(review): recent Streamlit releases deprecate use_column_width
            # in favour of use_container_width; confirm against the pinned version.
            st.image(input_image, use_column_width=True)

with col2:
    st.subheader("Converted Image")
    if input_image is not None:
        try:
            # Load (cached) and run the selected model.
            model = load_model(MODELS[selected_model]["model_path"])
            result = process_image(input_image, model, direction)
            st.image(result, use_column_width=True)
        except Exception as e:  # UI boundary: surface the failure to the user
            st.error(f"Error during conversion: {e}")
    else:
        st.info("Upload an image to see the conversion result")