mostlycached commited on
Commit
a38e6ab
·
verified ·
1 Parent(s): fd31821

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +443 -265
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py - TextDiffuser-2 implementation for Hugging Face Spaces
2
  import os
3
  import torch
4
  import gradio as gr
@@ -7,357 +7,535 @@ import json
7
  from PIL import Image, ImageDraw, ImageFont
8
  from transformers import AutoTokenizer, AutoModelForCausalLM
9
  from diffusers import StableDiffusionPipeline
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # Check for GPU
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  print(f"Using device: {device}")
14
 
15
- class SimpleTextDiffuser:
 
 
 
16
  """
17
- Simple implementation of TextDiffuser-2 concept for Hugging Face Spaces
18
  """
19
  def __init__(self):
20
- # Load language model for layout generation
21
- # Using a small model for efficiency
22
- self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
23
- self.language_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
24
- self.language_model.to(device)
25
 
26
- # Only load the diffusion model if we have a GPU
27
- self.diffusion_model = None
28
- if torch.cuda.is_available():
29
- self.diffusion_model = StableDiffusionPipeline.from_pretrained(
30
- "runwayml/stable-diffusion-v1-5",
31
- torch_dtype=torch.float16
 
32
  )
33
- self.diffusion_model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- print("Models initialized")
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- def generate_layout(self, prompt, image_size=(512, 512), num_text_elements=3):
38
- """Generate text layout based on prompt"""
39
- width, height = image_size
40
-
41
- # Format the prompt for layout generation
42
- layout_prompt = f"""
43
- Create a layout for an image with:
44
- - Description: {prompt}
45
- - Image size: {width}x{height}
46
- - Number of text elements: {num_text_elements}
47
-
48
- Generate text content and positions:
49
  """
 
50
 
51
- # Generate layout using LM
52
- input_ids = self.tokenizer.encode(layout_prompt, return_tensors="pt").to(device)
53
- with torch.no_grad():
54
- output = self.language_model.generate(
55
- input_ids,
56
- max_length=input_ids.shape[1] + 150,
57
- temperature=0.7,
58
- num_return_sequences=1,
59
- pad_token_id=self.tokenizer.eos_token_id
60
- )
61
-
62
- layout_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
63
-
64
- # Parse the generated layout (simplified)
65
- # In a real implementation, this would be more sophisticated
66
- text_elements = []
67
-
68
- # Simple fallback: generate random layout
69
- import random
70
-
71
- # Create a title element
72
- title = prompt.split()[:5]
73
- title = " ".join(title) + "..."
74
- title_x = width // 4
75
- title_y = height // 4
76
- text_elements.append({
77
- "text": title,
78
- "position": (title_x, title_y),
79
- "size": 24,
80
- "color": (0, 0, 0),
81
- "type": "title"
82
- })
83
-
84
- # Create additional text elements
85
- sample_texts = [
86
- "Premium Quality",
87
- "Best Value",
88
- "Limited Edition",
89
- "New Collection",
90
- "Special Offer",
91
- "Coming Soon",
92
- "Best Seller",
93
- "Top Choice",
94
- "Featured Product",
95
- "Exclusive Deal"
96
- ]
97
-
98
- for i in range(1, num_text_elements):
99
- x = random.randint(width // 8, width * 3 // 4)
100
- y = random.randint(height // 3, height * 3 // 4)
101
- text = sample_texts[i % len(sample_texts)]
102
- color = (
103
- random.randint(0, 200),
104
- random.randint(0, 200),
105
- random.randint(0, 200)
106
- )
107
 
108
- text_elements.append({
109
- "text": text,
110
- "position": (x, y),
111
- "size": 18,
112
- "color": color,
113
- "type": f"element_{i}"
114
- })
115
-
116
- return text_elements, layout_text
117
-
118
- def generate_image(self, prompt, image_size=(512, 512)):
119
- """Generate base image using diffusion model or placeholder"""
120
  width, height = image_size
121
 
122
- if self.diffusion_model and torch.cuda.is_available():
123
- # Generate image using diffusion model
124
- image = self.diffusion_model(
125
- prompt=prompt,
126
- height=height,
127
- width=width,
128
- num_inference_steps=30
129
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  else:
131
- # Create a placeholder gradient image
132
- image = Image.new("RGB", image_size, (240, 240, 240))
133
-
134
- # Add a colored gradient background
135
- for y in range(height):
136
- for x in range(width):
137
- r = int(240 - 100 * (y / height))
138
- g = int(240 - 50 * (x / width))
139
- b = int(240 - 75 * ((x + y) / (width + height)))
140
- image.putpixel((x, y), (r, g, b))
141
 
142
- return image
143
 
144
- def render_text(self, image, text_elements):
145
- """Render text elements onto the image"""
146
- image_with_text = image.copy()
147
- draw = ImageDraw.Draw(image_with_text)
 
 
 
 
 
 
 
 
 
148
 
149
- for element in text_elements:
 
 
 
 
150
  try:
151
- font_size = element["size"]
 
 
 
152
 
153
- # Try to load a font, fall back to default if not available
154
- try:
155
- font = ImageFont.truetype("DejaVuSans.ttf", font_size)
156
- except IOError:
157
- try:
158
- font = ImageFont.truetype("Arial.ttf", font_size)
159
- except IOError:
160
- font = ImageFont.load_default()
161
 
162
- # Draw text with background for better visibility
163
- text = element["text"]
164
- position = element["position"]
165
- color = element["color"]
166
-
167
- # Get text size to create background
168
- bbox = draw.textbbox(position, text, font=font)
169
- text_width = bbox[2] - bbox[0]
170
- text_height = bbox[3] - bbox[1]
171
-
172
- # Draw semi-transparent background
173
- padding = 5
174
- background_box = [
175
- position[0] - padding,
176
- position[1] - padding,
177
- position[0] + text_width + padding,
178
- position[1] + text_height + padding
179
- ]
180
- draw.rectangle(background_box, fill=(255, 255, 255, 200))
181
-
182
- # Draw text
183
- draw.text(position, text, fill=color, font=font)
 
184
 
185
  except Exception as e:
186
- print(f"Error rendering text: {e}")
187
  continue
188
 
189
- return image_with_text
190
 
191
- def visualize_layout(self, text_elements, image_size=(512, 512)):
192
- """Create a visualization of the text layout"""
 
 
 
 
 
 
 
 
 
 
193
  width, height = image_size
194
- image = Image.new("RGB", image_size, (255, 255, 255))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  draw = ImageDraw.Draw(image)
196
 
197
- # Draw grid
198
- for x in range(0, width, 50):
199
- draw.line([(x, 0), (x, height)], fill=(230, 230, 230))
200
- for y in range(0, height, 50):
201
- draw.line([(0, y), (width, y)], fill=(230, 230, 230))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  # Draw text elements
204
- for element in text_elements:
205
- position = element["position"]
206
  text = element["text"]
207
- element_type = element.get("type", "unknown")
208
-
209
- # Draw position marker
210
- circle_radius = 5
211
- circle_bbox = [
212
- position[0] - circle_radius,
213
- position[1] - circle_radius,
214
- position[0] + circle_radius,
215
- position[1] + circle_radius
216
- ]
217
- draw.ellipse(circle_bbox, fill=(255, 0, 0))
218
 
219
  # Draw text label
220
- try:
221
- font = ImageFont.truetype("DejaVuSans.ttf", 12)
222
- except IOError:
223
- font = ImageFont.load_default()
 
 
224
 
225
- # Draw text preview and position info
226
- info_text = f"{text} ({element_type})"
227
- pos_text = f"Position: ({position[0]}, {position[1]})"
228
- draw.text((position[0] + 10, position[1]), info_text, fill=(0, 0, 0), font=font)
229
- draw.text((position[0] + 10, position[1] + 15), pos_text, fill=(0, 0, 255), font=font)
 
 
 
230
 
231
  return image
232
 
233
- def generate_text_image(self, prompt, width=512, height=512, num_text_elements=3):
234
- """Generate an image with rendered text based on prompt"""
235
- # Validate inputs
236
- width = max(256, min(1024, width))
237
- height = max(256, min(1024, height))
238
- num_text_elements = max(1, min(5, num_text_elements))
239
 
240
- image_size = (width, height)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
- # Step 1: Generate text layout
243
- text_elements, layout_text = self.generate_layout(prompt, image_size, num_text_elements)
 
244
 
245
- # Step 2: Generate base image
246
- base_image = self.generate_image(prompt, image_size)
 
 
 
 
 
247
 
248
- # Step 3: Render text onto the image
249
- image_with_text = self.render_text(base_image, text_elements)
 
 
 
250
 
251
- # Step 4: Create layout visualization
252
- layout_visualization = self.visualize_layout(text_elements, image_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
- # Step 5: Format layout information for display
255
- layout_info = {
 
 
 
 
 
 
 
 
 
 
256
  "prompt": prompt,
 
257
  "image_size": image_size,
258
- "num_text_elements": num_text_elements,
259
- "text_elements": text_elements,
260
- "layout_generation_prompt": layout_text
261
  }
262
 
263
- formatted_layout = json.dumps(layout_info, indent=2)
264
-
265
- return image_with_text, layout_visualization, formatted_layout
266
 
267
  # Initialize the model
268
- model = SimpleTextDiffuser()
269
-
270
- # Define the Gradio interface
271
- def process_request(prompt, width, height, num_text_elements):
272
- try:
273
- width = int(width)
274
- height = int(height)
275
- num_text_elements = int(num_text_elements)
276
-
277
- image, layout, layout_info = model.generate_text_image(
278
- prompt,
279
- width=width,
280
- height=height,
281
- num_text_elements=num_text_elements
282
- )
283
-
284
- return image, layout, layout_info
285
- except Exception as e:
286
- error_message = f"Error: {str(e)}"
287
- print(error_message)
288
- return None, None, error_message
289
 
290
- # Create the Gradio app
291
- with gr.Blocks(title="TextDiffuser-2 Demo") as demo:
292
  gr.Markdown("""
293
- # TextDiffuser-2 Demo
294
 
295
- This demo implements the concepts from the paper "[TextDiffuser-2: Unleashing the Power of Language Models for Text Rendering](https://arxiv.org/abs/2311.16465)" by Jingye Chen et al.
296
 
297
- Generate images with text by providing a descriptive prompt below.
 
 
 
 
298
  """)
299
 
300
  with gr.Row():
301
  with gr.Column(scale=1):
302
  prompt_input = gr.Textbox(
303
  label="Prompt",
304
- value="A modern business poster with company name and tagline",
305
- lines=3
 
 
 
 
 
 
 
306
  )
307
 
308
  with gr.Row():
309
  width_input = gr.Number(label="Width", value=512, minimum=256, maximum=1024, step=64)
310
  height_input = gr.Number(label="Height", value=512, minimum=256, maximum=1024, step=64)
311
 
312
- num_elements_input = gr.Slider(
313
- label="Number of Text Elements",
314
- minimum=1,
315
- maximum=5,
316
- value=3,
317
- step=1
 
318
  )
319
 
320
- submit_button = gr.Button("Generate Image", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
  with gr.Column(scale=2):
323
  with gr.Tabs():
324
- with gr.TabItem("Generated Image"):
325
- image_output = gr.Image(label="Image with Text")
326
-
327
  with gr.TabItem("Layout Visualization"):
328
- layout_output = gr.Image(label="Text Layout")
 
 
 
329
 
330
  with gr.TabItem("Layout Information"):
331
- layout_info_output = gr.Code(language="json", label="Layout Data")
332
-
333
- gr.Markdown("""
334
- ## Example Prompts
335
-
336
- Try these prompts or create your own:
337
- """)
338
 
339
- examples = gr.Examples(
 
340
  examples=[
341
- ["A movie poster for a sci-fi thriller", 512, 768, 3],
342
- ["A motivational quote on a sunset background", 768, 512, 2],
343
- ["A coffee shop menu with prices", 512, 512, 4],
344
- ["A modern business card design", 512, 384, 3],
 
 
 
345
  ],
346
- inputs=[prompt_input, width_input, height_input, num_elements_input]
347
  )
348
 
349
- submit_button.click(
350
- fn=process_request,
351
- inputs=[prompt_input, width_input, height_input, num_elements_input],
352
- outputs=[image_output, layout_output, layout_info_output]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  )
354
 
355
  gr.Markdown("""
356
- ## About
 
 
357
 
358
- This is a simplified implementation for demonstration purposes. The full approach described in the paper involves deeper integration of language models with the diffusion process.
359
 
360
- Running on: """ + str(device))
 
361
 
362
  # Launch the app
363
  if __name__ == "__main__":
 
1
+ # app.py - TextDiffuser-2 implementation with focus on layout planning
2
  import os
3
  import torch
4
  import gradio as gr
 
7
  from PIL import Image, ImageDraw, ImageFont
8
  from transformers import AutoTokenizer, AutoModelForCausalLM
9
  from diffusers import StableDiffusionPipeline
10
+ import time
11
+ import random
12
+
13
+ # Try to import fastchat - may need to install with pip if not available
14
+ try:
15
+ from fastchat.model import get_conversation_template
16
+ except ImportError:
17
+ # Fallback implementation if fastchat is not available
18
+ print("FastChat not found. Installing...")
19
+ os.system("pip install fschat")
20
+ from fastchat.model import get_conversation_template
21
 
22
  # Check for GPU
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  print(f"Using device: {device}")
25
 
26
+ # Define global storage for user interactions
27
+ global_dict = {}
28
+
29
+ class TextDiffuserLayoutPlanner:
30
  """
31
+ Implementation focused on the layout planning aspect of TextDiffuser-2
32
  """
33
  def __init__(self):
34
+ # Load the layout planner model
35
+ self.layout_model_path = "JingyeChen22/textdiffuser2_layout_planner"
 
 
 
36
 
37
+ print(f"Loading layout planner model from {self.layout_model_path}...")
38
+
39
+ try:
40
+ # Initialize the tokenizer and model
41
+ self.layout_tokenizer = AutoTokenizer.from_pretrained(
42
+ self.layout_model_path,
43
+ use_fast=False
44
  )
45
+
46
+ # Load the model with half precision if GPU is available
47
+ model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
48
+ self.layout_model = AutoModelForCausalLM.from_pretrained(
49
+ self.layout_model_path,
50
+ torch_dtype=model_dtype,
51
+ low_cpu_mem_usage=True
52
+ ).to(device)
53
+
54
+ print("Layout planner model loaded successfully")
55
+ except Exception as e:
56
+ print(f"Error loading layout planner: {e}")
57
+ print("Falling back to simpler implementation...")
58
+ # Set models to None to indicate fallback mode
59
+ self.layout_model = None
60
+ self.layout_tokenizer = None
61
 
62
+ # Initialize a simple diffusion model for context visualization
63
+ # This is optional and could be removed if you only need layout
64
+ self.diffusion_model = None
65
+ if torch.cuda.is_available():
66
+ try:
67
+ self.diffusion_model = StableDiffusionPipeline.from_pretrained(
68
+ "runwayml/stable-diffusion-v1-5",
69
+ torch_dtype=torch.float16
70
+ ).to(device)
71
+ print("Diffusion model loaded for context visualization")
72
+ except Exception as e:
73
+ print(f"Could not load diffusion model: {e}")
74
+ print("Will use placeholder images instead")
75
 
76
+ def generate_layout(self, prompt, keywords="", image_size=(512, 512), temperature=0.7):
 
 
 
 
 
 
 
 
 
 
 
77
  """
78
+ Generate a text layout based on the prompt using the layout planner model
79
 
80
+ Args:
81
+ prompt: Description of the image to generate
82
+ keywords: Optional keywords to include in the layout (format: "word1/word2/...")
83
+ image_size: Size of the target image (width, height)
84
+ temperature: Temperature for layout generation (higher = more diverse)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
+ Returns:
87
+ layout_elements: List of text elements with positions
88
+ layout_text: Raw output from the layout planner
89
+ layout_image: Visualization of the layout
90
+ """
 
 
 
 
 
 
 
91
  width, height = image_size
92
 
93
+ # Only proceed with the layout planner if available
94
+ if self.layout_model is not None and self.layout_tokenizer is not None:
95
+ # Format the prompt for layout generation
96
+ if len(keywords.strip()) == 0:
97
+ template = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is {width//4}x{height//4}. Therefore, all properties of the positions should not exceed {width//4}, including the coordinates of top, left, right, and bottom. All keywords are included in the caption. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {prompt}'
98
+ else:
99
+ keywords_list = keywords.split('/')
100
+ keywords_list = [k.strip() for k in keywords_list]
101
+ template = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is {width//4}x{height//4}. Therefore, all properties of the positions should not exceed {width//4}, including the coordinates of top, left, right, and bottom. In addition, we also provide all keywords at random order for reference. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {prompt}. Keywords: {str(keywords_list)}'
102
+
103
+ # Use FastChat's conversation template
104
+ conv = get_conversation_template(self.layout_model_path)
105
+ conv.append_message(conv.roles[0], template)
106
+ conv.append_message(conv.roles[1], None)
107
+ prompt_text = conv.get_prompt()
108
+
109
+ # Generate the layout
110
+ time_start = time.time()
111
+ print(f"Generating layout for prompt: {prompt}")
112
+
113
+ # Tokenize and prepare inputs
114
+ inputs = self.layout_tokenizer([prompt_text], return_token_type_ids=False)
115
+ inputs = {k: torch.tensor(v).to(device) for k, v in inputs.items()}
116
+
117
+ # Generate layout with the model
118
+ with torch.no_grad():
119
+ output_ids = self.layout_model.generate(
120
+ **inputs,
121
+ do_sample=True,
122
+ temperature=temperature,
123
+ repetition_penalty=1.0,
124
+ max_new_tokens=512,
125
+ )
126
+
127
+ # Process the output
128
+ if self.layout_model.config.is_encoder_decoder:
129
+ output_ids = output_ids[0]
130
+ else:
131
+ output_ids = output_ids[0][len(inputs["input_ids"][0]):]
132
+
133
+ layout_text = self.layout_tokenizer.decode(
134
+ output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
135
+ )
136
+
137
+ time_end = time.time()
138
+ print(f"Layout generation took {time_end - time_start:.2f} seconds")
139
+ print(f"Layout output: {layout_text}")
140
+
141
+ # Parse the layout text to extract text elements
142
+ layout_elements = self.parse_layout_text(layout_text, image_size)
143
+
144
+ # Create a visualization of the layout
145
+ layout_image = self.visualize_layout(layout_elements, image_size)
146
+
147
  else:
148
+ # Fallback: Generate a simple layout
149
+ print("Using fallback layout generation")
150
+ layout_elements = self.generate_fallback_layout(prompt, keywords, image_size)
151
+ layout_text = "Fallback layout generation - Layout planner model not available"
152
+ layout_image = self.visualize_layout(layout_elements, image_size)
 
 
 
 
 
153
 
154
+ return layout_elements, layout_text, layout_image
155
 
156
+ def parse_layout_text(self, layout_text, image_size=(512, 512)):
157
+ """
158
+ Parse the layout text from the layout planner to extract text elements
159
+
160
+ Args:
161
+ layout_text: Output text from the layout planner
162
+ image_size: Size of the target image
163
+
164
+ Returns:
165
+ layout_elements: List of text elements with positions
166
+ """
167
+ layout_elements = []
168
+ lines = layout_text.strip().split('\n')
169
 
170
+ for line in lines:
171
+ line = line.strip()
172
+ if not line or '###' in line or '.com' in line:
173
+ continue
174
+
175
  try:
176
+ # Parse the line to extract text and position
177
+ parts = line.split()
178
+ if len(parts) < 5: # Need at least text and 4 coordinates
179
+ continue
180
 
181
+ # Last 4 parts should be coordinates, everything else is text
182
+ coords = parts[-1]
183
+ text = ' '.join(parts[:-1])
 
 
 
 
 
184
 
185
+ # Parse coordinates (left, top, right, bottom)
186
+ try:
187
+ l, t, r, b = map(int, coords.split(','))
188
+
189
+ # Scale coordinates to image size (they are given in 1/4 scale)
190
+ l, t, r, b = l*4, t*4, r*4, b*4
191
+
192
+ # Create text element
193
+ element = {
194
+ "text": text,
195
+ "position": (l, t),
196
+ "size": (r-l, b-t),
197
+ "box": (l, t, r, b),
198
+ "style": {
199
+ "font": "Arial",
200
+ "size": 24,
201
+ "color": (0, 0, 0)
202
+ }
203
+ }
204
+ layout_elements.append(element)
205
+ except ValueError:
206
+ print(f"Could not parse coordinates in line: {line}")
207
+ continue
208
 
209
  except Exception as e:
210
+ print(f"Error parsing layout line: {e}")
211
  continue
212
 
213
+ return layout_elements
214
 
215
+ def generate_fallback_layout(self, prompt, keywords="", image_size=(512, 512)):
216
+ """
217
+ Generate a fallback layout when the layout planner is not available
218
+
219
+ Args:
220
+ prompt: Description of the image
221
+ keywords: Optional keywords to include
222
+ image_size: Size of the target image
223
+
224
+ Returns:
225
+ layout_elements: List of text elements with positions
226
+ """
227
  width, height = image_size
228
+ layout_elements = []
229
+
230
+ # Extract keywords from the prompt or use provided keywords
231
+ if keywords:
232
+ keywords_list = keywords.split('/')
233
+ keywords_list = [k.strip() for k in keywords_list]
234
+ else:
235
+ # Extract potential keywords from the prompt
236
+ words = prompt.split()
237
+ keywords_list = [word for word in words if len(word) > 3 and word.isalpha()]
238
+ keywords_list = keywords_list[:3] # Limit to 3 keywords
239
+
240
+ # Generate positions for the keywords
241
+ for i, keyword in enumerate(keywords_list):
242
+ # Calculate a position based on the index
243
+ row = i // 2
244
+ col = i % 2
245
+
246
+ l = 50 + col * (width // 2)
247
+ t = 50 + row * (height // 3)
248
+ r = l + 200
249
+ b = t + 50
250
+
251
+ element = {
252
+ "text": keyword,
253
+ "position": (l, t),
254
+ "size": (r-l, b-t),
255
+ "box": (l, t, r, b),
256
+ "style": {
257
+ "font": "Arial",
258
+ "size": 24,
259
+ "color": (0, 0, 0)
260
+ }
261
+ }
262
+ layout_elements.append(element)
263
+
264
+ return layout_elements
265
+
266
+ def visualize_layout(self, layout_elements, image_size=(512, 512)):
267
+ """
268
+ Create a visualization of the text layout
269
+
270
+ Args:
271
+ layout_elements: List of text elements with positions
272
+ image_size: Size of the target image
273
+
274
+ Returns:
275
+ layout_image: Visualization of the layout
276
+ """
277
+ width, height = image_size
278
+ image = Image.new("RGB", image_size, (240, 240, 240))
279
  draw = ImageDraw.Draw(image)
280
 
281
+ # Draw grid lines
282
+ for x in range(0, width, 32):
283
+ alpha = 255 if x % 128 == 0 else 100
284
+ draw.line([(x, 0), (x, height)], fill=(200, 200, 200, alpha), width=1)
285
+
286
+ for y in range(0, height, 32):
287
+ alpha = 255 if y % 128 == 0 else 100
288
+ draw.line([(0, y), (width, y)], fill=(200, 200, 200, alpha), width=1)
289
+
290
+ # Try to load a font
291
+ try:
292
+ font_large = ImageFont.truetype("Arial.ttf", 20)
293
+ font_small = ImageFont.truetype("Arial.ttf", 12)
294
+ except IOError:
295
+ try:
296
+ font_large = ImageFont.truetype("DejaVuSans.ttf", 20)
297
+ font_small = ImageFont.truetype("DejaVuSans.ttf", 12)
298
+ except IOError:
299
+ font_large = ImageFont.load_default()
300
+ font_small = ImageFont.load_default()
301
 
302
  # Draw text elements
303
+ for i, element in enumerate(layout_elements):
304
+ box = element.get("box", (0, 0, 0, 0))
305
  text = element["text"]
306
+
307
+ # Draw bounding box
308
+ draw.rectangle(box, outline=(255, 0, 0), width=2)
 
 
 
 
 
 
 
 
309
 
310
  # Draw text label
311
+ draw.text(
312
+ (box[0] + 5, box[1] - 20),
313
+ f"{i+1}: {text}",
314
+ font=font_small,
315
+ fill=(0, 0, 0)
316
+ )
317
 
318
+ # Draw coordinates
319
+ coord_text = f"({box[0]},{box[1]}) to ({box[2]},{box[3]})"
320
+ draw.text(
321
+ (box[0] + 5, box[3] + 5),
322
+ coord_text,
323
+ font=font_small,
324
+ fill=(0, 0, 255)
325
+ )
326
 
327
  return image
328
 
329
+ def generate_context_image(self, prompt, image_size=(512, 512)):
330
+ """
331
+ Generate a context image based on the prompt
 
 
 
332
 
333
+ Args:
334
+ prompt: Description of the image
335
+ image_size: Size of the target image
336
+
337
+ Returns:
338
+ image: Generated or placeholder image
339
+ """
340
+ if self.diffusion_model is not None:
341
+ # Generate image using the diffusion model
342
+ try:
343
+ images = self.diffusion_model(
344
+ prompt=prompt,
345
+ height=image_size[1],
346
+ width=image_size[0],
347
+ num_inference_steps=20
348
+ ).images
349
+ return images[0]
350
+ except Exception as e:
351
+ print(f"Error generating image: {e}")
352
+ print("Using placeholder image instead")
353
 
354
+ # Create a placeholder gradient image
355
+ width, height = image_size
356
+ image = Image.new("RGB", image_size, (240, 240, 240))
357
 
358
+ # Add a subtle gradient background
359
+ for y in range(height):
360
+ for x in range(width):
361
+ r = int(240 - 30 * (y / height))
362
+ g = int(240 - 20 * (x / width))
363
+ b = int(240 - 40 * ((x + y) / (width + height)))
364
+ image.putpixel((x, y), (r, g, b))
365
 
366
+ return image
367
+
368
+ def process_request(self, prompt, keywords="", width=512, height=512, temperature=0.7, generate_image=False):
369
+ """
370
+ Process a user request to generate a layout
371
 
372
+ Args:
373
+ prompt: Description of the image
374
+ keywords: Optional keywords to include
375
+ width: Width of the target image
376
+ height: Height of the target image
377
+ temperature: Temperature for layout generation
378
+ generate_image: Whether to generate a context image
379
+
380
+ Returns:
381
+ layout_elements: List of text elements with positions
382
+ layout_text: Raw output from the layout planner
383
+ layout_image: Visualization of the layout
384
+ context_image: Generated or placeholder image (if requested)
385
+ """
386
+ image_size = (width, height)
387
 
388
+ # Generate layout
389
+ layout_elements, layout_text, layout_image = self.generate_layout(
390
+ prompt, keywords, image_size, temperature
391
+ )
392
+
393
+ # Generate context image if requested
394
+ context_image = None
395
+ if generate_image:
396
+ context_image = self.generate_context_image(prompt, image_size)
397
+
398
+ # Format the layout data for display
399
+ layout_data = {
400
  "prompt": prompt,
401
+ "keywords": keywords,
402
  "image_size": image_size,
403
+ "text_elements": layout_elements,
 
 
404
  }
405
 
406
+ return layout_elements, layout_text, layout_image, context_image, layout_data
 
 
407
 
408
  # Initialize the model
409
+ model = TextDiffuserLayoutPlanner()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
+ # Create the Gradio interface
412
+ with gr.Blocks(title="TextDiffuser-2 Layout Planner") as demo:
413
  gr.Markdown("""
414
+ # TextDiffuser-2 Layout Planner
415
 
416
+ This application focuses on the layout planning aspect of TextDiffuser-2. It allows you to:
417
 
418
+ 1. Generate text layouts for images based on prompts
419
+ 2. Visualize the layout with text positions and bounding boxes
420
+ 3. Export the layout information for use in your own HTML5 Canvas UI editor
421
+
422
+ Based on the paper "[TextDiffuser-2: Unleashing the Power of Language Models for Text Rendering](https://arxiv.org/abs/2311.16465)" by Jingye Chen et al.
423
  """)
424
 
425
  with gr.Row():
426
  with gr.Column(scale=1):
427
  prompt_input = gr.Textbox(
428
  label="Prompt",
429
+ value="A beautiful city skyline stamp of Shanghai",
430
+ lines=3,
431
+ placeholder="Describe the image you want to generate with text elements"
432
+ )
433
+
434
+ keywords_input = gr.Textbox(
435
+ label="Optional Keywords (separated by /)",
436
+ placeholder="keyword1/keyword2/keyword3",
437
+ info="If provided, the layout planner will try to use these keywords"
438
  )
439
 
440
  with gr.Row():
441
  width_input = gr.Number(label="Width", value=512, minimum=256, maximum=1024, step=64)
442
  height_input = gr.Number(label="Height", value=512, minimum=256, maximum=1024, step=64)
443
 
444
+ temperature_input = gr.Slider(
445
+ label="Temperature",
446
+ minimum=0.1,
447
+ maximum=2.0,
448
+ value=0.7,
449
+ step=0.1,
450
+ info="Controls randomness in layout generation. Higher values produce more diverse layouts."
451
  )
452
 
453
+ show_image_input = gr.Checkbox(
454
+ label="Generate Context Image",
455
+ value=False,
456
+ info="Generate a simple image to provide context (this is just for visualization)"
457
+ )
458
+
459
+ generate_button = gr.Button("Generate Layout", variant="primary")
460
+
461
+ gr.Markdown("""
462
+ ## Tips for using this demo
463
+
464
+ 1. The layout planner works best with descriptive prompts
465
+ 2. You can specify keywords to ensure they appear in the layout
466
+ 3. Increase temperature for more diverse layouts
467
+ 4. The context image is optional and just for visualization
468
+ """)
469
 
470
  with gr.Column(scale=2):
471
  with gr.Tabs():
 
 
 
472
  with gr.TabItem("Layout Visualization"):
473
+ layout_output = gr.Image(label="Text Layout Visualization", type="pil")
474
+
475
+ with gr.TabItem("Context Image"):
476
+ context_image_output = gr.Image(label="Context Image (Optional)", type="pil")
477
 
478
  with gr.TabItem("Layout Information"):
479
+ layout_elements_output = gr.JSON(label="Layout Elements")
480
+
481
+ with gr.TabItem("Raw Layout Output"):
482
+ layout_text_output = gr.Textbox(label="Raw Layout Planner Output", lines=10)
 
 
 
483
 
484
+ # Examples
485
+ gr.Examples(
486
  examples=[
487
+ ["A new year greeting card of happy 2024, surrounded by balloons", "", 512, 512, 0.7, True],
488
+ ["A beautiful city skyline stamp of Shanghai", "", 512, 512, 0.7, True],
489
+ ["The words 'KFC VIVO50' are inscribed upon the wall in a neon light effect", "KFC/VIVO50", 512, 512, 0.7, True],
490
+ ["A logo of superman", "", 512, 512, 0.7, True],
491
+ ["A pencil sketch of a tree with the title nothing to tree here", "nothing/tree/here", 512, 512, 0.7, True],
492
+ ["Delicate greeting card of happy birthday to xyz", "happy/birthday/xyz", 768, 512, 1.0, True],
493
+ ["Book cover of good morning baby", "good/morning/baby", 512, 768, 0.7, True],
494
  ],
495
+ inputs=[prompt_input, keywords_input, width_input, height_input, temperature_input, show_image_input]
496
  )
497
 
498
+ # Function to process the request
499
+ def process_ui_request(prompt, keywords, width, height, temperature, show_image):
500
+ try:
501
+ width = int(width)
502
+ height = int(height)
503
+
504
+ layout_elements, layout_text, layout_image, context_image, layout_data = model.process_request(
505
+ prompt,
506
+ keywords,
507
+ width,
508
+ height,
509
+ temperature,
510
+ show_image
511
+ )
512
+
513
+ if show_image and context_image is not None:
514
+ return layout_image, context_image, layout_data, layout_text
515
+ else:
516
+ return layout_image, None, layout_data, layout_text
517
+
518
+ except Exception as e:
519
+ error_message = f"Error: {str(e)}"
520
+ print(error_message)
521
+ return None, None, {"error": error_message}, error_message
522
+
523
+ # Connect the button to the processing function
524
+ generate_button.click(
525
+ fn=process_ui_request,
526
+ inputs=[prompt_input, keywords_input, width_input, height_input, temperature_input, show_image_input],
527
+ outputs=[layout_output, context_image_output, layout_elements_output, layout_text_output]
528
  )
529
 
530
  gr.Markdown("""
531
+ ## About TextDiffuser-2
532
+
533
+ TextDiffuser-2 is a system that uses language models for text rendering in images. The layout planner component is responsible for determining where text should be positioned in the generated image.
534
 
535
+ This demo focuses only on the layout planning aspect, allowing you to generate and export layout information that can be used in your own HTML5 Canvas UI editor.
536
 
537
+ For the full TextDiffuser-2 implementation, please visit the [official repository](https://github.com/microsoft/unilm/tree/master/textdiffuser-2).
538
+ """)
539
 
540
  # Launch the app
541
  if __name__ == "__main__":