# ugly-holiday-card-generator / pages/2_Generate_Multiple_Holiday_Borders.py
# Mikiko Bazeley — "Refactored and removed controlnet" (57eccf2)
# 2_Generate_Multiple_Holiday_Borders.py
import os
import zipfile
from io import BytesIO

import streamlit as st
from dotenv import load_dotenv, set_key
from PIL import Image

# Import from helper_utilities.py
from utils.helper_utilities import (
    get_closest_aspect_ratio, process_image, generate_flux_image,  # Flux API call (replaced the earlier ControlNet call)
    draw_crop_preview, combine_images, get_next_largest_aspect_ratio,
)
# Import from configuration.py
from utils.configuration import (
    default_guidance_scale,
    default_num_inference_steps, default_seed,
    holiday_border_prompts,
)
# Initialize session state. Streamlit re-runs this script on every interaction,
# so each key is seeded only the first time it is missing for this session.
for _state_key, _make_default in (
    ("uploaded_file", lambda: None),
    # One fresh dict per card (not [{}] * 4) so cards never share parameters.
    ("card_params", lambda: [{} for _ in range(4)]),
    ("generated_cards", lambda: [None] * 4),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Streamlit app starts here: logo, title, and the step-by-step overview.
# (The emoji below were previously mis-encoded mojibake such as "πŸŽ‰".)
st.image("img/fireworksai_logo.png")
st.title("🎨 Holiday Multi-Card Generator 🎨")
st.markdown(
"""Welcome to the first part of your holiday card creation journey! 🌟 Here, you'll play around with different styles, prompts, and parameters to design the perfect card border before adding a personal message in the section 'Customize Holiday Borders'. Let your creativity flow! 🎉
### How it works:
1. **🖼️ Upload Your Image:** Choose the image that will be the center of your card.
2. **✂️ Crop It:** Adjust the crop to highlight the most important part of your image.
3. **💡 Choose Your Style:** Select from festive border themes or input your own custom prompt to design something unique.
4. **⚙️ Fine-Tune Parameters:** Experiment with guidance scales, seeds, inference steps, and more for the perfect aesthetic.
5. **👀 Preview & Download:** See your generated holiday cards, tweak them until they're just right, and download the final designs and metadata in a neat ZIP file!
Once you've got the perfect look, head over to **Part B** to add your personal message and finalize your holiday card! 💌
"""
)
# Load API Key: read FIREWORKS_API_KEY from env/.env (one directory above this
# script's folder), or let the user type it in and optionally persist it.
st.divider()
st.subheader("Load Fireworks API Key")

# Define and ensure the .env directory and file exist
dotenv_path = os.path.join(os.path.dirname(__file__), '..', 'env', '.env')
os.makedirs(os.path.dirname(dotenv_path), exist_ok=True)

# Create an empty .env file if it doesn't exist, so later loads/saves have a target.
if not os.path.exists(dotenv_path):
    with open(dotenv_path, "w", encoding="utf-8"):
        pass
    st.success(f"Created {dotenv_path}")

# Load environment variables from the .env file; override=True makes values in
# the file win over anything already present in os.environ.
load_dotenv(dotenv_path, override=True)

# Treat a missing, empty, or whitespace-only key as "not configured", and strip
# stray whitespace so it is never sent to the API.
fireworks_api_key = (os.getenv("FIREWORKS_API_KEY") or "").strip()

# Show the entire app but disable running parts if no API key
if not fireworks_api_key:
    fireworks_api_key = st.text_input("Enter Fireworks API Key", type="password").strip()
    # Optionally persist the key. set_key rewrites an existing FIREWORKS_API_KEY
    # entry (e.g. a blank placeholder) in place instead of appending a duplicate
    # line, and cannot glue onto a previous line that lacks a trailing newline.
    if fireworks_api_key and st.checkbox("Save API key for future use"):
        set_key(dotenv_path, "FIREWORKS_API_KEY", fireworks_api_key)
        st.success("API key saved to .env file.")
else:
    st.success(f"API key loaded successfully: partial preview {fireworks_api_key[:5]}")
# Step 1: Upload Image
st.divider()
st.subheader("🖼️ Step 1: Upload Your Picture!")
uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Steps 2-4 all need the uploaded image, so they render only once one exists.
if uploaded_file is not None:
    st.session_state.uploaded_file = uploaded_file
    original_image = Image.open(uploaded_file)
    img_width, img_height = original_image.size

    # Calculate the next largest valid aspect ratio; it is indexed as a
    # (width, height) pair when formatted for the Flux call in Step 4.
    aspect_ratio = get_next_largest_aspect_ratio(img_width, img_height)
    st.image(original_image, caption="Uploaded Image", use_column_width=True)

    # Step 2: Crop Image
    st.divider()
    st.subheader("✂️ Step 2: Crop It Like It's Hot!")

    # Smallest crop side the sliders allow, capped at the image size so very
    # small images still produce a valid range.
    min_crop_w = min(10, img_width)
    min_crop_h = min(10, img_height)
    # Keep at least one minimum-sized crop to the right of / below the position
    # slider; otherwise the Width/Height slider's max would drop below its min.
    max_x = img_width - min_crop_w
    max_y = img_height - min_crop_h

    col1, col2 = st.columns(2)
    with col1:
        x_pos = st.slider("X position (Left-Right)", 0, max_x, min(img_width // 4, max_x))
        crop_width = st.slider(
            "Width", min_crop_w, img_width - x_pos,
            max(min_crop_w, min(img_width // 2, img_width - x_pos)),
        )
    with col2:
        y_pos = st.slider("Y position (Up-Down)", 0, max_y, min(img_height // 4, max_y))
        crop_height = st.slider(
            "Height", min_crop_h, img_height - y_pos,
            max(min_crop_h, min(img_height // 2, img_height - y_pos)),
        )

    preview_image = draw_crop_preview(original_image.copy(), x_pos, y_pos, crop_width, crop_height)
    st.image(preview_image, caption="Crop Preview", use_column_width=True)

    # Step 3: Set Card Parameters
    st.divider()
    st.subheader("⚙️ Step 3: Set Your Festive Border Design with Flux + Fireworks!")
    for i, card_params in enumerate(st.session_state.card_params):
        card_number = i + 1
        with st.expander(f"Holiday Card {card_number} Parameters"):
            # Fill in defaults the first time a card is shown; later reruns
            # start from whatever the user last chose.
            card_params.setdefault("prompt", holiday_border_prompts[i % len(holiday_border_prompts)])
            card_params.setdefault("guidance_scale", default_guidance_scale)
            card_params.setdefault("num_inference_steps", default_num_inference_steps)
            card_params.setdefault("seed", i * 100)

            selected_prompt = st.selectbox(
                f"Choose a holiday-themed prompt for Holiday Card {card_number}",
                options=["Custom"] + holiday_border_prompts,
            )
            if selected_prompt == "Custom":
                prompt = st.text_input(
                    f"Enter custom prompt for Holiday Card {card_number}",
                    value=card_params["prompt"],
                )
            else:
                prompt = selected_prompt

            # Allow the user to tweak other parameters. The stored values are
            # coerced so every numeric argument of each slider shares one type;
            # st.slider rejects a mix of int and float arguments.
            guidance_scale = st.slider(
                f"Guidance Scale for Holiday Card {card_number}",
                min_value=0.0, max_value=20.0,
                value=float(card_params["guidance_scale"]), step=0.1,
            )
            num_inference_steps = st.slider(
                f"Number of Inference Steps for Holiday Card {card_number}",
                min_value=1, max_value=100,
                value=int(card_params["num_inference_steps"]), step=1,
            )
            seed = st.slider(
                f"Random Seed for Holiday Card {card_number}",
                min_value=0, max_value=1000, value=int(card_params["seed"]),
            )

            st.session_state.card_params[i] = {
                "prompt": prompt,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "seed": seed,
            }

    # Step 4: Generate Holiday Cards
    st.divider()
    st.subheader("Preview and Share the Holiday Cheer! 🎅📬")
    st.markdown("""
Click "Generate Holiday Cards" and watch the magic happen! Your holiday card is just moments away from spreading joy to everyone on your list. 🎄🎁✨
""")

    # Disable the generate button if the API key is missing
    api_key_missing = not fireworks_api_key or fireworks_api_key.strip() == ""
    if api_key_missing:
        st.warning("Enter a valid Fireworks API key to enable card generation.")
    generate_button = st.button("Generate Holiday Cards", disabled=api_key_missing)

    if generate_button:
        with st.spinner("Processing..."):
            cols = st.columns(len(st.session_state.card_params))
            image_files = []  # (file name, PNG bytes) for each successful card
            metadata = []

            # The crop and the aspect-ratio string are the same for every card.
            cropped_original = original_image.crop(
                (x_pos, y_pos, x_pos + crop_width, y_pos + crop_height)
            )
            flux_aspect_ratio = f"{aspect_ratio[0]}:{aspect_ratio[1]}"  # "width:height"

            for i, params in enumerate(st.session_state.card_params):
                card_number = i + 1
                # Generate the border with Flux at the next largest valid aspect ratio.
                generated_image = generate_flux_image(
                    model_path="flux-1-schnell-fp8",
                    prompt=params['prompt'],
                    steps=params['num_inference_steps'],
                    guidance_scale=params['guidance_scale'],
                    seed=params['seed'],
                    api_key=fireworks_api_key,
                    aspect_ratio=flux_aspect_ratio,
                )
                if not generated_image:
                    # Don't leave a previous run's card behind for this slot.
                    st.session_state.generated_cards[i] = None
                    st.error(f"Failed to generate holiday card {card_number}. Please try again.")
                    continue

                # Scale the border to the original photo's size (resize returns a
                # new image), then paste the cropped photo at its center.
                final_image = generated_image.resize(original_image.size)
                center_x = (final_image.width - cropped_original.width) // 2
                center_y = (final_image.height - cropped_original.height) // 2
                final_image.paste(cropped_original, (center_x, center_y))

                # Save final image and metadata
                img_byte_arr = BytesIO()
                final_image.save(img_byte_arr, format="PNG")
                image_files.append((f"holiday_card_{card_number}.png", img_byte_arr.getvalue()))

                card_metadata = {
                    "Card": f"Holiday Card {card_number}",
                    "Prompt": params['prompt'],
                    "Guidance Scale": params['guidance_scale'],
                    "Inference Steps": params['num_inference_steps'],
                    "Seed": params['seed'],
                }
                metadata.append(card_metadata)
                st.session_state.generated_cards[i] = {
                    "image": final_image,
                    "metadata": card_metadata,
                }

                # Display the final holiday card
                col = cols[i]
                col.image(final_image, caption=f"Holiday Card {card_number}", use_column_width=True)
                col.write(f"**Prompt:** {params['prompt']}")
                col.write(f"**Guidance Scale:** {params['guidance_scale']}")
                col.write(f"**Inference Steps:** {params['num_inference_steps']}")
                col.write(f"**Seed:** {params['seed']}")

            # Create the ZIP file with all successful images plus a metadata summary
            if image_files:
                zip_buffer = BytesIO()
                with zipfile.ZipFile(zip_buffer, "w") as zf:
                    for file_name, png_bytes in image_files:
                        zf.writestr(file_name, png_bytes)
                    metadata_str = "\n\n".join(
                        f"{m['Card']}:\nPrompt: {m['Prompt']}\nGuidance Scale: {m['Guidance Scale']}\n"
                        f"Inference Steps: {m['Inference Steps']}\nSeed: {m['Seed']}"
                        for m in metadata
                    )
                    zf.writestr("metadata.txt", metadata_str)
                zip_buffer.seek(0)

                # Single download button for all images and metadata
                st.download_button(
                    label="Download all images and metadata as ZIP",
                    data=zip_buffer,
                    file_name="holiday_cards.zip",
                    mime="application/zip",
                )
# Footer Section: closing thanks shown on every run.
# (The emoji were previously mis-encoded mojibake such as "πŸ’₯".)
st.divider()
st.markdown(
"""
Thank you for using the Holiday Card Generator powered by **Fireworks**! 🎉
Share your creations with the world and spread the holiday cheer!
Happy Holidays from the **Fireworks Team**. 💥
"""
)