iris / app.py
Reality123b's picture
Update app.py
bf20e5c verified
raw
history blame
5.92 kB
import gradio as gr
from huggingface_hub import InferenceClient
from PIL import Image
import time
import os
import base64
from io import BytesIO
# Hugging Face API token, read from the environment / Gradio secrets.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Deferred configuration error: stored here and raised later from the event
# handlers, so the UI still renders even when no token is configured.
HF_TOKEN_ERROR = (
    None
    if HF_TOKEN
    else "Hugging Face API token (HF_TOKEN) not found. Please set it as an environment variable or Gradio secret."
)

# Shared Inference API client used for both prompt improvement and image generation.
client = InferenceClient(token=HF_TOKEN)

# LLM that rewrites the user's prompt before image generation.
PROMPT_IMPROVER_MODEL = "HuggingFaceH4/zephyr-7b-beta"
def improve_prompt(original_prompt):
    """Rewrite *original_prompt* into a richer image-generation prompt via an LLM.

    Sends the prompt through the Zephyr chat model and returns the stripped
    completion. On any API failure the original prompt is returned unchanged
    (best-effort enhancement). Raises gr.Error when no HF token is configured.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)
    system_prompt = "You are a helpful assistant that improves text prompts for image generation models. Make the prompt more descriptive, detailed, and artistic, while keeping the user's original intent."
    # Zephyr chat template: system / user / assistant turns, each closed with </s>.
    llm_input = (
        "<|system|>\n"
        f"{system_prompt}</s>\n"
        "<|user|>\n"
        f"Improve this prompt: {original_prompt}\n"
        "</s>\n"
        "<|assistant|>\n"
    )
    try:
        completion = client.text_generation(
            prompt=llm_input,
            model=PROMPT_IMPROVER_MODEL,
            max_new_tokens=128,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            stop_sequences=["</s>"],
        )
    except Exception as exc:
        # Best-effort: log and fall back to the user's original prompt.
        print(f"Error improving prompt: {exc}")
        return original_prompt
    return completion.strip()
def generate_image(prompt, progress=gr.Progress()):
    """Generate an image for *prompt*, after LLM-based prompt enhancement.

    Returns a (PIL image, improved prompt text) pair. Raises gr.Error when the
    token is missing or the image-generation request fails; rate-limit errors
    get a dedicated message.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)
    progress(0, desc="Improving prompt...")
    enhanced_prompt = improve_prompt(prompt)
    progress(0.2, desc="Sending request to Hugging Face...")
    try:
        result = client.text_to_image(enhanced_prompt, model="black-forest-labs/FLUX.1-schnell")
        # Guard against unexpected return types from the inference client.
        if not isinstance(result, Image.Image):
            raise Exception(f"Expected a PIL Image, but got: {type(result)}")
        progress(0.8, desc="Processing image...")
        time.sleep(0.5)  # brief pause so the progress stages are visible
        progress(1.0, desc="Done!")
        return result, enhanced_prompt
    except Exception as err:
        message = (
            f"Rate limit exceeded. Please try again later. Error: {err}"
            if "rate limit" in str(err).lower()
            else f"An error occurred: {err}"
        )
        raise gr.Error(message)
def pil_to_base64(img):
    """Encode *img* (a PIL image) as a PNG data-URI string for inline download links."""
    with BytesIO() as buffer:
        img.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"
# Custom CSS injected into the Blocks app: card-style container, fade-in
# animation for freshly generated images, button/link styling, and the
# italic box that displays the LLM-improved prompt.
css = """
.container {
max-width: 800px;
margin: auto;
padding: 20px;
border: 1px solid #ddd;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.title {
text-align: center;
font-size: 2.5em;
margin-bottom: 0.5em;
color: #333;
font-family: 'Arial', sans-serif;
}
.description {
text-align: center;
font-size: 1.1em;
margin-bottom: 1.5em;
color: #555;
}
.input-section, .output-section {
margin-bottom: 1.5em;
}
.output-section img {
display: block;
margin: auto;
max-width: 100%;
height: auto;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(20px); }
to { opacity: 1; transform: translateY(0); }
}
.output-section.animate img {
animation: fadeIn 0.8s ease-out;
}
.submit-button {
display: block;
margin: auto;
padding: 10px 20px;
font-size: 1.1em;
color: white;
background-color: #4CAF50;
border: none;
border-radius: 5px;
cursor: pointer;
transition: background-color 0.3s ease;
}
.submit-button:hover {
background-color: #367c39;
}
.error-message {
color: red;
text-align: center;
margin-top: 1em;
font-weight: bold;
}
label{
font-weight: bold;
display: block;
margin-bottom: 0.5em;
}
.improved-prompt-display {
margin-top: 10px;
padding: 8px;
border: 1px solid #ccc;
border-radius: 4px;
background-color: #f9f9f9;
font-style: italic;
color: #444;
}
.download-link {
display: block;
text-align: center;
margin-top: 10px;
color: #4CAF50;
text-decoration: none;
font-weight: bold;
}
.download-link:hover{
text-decoration: underline;
}
"""
# UI layout: prompt input + button on the left, generated image, improved
# prompt, and download link on the right.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# Xylaria Iris Image Generator
""",
        elem_classes="title"
    )
    with gr.Row():
        with gr.Column():
            with gr.Group(elem_classes="input-section"):
                prompt_input = gr.Textbox(label="Enter your prompt", placeholder="e.g., A cat", lines=3)
                generate_button = gr.Button("Generate Image", elem_classes="submit-button")
        with gr.Column():
            with gr.Group(elem_classes="output-section") as output_group:
                image_output = gr.Image(label="Generated Image", interactive=False)
                improved_prompt_output = gr.Textbox(label="Improved Prompt", interactive=False, elem_classes="improved-prompt-display")
                # Starts hidden; the handler returns an <a> tag for it.
                # NOTE(review): the handler never toggles visible=True, so this
                # link may never actually appear — confirm against Gradio docs.
                download_link = gr.HTML(visible=False)
    def on_generate_click(prompt):
        # Handler for both the button click and textbox submit.
        # NOTE(review): mutating elem_classes on an already-rendered component
        # likely has no effect at runtime in Gradio, so the fade-in animation
        # may never trigger — confirm.
        output_group.elem_classes = ["output-section", "animate"]
        image, improved_prompt = generate_image(prompt)
        output_group.elem_classes = ["output-section"]
        # Embed the PNG as a data URI so the download works without a server file.
        image_b64 = pil_to_base64(image)
        download_html = f'<a class="download-link" href="{image_b64}" download="generated_image.png">Download Image</a>'
        return image, improved_prompt, download_html
    # Button click and pressing Enter in the textbox run the same handler.
    generate_button.click(on_generate_click, inputs=prompt_input, outputs=[image_output, improved_prompt_output, download_link])
    prompt_input.submit(on_generate_click, inputs=prompt_input, outputs=[image_output, improved_prompt_output, download_link])
    gr.Examples(
        [["A dog"],
        ["A house on a hill"],
        ["A spaceship"]],
        inputs=prompt_input
    )
# queue() serializes requests so concurrent users don't exhaust the API.
if __name__ == "__main__":
    demo.queue().launch()