jbilcke-hf HF Staff commited on
Commit
aaa160e
·
1 Parent(s): 075fea9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -51
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import spaces
2
  import os
3
  import datetime
4
  import einops
@@ -22,6 +22,28 @@ from myutils.misc import load_dreambooth_lora, rand_name
22
  from myutils.wavelet_color_fix import wavelet_color_fix
23
  from annotator.retinaface import RetinaFaceDetection
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  use_pasd_light = False
26
  face_detector = RetinaFaceDetection()
27
 
@@ -84,7 +106,7 @@ def resize_image(image_path, target_height):
84
  #resized_img.save(output_path)
85
  return resized_img
86
 
87
- @spaces.GPU(enable_queue=True)
88
  def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed):
89
  input_image = resize_image(input_image, 512)
90
  process_size = 768
@@ -138,55 +160,18 @@ def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, a
138
  print(e)
139
  image = Image.new(mode="RGB", size=(512, 512))
140
 
141
- # Convert and save the image as JPEG
142
- image.save(f'result_{timestamp}.jpg', 'JPEG')
143
-
144
- # Convert and save the image as JPEG
145
- input_image.save(f'input_{timestamp}.jpg', 'JPEG')
146
-
147
- return (f"input_{timestamp}.jpg", f"result_{timestamp}.jpg"), f"result_{timestamp}.jpg"
148
 
149
  title = "Pixel-Aware Stable Diffusion for Real-ISR"
150
  description = "Gradio Demo for PASD Real-ISR. To use it, simply upload your image, or click one of the examples to load them."
151
  article = "<a href='https://github.com/yangxy/PASD' target='_blank'>Github Repo Pytorch</a>"
152
- #examples=[['samples/27d38eeb2dbbe7c9.png'],['samples/629e4da70703193b.png']]
153
-
154
- css = """
155
- #col-container{
156
- margin: 0 auto;
157
- max-width: 720px;
158
- }
159
- #project-links{
160
- margin: 0 0 12px !important;
161
- column-gap: 8px;
162
- display: flex;
163
- justify-content: center;
164
- flex-wrap: nowrap;
165
- flex-direction: row;
166
- align-items: center;
167
- }
168
- """
169
-
170
- with gr.Blocks(css=css) as demo:
171
- with gr.Column(elem_id="col-container"):
172
- gr.HTML(f"""
173
- <h2 style="text-align: center;">
174
- PASD Magnify
175
- </h2>
176
- <p style="text-align: center;">
177
- Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
178
- </p>
179
- <p id="project-links" align="center">
180
- <a href='https://github.com/yangxy/PASD'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://huggingface.co/papers/2308.14469'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
181
- </p>
182
- <p style="margin:12px auto;display: flex;justify-content: center;">
183
- <a href="https://huggingface.co/spaces/fffiloni/PASD?duplicate=true"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space"></a>
184
- </p>
185
-
186
- """)
187
  with gr.Row():
188
  with gr.Column():
189
- input_image = gr.Image(type="filepath", sources=["upload"], value="samples/frog.png")
190
  prompt_in = gr.Textbox(label="Prompt", value="Frog")
191
  with gr.Accordion(label="Advanced settings", open=False):
192
  added_prompt = gr.Textbox(label="Added Prompt", value='clean, high-resolution, 8k, best quality, masterpiece')
@@ -198,8 +183,7 @@ with gr.Blocks(css=css) as demo:
198
  seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
199
  submit_btn = gr.Button("Submit")
200
  with gr.Column():
201
- b_a_slider = ImageSlider(label="B/A result", position=0.5)
202
- file_output = gr.File(label="Downloadable image result")
203
 
204
  submit_btn.click(
205
  fn = inference,
@@ -210,9 +194,6 @@ with gr.Blocks(css=css) as demo:
210
  upsample_scale, condition_scale,
211
  classifier_free_guidance, seed
212
  ],
213
- outputs = [
214
- b_a_slider,
215
- file_output
216
- ]
217
  )
218
  demo.queue().launch()
 
1
+ # import spaces
2
  import os
3
  import datetime
4
  import einops
 
22
  from myutils.wavelet_color_fix import wavelet_color_fix
23
  from annotator.retinaface import RetinaFaceDetection
24
 
25
+ from io import BytesIO
26
+ import base64
27
+ import re
28
+
29
+ # Regex pattern to match data URI scheme
30
+ data_uri_pattern = re.compile(r'data:image/(png|jpeg|jpg|webp);base64,')
31
+
32
+ def readb64(b64):
33
+ # Remove any data URI scheme prefix with regex
34
+ b64 = data_uri_pattern.sub("", b64)
35
+ # Decode and open the image with PIL
36
+ img = Image.open(BytesIO(base64.b64decode(b64)))
37
+ return img
38
+
39
+ # convert from PIL to base64
40
+ def writeb64(image):
41
+ buffered = BytesIO()
42
+ image.save(buffered, format="PNG")
43
+ b64image = base64.b64encode(buffered.getvalue())
44
+ b64image_str = b64image.decode("utf-8")
45
+ return b64image_str
46
+
47
  use_pasd_light = False
48
  face_detector = RetinaFaceDetection()
49
 
 
106
  #resized_img.save(output_path)
107
  return resized_img
108
 
109
+ # @spaces.GPU(enable_queue=True)
110
  def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed):
111
  input_image = resize_image(input_image, 512)
112
  process_size = 768
 
160
  print(e)
161
  image = Image.new(mode="RGB", size=(512, 512))
162
 
163
+ return writeb64(image)
 
 
 
 
 
 
164
 
165
  title = "Pixel-Aware Stable Diffusion for Real-ISR"
166
  description = "Gradio Demo for PASD Real-ISR. To use it, simply upload your image, or click one of the examples to load them."
167
  article = "<a href='https://github.com/yangxy/PASD' target='_blank'>Github Repo Pytorch</a>"
168
+
169
+
170
+ with gr.Blocks() as demo:
171
+ with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  with gr.Row():
173
  with gr.Column():
174
+ input_image = gr.Textbox()
175
  prompt_in = gr.Textbox(label="Prompt", value="Frog")
176
  with gr.Accordion(label="Advanced settings", open=False):
177
  added_prompt = gr.Textbox(label="Added Prompt", value='clean, high-resolution, 8k, best quality, masterpiece')
 
183
  seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
184
  submit_btn = gr.Button("Submit")
185
  with gr.Column():
186
+ output_image = gr.Textbox()
 
187
 
188
  submit_btn.click(
189
  fn = inference,
 
194
  upsample_scale, condition_scale,
195
  classifier_free_guidance, seed
196
  ],
197
+ outputs = output_image
 
 
 
198
  )
199
  demo.queue().launch()