# Hugging Face Space (status shown on the page when captured: Runtime error)
import base64
import os
import traceback
from io import BytesIO

import fal_client
import gradio as gr
import requests
from fal_client.client import FalClientError
from PIL import Image
# Seconds to wait for a generated image download before giving up, so a
# stalled transfer cannot hang the UI callback indefinitely.
_IMAGE_DOWNLOAD_TIMEOUT = 60


def _format_fal_error(exc):
    """Build the user-facing "Errors:" message for an exception from the fal client.

    The first exception argument is treated as the error detail. When it is a
    list/tuple, each entry contributes its ``"msg"`` field (or its string form
    when it has none); any other detail falls back to ``str(exc)``.
    """
    detail = exc.args[0] if exc.args else None
    if isinstance(detail, (list, tuple)):
        messages = [
            str(item["msg"]) if isinstance(item, dict) and "msg" in item else str(item)
            for item in detail
        ]
    else:
        # A plain-string detail must not be iterated character by character.
        messages = [str(exc) or type(exc).__name__]
    return "Errors:\n" + "\n".join(messages)


def _fetch_image_bytes(url):
    """Return the raw bytes behind an image URL.

    Handles base64 ``data:`` URIs inline and downloads everything else over
    HTTP with a timeout, raising on a non-success status.
    """
    if url.startswith("data:"):
        # NOTE(review): the API is believed to return inline data URIs when
        # sync_mode is enabled -- confirm against the fal model documentation.
        header, separator, payload = url.partition(",")
        if not separator or not header.endswith(";base64"):
            raise ValueError("Unsupported data URI returned for image")
        return base64.b64decode(payload)
    response = requests.get(url, timeout=_IMAGE_DOWNLOAD_TIMEOUT)
    response.raise_for_status()
    return response.content


def generate_image(api_key, prompt, image_size, seed, sync_mode, num_images, enable_safety_checker, safety_tolerance):
    """Submit a "fal-ai/flux-pro/v1.1" request and return updates for the UI.

    Args:
        api_key: fal API key used for this request only.
        prompt: Text prompt sent to the model.
        image_size: Image size preset name.
        seed: Optional seed; blank or None means no seed is sent.
        sync_mode: Forwarded as ``sync_mode`` unless None.
        num_images: Number of images to request.
        enable_safety_checker: Forwarded as ``enable_safety_checker``.
        safety_tolerance: Forwarded as ``safety_tolerance``.

    Returns:
        A two-item list of ``gr.update`` objects: the first for the image
        gallery, the second for the response textbox. On failure the gallery
        is cleared and the textbox carries the error message.
    """
    if not api_key or not str(api_key).strip():
        error_msg = "Error: API key is required."
        print(error_msg)
        return [gr.update(value=[]), gr.update(value=error_msg)]

    try:
        arguments = {
            "prompt": prompt,
            "image_size": image_size,
            "num_images": int(num_images),
            "enable_safety_checker": enable_safety_checker,
            "safety_tolerance": safety_tolerance,
        }

        # Treat whitespace-only input like an empty seed field.
        seed_text = "" if seed is None else str(seed).strip()
        if seed_text:
            arguments["seed"] = int(seed_text)
        if sync_mode is not None:
            arguments["sync_mode"] = sync_mode

        # Log the actual request body
        print(f"Request Body: {arguments}")

        # A per-request client keeps the key out of os.environ, which is
        # shared by every user of the running app.
        client = fal_client.SyncClient(key=str(api_key).strip())
        handler = client.submit(
            "fal-ai/flux-pro/v1.1",
            arguments=arguments,
        )
        result = handler.get()

        # Display and log the response
        print(f"Response: {result}")

        images = [
            Image.open(BytesIO(_fetch_image_bytes(img_info["url"])))
            for img_info in result["images"]
        ]
        return [gr.update(value=images, visible=True), gr.update(value=str(result), visible=True)]
    except FalClientError as e:
        error_msg = _format_fal_error(e)
        print(error_msg)
        return [gr.update(value=[]), gr.update(value=error_msg)]
    except Exception as e:
        # Top-level boundary for the UI callback: surface everything else.
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return [gr.update(value=[]), gr.update(value=error_msg)]
def update_safety_tolerance_visibility(enable_safety):
    """Show or hide the safety-tolerance dropdown to match the safety checker toggle.

    Args:
        enable_safety: Current state of the "Enable Safety Checker" checkbox.

    Returns:
        A ``gr.update`` that sets the dropdown's visibility and resets its
        value to "2".
    """
    # Reset to "2" (the default shown in the UI) rather than "6", which is
    # not one of the dropdown's declared choices.
    return gr.update(visible=enable_safety, value="2")
# Gradio UI: collects the API key and generation options, then wires the
# controls to generate_image and update_safety_tolerance_visibility.
with gr.Blocks() as demo:
    gr.Markdown("# FLUX1.1 [pro] Text-to-Image Generator")
    gr.Markdown("Get your API key at https://fal.ai/dashboard/keys")

    with gr.Row():
        api_key = gr.Textbox(label="API Key", type="password", placeholder="Enter your API key here")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt here")

    with gr.Row():
        image_size = gr.Dropdown(
            label="Image Size",
            choices=["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"],
            value="landscape_4_3",
        )
        num_images = gr.Slider(label="Number of Images", minimum=1, maximum=4, step=1, value=1)

    with gr.Row():
        seed = gr.Textbox(label="Seed (optional)", placeholder="Enter a number for reproducible results")
        sync_mode = gr.Checkbox(label="Sync Mode", value=False)

    with gr.Row():
        enable_safety_checker = gr.Checkbox(label="Enable Safety Checker", value=True)
        # "6" is offered so the choices cover the full 1-6 scale described in
        # the note below.
        safety_tolerance = gr.Dropdown(
            label="Safety Tolerance",
            choices=["1", "2", "3", "4", "5", "6"],
            value="2",
            visible=True,
        )

    gr.Markdown("**Note:** Safety Tolerance: 1 is the most strict, 6 is the most permissive. Default is 2.")

    generate_btn = gr.Button("Generate Image")
    output_gallery = gr.Gallery(label="Generated Images", columns=2, rows=2)
    response_output = gr.Textbox(label="Response", visible=True)

    # Toggle the tolerance dropdown together with the safety checker checkbox.
    enable_safety_checker.change(
        fn=update_safety_tolerance_visibility,
        inputs=[enable_safety_checker],
        outputs=[safety_tolerance],
    )

    # Input order must match generate_image's positional parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[api_key, prompt, image_size, seed, sync_mode, num_images, enable_safety_checker, safety_tolerance],
        outputs=[output_gallery, response_output],
    )

if __name__ == "__main__":
    demo.launch()