File size: 3,893 Bytes
27f5740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db768d2
27f5740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db768d2
27f5740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b4eaa1
27f5740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b4eaa1
27f5740
71a1f99
b414b9f
 
 
db768d2
b414b9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27f5740
b414b9f
 
27f5740
b414b9f
27f5740
 
b414b9f
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr  
from huggingface_hub import InferenceClient  
from PIL import Image  
import time  
import os  
import base64  
from io import BytesIO  

HF_TOKEN = os.getenv("HF_TOKEN")

# User-facing message raised via gr.Error by the request handlers when no token
# is configured; None means a (truthy) token was found.
HF_TOKEN_ERROR = (
    None
    if HF_TOKEN
    else "Hugging Face API token (HF_TOKEN) not found. Please set it as an environment variable or Gradio secret."
)

# Single Inference API client shared by prompt improvement and image generation.
client = InferenceClient(token=HF_TOKEN)
# Text-generation model used to rewrite user prompts before image generation.
PROMPT_IMPROVER_MODEL = "Qwen/Qwen2.5-Coder-32B-Instruct"

def improve_prompt(original_prompt):
    """Rewrite *original_prompt* into a more detailed image-generation prompt.

    The prompt is sent to ``PROMPT_IMPROVER_MODEL`` through the shared
    Inference API ``client``.

    Args:
        original_prompt: The user's prompt text.

    Returns:
        The model's rewrite with surrounding whitespace (and any echoed stop
        sequence) removed. Falls back to ``original_prompt`` unchanged when
        the request fails or the model returns an empty answer, so that image
        generation can still proceed.

    Raises:
        gr.Error: If no Hugging Face API token is configured.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)

    system_prompt = (
        "You are a helpful assistant that improves text prompts for image generation models. "
        "Make the prompt more descriptive, detailed, and artistic, while keeping the user's original intent."
    )
    # PROMPT_IMPROVER_MODEL is a Qwen2.5 model, whose chat format is ChatML
    # (<|im_start|>role ... <|im_end|>). Zephyr-style <|system|>/</s> markup is
    # not its template, and "</s>" is not one of its end-of-turn tokens, so
    # stopping must use Qwen's own tokens or generation may run to the cap.
    stop_sequences = ["<|im_end|>", "<|endoftext|>"]
    prompt_for_llm = (
        f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        f"<|im_start|>user\nImprove this prompt: {original_prompt}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    try:
        raw_output = client.text_generation(
            prompt=prompt_for_llm,
            model=PROMPT_IMPROVER_MODEL,
            max_new_tokens=1280,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            stop_sequences=stop_sequences,
        )
    except Exception as e:
        # Best effort: a failed rewrite must not block image generation.
        print(f"Error improving prompt: {e}")
        return original_prompt

    improved_prompt = raw_output.strip()
    # Some backends echo the matched stop sequence at the end of the text.
    for stop in stop_sequences:
        if improved_prompt.endswith(stop):
            improved_prompt = improved_prompt[: -len(stop)].rstrip()

    # An empty rewrite would send a blank prompt to the image model.
    return improved_prompt or original_prompt


def generate_image(prompt, progress=gr.Progress()):
    """Improve *prompt* with the LLM, then render it with FLUX.1-schnell.

    Args:
        prompt: The user's prompt text.
        progress: Gradio progress tracker used to report status updates.

    Returns:
        The generated image as a ``PIL.Image.Image``.

    Raises:
        gr.Error: If no API token is configured, if the image request fails
            (with a dedicated message when the error mentions a rate limit),
            or if the API returns something other than a PIL image.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)

    progress(0, desc="Improving prompt...")
    improved_prompt = improve_prompt(prompt)

    progress(0.2, desc="Sending request ")
    try:
        image = client.text_to_image(improved_prompt, model="black-forest-labs/FLUX.1-schnell")
        if not isinstance(image, Image.Image):
            raise TypeError(f"Expected a PIL Image, but got: {type(image)}")
    except Exception as e:
        # Surface every failure to the UI as a gr.Error, keeping the original
        # exception as the cause for server-side tracebacks.
        if "rate limit" in str(e).lower():
            error_message = f"Rate limit exceeded. Please try again later. Error: {e}"
        else:
            error_message = f"An error occurred: {e}"
        raise gr.Error(error_message) from e

    progress(1.0, desc="Done!")
    return image


def pil_to_base64(img):
    """Encode *img* as a PNG ``data:`` URI.

    Args:
        img: An object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        A string of the form ``"data:image/png;base64,<payload>"``.
    """
    with BytesIO() as stream:
        img.save(stream, format="PNG")
        png_bytes = stream.getvalue()
    payload = base64.b64encode(png_bytes).decode()
    return "data:image/png;base64," + payload


# Custom stylesheet for the elem_classes used below (currently empty).
css = """

"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
        # Xylaria Iris v3
        """,
        elem_classes="title"
    )

    with gr.Row():
        with gr.Column():
            with gr.Group(elem_classes="input-section"):
                prompt_input = gr.Textbox(label="Enter your prompt", placeholder="e.g., A cat", lines=3)
                generate_button = gr.Button("Generate Image", elem_classes="submit-button")
        with gr.Column():
            with gr.Group(elem_classes="output-section") as output_group:
                image_output = gr.Image(label="Generated Image", interactive=False)

    def on_generate_click(prompt, progress=gr.Progress()):
        """Event handler: validate *prompt* and return the generated image.

        Declaring ``progress`` here is what lets Gradio bind a progress
        tracker to this event; it is forwarded so the status updates inside
        ``generate_image`` are attached to the running request.
        """
        # Reject blank input up front instead of paying for two API calls
        # that would invent a prompt from nothing.
        if not (prompt or "").strip():
            raise gr.Error("Please enter a prompt.")
        return generate_image(prompt, progress=progress)

    generate_button.click(on_generate_click, inputs=prompt_input, outputs=image_output)
    prompt_input.submit(on_generate_click, inputs=prompt_input, outputs=image_output)

    gr.Examples(
        [["A dog"],
         ["A house on a hill"],
         ["A spaceship"]],
        inputs=prompt_input
    )

if __name__ == "__main__":
    # Enable request queuing on the app, then start the web server.
    queued_app = demo.queue()
    queued_app.launch()