import base64
import os
import time
from io import BytesIO

import gradio as gr
from huggingface_hub import InferenceClient
from PIL import Image

HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    HF_TOKEN_ERROR = (
        "Hugging Face API token (HF_TOKEN) not found. "
        "Please set it as an environment variable or Gradio secret."
    )
else:
    HF_TOKEN_ERROR = None

client = InferenceClient(token=HF_TOKEN)

PROMPT_IMPROVER_MODEL = "Qwen/Qwen2.5-Coder-32B-Instruct"
IMAGE_MODEL = "black-forest-labs/FLUX.1-schnell"


def improve_prompt(original_prompt):
    """Rewrite the user's prompt into a more descriptive one via an LLM.

    Falls back to the original prompt if the call fails.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)
    try:
        system_prompt = (
            "You are a helpful assistant that improves text prompts for "
            "image generation models. Make the prompt more descriptive, "
            "detailed, and artistic, while keeping the user's original intent."
        )
        # Use the chat-completion API so the server applies the model's own
        # chat template, instead of hand-rolling <|system|>/<|user|> markers
        # that do not match Qwen's ChatML format. This also removes the
        # invalid empty-string stop sequence from the old text_generation call.
        response = client.chat_completion(
            model=PROMPT_IMPROVER_MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Improve this prompt: {original_prompt}"},
            ],
            max_tokens=1280,
            temperature=0.7,
            top_p=0.9,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error improving prompt: {e}")
        return original_prompt


def generate_image(prompt, progress=gr.Progress()):
    """Improve the prompt, then generate an image with FLUX.1-schnell."""
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)
    progress(0, desc="Improving prompt...")
    improved_prompt = improve_prompt(prompt)
    progress(0.2, desc="Sending request...")
    try:
        image = client.text_to_image(improved_prompt, model=IMAGE_MODEL)
        if not isinstance(image, Image.Image):
            raise TypeError(f"Expected a PIL Image, but got: {type(image)}")
        progress(0.8, desc="Processing image...")
        time.sleep(0.5)
        progress(1.0, desc="Done!")
        return image
    except Exception as e:
        if "rate limit" in str(e).lower():
            error_message = f"Rate limit exceeded. Please try again later. Error: {e}"
        else:
            error_message = f"An error occurred: {e}"
        raise gr.Error(error_message)


def pil_to_base64(img):
    """Encode a PIL image as a PNG data URI (helper, currently unused)."""
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"


css = """
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Xylaria Iris v3", elem_classes="title")
    with gr.Row():
        with gr.Column():
            with gr.Group(elem_classes="input-section"):
                prompt_input = gr.Textbox(
                    label="Enter your prompt", placeholder="e.g., A cat", lines=3
                )
                generate_button = gr.Button("Generate Image", elem_classes="submit-button")
        with gr.Column():
            with gr.Group(elem_classes="output-section"):
                image_output = gr.Image(label="Generated Image", interactive=False)

    # Wire the handlers directly to generate_image. (The old on_generate_click
    # wrapper mutated output_group.elem_classes at event time, which has no
    # effect on the already-rendered page, so the animation toggle was dropped.)
    generate_button.click(generate_image, inputs=prompt_input, outputs=image_output)
    prompt_input.submit(generate_image, inputs=prompt_input, outputs=image_output)

    gr.Examples(
        [["A dog"], ["A house on a hill"], ["A spaceship"]],
        inputs=prompt_input,
    )

if __name__ == "__main__":
    demo.queue().launch()