|
|
|
|
|
import os
import zipfile
from io import BytesIO

import streamlit as st
from dotenv import load_dotenv, set_key
from PIL import Image

from utils.configuration import (
    default_guidance_scale,
    default_num_inference_steps,
    default_seed,
    holiday_border_prompts,
)
from utils.helper_utilities import (
    combine_images,
    draw_crop_preview,
    generate_flux_image,
    get_closest_aspect_ratio,
    get_next_largest_aspect_ratio,
    process_image,
)
|
|
|
|
|
|
|
|
|
def _seed_session_state():
    """Create this page's per-session state entries on the first run only.

    Streamlit re-executes the whole script on every interaction, so keys that
    already exist are left untouched; only missing ones receive a fresh default.
    Each default is built by a factory so every session gets its own lists.
    """
    fresh_defaults = {
        "uploaded_file": lambda: None,
        "card_params": lambda: [{} for _ in range(4)],
        "generated_cards": lambda: [None] * 4,
    }
    for state_key, make_default in fresh_defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()


_seed_session_state()
|
|
|
|
|
# Page header: logo, title and an overview of the card-building workflow.
# The emoji below were mis-encoded (UTF-8 shown as ISO-8859-7) and have been restored.
st.image("img/fireworksai_logo.png")
st.title("🎨 Holiday Multi-Card Generator🎨")
st.markdown(
    """Welcome to the first part of your holiday card creation journey! 🎄 Here, you'll play around with different styles, prompts, and parameters to design the perfect card border before adding a personal message in the section 'Customize Holiday Borders'. Let your creativity flow! 🎉

### How it works:

1. **🖼️ Upload Your Image:** Choose the image that will be the center of your card.
2. **✂️ Crop It:** Adjust the crop to highlight the most important part of your image.
3. **💡 Choose Your Style:** Select from festive border themes or input your own custom prompt to design something unique.
4. **⚙️ Fine-Tune Parameters:** Experiment with guidance scales, seeds, inference steps, and more for the perfect aesthetic.
5. **🎁 Preview & Download:** See your generated holiday cards, tweak them until they're just right, and download the final designs and metadata in a neat ZIP file!

Once you've got the perfect look, head over to **Part B** to add your personal message and finalize your holiday card! 💌
"""
)
|
|
|
|
|
st.divider()
st.subheader("Load Fireworks API Key")

# The key is persisted in env/.env, one directory above this script's folder.
dotenv_path = os.path.join(os.path.dirname(__file__), '..', 'env', '.env')
os.makedirs(os.path.dirname(dotenv_path), exist_ok=True)

# Create an empty .env on first run so there is always a file to load/update.
if not os.path.exists(dotenv_path):
    with open(dotenv_path, "w", encoding="utf-8"):
        pass
    st.success(f"Created {dotenv_path}")

load_dotenv(dotenv_path, override=True)

# Normalise to a stripped string so "missing", "empty" and "whitespace only"
# are all treated the same way below (and downstream).
fireworks_api_key = (os.getenv("FIREWORKS_API_KEY") or "").strip()

if not fireworks_api_key:
    fireworks_api_key = st.text_input("Enter Fireworks API Key", type="password").strip()

    if fireworks_api_key and st.checkbox("Save API key for future use"):
        # set_key rewrites an existing FIREWORKS_API_KEY entry (e.g. an empty one)
        # in place instead of appending a duplicate, and does not depend on the
        # file ending with a newline the way a raw append does.
        set_key(dotenv_path, "FIREWORKS_API_KEY", fireworks_api_key)
        st.success("API key saved to .env file.")
else:
    st.success(f"API key loaded successfully: partial preview {fireworks_api_key[:5]}")
|
|
|
|
|
# Smallest crop width/height, in pixels, that the Step 2 size sliders offer.
_MIN_CROP_SIZE = 10


def _axis_crop_sliders(pos_label, size_label, extent):
    """Render the offset and size sliders for one image axis.

    Args:
        pos_label: Label of the offset slider (e.g. "X position (Left-Right)").
        size_label: Label of the size slider (e.g. "Width").
        extent: Length of the image along this axis, in pixels.

    Returns:
        ``(offset, size)`` in pixels with ``offset + size <= extent``, so the
        crop box always stays inside the image.
    """
    # Stop the offset slider short of the far edge so the size slider always
    # keeps a non-degenerate range [_MIN_CROP_SIZE, room] with room > _MIN_CROP_SIZE.
    max_offset = extent - _MIN_CROP_SIZE - 1
    if max_offset < 1:
        # Too small along this axis to offer a meaningful choice: keep all of it.
        return 0, extent
    offset = st.slider(pos_label, 0, max_offset, min(extent // 4, max_offset))
    room = extent - offset
    size = st.slider(size_label, _MIN_CROP_SIZE, room, max(_MIN_CROP_SIZE, min(extent // 2, room)))
    return offset, size


def _card_parameter_controls(index):
    """Render the Step 3 widgets for card ``index`` and persist the choices.

    Missing entries in ``st.session_state.card_params[index]`` are seeded from
    the configuration defaults; the values picked on this run then replace the
    stored dict so they become the widgets' defaults on the next rerun and are
    what the generation step reads.
    """
    card_params = st.session_state.card_params[index]
    card_params.setdefault("prompt", holiday_border_prompts[index % len(holiday_border_prompts)])
    card_params.setdefault("guidance_scale", default_guidance_scale)
    card_params.setdefault("num_inference_steps", default_num_inference_steps)
    card_params.setdefault("seed", index * 100)

    card_label = f"Holiday Card {index + 1}"
    selected_prompt = st.selectbox(
        f"Choose a holiday-themed prompt for {card_label}",
        options=["Custom"] + holiday_border_prompts,
    )
    if selected_prompt == "Custom":
        prompt = st.text_input(f"Enter custom prompt for {card_label}", value=card_params["prompt"])
    else:
        prompt = selected_prompt

    guidance_scale = st.slider(
        f"Guidance Scale for {card_label}",
        min_value=0.0, max_value=20.0, value=card_params["guidance_scale"], step=0.1,
    )
    num_inference_steps = st.slider(
        f"Number of Inference Steps for {card_label}",
        min_value=1, max_value=100, value=card_params["num_inference_steps"], step=1,
    )
    seed = st.slider(
        f"Random Seed for {card_label}",
        min_value=0, max_value=1000, value=card_params["seed"],
    )

    st.session_state.card_params[index] = {
        "prompt": prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "seed": seed,
    }


def _compose_card(border_image, cropped_photo, card_size):
    """Return the final card image.

    The generated border is resized to ``card_size`` (the uploaded photo's full
    size) and the cropped photo is pasted at its centre. Because the crop lies
    inside the photo, it always fits on the resized border.
    """
    card = border_image.resize(card_size)  # resize returns a new image
    card_width, card_height = card.size
    photo_width, photo_height = cropped_photo.size
    card.paste(cropped_photo, ((card_width - photo_width) // 2, (card_height - photo_height) // 2))
    return card


def _metadata_text(metadata):
    """Render the per-card metadata dicts as the plain-text summary for the ZIP."""
    return "\n\n".join(
        f"{m['Card']}:\nPrompt: {m['Prompt']}\nGuidance Scale: {m['Guidance Scale']}"
        f"\nInference Steps: {m['Inference Steps']}\nSeed: {m['Seed']}"
        for m in metadata
    )


def _build_zip(image_files, metadata):
    """Bundle ``(file_name, png_bytes)`` pairs plus metadata.txt into an in-memory ZIP.

    Returns:
        A ``BytesIO`` rewound to the start, ready to hand to ``st.download_button``.
    """
    zip_buffer = BytesIO()
    with zipfile.ZipFile(zip_buffer, "w") as zf:
        for file_name, png_bytes in image_files:
            zf.writestr(file_name, png_bytes)
        zf.writestr("metadata.txt", _metadata_text(metadata))
    zip_buffer.seek(0)
    return zip_buffer


# ---- Step 1: upload ---------------------------------------------------------
st.divider()
st.subheader("🖼️ Step 1: Upload Your Picture!")
uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Every later step needs the uploaded image, so the rest of the workflow only
# renders once a file is present.
if uploaded_file is not None:
    st.session_state.uploaded_file = uploaded_file
    original_image = Image.open(uploaded_file)
    img_width, img_height = original_image.size
    aspect_ratio = get_next_largest_aspect_ratio(img_width, img_height)
    st.image(original_image, caption="Uploaded Image", use_column_width=True)

    # ---- Step 2: crop -------------------------------------------------------
    st.divider()
    st.subheader("✂️ Step 2: Crop It Like It's Hot!")
    col1, col2 = st.columns(2)
    with col1:
        x_pos, crop_width = _axis_crop_sliders("X position (Left-Right)", "Width", img_width)
    with col2:
        y_pos, crop_height = _axis_crop_sliders("Y position (Up-Down)", "Height", img_height)

    preview_image = draw_crop_preview(original_image.copy(), x_pos, y_pos, crop_width, crop_height)
    st.image(preview_image, caption="Crop Preview", use_column_width=True)

    # ---- Step 3: per-card prompt and sampling parameters ---------------------
    st.divider()
    st.subheader("⚙️ Step 3: Set Your Festive Border Design with Flux + Fireworks!")
    num_cards = len(st.session_state.card_params)
    for i in range(num_cards):
        with st.expander(f"Holiday Card {i + 1} Parameters"):
            _card_parameter_controls(i)

    # ---- Step 4: generate, preview and download ------------------------------
    st.divider()
    st.subheader("Preview and Share the Holiday Cheer! 🎅📬")
    st.markdown("""
Click "Generate Holiday Cards" and watch the magic happen! Your holiday card is just moments away from spreading joy to everyone on your list. 🎄🎁✨
""")

    has_api_key = bool(fireworks_api_key and fireworks_api_key.strip())
    if not has_api_key:
        st.warning("Enter a valid Fireworks API key to enable card generation.")
    generate_button = st.button("Generate Holiday Cards", disabled=not has_api_key)

    if generate_button:
        image_files = []
        metadata = []
        with st.spinner("Processing..."):
            cols = st.columns(num_cards)
            # The crop is identical for every card, so cut it once up front.
            cropped_original = original_image.crop((x_pos, y_pos, x_pos + crop_width, y_pos + crop_height))

            for i, params in enumerate(st.session_state.card_params):
                generated_image = generate_flux_image(
                    model_path="flux-1-schnell-fp8",
                    prompt=params['prompt'],
                    steps=params['num_inference_steps'],
                    guidance_scale=params['guidance_scale'],
                    seed=params['seed'],
                    api_key=fireworks_api_key,
                    aspect_ratio=f"{aspect_ratio[0]}:{aspect_ratio[1]}",
                )
                if not generated_image:
                    # Drop any card left over from an earlier run so session state
                    # never mixes results from different generations.
                    st.session_state.generated_cards[i] = None
                    st.error(f"Failed to generate holiday card {i + 1}. Please try again.")
                    continue

                final_image = _compose_card(generated_image, cropped_original, original_image.size)

                png_buffer = BytesIO()
                final_image.save(png_buffer, format="PNG")
                image_files.append((f"holiday_card_{i + 1}.png", png_buffer.getvalue()))

                card_metadata = {
                    "Card": f"Holiday Card {i + 1}",
                    "Prompt": params['prompt'],
                    "Guidance Scale": params['guidance_scale'],
                    "Inference Steps": params['num_inference_steps'],
                    "Seed": params['seed'],
                }
                metadata.append(card_metadata)
                st.session_state.generated_cards[i] = {
                    "image": final_image,
                    "metadata": card_metadata,
                }

                col = cols[i]
                col.image(final_image, caption=f"Holiday Card {i + 1}", use_column_width=True)
                col.write(f"**Prompt:** {params['prompt']}")
                col.write(f"**Guidance Scale:** {params['guidance_scale']}")
                col.write(f"**Inference Steps:** {params['num_inference_steps']}")
                col.write(f"**Seed:** {params['seed']}")

        if image_files:
            st.download_button(
                label="Download all images and metadata as ZIP",
                data=_build_zip(image_files, metadata),
                file_name="holiday_cards.zip",
                mime="application/zip",
            )
|
|
|
|
|
|
|
# Closing thank-you note shown at the bottom of the page.
# The emoji were mis-encoded (UTF-8 shown as ISO-8859-7) and have been restored.
st.divider()
st.markdown(
    """
Thank you for using the Holiday Card Generator powered by **Fireworks**! 🎉
Share your creations with the world and spread the holiday cheer!
Happy Holidays from the **Fireworks Team**. 🔥
"""
)