Luis Oala committed on
Commit e237845 · 2 Parent(s): 2e0468e 8c239b8
Files changed (4)
  1. app.py +2 -3
  2. server.py +1 -1
  3. server.py~ +246 -0
  4. setup.py +0 -1
app.py CHANGED
@@ -192,6 +192,5 @@ iface = gr.Interface(fn=sample,
                      inputs=gr.inputs.Textbox(label='enter text'),
                      outputs=gr.outputs.Image(type="pil", label="..."),
                      title=title,
-                     description=description,
-                     enable_queue=True)
-iface.launch(debug=True)
+                     description=description)
+iface.launch(debug=True,enable_queue=True)
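This change moves `enable_queue` from the `gr.Interface(...)` constructor to `launch()`, which is where the Gradio versions of this era accept it; queueing keeps long-running inference requests, such as GLIDE sampling, from timing out on Hugging Face Spaces. A minimal sketch of the resulting pattern, assuming `sample`, `title`, and `description` are defined earlier in app.py:

    import gradio as gr

    # Queueing is now requested at launch time, not at Interface construction.
    iface = gr.Interface(fn=sample,
                         inputs=gr.inputs.Textbox(label='enter text'),
                         outputs=gr.outputs.Image(type="pil", label="..."),
                         title=title,
                         description=description)
    iface.launch(debug=True, enable_queue=True)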
 
server.py CHANGED
@@ -172,4 +172,4 @@ def sample(prompt):
 def to_base64(pil_image):
     buffered = BytesIO()
     pil_image.save(buffered, format="JPEG")
-    return base64.b64encode(buffered.getvalue())
+    return base64.b64encode(buffered.getvalue())
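The removed and re-added `return` lines are textually identical, so this hunk most likely only normalizes the newline at the end of the file after the merge. One detail worth noting: `base64.b64encode` returns `bytes`. A hedged variant of `to_base64` that decodes to `str` (the `.decode("ascii")` call is the only addition) makes the JSON payload explicit instead of relying on the framework to serialize `bytes`:

    def to_base64(pil_image):
        buffered = BytesIO()
        pil_image.save(buffered, format="JPEG")
        # b64encode yields bytes; base64 text is pure ASCII, so decode to str.
        return base64.b64encode(buffered.getvalue()).decode("ascii")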
server.py~ ADDED
@@ -0,0 +1,246 @@
+import base64
+from io import BytesIO
+from fastapi import FastAPI
+<<<<<<< HEAD
+
+from PIL import Image
+import torch as th
+
+=======
+from PIL import Image
+import torch as th
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+from glide_text2im.download import load_checkpoint
+from glide_text2im.model_creation import (
+    create_model_and_diffusion,
+    model_and_diffusion_defaults,
+    model_and_diffusion_defaults_upsampler
+)
+<<<<<<< HEAD
+
+print("Loading models...")
+app = FastAPI()
+
+# This notebook supports both CPU and GPU.
+# On CPU, generating one sample may take on the order of 20 minutes.
+# On a GPU, it should be under a minute.
+
+has_cuda = th.cuda.is_available()
+device = th.device('cpu' if not has_cuda else 'cuda')
+
+=======
+print("Loading models...")
+app = FastAPI()
+# This notebook supports both CPU and GPU.
+# On CPU, generating one sample may take on the order of 20 minutes.
+# On a GPU, it should be under a minute.
+has_cuda = th.cuda.is_available()
+device = th.device('cpu' if not has_cuda else 'cuda')
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+# Create base model.
+options = model_and_diffusion_defaults()
+options['use_fp16'] = has_cuda
+options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling
+model, diffusion = create_model_and_diffusion(**options)
+model.eval()
+if has_cuda:
+    model.convert_to_fp16()
+model.to(device)
+model.load_state_dict(load_checkpoint('base', device))
+print('total base parameters', sum(x.numel() for x in model.parameters()))
+<<<<<<< HEAD
+
+=======
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+# Create upsampler model.
+options_up = model_and_diffusion_defaults_upsampler()
+options_up['use_fp16'] = has_cuda
+options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling
+model_up, diffusion_up = create_model_and_diffusion(**options_up)
+model_up.eval()
+if has_cuda:
+    model_up.convert_to_fp16()
+model_up.to(device)
+model_up.load_state_dict(load_checkpoint('upsample', device))
+print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))
+<<<<<<< HEAD
+
+
+=======
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+def get_images(batch: th.Tensor):
+    """ Convert a batch of images into a single PIL image. """
+    scaled = ((batch + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()
+    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
+    return Image.fromarray(reshaped.numpy())
+<<<<<<< HEAD
+
+
+# Create a classifier-free guidance sampling function
+guidance_scale = 3.0
+
+=======
+# Create a classifier-free guidance sampling function
+guidance_scale = 3.0
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+def model_fn(x_t, ts, **kwargs):
+    half = x_t[: len(x_t) // 2]
+    combined = th.cat([half, half], dim=0)
+    model_out = model(combined, ts, **kwargs)
+    eps, rest = model_out[:, :3], model_out[:, 3:]
+    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
+    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
+    eps = th.cat([half_eps, half_eps], dim=0)
+    return th.cat([eps, rest], dim=1)
+<<<<<<< HEAD
+
+
+@app.get("/")
+def read_root():
+    return {"glide!"}
+
+=======
+@app.get("/")
+def read_root():
+    return {"glide!"}
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+@app.get("/{generate}")
+def sample(prompt):
+    # Sampling parameters
+    batch_size = 1
+<<<<<<< HEAD
+
+    # Tune this parameter to control the sharpness of 256x256 images.
+    # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
+    upsample_temp = 0.997
+
+    ##############################
+    # Sample from the base model #
+    ##############################
+
+=======
+    # Tune this parameter to control the sharpness of 256x256 images.
+    # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
+    upsample_temp = 0.997
+    ##############################
+    # Sample from the base model #
+    ##############################
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+    # Create the text tokens to feed to the model.
+    tokens = model.tokenizer.encode(prompt)
+    tokens, mask = model.tokenizer.padded_tokens_and_mask(
+        tokens, options['text_ctx']
+    )
+<<<<<<< HEAD
+
+=======
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+    # Create the classifier-free guidance tokens (empty)
+    full_batch_size = batch_size * 2
+    uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
+        [], options['text_ctx']
+    )
+<<<<<<< HEAD
+
+=======
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+    # Pack the tokens together into model kwargs.
+    model_kwargs = dict(
+        tokens=th.tensor(
+            [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
+        ),
+        mask=th.tensor(
+            [mask] * batch_size + [uncond_mask] * batch_size,
+            dtype=th.bool,
+            device=device,
+        ),
+    )
+<<<<<<< HEAD
+
+=======
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+    # Sample from the base model.
+    model.del_cache()
+    samples = diffusion.p_sample_loop(
+        model_fn,
+        (full_batch_size, 3, options["image_size"], options["image_size"]),
+        device=device,
+        clip_denoised=True,
+        progress=True,
+        model_kwargs=model_kwargs,
+        cond_fn=None,
+    )[:batch_size]
+    model.del_cache()
+<<<<<<< HEAD
+
+
+    ##############################
+    # Upsample the 64x64 samples #
+    ##############################
+
+=======
+    ##############################
+    # Upsample the 64x64 samples #
+    ##############################
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+    tokens = model_up.tokenizer.encode(prompt)
+    tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
+        tokens, options_up['text_ctx']
+    )
+<<<<<<< HEAD
+
+=======
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+    # Create the model conditioning dict.
+    model_kwargs = dict(
+        # Low-res image to upsample.
+        low_res=((samples+1)*127.5).round()/127.5 - 1,
+<<<<<<< HEAD
+
+=======
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+        # Text tokens
+        tokens=th.tensor(
+            [tokens] * batch_size, device=device
+        ),
+        mask=th.tensor(
+            [mask] * batch_size,
+            dtype=th.bool,
+            device=device,
+        ),
+    )
+<<<<<<< HEAD
+
+=======
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+    # Sample from the upsampler model.
+    model_up.del_cache()
+    up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
+    up_samples = diffusion_up.ddim_sample_loop(
+        model_up,
+        up_shape,
+        noise=th.randn(up_shape, device=device) * upsample_temp,
+        device=device,
+        clip_denoised=True,
+        progress=True,
+        model_kwargs=model_kwargs,
+        cond_fn=None,
+    )[:batch_size]
+    model_up.del_cache()
+<<<<<<< HEAD
+
+=======
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+    # Show the output
+    image = get_images(up_samples)
+    image = to_base64(image)
+    return {"image": image}
+<<<<<<< HEAD
+
+
+=======
+>>>>>>> 8c239b8a9cdaf13e28c145e788b984c129547a37
+def to_base64(pil_image):
+    buffered = BytesIO()
+    pil_image.save(buffered, format="JPEG")
+    return base64.b64encode(buffered.getvalue())
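server.py~ is an editor backup that still carries the unresolved `<<<<<<< HEAD`/`=======`/`>>>>>>>` markers from this merge; the two sides differ only in blank lines. To exercise the resolved endpoint in server.py, here is a hypothetical client sketch (the host, port, and the `requests` dependency are assumptions, not part of this repo). Because the route is declared as `@app.get("/{generate}")` while the handler signature is `sample(prompt)`, any single path segment matches the route and FastAPI reads `prompt` as a required query parameter:

    import base64
    from io import BytesIO

    import requests
    from PIL import Image

    # Any path segment satisfies /{generate}; the prompt travels in the query string.
    resp = requests.get(
        "http://localhost:8000/generate",
        params={"prompt": "an oil painting of a corgi"},
    )
    img_b64 = resp.json()["image"]  # base64-encoded JPEG produced by to_base64()
    Image.open(BytesIO(base64.b64decode(img_b64))).save("sample.jpg")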
setup.py CHANGED
@@ -1,5 +1,4 @@
 from setuptools import setup
-
 setup(
     name="glide-text2im",
     packages=["glide_text2im"],