# iris / app.py — "Xylaria Iris" Gradio image-generation Space
# (original upload: Reality123b, commit 27f5740, ~5.25 kB)
import gradio as gr
from huggingface_hub import InferenceClient
from PIL import Image
import time
import os
import base64
from io import BytesIO
# Hugging Face API token, read from the environment (or a Gradio secret).
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    # Stored instead of raised so the UI can still load; each handler checks
    # HF_TOKEN_ERROR and surfaces it as a gr.Error at call time.
    HF_TOKEN_ERROR = "Hugging Face API token (HF_TOKEN) not found. Please set it as an environment variable or Gradio secret."
else:
    HF_TOKEN_ERROR = None

# NOTE(review): the client is constructed even when HF_TOKEN is None; API
# calls would fail later, but the handlers guard on HF_TOKEN_ERROR first.
client = InferenceClient(token=HF_TOKEN)

# LLM used to rewrite the user's prompt before image generation.
PROMPT_IMPROVER_MODEL = "HuggingFaceH4/zephyr-7b-beta"
def improve_prompt(original_prompt):
    """Rewrite *original_prompt* into a richer image-generation prompt.

    Sends the prompt to the Zephyr chat model and returns the stripped
    completion. On any API failure the original prompt is returned
    unchanged (best-effort). Raises gr.Error only when the HF token
    is missing.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)

    instruction = "You are a helpful assistant that improves text prompts for image generation models. Make the prompt more descriptive, detailed, and artistic, while keeping the user's original intent."
    # Zephyr chat template: <|system|> ... </s> <|user|> ... </s> <|assistant|>
    llm_input = f"""<|system|>
{instruction}</s>
<|user|>
Improve this prompt: {original_prompt}
</s>
<|assistant|>
"""
    try:
        completion = client.text_generation(
            prompt=llm_input,
            model=PROMPT_IMPROVER_MODEL,
            max_new_tokens=128,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            stop_sequences=["</s>"],
        )
        return completion.strip()
    except Exception as exc:
        # Best-effort step: never block image generation on a failed rewrite.
        print(f"Error improving prompt: {exc}")
        return original_prompt
def generate_image(prompt, progress=gr.Progress()):
    """Generate an image for *prompt* with FLUX.1-schnell.

    The prompt is first enriched by improve_prompt(); progress updates
    are streamed to the Gradio UI. Raises gr.Error on a missing token,
    a rate limit, or any other API failure.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)

    progress(0, desc="Improving prompt...")
    enriched = improve_prompt(prompt)

    progress(0.2, desc="Sending request to Hugging Face...")
    try:
        result = client.text_to_image(enriched, model="black-forest-labs/FLUX.1-schnell")
        # The client should hand back a decoded PIL image; fail loudly otherwise.
        if not isinstance(result, Image.Image):
            raise Exception(f"Expected a PIL Image, but got: {type(result)}")
        progress(0.8, desc="Processing image...")
        time.sleep(0.5)  # brief pause so the progress bar is visible to the user
        progress(1.0, desc="Done!")
        return result
    except Exception as err:
        message = (
            f"Rate limit exceeded. Please try again later. Error: {err}"
            if "rate limit" in str(err).lower()
            else f"An error occurred: {err}"
        )
        raise gr.Error(message)
def pil_to_base64(img):
    """Encode a PIL image as a PNG data-URI string."""
    buf = BytesIO()
    img.save(buf, format="PNG")
    encoded = base64.b64encode(buf.getvalue()).decode()
    return "data:image/png;base64," + encoded
# Custom CSS injected into the Gradio Blocks app: page card layout, title,
# input/output panels, button hover animation, and error/download styling.
css = """
body {
background-color: #f4f4f4;
font-family: 'Arial', sans-serif;
}
.container {
max-width: 900px;
margin: auto;
padding: 30px;
border-radius: 10px;
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.1);
background-color: white;
}
.title {
text-align: center;
font-size: 3em;
margin-bottom: 0.5em;
color: #3a3a3a;
}
.input-section {
background-color: #e3f7fc;
border-radius: 8px;
padding: 15px;
}
.output-section {
background-color: #f0f0f0;
border-radius: 8px;
padding: 15px;
}
.output-section img {
max-width: 100%;
height: auto;
border-radius: 8px;
}
.submit-button {
background-color: #007BFF;
border: none;
border-radius: 5px;
color: white;
padding: 12px 20px;
cursor: pointer;
transition: background-color 0.3s ease, transform 0.2s ease;
}
.submit-button:hover {
background-color: #0056b3;
transform: scale(1.05);
}
.error-message {
color: red;
text-align: center;
font-weight: bold;
}
.label {
font-weight: bold;
}
.download-link {
color: #007BFF;
font-weight: bold;
text-decoration: none;
}
.download-link:hover {
text-decoration: underline;
}
"""
# Build the UI: a title, a prompt input with a submit button on the left,
# and the generated-image output on the right.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# Xylaria Iris Image Generator
""",
        elem_classes="title"
    )
    with gr.Row():
        with gr.Column():
            with gr.Group(elem_classes="input-section"):
                prompt_input = gr.Textbox(label="Enter your prompt", placeholder="e.g., A cat", lines=3)
                generate_button = gr.Button("Generate Image", elem_classes="submit-button")
        with gr.Column():
            with gr.Group(elem_classes="output-section") as output_group:
                image_output = gr.Image(label="Generated Image", interactive=False)

    def on_generate_click(prompt):
        # Event handler shared by the button click and textbox submit.
        # NOTE(review): mutating elem_classes on a live component at event time
        # likely has no effect on the rendered page — confirm; gr.update(...)
        # would be the supported way to change component properties.
        output_group.elem_classes = ["output-section", "animate"]
        image = generate_image(prompt)  # Ignore the improved prompt
        output_group.elem_classes = ["output-section"]
        return image  # Return only the generated image

    generate_button.click(on_generate_click, inputs=prompt_input, outputs=image_output)
    prompt_input.submit(on_generate_click, inputs=prompt_input, outputs=image_output)

    # Clickable example prompts that populate the textbox.
    gr.Examples(
        [["A dog"],
        ["A house on a hill"],
        ["A spaceship"]],
        inputs=prompt_input
    )
if __name__ == "__main__":
    # queue() enables Gradio's request queue so concurrent generations are
    # serialized instead of overwhelming the inference API.
    demo.queue().launch()