Spaces:
Sleeping
Sleeping
import gradio as gr | |
from huggingface_hub import InferenceClient | |
from PIL import Image | |
import time | |
import os | |
import base64 | |
from io import BytesIO | |
# API token for the Hugging Face Inference API, taken from the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")

# User-facing message describing a missing token; None when a token is set.
# The functions that call the API check this before doing any work.
HF_TOKEN_ERROR = (
    "Hugging Face API token (HF_TOKEN) not found. Please set it as an environment variable or Gradio secret."
    if not HF_TOKEN
    else None
)

# Shared inference client, created unconditionally (token may be None).
client = InferenceClient(token=HF_TOKEN)

# Chat model used to rewrite user prompts before image generation.
PROMPT_IMPROVER_MODEL = "HuggingFaceH4/zephyr-7b-beta"
def _clean_improved_prompt(generated, fallback):
    """Return the stripped model output, or *fallback* when nothing usable is left.

    Args:
        generated: Raw text returned by the text-generation call.
        fallback: Value to return when the cleaned output is empty.
    """
    text = generated.strip()
    # The matched stop sequence may be echoed at the end of the generated text.
    if text.endswith("</s>"):
        text = text[: -len("</s>")].rstrip()
    # An empty completion would otherwise replace the user's prompt with "".
    return text or fallback


def improve_prompt(original_prompt):
    """Ask the LLM to rewrite *original_prompt* into a richer image prompt.

    This is best-effort: if the request fails, or the model returns nothing
    usable, the original prompt is returned unchanged.

    Args:
        original_prompt: The user's raw prompt text.

    Returns:
        The improved prompt, or ``original_prompt`` on failure/empty output.

    Raises:
        gr.Error: if no Hugging Face token is configured.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)

    system_prompt = "You are a helpful assistant that improves text prompts for image generation models. Make the prompt more descriptive, detailed, and artistic, while keeping the user's original intent."
    # Chat markup with <|system|>/<|user|>/<|assistant|> turns, leaving the
    # assistant turn open for the model to complete.
    prompt_for_llm = (
        f"<|system|>\n{system_prompt}</s>\n"
        f"<|user|>\nImprove this prompt: {original_prompt}\n</s>\n"
        "<|assistant|>\n"
    )

    try:
        generated = client.text_generation(
            prompt=prompt_for_llm,
            model=PROMPT_IMPROVER_MODEL,
            max_new_tokens=128,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            stop_sequences=["</s>"],
        )
    except Exception as e:
        # Deliberate fallback: a failed improvement must not block image generation.
        print(f"Error improving prompt: {e}")
        return original_prompt

    return _clean_improved_prompt(generated, original_prompt)
def _is_rate_limited(exc):
    """Return True if *exc* looks like an HTTP 429 / rate-limit failure."""
    # HTTP errors from huggingface_hub carry the failed response; prefer its
    # status code, keeping the message check as a fallback for other errors.
    # (Don't test `response` for truthiness: a requests.Response is falsy on 4xx.)
    response = getattr(exc, "response", None)
    if getattr(response, "status_code", None) == 429:
        return True
    return "rate limit" in str(exc).lower()


def generate_image(prompt, progress=gr.Progress()):
    """Improve *prompt* with the LLM, then render it with FLUX.1-schnell.

    Args:
        prompt: The user's raw text prompt.
        progress: Gradio progress tracker (standard ``gr.Progress()`` default).

    Returns:
        ``(image, improved_prompt)``: the generated PIL image and the prompt
        that was actually sent to the image model.

    Raises:
        gr.Error: if the token is missing, the prompt is blank, the request
            fails (rate limits get a dedicated message), or the response is
            not a PIL image.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    progress(0, desc="Improving prompt...")
    improved_prompt = improve_prompt(prompt)

    progress(0.2, desc="Sending request to Hugging Face...")
    try:
        image = client.text_to_image(improved_prompt, model="black-forest-labs/FLUX.1-schnell")
    except Exception as e:
        if _is_rate_limited(e):
            error_message = f"Rate limit exceeded. Please try again later. Error: {e}"
        else:
            error_message = f"An error occurred: {e}"
        raise gr.Error(error_message) from e

    if not isinstance(image, Image.Image):
        raise gr.Error(f"An error occurred: Expected a PIL Image, but got: {type(image)}")

    progress(1.0, desc="Done!")
    return image, improved_prompt
def pil_to_base64(img):
    """Serialize *img* as PNG and wrap it in a ``data:image/png;base64,...`` URI.

    *img* only needs a PIL-style ``save(fp, format=...)`` method.
    """
    with BytesIO() as png_bytes:
        img.save(png_bytes, format="PNG")
        payload = base64.b64encode(png_bytes.getvalue()).decode()
    return "data:image/png;base64," + payload
# Custom stylesheet passed to gr.Blocks(css=...). Its selectors target the
# elem_classes given to components in the UI (title, input-section,
# output-section, submit-button, improved-prompt-display) and the
# "download-link" class of the anchor built for the download HTML.
# The bare `label` rule applies to every <label> on the page.
# NOTE(review): .container, .description and .error-message are not referenced
# anywhere else in this file, and the `animate` class required by the fade-in
# rule is never part of an elem_classes value given at construction time —
# confirm these rules are still needed / actually take effect.
css = """
.container {
max-width: 800px;
margin: auto;
padding: 20px;
border: 1px solid #ddd;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.title {
text-align: center;
font-size: 2.5em;
margin-bottom: 0.5em;
color: #333;
font-family: 'Arial', sans-serif;
}
.description {
text-align: center;
font-size: 1.1em;
margin-bottom: 1.5em;
color: #555;
}
.input-section, .output-section {
margin-bottom: 1.5em;
}
.output-section img {
display: block;
margin: auto;
max-width: 100%;
height: auto;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(20px); }
to { opacity: 1; transform: translateY(0); }
}
.output-section.animate img {
animation: fadeIn 0.8s ease-out;
}
.submit-button {
display: block;
margin: auto;
padding: 10px 20px;
font-size: 1.1em;
color: white;
background-color: #4CAF50;
border: none;
border-radius: 5px;
cursor: pointer;
transition: background-color 0.3s ease;
}
.submit-button:hover {
background-color: #367c39;
}
.error-message {
color: red;
text-align: center;
margin-top: 1em;
font-weight: bold;
}
label{
font-weight: bold;
display: block;
margin-bottom: 0.5em;
}
.improved-prompt-display {
margin-top: 10px;
padding: 8px;
border: 1px solid #ccc;
border-radius: 4px;
background-color: #f9f9f9;
font-style: italic;
color: #444;
}
.download-link {
display: block;
text-align: center;
margin-top: 10px;
color: #4CAF50;
text-decoration: none;
font-weight: bold;
}
.download-link:hover{
text-decoration: underline;
}
"""
# Gradio UI: prompt input on the left, generated image / improved prompt /
# download link on the right.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# Xylaria Iris Image Generator
""",
        elem_classes="title",
    )

    with gr.Row():
        with gr.Column():
            with gr.Group(elem_classes="input-section"):
                prompt_input = gr.Textbox(label="Enter your prompt", placeholder="e.g., A cat", lines=3)
                generate_button = gr.Button("Generate Image", elem_classes="submit-button")
        with gr.Column():
            with gr.Group(elem_classes="output-section"):
                image_output = gr.Image(label="Generated Image", interactive=False)
                improved_prompt_output = gr.Textbox(
                    label="Improved Prompt",
                    interactive=False,
                    elem_classes="improved-prompt-display",
                )
                # Hidden until an image exists; on_generate_click reveals it.
                download_link = gr.HTML(visible=False)

    def on_generate_click(prompt, progress=gr.Progress()):
        """Generate an image for *prompt* and build a download link for it.

        Declaring ``progress`` here lets Gradio attach progress tracking to
        this event; it is forwarded so generate_image's updates are reported.

        Returns:
            ``(image, improved_prompt, download_link update)`` matching the
            outputs wired below.
        """
        image, improved_prompt = generate_image(prompt, progress)
        download_html = (
            f'<a class="download-link" href="{pil_to_base64(image)}" '
            'download="generated_image.png">Download Image</a>'
        )
        # download_link was created with visible=False; returning only the HTML
        # string would update its value but leave it hidden.
        return image, improved_prompt, gr.update(value=download_html, visible=True)

    generation_outputs = [image_output, improved_prompt_output, download_link]
    generate_button.click(on_generate_click, inputs=prompt_input, outputs=generation_outputs)
    prompt_input.submit(on_generate_click, inputs=prompt_input, outputs=generation_outputs)

    gr.Examples(
        [["A dog"], ["A house on a hill"], ["A spaceship"]],
        inputs=prompt_input,
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server.
    app = demo.queue()
    app.launch()