Mikiko Bazeley commited on
Commit
be47fca
·
1 Parent(s): f8c7b2e

Added multiple variant generator

Browse files
pages/1_Autocenter_HolidayCard.py CHANGED
@@ -89,6 +89,9 @@ if uploaded_file is not None:
89
  preview_image = draw_crop_preview(original_image.copy(), x_pos, y_pos, crop_width, crop_height)
90
  st.image(preview_image, caption="Crop Preview", use_column_width=True)
91
 
 
 
 
92
  holiday_prompts = [
93
  "A border of Festive snowflakes and winter patterns for a holiday card border",
94
  "A border of Joyful Christmas ornaments and lights decorating the edges",
@@ -97,75 +100,79 @@ if uploaded_file is not None:
97
  "A border of New Year's Eve fireworks with stars and confetti framing the image"
98
  ]
99
 
100
- selected_prompt = st.selectbox("Choose a holiday-themed prompt or enter your own", options=["Custom"] + holiday_prompts)
101
- custom_prompt = st.text_input("Enter your custom prompt") if selected_prompt == "Custom" else ""
102
- prompt = custom_prompt if selected_prompt == "Custom" else selected_prompt
103
-
104
- with st.expander("Advanced Parameters"):
105
- control_mode = st.slider("Control Mode", min_value=0, max_value=2, value=0)
106
- controlnet_conditioning_scale = st.slider("ControlNet Conditioning Scale", min_value=0.0, max_value=1.0, value=0.5, step=0.1)
107
- guidance_scale = st.slider("Guidance Scale", min_value=0.0, max_value=20.0, value=3.5, step=0.1)
108
- num_inference_steps = st.slider("Number of Inference Steps", min_value=1, max_value=100, value=30, step=1)
109
- seed = st.slider("Random Seed", min_value=0, max_value=1000, value=0)
110
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  if st.button("Generate Holiday Cards"):
112
- if not prompt.strip():
113
- st.error("Please enter a prompt.")
114
- else:
115
- with st.spinner("Processing..."):
116
- # Define multiple sets of parameters for different holiday cards
117
- holiday_card_params = [
118
- {"prompt": "A border of Festive snowflakes and winter patterns for a holiday card border", "guidance_scale": 3.5, "num_inference_steps": 30, "seed": 0},
119
- {"prompt": "A border of Joyful Christmas ornaments and lights decorating the edges", "guidance_scale": 5.0, "num_inference_steps": 40, "seed": 123},
120
- {"prompt": "A border of Warm and cozy fireplace scene with stockings and garlands", "guidance_scale": 4.0, "num_inference_steps": 35, "seed": 456},
121
- {"prompt": "A border of New Year's Eve fireworks with stars and confetti framing the image", "guidance_scale": 6.0, "num_inference_steps": 50, "seed": 789}
122
- ]
 
 
 
123
 
124
- # Create a column layout for displaying cards side by side
125
- col1, col2 = st.columns(2)
126
- col3, col4 = st.columns(2)
127
-
128
- columns = [col1, col2, col3, col4] # To display images in a 2x2 grid
129
-
130
- # Loop through each parameter set and generate a holiday card
131
- for i, params in enumerate(holiday_card_params):
132
- prompt = params['prompt']
133
- guidance_scale = params['guidance_scale']
134
- num_inference_steps = params['num_inference_steps']
135
- seed = params['seed']
136
-
137
- # Generate the holiday card using the current parameters
138
- generated_image, processed_image, _ = call_control_net_api(
139
- uploaded_file, prompt, control_mode=control_mode,
140
- guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
141
- seed=seed, controlnet_conditioning_scale=controlnet_conditioning_scale
142
- )
143
-
144
- if generated_image is not None:
145
- # Resize generated_image to match original_image size
146
- generated_image = generated_image.resize(original_image.size)
147
-
148
- # Create a copy of the generated image
149
- final_image = generated_image.copy()
150
 
151
- # Crop the selected portion of the original image
152
- cropped_original = original_image.crop((x_pos, y_pos, x_pos + crop_width, y_pos + crop_height))
 
153
 
154
- # Get the size of the cropped image
155
- cropped_width, cropped_height = cropped_original.size
156
 
157
- # Calculate the center of the generated image
158
- center_x = (final_image.width - cropped_width) // 2
159
- center_y = (final_image.height - cropped_height) // 2
160
 
161
- # Paste the cropped portion of the original image onto the generated image at the calculated center
162
- final_image.paste(cropped_original, (center_x, center_y))
163
 
164
- # Display the final holiday card in one of the columns
165
- columns[i].image(final_image, caption=f"Holiday Card {i + 1}", use_column_width=True)
166
- else:
167
- st.error(f"Failed to generate holiday card {i + 1}. Please try again.")
168
 
 
 
169
 
 
 
 
 
170
  else:
171
- st.warning("Please upload an image to get started.")
 
89
  preview_image = draw_crop_preview(original_image.copy(), x_pos, y_pos, crop_width, crop_height)
90
  st.image(preview_image, caption="Crop Preview", use_column_width=True)
91
 
92
+ st.subheader("Set Parameters for Each Holiday Card")
93
+
94
+ # Define the list of suggested holiday prompts
95
  holiday_prompts = [
96
  "A border of Festive snowflakes and winter patterns for a holiday card border",
97
  "A border of Joyful Christmas ornaments and lights decorating the edges",
 
100
  "A border of New Year's Eve fireworks with stars and confetti framing the image"
101
  ]
102
 
103
+ # Define input fields for each holiday card's parameters
104
+ card_params = []
105
+ for i in range(4):
106
+ st.write(f"### Holiday Card {i + 1}")
107
+
108
+ # Dropdown to choose a suggested holiday prompt or enter custom prompt
109
+ selected_prompt = st.selectbox(f"Choose a holiday-themed prompt for Holiday Card {i + 1}", options=["Custom"] + holiday_prompts)
110
+ custom_prompt = st.text_input(f"Enter custom prompt for Holiday Card {i + 1}", value=f"Custom Prompt {i + 1}") if selected_prompt == "Custom" else selected_prompt
111
+
112
+ # Parameter sliders for each holiday card
113
+ guidance_scale = st.slider(f"Guidance Scale for Holiday Card {i + 1}", min_value=0.0, max_value=20.0, value=3.5, step=0.1)
114
+ num_inference_steps = st.slider(f"Number of Inference Steps for Holiday Card {i + 1}", min_value=1, max_value=100, value=30, step=1)
115
+ seed = st.slider(f"Random Seed for Holiday Card {i + 1}", min_value=0, max_value=1000, value=i * 100)
116
+ controlnet_conditioning_scale = st.slider(f"ControlNet Conditioning Scale for Holiday Card {i + 1}", min_value=0.0, max_value=1.0, value=0.5, step=0.1)
117
+ control_mode = st.slider(f"Control Mode for Holiday Card {i + 1}", min_value=0, max_value=2, value=0, help="0: None, 1: Partial, 2: Full")
118
+
119
+ # Save the parameters for each holiday card
120
+ card_params.append({
121
+ "prompt": custom_prompt,
122
+ "guidance_scale": guidance_scale,
123
+ "num_inference_steps": num_inference_steps,
124
+ "seed": seed,
125
+ "controlnet_conditioning_scale": controlnet_conditioning_scale,
126
+ "control_mode": control_mode
127
+ })
128
+
129
+ # Generate the holiday cards
130
  if st.button("Generate Holiday Cards"):
131
+ with st.spinner("Processing..."):
132
+ # Create a column layout for displaying cards side by side
133
+ col1, col2 = st.columns(2)
134
+ col3, col4 = st.columns(2)
135
+ columns = [col1, col2, col3, col4] # To display images in a 2x2 grid
136
+
137
+ # Loop through each card's parameters and generate the holiday card
138
+ for i, params in enumerate(card_params):
139
+ prompt = params['prompt']
140
+ guidance_scale = params['guidance_scale']
141
+ num_inference_steps = params['num_inference_steps']
142
+ seed = params['seed']
143
+ controlnet_conditioning_scale = params['controlnet_conditioning_scale']
144
+ control_mode = params['control_mode']
145
 
146
+ # Generate the holiday card using the current parameters
147
+ generated_image, processed_image, _ = call_control_net_api(
148
+ uploaded_file, prompt, control_mode=control_mode,
149
+ guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
150
+ seed=seed, controlnet_conditioning_scale=controlnet_conditioning_scale
151
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
+ if generated_image is not None:
154
+ # Resize generated_image to match original_image size
155
+ generated_image = generated_image.resize(original_image.size)
156
 
157
+ # Create a copy of the generated image
158
+ final_image = generated_image.copy()
159
 
160
+ # Crop the selected portion of the original image
161
+ cropped_original = original_image.crop((x_pos, y_pos, x_pos + crop_width, y_pos + crop_height))
 
162
 
163
+ # Get the size of the cropped image
164
+ cropped_width, cropped_height = cropped_original.size
165
 
166
+ # Calculate the center of the generated image
167
+ center_x = (final_image.width - cropped_width) // 2
168
+ center_y = (final_image.height - cropped_height) // 2
 
169
 
170
+ # Paste the cropped portion of the original image onto the generated image at the calculated center
171
+ final_image.paste(cropped_original, (center_x, center_y))
172
 
173
+ # Display the final holiday card in one of the columns
174
+ columns[i].image(final_image, caption=f"Holiday Card {i + 1}", use_column_width=True)
175
+ else:
176
+ st.error(f"Failed to generate holiday card {i + 1}. Please try again.")
177
  else:
178
+ st.warning("Please upload an image to get started.")
pages/2_Multiple_HolidayCard.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import cv2
3
+ import requests
4
+ from io import BytesIO
5
+ from PIL import Image, ImageDraw
6
+ import numpy as np
7
+ import os
8
+ from dotenv import load_dotenv
9
+
10
+ # Load environment variables
11
+ dotenv_path = os.path.join(os.path.dirname(__file__), '../env/.env')
12
+ load_dotenv(dotenv_path, override=True)
13
+ api_key = os.getenv("FIREWORKS_API_KEY")
14
+
15
+ if not api_key:
16
+ st.error("API key not found. Make sure FIREWORKS_API_KEY is set in the .env file.")
17
+ st.stop()
18
+
19
+ VALID_ASPECT_RATIOS = {
20
+ (1, 1): "1:1", (21, 9): "21:9", (16, 9): "16:9", (3, 2): "3:2", (5, 4): "5:4",
21
+ (4, 5): "4:5", (2, 3): "2:3", (9, 16): "9:16", (9, 21): "9:21",
22
+ }
23
+
24
+ def get_closest_aspect_ratio(width, height):
25
+ aspect_ratio = width / height
26
+ closest_ratio = min(VALID_ASPECT_RATIOS.keys(), key=lambda x: abs((x[0] / x[1]) - aspect_ratio))
27
+ return VALID_ASPECT_RATIOS[closest_ratio]
28
+
29
+ def process_image(uploaded_image):
30
+ image = np.array(Image.open(uploaded_image).convert('L'))
31
+ edges = cv2.Canny(image, 100, 200)
32
+ edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
33
+ pil_image = Image.fromarray(edges_rgb)
34
+ byte_arr = BytesIO()
35
+ pil_image.save(byte_arr, format='JPEG')
36
+ byte_arr.seek(0)
37
+ return byte_arr, pil_image
38
+
39
+ def call_control_net_api(uploaded_image, prompt, control_mode=0, guidance_scale=3.5, num_inference_steps=30, seed=0, controlnet_conditioning_scale=1.0):
40
+ control_image, processed_image = process_image(uploaded_image)
41
+ files = {'control_image': ('control_image.jpg', control_image, 'image/jpeg')}
42
+ original_image = Image.open(uploaded_image)
43
+ width, height = original_image.size
44
+ aspect_ratio = get_closest_aspect_ratio(width, height)
45
+ data = {
46
+ 'prompt': prompt,
47
+ 'control_mode': control_mode,
48
+ 'aspect_ratio': aspect_ratio,
49
+ 'guidance_scale': guidance_scale,
50
+ 'num_inference_steps': num_inference_steps,
51
+ 'seed': seed,
52
+ 'controlnet_conditioning_scale': controlnet_conditioning_scale
53
+ }
54
+ headers = {
55
+ 'accept': 'image/jpeg',
56
+ 'authorization': f'Bearer {api_key}',
57
+ }
58
+ response = requests.post('https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/flux-1-dev-controlnet-union/control_net',
59
+ files=files, data=data, headers=headers)
60
+ if response.status_code == 200:
61
+ return Image.open(BytesIO(response.content)), processed_image, original_image
62
+ else:
63
+ st.error(f"Request failed with status code: {response.status_code}, Response: {response.text}")
64
+ return None, None, None
65
+
66
+ def draw_crop_preview(image, x, y, width, height):
67
+ draw = ImageDraw.Draw(image)
68
+ draw.rectangle([x, y, x + width, y + height], outline="red", width=2)
69
+ return image
70
+
71
+ st.title("Holiday Card Generator with ControlNet")
72
+
73
+ uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
74
+
75
+ if uploaded_file is not None:
76
+ original_image = Image.open(uploaded_file)
77
+ st.image(original_image, caption="Uploaded Image", use_column_width=True)
78
+
79
+ img_width, img_height = original_image.size
80
+
81
+ col1, col2 = st.columns(2)
82
+ with col1:
83
+ x_pos = st.slider("X position", 0, img_width, img_width // 4)
84
+ crop_width = st.slider("Width", 10, img_width - x_pos, min(img_width // 2, img_width - x_pos))
85
+ with col2:
86
+ y_pos = st.slider("Y position", 0, img_height, img_height // 4)
87
+ crop_height = st.slider("Height", 10, img_height - y_pos, min(img_height // 2, img_height - y_pos))
88
+
89
+ preview_image = draw_crop_preview(original_image.copy(), x_pos, y_pos, crop_width, crop_height)
90
+ st.image(preview_image, caption="Crop Preview", use_column_width=True)
91
+
92
+ st.subheader("Set Parameters for Each Holiday Card")
93
+
94
+ # Define the list of suggested holiday prompts
95
+ holiday_prompts = [
96
+ "A border of Festive snowflakes and winter patterns for a holiday card border",
97
+ "A border of Joyful Christmas ornaments and lights decorating the edges",
98
+ "A border of Warm and cozy fireplace scene with stockings and garlands",
99
+ "A border of Colorful Hanukkah menorahs and dreidels along the border",
100
+ "A border of New Year's Eve fireworks with stars and confetti framing the image"
101
+ ]
102
+
103
+ # Define input fields for each holiday card's parameters in expanders
104
+ card_params = []
105
+ for i in range(4):
106
+ with st.expander(f"Holiday Card {i + 1} Parameters"):
107
+ st.write(f"### Holiday Card {i + 1}")
108
+
109
+ # Dropdown to choose a suggested holiday prompt or enter custom prompt
110
+ selected_prompt = st.selectbox(f"Choose a holiday-themed prompt for Holiday Card {i + 1}", options=["Custom"] + holiday_prompts)
111
+ custom_prompt = st.text_input(f"Enter custom prompt for Holiday Card {i + 1}", value=f"Custom Prompt {i + 1}") if selected_prompt == "Custom" else selected_prompt
112
+
113
+ # Parameter sliders for each holiday card
114
+ guidance_scale = st.slider(f"Guidance Scale for Holiday Card {i + 1}", min_value=0.0, max_value=20.0, value=3.5, step=0.1)
115
+ num_inference_steps = st.slider(f"Number of Inference Steps for Holiday Card {i + 1}", min_value=1, max_value=100, value=30, step=1)
116
+ seed = st.slider(f"Random Seed for Holiday Card {i + 1}", min_value=0, max_value=1000, value=i * 100)
117
+ controlnet_conditioning_scale = st.slider(f"ControlNet Conditioning Scale for Holiday Card {i + 1}", min_value=0.0, max_value=1.0, value=0.5, step=0.1)
118
+ control_mode = st.slider(f"Control Mode for Holiday Card {i + 1}", min_value=0, max_value=2, value=0, help="0: None, 1: Partial, 2: Full")
119
+
120
+ # Save the parameters for each holiday card
121
+ card_params.append({
122
+ "prompt": custom_prompt,
123
+ "guidance_scale": guidance_scale,
124
+ "num_inference_steps": num_inference_steps,
125
+ "seed": seed,
126
+ "controlnet_conditioning_scale": controlnet_conditioning_scale,
127
+ "control_mode": control_mode
128
+ })
129
+
130
+ # Generate the holiday cards
131
+ if st.button("Generate Holiday Cards"):
132
+ with st.spinner("Processing..."):
133
+ # Create a column layout for displaying cards side by side
134
+ col1, col2 = st.columns(2)
135
+ col3, col4 = st.columns(2)
136
+ columns = [col1, col2, col3, col4] # To display images in a 2x2 grid
137
+
138
+ # Loop through each card's parameters and generate the holiday card
139
+ for i, params in enumerate(card_params):
140
+ prompt = params['prompt']
141
+ guidance_scale = params['guidance_scale']
142
+ num_inference_steps = params['num_inference_steps']
143
+ seed = params['seed']
144
+ controlnet_conditioning_scale = params['controlnet_conditioning_scale']
145
+ control_mode = params['control_mode']
146
+
147
+ # Generate the holiday card using the current parameters
148
+ generated_image, processed_image, _ = call_control_net_api(
149
+ uploaded_file, prompt, control_mode=control_mode,
150
+ guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
151
+ seed=seed, controlnet_conditioning_scale=controlnet_conditioning_scale
152
+ )
153
+
154
+ if generated_image is not None:
155
+ # Resize generated_image to match original_image size
156
+ generated_image = generated_image.resize(original_image.size)
157
+
158
+ # Create a copy of the generated image
159
+ final_image = generated_image.copy()
160
+
161
+ # Crop the selected portion of the original image
162
+ cropped_original = original_image.crop((x_pos, y_pos, x_pos + crop_width, y_pos + crop_height))
163
+
164
+ # Get the size of the cropped image
165
+ cropped_width, cropped_height = cropped_original.size
166
+
167
+ # Calculate the center of the generated image
168
+ center_x = (final_image.width - cropped_width) // 2
169
+ center_y = (final_image.height - cropped_height) // 2
170
+
171
+ # Paste the cropped portion of the original image onto the generated image at the calculated center
172
+ final_image.paste(cropped_original, (center_x, center_y))
173
+
174
+ # Display the final holiday card in one of the columns
175
+ columns[i].image(final_image, caption=f"Holiday Card {i + 1}", use_column_width=True)
176
+
177
+ # Display the parameters used for this card
178
+ columns[i].write(f"**Prompt:** {prompt}")
179
+ columns[i].write(f"**Guidance Scale:** {guidance_scale}")
180
+ columns[i].write(f"**Inference Steps:** {num_inference_steps}")
181
+ columns[i].write(f"**Seed:** {seed}")
182
+ columns[i].write(f"**ControlNet Conditioning Scale:** {controlnet_conditioning_scale}")
183
+ columns[i].write(f"**Control Mode:** {control_mode}")
184
+ else:
185
+ st.error(f"Failed to generate holiday card {i + 1}. Please try again.")
186
+ else:
187
+ st.warning("Please upload an image to get started.")
188
+