|
|
|
import json
import os
import random

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")
|
|
|
class SimpleTextDiffuser:
    """Simple implementation of the TextDiffuser-2 concept for Hugging Face Spaces.

    The pipeline runs three stages: a small causal language model is prompted
    for a layout description, a base image is produced (Stable Diffusion when
    CUDA is available, a synthetic gradient otherwise), and the text elements
    are drawn on top of that image with Pillow.
    """

    # Stock captions used for every text element after the title.
    SAMPLE_TEXTS = (
        "Premium Quality",
        "Best Value",
        "Limited Edition",
        "New Collection",
        "Special Offer",
        "Coming Soon",
        "Best Seller",
        "Top Choice",
        "Featured Product",
        "Exclusive Deal",
    )

    # Font files tried in order before falling back to Pillow's built-in font.
    FONT_CANDIDATES = ("DejaVuSans.ttf", "Arial.ttf")

    # Limits applied to user-requested sizes in generate_text_image().
    MIN_DIMENSION = 256
    MAX_DIMENSION = 1024
    # The Stable Diffusion pipeline rejects sizes not divisible by 8.
    DIMENSION_MULTIPLE = 8
    MAX_TEXT_ELEMENTS = 5

    # Number of tokens the language model may append to the layout prompt.
    LAYOUT_NEW_TOKENS = 150

    def __init__(self):
        """Load the layout language model and, on CUDA hosts, Stable Diffusion."""
        self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
        self.language_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
        self.language_model.to(device)

        # The diffusion pipeline is only loaded (in float16) when CUDA is
        # available; generate_image() draws a gradient placeholder otherwise.
        self.diffusion_model = None
        if torch.cuda.is_available():
            self.diffusion_model = StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=torch.float16,
            )
            self.diffusion_model.to(device)

        print("Models initialized")

    @classmethod
    def _clamp_dimension(cls, value):
        """Return *value* as an int clamped to the allowed range and snapped down to a multiple of 8."""
        clamped = max(cls.MIN_DIMENSION, min(cls.MAX_DIMENSION, int(value)))
        return clamped - clamped % cls.DIMENSION_MULTIPLE

    @classmethod
    def _load_font(cls, size):
        """Return the first loadable candidate font at *size*, else Pillow's default font."""
        for font_name in cls.FONT_CANDIDATES:
            try:
                return ImageFont.truetype(font_name, size)
            except OSError:
                continue
        return ImageFont.load_default()

    @staticmethod
    def _placeholder_gradient(width, height):
        """Return a (height, width, 3) uint8 array holding the placeholder colour gradient.

        Red fades with the row, green with the column and blue along the
        diagonal; values are truncated to integers exactly like ``int()``.
        """
        cols = np.arange(width, dtype=np.float64)
        rows = np.arange(height, dtype=np.float64)

        red = np.broadcast_to((240 - 100 * (rows / height))[:, None], (height, width))
        green = np.broadcast_to((240 - 50 * (cols / width))[None, :], (height, width))
        blue = 240 - 75 * ((cols[None, :] + rows[:, None]) / (width + height))

        return np.stack([red, green, blue], axis=-1).astype(np.uint8)

    def generate_layout(self, prompt, image_size=(512, 512), num_text_elements=3):
        """Generate text layout based on prompt.

        Args:
            prompt: Free-text description of the desired image.
            image_size: ``(width, height)`` of the target image in pixels.
            num_text_elements: Total number of elements to place, title included.

        Returns:
            A ``(text_elements, layout_text)`` tuple. ``text_elements`` is a list
            of dicts with the keys ``text``, ``position``, ``size``, ``color``
            and ``type``. ``layout_text`` is the decoded language-model output
            (prompt plus continuation); it is returned for display only and
            does not influence the element positions.
        """
        width, height = image_size

        layout_prompt = (
            "\n"
            "Create a layout for an image with:\n"
            f"- Description: {prompt}\n"
            f"- Image size: {width}x{height}\n"
            f"- Number of text elements: {num_text_elements}\n"
            "\n"
            "Generate text content and positions:\n"
        )

        # Truncate the encoded prompt so that prompt + generated tokens still
        # fit into the model's context window.
        context_limit = getattr(
            self.language_model.config, "max_position_embeddings", 1024
        )
        encoded = self.tokenizer(
            layout_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=context_limit - self.LAYOUT_NEW_TOKENS,
        ).to(device)

        with torch.no_grad():
            output = self.language_model.generate(
                **encoded,  # input_ids and the matching attention_mask
                max_new_tokens=self.LAYOUT_NEW_TOKENS,
                # temperature only takes effect when sampling is enabled.
                do_sample=True,
                temperature=0.7,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        layout_text = self.tokenizer.decode(output[0], skip_special_tokens=True)

        # Title: the first five words of the prompt, with an ellipsis only
        # when words were actually dropped.
        words = prompt.split()
        if not words:
            title = "Untitled"
        else:
            title = " ".join(words[:5])
            if len(words) > 5:
                title += "..."

        text_elements = [
            {
                "text": title,
                "position": (width // 4, height // 4),
                "size": 24,
                "color": (0, 0, 0),
                "type": "title",
            }
        ]

        # Remaining elements: stock captions at random positions and colours.
        for i in range(1, num_text_elements):
            x = random.randint(width // 8, width * 3 // 4)
            y = random.randint(height // 3, height * 3 // 4)
            color = (
                random.randint(0, 200),
                random.randint(0, 200),
                random.randint(0, 200),
            )
            text_elements.append(
                {
                    "text": self.SAMPLE_TEXTS[i % len(self.SAMPLE_TEXTS)],
                    "position": (x, y),
                    "size": 18,
                    "color": color,
                    "type": f"element_{i}",
                }
            )

        return text_elements, layout_text

    def generate_image(self, prompt, image_size=(512, 512)):
        """Generate base image using diffusion model or placeholder.

        Args:
            prompt: Text prompt passed to the diffusion pipeline.
            image_size: ``(width, height)`` in pixels.

        Returns:
            A PIL image: the first image of the diffusion pipeline when one was
            loaded, otherwise an RGB colour gradient of the requested size.
        """
        width, height = image_size

        if self.diffusion_model is not None:
            return self.diffusion_model(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=30,
            ).images[0]

        # No diffusion model (CPU host): build the gradient in one vectorised
        # step instead of setting every pixel from Python.
        return Image.fromarray(self._placeholder_gradient(width, height))

    def render_text(self, image, text_elements):
        """Render text elements onto the image.

        Each element gets a padded, translucent white box behind its text.
        An element that fails to render is reported on stdout and skipped.

        Args:
            image: Base PIL image; it is copied, not modified.
            text_elements: Dicts with ``text``, ``position``, ``size`` and ``color``.

        Returns:
            A copy of ``image`` with the text drawn on it.
        """
        image_with_text = image.copy()

        # Drawing in "RGBA" mode on an RGB image makes Pillow blend each ink
        # by its alpha, so the background box is actually translucent.
        if image_with_text.mode == "RGB":
            draw = ImageDraw.Draw(image_with_text, "RGBA")
        else:
            draw = ImageDraw.Draw(image_with_text)

        padding = 5
        fonts = {}  # font size -> loaded font, so each size is loaded once

        for element in text_elements:
            try:
                font_size = element["size"]
                if font_size not in fonts:
                    fonts[font_size] = self._load_font(font_size)
                font = fonts[font_size]

                text = element["text"]
                position = element["position"]
                color = element["color"]

                # Pad the measured bounding box itself; it can be offset from
                # `position` by the font's metrics.
                left, top, right, bottom = draw.textbbox(position, text, font=font)
                draw.rectangle(
                    [left - padding, top - padding, right + padding, bottom + padding],
                    fill=(255, 255, 255, 200),
                )

                draw.text(position, text, fill=color, font=font)

            except Exception as e:
                # Best effort: one bad element must not abort the whole render.
                print(f"Error rendering text: {e}")
                continue

        return image_with_text

    def visualize_layout(self, text_elements, image_size=(512, 512)):
        """Create a visualization of the text layout.

        Draws a 50-pixel grid on a white canvas and marks each element's anchor
        with a red dot, its text and type, and its coordinates.

        Args:
            text_elements: Dicts with ``position``, ``text`` and optionally ``type``.
            image_size: ``(width, height)`` of the canvas in pixels.

        Returns:
            The annotated PIL image.
        """
        width, height = image_size
        image = Image.new("RGB", image_size, (255, 255, 255))
        draw = ImageDraw.Draw(image)

        # Light grid lines every 50 pixels.
        for x in range(0, width, 50):
            draw.line([(x, 0), (x, height)], fill=(230, 230, 230))
        for y in range(0, height, 50):
            draw.line([(0, y), (width, y)], fill=(230, 230, 230))

        # The label font is identical for every element; load it once.
        font = self._load_font(12)
        circle_radius = 5

        for element in text_elements:
            position = element["position"]
            text = element["text"]
            element_type = element.get("type", "unknown")

            draw.ellipse(
                [
                    position[0] - circle_radius,
                    position[1] - circle_radius,
                    position[0] + circle_radius,
                    position[1] + circle_radius,
                ],
                fill=(255, 0, 0),
            )

            info_text = f"{text} ({element_type})"
            pos_text = f"Position: ({position[0]}, {position[1]})"
            draw.text(
                (position[0] + 10, position[1]), info_text, fill=(0, 0, 0), font=font
            )
            draw.text(
                (position[0] + 10, position[1] + 15),
                pos_text,
                fill=(0, 0, 255),
                font=font,
            )

        return image

    def generate_text_image(self, prompt, width=512, height=512, num_text_elements=3):
        """Generate an image with rendered text based on prompt.

        Args:
            prompt: Description of the desired image.
            width: Requested width; clamped to 256..1024 and snapped down to a
                multiple of 8.
            height: Requested height; clamped and snapped like ``width``.
            num_text_elements: Requested element count; clamped to 1..5.

        Returns:
            ``(image_with_text, layout_visualization, formatted_layout)`` where
            the first two are PIL images and the last is a JSON string
            describing the prompt, size and text elements.
        """
        width = self._clamp_dimension(width)
        height = self._clamp_dimension(height)
        num_text_elements = max(1, min(self.MAX_TEXT_ELEMENTS, int(num_text_elements)))

        image_size = (width, height)

        text_elements, layout_text = self.generate_layout(
            prompt, image_size, num_text_elements
        )
        base_image = self.generate_image(prompt, image_size)
        image_with_text = self.render_text(base_image, text_elements)
        layout_visualization = self.visualize_layout(text_elements, image_size)

        layout_info = {
            "prompt": prompt,
            "image_size": image_size,
            "num_text_elements": num_text_elements,
            "text_elements": text_elements,
            "layout_generation_prompt": layout_text,
        }
        formatted_layout = json.dumps(layout_info, indent=2)

        return image_with_text, layout_visualization, formatted_layout
|
|
|
|
|
# Single shared pipeline instance, created at import time and used by the
# Gradio callback below.
model = SimpleTextDiffuser()
|
|
|
|
|
def process_request(prompt, width, height, num_text_elements):
    """Gradio callback: run the pipeline for one request.

    The numeric inputs are converted to ``int`` before the pipeline is called.

    Returns:
        ``(image, layout_image, layout_json)`` on success. On any failure the
        error is printed and ``(None, None, "Error: ...")`` is returned so the
        message shows up in the layout-information output.
    """
    try:
        # Convert in argument order so the first bad value is the one reported.
        size_w, size_h, element_count = (
            int(raw) for raw in (width, height, num_text_elements)
        )

        rendered, layout_view, layout_json = model.generate_text_image(
            prompt,
            width=size_w,
            height=size_h,
            num_text_elements=element_count,
        )
    except Exception as exc:
        error_message = f"Error: {exc!s}"
        print(error_message)
        return None, None, error_message

    return rendered, layout_view, layout_json
|
|
|
|
|
# Gradio UI: inputs on the left, tabbed outputs on the right, then example
# prompts and a short "About" section.
with gr.Blocks(title="TextDiffuser-2 Demo") as demo:
    gr.Markdown(
        "\n"
        "# TextDiffuser-2 Demo\n"
        "\n"
        'This demo implements the concepts from the paper "[TextDiffuser-2: Unleashing the Power of Language Models for Text Rendering](https://arxiv.org/abs/2311.16465)" by Jingye Chen et al.\n'
        "\n"
        "Generate images with text by providing a descriptive prompt below.\n"
    )

    with gr.Row():
        # Left column: request parameters.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Prompt",
                value="A modern business poster with company name and tagline",
                lines=3,
            )

            with gr.Row():
                width_input = gr.Number(
                    label="Width",
                    value=512,
                    minimum=256,
                    maximum=1024,
                    step=64,
                )
                height_input = gr.Number(
                    label="Height",
                    value=512,
                    minimum=256,
                    maximum=1024,
                    step=64,
                )

            num_elements_input = gr.Slider(
                label="Number of Text Elements",
                minimum=1,
                maximum=5,
                value=3,
                step=1,
            )

            submit_button = gr.Button("Generate Image", variant="primary")

        # Right column: one tab per pipeline output.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Generated Image"):
                    image_output = gr.Image(label="Image with Text")

                with gr.TabItem("Layout Visualization"):
                    layout_output = gr.Image(label="Text Layout")

                with gr.TabItem("Layout Information"):
                    layout_info_output = gr.Code(language="json", label="Layout Data")

    gr.Markdown(
        "\n"
        "## Example Prompts\n"
        "\n"
        "Try these prompts or create your own:\n"
    )

    # Each example row fills prompt, width, height and element count.
    examples = gr.Examples(
        examples=[
            ["A movie poster for a sci-fi thriller", 512, 768, 3],
            ["A motivational quote on a sunset background", 768, 512, 2],
            ["A coffee shop menu with prices", 512, 512, 4],
            ["A modern business card design", 512, 384, 3],
        ],
        inputs=[prompt_input, width_input, height_input, num_elements_input],
    )

    # Wire the button to the callback; outputs map onto the three tabs.
    submit_button.click(
        fn=process_request,
        inputs=[prompt_input, width_input, height_input, num_elements_input],
        outputs=[image_output, layout_output, layout_info_output],
    )

    gr.Markdown(
        "\n"
        "## About\n"
        "\n"
        "This is a simplified implementation for demonstration purposes. The full approach described in the paper involves deeper integration of language models with the diffusion process.\n"
        "\n"
        "Running on: " + str(device)
    )


if __name__ == "__main__":
    demo.launch()