mostlycached commited on
Commit
0773ab4
·
verified ·
1 Parent(s): f837703

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +1 -13
  2. app.py +364 -0
  3. requirements.txt +7 -0
README.md CHANGED
@@ -1,13 +1 @@
1
- ---
2
- title: Textdiffuser 2 Demo
3
- emoji: ⚡
4
- colorFrom: green
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.23.1
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # textdiffuser-2-demo
 
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py - TextDiffuser-2 implementation for Hugging Face Spaces
2
+ import os
3
+ import torch
4
+ import gradio as gr
5
+ import numpy as np
6
+ import json
7
+ from PIL import Image, ImageDraw, ImageFont
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from diffusers import StableDiffusionPipeline
10
+
11
+ # Check for GPU
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ print(f"Using device: {device}")
14
+
15
+ class SimpleTextDiffuser:
16
+ """
17
+ Simple implementation of TextDiffuser-2 concept for Hugging Face Spaces
18
+ """
19
+ def __init__(self):
20
+ # Load language model for layout generation
21
+ # Using a small model for efficiency
22
+ self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
23
+ self.language_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
24
+ self.language_model.to(device)
25
+
26
+ # Only load the diffusion model if we have a GPU
27
+ self.diffusion_model = None
28
+ if torch.cuda.is_available():
29
+ self.diffusion_model = StableDiffusionPipeline.from_pretrained(
30
+ "runwayml/stable-diffusion-v1-5",
31
+ torch_dtype=torch.float16
32
+ )
33
+ self.diffusion_model.to(device)
34
+
35
+ print("Models initialized")
36
+
37
+ def generate_layout(self, prompt, image_size=(512, 512), num_text_elements=3):
38
+ """Generate text layout based on prompt"""
39
+ width, height = image_size
40
+
41
+ # Format the prompt for layout generation
42
+ layout_prompt = f"""
43
+ Create a layout for an image with:
44
+ - Description: {prompt}
45
+ - Image size: {width}x{height}
46
+ - Number of text elements: {num_text_elements}
47
+
48
+ Generate text content and positions:
49
+ """
50
+
51
+ # Generate layout using LM
52
+ input_ids = self.tokenizer.encode(layout_prompt, return_tensors="pt").to(device)
53
+ with torch.no_grad():
54
+ output = self.language_model.generate(
55
+ input_ids,
56
+ max_length=input_ids.shape[1] + 150,
57
+ temperature=0.7,
58
+ num_return_sequences=1,
59
+ pad_token_id=self.tokenizer.eos_token_id
60
+ )
61
+
62
+ layout_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
63
+
64
+ # Parse the generated layout (simplified)
65
+ # In a real implementation, this would be more sophisticated
66
+ text_elements = []
67
+
68
+ # Simple fallback: generate random layout
69
+ import random
70
+
71
+ # Create a title element
72
+ title = prompt.split()[:5]
73
+ title = " ".join(title) + "..."
74
+ title_x = width // 4
75
+ title_y = height // 4
76
+ text_elements.append({
77
+ "text": title,
78
+ "position": (title_x, title_y),
79
+ "size": 24,
80
+ "color": (0, 0, 0),
81
+ "type": "title"
82
+ })
83
+
84
+ # Create additional text elements
85
+ sample_texts = [
86
+ "Premium Quality",
87
+ "Best Value",
88
+ "Limited Edition",
89
+ "New Collection",
90
+ "Special Offer",
91
+ "Coming Soon",
92
+ "Best Seller",
93
+ "Top Choice",
94
+ "Featured Product",
95
+ "Exclusive Deal"
96
+ ]
97
+
98
+ for i in range(1, num_text_elements):
99
+ x = random.randint(width // 8, width * 3 // 4)
100
+ y = random.randint(height // 3, height * 3 // 4)
101
+ text = sample_texts[i % len(sample_texts)]
102
+ color = (
103
+ random.randint(0, 200),
104
+ random.randint(0, 200),
105
+ random.randint(0, 200)
106
+ )
107
+
108
+ text_elements.append({
109
+ "text": text,
110
+ "position": (x, y),
111
+ "size": 18,
112
+ "color": color,
113
+ "type": f"element_{i}"
114
+ })
115
+
116
+ return text_elements, layout_text
117
+
118
+ def generate_image(self, prompt, image_size=(512, 512)):
119
+ """Generate base image using diffusion model or placeholder"""
120
+ width, height = image_size
121
+
122
+ if self.diffusion_model and torch.cuda.is_available():
123
+ # Generate image using diffusion model
124
+ image = self.diffusion_model(
125
+ prompt=prompt,
126
+ height=height,
127
+ width=width,
128
+ num_inference_steps=30
129
+ ).images[0]
130
+ else:
131
+ # Create a placeholder gradient image
132
+ image = Image.new("RGB", image_size, (240, 240, 240))
133
+
134
+ # Add a colored gradient background
135
+ for y in range(height):
136
+ for x in range(width):
137
+ r = int(240 - 100 * (y / height))
138
+ g = int(240 - 50 * (x / width))
139
+ b = int(240 - 75 * ((x + y) / (width + height)))
140
+ image.putpixel((x, y), (r, g, b))
141
+
142
+ return image
143
+
144
+ def render_text(self, image, text_elements):
145
+ """Render text elements onto the image"""
146
+ image_with_text = image.copy()
147
+ draw = ImageDraw.Draw(image_with_text)
148
+
149
+ for element in text_elements:
150
+ try:
151
+ font_size = element["size"]
152
+
153
+ # Try to load a font, fall back to default if not available
154
+ try:
155
+ font = ImageFont.truetype("DejaVuSans.ttf", font_size)
156
+ except IOError:
157
+ try:
158
+ font = ImageFont.truetype("Arial.ttf", font_size)
159
+ except IOError:
160
+ font = ImageFont.load_default()
161
+
162
+ # Draw text with background for better visibility
163
+ text = element["text"]
164
+ position = element["position"]
165
+ color = element["color"]
166
+
167
+ # Get text size to create background
168
+ bbox = draw.textbbox(position, text, font=font)
169
+ text_width = bbox[2] - bbox[0]
170
+ text_height = bbox[3] - bbox[1]
171
+
172
+ # Draw semi-transparent background
173
+ padding = 5
174
+ background_box = [
175
+ position[0] - padding,
176
+ position[1] - padding,
177
+ position[0] + text_width + padding,
178
+ position[1] + text_height + padding
179
+ ]
180
+ draw.rectangle(background_box, fill=(255, 255, 255, 200))
181
+
182
+ # Draw text
183
+ draw.text(position, text, fill=color, font=font)
184
+
185
+ except Exception as e:
186
+ print(f"Error rendering text: {e}")
187
+ continue
188
+
189
+ return image_with_text
190
+
191
+ def visualize_layout(self, text_elements, image_size=(512, 512)):
192
+ """Create a visualization of the text layout"""
193
+ width, height = image_size
194
+ image = Image.new("RGB", image_size, (255, 255, 255))
195
+ draw = ImageDraw.Draw(image)
196
+
197
+ # Draw grid
198
+ for x in range(0, width, 50):
199
+ draw.line([(x, 0), (x, height)], fill=(230, 230, 230))
200
+ for y in range(0, height, 50):
201
+ draw.line([(0, y), (width, y)], fill=(230, 230, 230))
202
+
203
+ # Draw text elements
204
+ for element in text_elements:
205
+ position = element["position"]
206
+ text = element["text"]
207
+ element_type = element.get("type", "unknown")
208
+
209
+ # Draw position marker
210
+ circle_radius = 5
211
+ circle_bbox = [
212
+ position[0] - circle_radius,
213
+ position[1] - circle_radius,
214
+ position[0] + circle_radius,
215
+ position[1] + circle_radius
216
+ ]
217
+ draw.ellipse(circle_bbox, fill=(255, 0, 0))
218
+
219
+ # Draw text label
220
+ try:
221
+ font = ImageFont.truetype("DejaVuSans.ttf", 12)
222
+ except IOError:
223
+ font = ImageFont.load_default()
224
+
225
+ # Draw text preview and position info
226
+ info_text = f"{text} ({element_type})"
227
+ pos_text = f"Position: ({position[0]}, {position[1]})"
228
+ draw.text((position[0] + 10, position[1]), info_text, fill=(0, 0, 0), font=font)
229
+ draw.text((position[0] + 10, position[1] + 15), pos_text, fill=(0, 0, 255), font=font)
230
+
231
+ return image
232
+
233
+ def generate_text_image(self, prompt, width=512, height=512, num_text_elements=3):
234
+ """Generate an image with rendered text based on prompt"""
235
+ # Validate inputs
236
+ width = max(256, min(1024, width))
237
+ height = max(256, min(1024, height))
238
+ num_text_elements = max(1, min(5, num_text_elements))
239
+
240
+ image_size = (width, height)
241
+
242
+ # Step 1: Generate text layout
243
+ text_elements, layout_text = self.generate_layout(prompt, image_size, num_text_elements)
244
+
245
+ # Step 2: Generate base image
246
+ base_image = self.generate_image(prompt, image_size)
247
+
248
+ # Step 3: Render text onto the image
249
+ image_with_text = self.render_text(base_image, text_elements)
250
+
251
+ # Step 4: Create layout visualization
252
+ layout_visualization = self.visualize_layout(text_elements, image_size)
253
+
254
+ # Step 5: Format layout information for display
255
+ layout_info = {
256
+ "prompt": prompt,
257
+ "image_size": image_size,
258
+ "num_text_elements": num_text_elements,
259
+ "text_elements": text_elements,
260
+ "layout_generation_prompt": layout_text
261
+ }
262
+
263
+ formatted_layout = json.dumps(layout_info, indent=2)
264
+
265
+ return image_with_text, layout_visualization, formatted_layout
266
+
267
+ # Initialize the model
268
+ model = SimpleTextDiffuser()
269
+
270
+ # Define the Gradio interface
271
+ def process_request(prompt, width, height, num_text_elements):
272
+ try:
273
+ width = int(width)
274
+ height = int(height)
275
+ num_text_elements = int(num_text_elements)
276
+
277
+ image, layout, layout_info = model.generate_text_image(
278
+ prompt,
279
+ width=width,
280
+ height=height,
281
+ num_text_elements=num_text_elements
282
+ )
283
+
284
+ return image, layout, layout_info
285
+ except Exception as e:
286
+ error_message = f"Error: {str(e)}"
287
+ print(error_message)
288
+ return None, None, error_message
289
+
290
+ # Create the Gradio app
291
+ with gr.Blocks(title="TextDiffuser-2 Demo") as demo:
292
+ gr.Markdown("""
293
+ # TextDiffuser-2 Demo
294
+
295
+ This demo implements the concepts from the paper "[TextDiffuser-2: Unleashing the Power of Language Models for Text Rendering](https://arxiv.org/abs/2311.16465)" by Jingye Chen et al.
296
+
297
+ Generate images with text by providing a descriptive prompt below.
298
+ """)
299
+
300
+ with gr.Row():
301
+ with gr.Column(scale=1):
302
+ prompt_input = gr.Textbox(
303
+ label="Prompt",
304
+ value="A modern business poster with company name and tagline",
305
+ lines=3
306
+ )
307
+
308
+ with gr.Row():
309
+ width_input = gr.Number(label="Width", value=512, minimum=256, maximum=1024, step=64)
310
+ height_input = gr.Number(label="Height", value=512, minimum=256, maximum=1024, step=64)
311
+
312
+ num_elements_input = gr.Slider(
313
+ label="Number of Text Elements",
314
+ minimum=1,
315
+ maximum=5,
316
+ value=3,
317
+ step=1
318
+ )
319
+
320
+ submit_button = gr.Button("Generate Image", variant="primary")
321
+
322
+ with gr.Column(scale=2):
323
+ with gr.Tabs():
324
+ with gr.TabItem("Generated Image"):
325
+ image_output = gr.Image(label="Image with Text")
326
+
327
+ with gr.TabItem("Layout Visualization"):
328
+ layout_output = gr.Image(label="Text Layout")
329
+
330
+ with gr.TabItem("Layout Information"):
331
+ layout_info_output = gr.Code(language="json", label="Layout Data")
332
+
333
+ gr.Markdown("""
334
+ ## Example Prompts
335
+
336
+ Try these prompts or create your own:
337
+ """)
338
+
339
+ examples = gr.Examples(
340
+ examples=[
341
+ ["A movie poster for a sci-fi thriller", 512, 768, 3],
342
+ ["A motivational quote on a sunset background", 768, 512, 2],
343
+ ["A coffee shop menu with prices", 512, 512, 4],
344
+ ["A modern business card design", 512, 384, 3],
345
+ ],
346
+ inputs=[prompt_input, width_input, height_input, num_elements_input]
347
+ )
348
+
349
+ submit_button.click(
350
+ fn=process_request,
351
+ inputs=[prompt_input, width_input, height_input, num_elements_input],
352
+ outputs=[image_output, layout_output, layout_info_output]
353
+ )
354
+
355
+ gr.Markdown("""
356
+ ## About
357
+
358
+ This is a simplified implementation for demonstration purposes. The full approach described in the paper involves deeper integration of language models with the diffusion process.
359
+
360
+ Running on: """ + str(device))
361
+
362
+ # Launch the app
363
+ if __name__ == "__main__":
364
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=1.12.0
2
+ transformers>=4.26.0
3
+ diffusers>=0.14.0
4
+ accelerate>=0.16.0
5
+ numpy>=1.22.0
6
+ Pillow>=9.0.0
7
+ gradio>=3.20.0