# Generate_holiday_borders.py
"""Streamlit page for designing holiday card borders with Flux on Fireworks.

Flow: load/enter a Fireworks API key, upload a picture, choose a crop,
configure four border designs (prompt, guidance scale, inference steps,
seed), then generate four cards. Each card is a Flux-generated image resized
to the upload's size with the cropped picture pasted in its centre. All
cards plus a metadata text file can be downloaded as a single ZIP.
"""
import os
import zipfile
from io import BytesIO

import streamlit as st
from PIL import Image
from dotenv import load_dotenv, set_key

# Import from helper_utilities.py
from utils.helper_utilities import (
    get_closest_aspect_ratio,
    process_image,
    generate_flux_image,  # Replaces the former ControlNet API call
    draw_crop_preview,
    combine_images,
    get_next_largest_aspect_ratio,
)

# Import from configuration.py
from utils.configuration import (
    default_guidance_scale,
    default_num_inference_steps,
    default_seed,
    holiday_border_prompts,
)

NUM_CARDS = 4
FLUX_MODEL_PATH = "flux-1-schnell-fp8"
# Smallest crop width/height (px) the size sliders allow.
MIN_CROP_SIZE = 10
# Position sliders stop MIN_CROP_SIZE + 1 px before the edge so the size
# slider's min (MIN_CROP_SIZE) always stays strictly below its max; an image
# therefore needs at least MIN_CROP_SIZE + 2 px per side to be croppable.
MIN_IMAGE_SIZE = MIN_CROP_SIZE + 2


def _crop_axis_sliders(position_label, size_label, full_size):
    """Render the position and size sliders for one image axis.

    Returns ``(offset, length)`` with ``offset + length <= full_size`` and
    ``length >= MIN_CROP_SIZE``. Requires ``full_size >= MIN_IMAGE_SIZE``.
    """
    max_offset = full_size - MIN_CROP_SIZE - 1
    offset = st.slider(position_label, 0, max_offset, min(full_size // 4, max_offset))
    max_length = full_size - offset
    length = st.slider(
        size_label,
        MIN_CROP_SIZE,
        max_length,
        max(MIN_CROP_SIZE, min(full_size // 2, max_length)),
    )
    return offset, length


def _compose_card(background, picture, crop_box):
    """Return a copy of *background* with the *crop_box* region of *picture* centred on it.

    PIL's ``paste`` converts the crop to the background's mode when they differ.
    """
    cropped = picture.crop(crop_box)
    offset_x = (background.width - cropped.width) // 2
    offset_y = (background.height - cropped.height) // 2
    card = background.copy()
    card.paste(cropped, (offset_x, offset_y))
    return card


def _format_metadata(metadata):
    """Render the per-card metadata dicts as the text stored in ``metadata.txt``."""
    return "\n\n".join(
        f"{m['Card']}:\nPrompt: {m['Prompt']}\nGuidance Scale: {m['Guidance Scale']}\n"
        f"Inference Steps: {m['Inference Steps']}\nSeed: {m['Seed']}"
        for m in metadata
    )


def _build_zip(image_files, metadata):
    """Return a rewound BytesIO ZIP holding each ``(name, png_buffer)`` pair plus metadata.txt."""
    zip_buffer = BytesIO()
    with zipfile.ZipFile(zip_buffer, "w") as zf:
        for file_name, img_data in image_files:
            zf.writestr(file_name, img_data.getvalue())
        zf.writestr("metadata.txt", _format_metadata(metadata))
    zip_buffer.seek(0)
    return zip_buffer


# Initialize session state
if "uploaded_file" not in st.session_state:
    st.session_state.uploaded_file = None
if "card_params" not in st.session_state:
    st.session_state.card_params = [{} for _ in range(NUM_CARDS)]
if "generated_cards" not in st.session_state:
    st.session_state.generated_cards = [None for _ in range(NUM_CARDS)]

# Streamlit app starts here
st.image("img/fireworksai_logo.png")
st.title("🎨 Holiday Multi-Card Generator🎨")
st.markdown(
    """Welcome to the first part of your holiday card creation journey! 🌟 Here, you'll play around with different styles, prompts, and parameters to design the perfect card border before adding a personal message in the section 'Customize Holiday Borders'. Let your creativity flow! 🎉

### How it works:

1. **🖼️ Upload Your Image:** Choose the image that will be the center of your card.
2. **✂️ Crop It:** Adjust the crop to highlight the most important part of your image.
3. **💡 Choose Your Style:** Select from festive border themes or input your own custom prompt to design something unique.
4. **⚙️ Fine-Tune Parameters:** Experiment with guidance scales, seeds, inference steps, and more for the perfect aesthetic.
5. **👀 Preview & Download:** See your generated holiday cards, tweak them until they're just right, and download the final designs and metadata in a neat ZIP file!

Once you've got the perfect look, head over to **Part B** to add your personal message and finalize your holiday card! 💌
"""
)

# Load API Key
st.divider()
st.subheader("Load Fireworks API Key")

# Define and ensure the .env directory and file exist
dotenv_path = os.path.join(os.path.dirname(__file__), "..", "env", ".env")
os.makedirs(os.path.dirname(dotenv_path), exist_ok=True)

# Create an empty .env file if it doesn't exist
if not os.path.exists(dotenv_path):
    with open(dotenv_path, "w", encoding="utf-8"):
        pass
    st.success(f"Created {dotenv_path}")

# Load environment variables from the .env file
load_dotenv(dotenv_path, override=True)

# Strip so a whitespace-only value counts as missing and a pasted key with
# stray spaces/newlines is never sent to the API or saved verbatim.
fireworks_api_key = (os.getenv("FIREWORKS_API_KEY") or "").strip()

# Show the entire app but disable running parts if no API key
if not fireworks_api_key:
    fireworks_api_key = st.text_input("Enter Fireworks API Key", type="password").strip()
    # Optionally, allow the user to save the API key to the .env file
    if fireworks_api_key and st.checkbox("Save API key for future use"):
        # set_key rewrites an existing (e.g. blank) FIREWORKS_API_KEY entry
        # instead of appending a duplicate line.
        set_key(dotenv_path, "FIREWORKS_API_KEY", fireworks_api_key)
        st.success("API key saved to .env file.")
else:
    st.success(f"API key loaded successfully: partial preview {fireworks_api_key[:5]}")

# Step 1: Upload Image
st.divider()
st.subheader("🖼️ Step 1: Upload Your Picture!")
uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

original_image = None
if uploaded_file is not None:
    st.session_state.uploaded_file = uploaded_file
    try:
        original_image = Image.open(uploaded_file)
        original_image.load()  # Image.open is lazy; surface truncated/corrupt data now
    except OSError as err:  # includes PIL.UnidentifiedImageError
        st.error(f"Could not read the uploaded image: {err}")
        original_image = None
    if original_image is not None and min(original_image.size) < MIN_IMAGE_SIZE:
        st.error(f"The image must be at least {MIN_IMAGE_SIZE}x{MIN_IMAGE_SIZE} pixels to be cropped.")
        original_image = None

if original_image is not None:
    img_width, img_height = original_image.size

    # Calculate the next largest valid aspect ratio for the Flux request
    aspect_ratio = get_next_largest_aspect_ratio(img_width, img_height)
    st.image(original_image, caption="Uploaded Image", use_column_width=True)

    # Step 2: Crop Image
    st.divider()
    st.subheader("✂️ Step 2: Crop It Like It's Hot!")
    col1, col2 = st.columns(2)
    with col1:
        x_pos, crop_width = _crop_axis_sliders("X position (Left-Right)", "Width", img_width)
    with col2:
        y_pos, crop_height = _crop_axis_sliders("Y position (Up-Down)", "Height", img_height)

    preview_image = draw_crop_preview(original_image.copy(), x_pos, y_pos, crop_width, crop_height)
    st.image(preview_image, caption="Crop Preview", use_column_width=True)

    # Step 3: Set Card Parameters
    st.divider()
    st.subheader("⚙️ Step 3: Set Your Festive Border Design with Flux + Fireworks!")

    for i in range(NUM_CARDS):
        with st.expander(f"Holiday Card {i + 1} Parameters"):
            card_params = st.session_state.card_params[i]

            # Set default values for card parameters if not already set
            card_params.setdefault("prompt", holiday_border_prompts[i % len(holiday_border_prompts)])
            card_params.setdefault("guidance_scale", default_guidance_scale)
            card_params.setdefault("num_inference_steps", default_num_inference_steps)
            card_params.setdefault("seed", i * 100)

            # Keyed widgets seeded once from the persistent card_params. Passing
            # a changing `value=` instead gives the widget a new identity on
            # every rerun, which silently drops every other user edit.
            prompt_key = f"holiday_card_{i}_prompt"
            guidance_key = f"holiday_card_{i}_guidance_scale"
            steps_key = f"holiday_card_{i}_num_inference_steps"
            seed_key = f"holiday_card_{i}_seed"
            st.session_state.setdefault(prompt_key, card_params["prompt"])
            st.session_state.setdefault(guidance_key, float(card_params["guidance_scale"]))
            st.session_state.setdefault(steps_key, int(card_params["num_inference_steps"]))
            st.session_state.setdefault(seed_key, int(card_params["seed"]))

            selected_prompt = st.selectbox(
                f"Choose a holiday-themed prompt for Holiday Card {i + 1}",
                options=["Custom"] + holiday_border_prompts,
                key=f"holiday_card_{i}_preset",
            )
            if selected_prompt == "Custom":
                custom_prompt = st.text_input(
                    f"Enter custom prompt for Holiday Card {i + 1}", key=prompt_key
                )
            else:
                custom_prompt = selected_prompt

            # Allow the user to tweak other parameters
            guidance_scale = st.slider(
                f"Guidance Scale for Holiday Card {i + 1}",
                min_value=0.0,
                max_value=20.0,
                step=0.1,
                key=guidance_key,
            )
            num_inference_steps = st.slider(
                f"Number of Inference Steps for Holiday Card {i + 1}",
                min_value=1,
                max_value=100,
                step=1,
                key=steps_key,
            )
            seed = st.slider(
                f"Random Seed for Holiday Card {i + 1}",
                min_value=0,
                max_value=1000,
                key=seed_key,
            )

            st.session_state.card_params[i] = {
                "prompt": custom_prompt,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "seed": seed,
            }

    # Generate Holiday Cards
    st.divider()
    st.subheader("Preview and Share the Holiday Cheer! 🎅📬")
    st.markdown(
        """
Click "Generate Holiday Cards" and watch the magic happen! Your holiday card is just moments away from spreading joy to everyone on your list. 🎄🎁✨
"""
    )

    # Disable the generate button if the API key is missing
    if not fireworks_api_key:
        st.warning("Enter a valid Fireworks API key to enable card generation.")
    generate_button = st.button("Generate Holiday Cards", disabled=not fireworks_api_key)

    if generate_button:
        crop_box = (x_pos, y_pos, x_pos + crop_width, y_pos + crop_height)
        # Flux expects the aspect ratio as a "width:height" string
        flux_aspect_ratio = f"{aspect_ratio[0]}:{aspect_ratio[1]}"
        image_files = []
        metadata = []

        with st.spinner("Processing..."):
            cols = st.columns(NUM_CARDS)
            for i, params in enumerate(st.session_state.card_params):
                try:
                    generated_image = generate_flux_image(
                        model_path=FLUX_MODEL_PATH,
                        prompt=params["prompt"],
                        steps=params["num_inference_steps"],
                        guidance_scale=params["guidance_scale"],
                        seed=params["seed"],
                        api_key=fireworks_api_key,
                        aspect_ratio=flux_aspect_ratio,
                    )
                except Exception as err:  # UI boundary: report it and keep the other cards
                    st.error(f"Failed to generate holiday card {i + 1}: {err}")
                    continue

                if not generated_image:
                    st.error(f"Failed to generate holiday card {i + 1}. Please try again.")
                    continue

                # Match the upload's size, then centre the cropped picture on it
                final_image = _compose_card(
                    generated_image.resize(original_image.size), original_image, crop_box
                )

                # Save final image and metadata
                png_buffer = BytesIO()
                final_image.save(png_buffer, format="PNG")
                png_buffer.seek(0)
                image_files.append((f"holiday_card_{i + 1}.png", png_buffer))

                card_metadata = {
                    "Card": f"Holiday Card {i + 1}",
                    "Prompt": params["prompt"],
                    "Guidance Scale": params["guidance_scale"],
                    "Inference Steps": params["num_inference_steps"],
                    "Seed": params["seed"],
                }
                metadata.append(card_metadata)
                st.session_state.generated_cards[i] = {
                    "image": final_image,
                    "metadata": card_metadata,
                }

                # Display the final holiday card
                cols[i].image(final_image, caption=f"Holiday Card {i + 1}", use_column_width=True)
                cols[i].write(f"**Prompt:** {params['prompt']}")
                cols[i].write(f"**Guidance Scale:** {params['guidance_scale']}")
                cols[i].write(f"**Inference Steps:** {params['num_inference_steps']}")
                cols[i].write(f"**Seed:** {params['seed']}")

        # Single download button for all images and metadata
        if image_files:
            st.download_button(
                label="Download all images and metadata as ZIP",
                data=_build_zip(image_files, metadata),
                file_name="holiday_cards.zip",
                mime="application/zip",
            )

# Footer Section
st.divider()
st.markdown(
    """
Thank you for using the Holiday Card Generator powered by **Fireworks**! 🎉
Share your creations with the world and spread the holiday cheer!
Happy Holidays from the **Fireworks Team**. 💥
"""
)