Spaces: Running on Zero
Update app.py
app.py CHANGED

@@ -5,7 +5,6 @@ from transformers import AutoConfig, AutoModelForCausalLM
 from janus.models import MultiModalityCausalLM, VLChatProcessor
 from janus.utils.io import load_pil_images
 from PIL import Image
-from diffusers import FluxPipeline
 
 import numpy as np
 import os
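With image generation moved onto the Janus model itself (see the rewritten generate_image below), the diffusers FluxPipeline dependency is no longer needed. For context, a minimal sketch of how the remaining Janus imports are typically wired up, following the Janus examples; the checkpoint name is an assumption, the Space may pin a different one:

```python
import torch
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor

model_path = "deepseek-ai/Janus-Pro-7B"  # assumed checkpoint, not confirmed by this diff
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

# trust_remote_code lets transformers resolve the MultiModalityCausalLM class
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
```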
@@ -162,18 +161,10 @@ def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_wei
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
-    outputs = vl_gpt.language_model.generate(
-        inputs_embeds=inputs_embeds,
-        attention_mask=prepare_inputs.attention_mask,
-        pad_token_id=tokenizer.eos_token_id,
-        bos_token_id=tokenizer.bos_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-        max_new_tokens=512,
-        do_sample=False if temperature == 0 else True,
-        use_cache=True,
-        temperature=temperature,
-        top_p=top_p,
-    )
+    outputs = vl_gpt.language_model.generate(inputs_embeds=inputs_embeds, attention_mask=prepare_inputs.attention_mask,
+                                             pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id,
+                                             eos_token_id=tokenizer.eos_token_id, max_new_tokens=512, temperature=temperature, top_p=top_p,
+                                             do_sample=False if temperature == 0 else True, use_cache=True,)
 
     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
 
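The rewritten generate() call keeps the same decoding arguments but packs them into one call, passing inputs_embeds and attention_mask explicitly and setting pad_token_id to the EOS token, which avoids transformers' missing-pad-token warning; at temperature 0 sampling is disabled, so decoding is greedy. A minimal sketch of the chat path around this call, assuming the conversation and image-placeholder format from the Janus examples:

```python
# Hedged sketch: names mirror the diff (vl_chat_processor, vl_gpt, tokenizer);
# the conversation format follows the Janus examples.
conversation = [
    {"role": "<|User|>", "content": "<image_placeholder>\nDescribe this image.", "images": [image]},
    {"role": "<|Assistant|>", "content": ""},
]
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
    conversations=conversation, images=pil_images, force_batchify=True
).to(vl_gpt.device)

# Fuse text and image embeddings, then decode greedily (the temperature == 0 path).
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
outputs = vl_gpt.language_model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=prepare_inputs.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=512,
    do_sample=False,
    use_cache=True,
)
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
```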
@@ -234,42 +225,35 @@ def unpack(dec, width, height, parallel_size=5):
 @torch.inference_mode()
 @spaces.GPU(duration=120)  # Specify a duration to avoid timeout
 def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress(track_tqdm=True)):
-    #
-
-    #
-
-
-
-
-
-
-
+    # Clear CUDA cache and avoid tracking gradients
+    torch.cuda.empty_cache()
+    # Set the seed for reproducible results
+    if seed is not None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+        np.random.seed(seed)
+    width = 384
+    height = 384
+    parallel_size = 1
 
-
-
-
-
-
-
-
+    with torch.no_grad():
+        messages = [{'role': '<|User|>', 'content': prompt},
+                    {'role': '<|Assistant|>', 'content': ''}]
+        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
+                                                                           sft_format=vl_chat_processor.sft_format,
+                                                                           system_prompt='')
+        text = text + vl_chat_processor.image_start_tag
 
-
-
-
-
-
-
-
-
-
-
-    # Depending on the variant being used, the pipeline call will slightly vary.
-    # Refer to the pipeline documentation for more details.
-    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
-    pipe.to("cuda")
+        input_ids = torch.LongTensor(tokenizer.encode(text))
+        output, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance,
+                                   parallel_size=parallel_size, temperature=t2i_temperature)
+        images = unpack(patches, width // 16 * 16, height // 16 * 16, parallel_size=parallel_size)
+
+        stime = time.time()
+        ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
+        print(f'upsample time: {time.time() - stime}')
+        return ret_images
 
-    image = pipe(prompt=prompt, guidance_scale=guidance, height=768, width=768, num_inference_steps=16,).images[0]
-    return image
 
 
 @spaces.GPU(duration=60)
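The new body seeds the CPU and CUDA RNGs plus numpy whenever a seed is given, so results are reproducible, and parallel_size = 1 yields a single 384x384 image that is then upsampled. A hypothetical example of wiring generate_image into a Gradio interface; the actual Space builds its own UI, so the widgets and defaults here are illustrative only:

```python
import gradio as gr

# Hypothetical wiring, not the Space's real layout; inputs map positionally to
# generate_image(prompt, seed, guidance, t2i_temperature).
demo = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Number(label="Seed", value=12345, precision=0),
        gr.Slider(minimum=1, maximum=10, value=5, label="CFG weight"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="T2I temperature"),
    ],
    outputs=gr.Gallery(label="Generated images"),  # the function returns a list of PIL images
)
demo.launch()
```

Returning a list pairs naturally with gr.Gallery even though parallel_size is fixed at 1.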
|