NikhilJoson committed
Commit b3020f6 · verified · 1 Parent(s): 55736f4

Update app.py

Files changed (1)
  1. app.py +30 -46
app.py CHANGED
@@ -5,7 +5,6 @@ from transformers import AutoConfig, AutoModelForCausalLM
 from janus.models import MultiModalityCausalLM, VLChatProcessor
 from janus.utils.io import load_pil_images
 from PIL import Image
-from diffusers import FluxPipeline
 
 import numpy as np
 import os
@@ -162,18 +161,10 @@ def unified_chat(image, message, chat_history, seed, top_p, temperature, cfg_wei
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
-    outputs = vl_gpt.language_model.generate(
-        inputs_embeds=inputs_embeds,
-        attention_mask=prepare_inputs.attention_mask,
-        pad_token_id=tokenizer.eos_token_id,
-        bos_token_id=tokenizer.bos_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-        max_new_tokens=512,
-        do_sample=False if temperature == 0 else True,
-        use_cache=True,
-        temperature=temperature,
-        top_p=top_p,
-    )
+    outputs = vl_gpt.language_model.generate(inputs_embeds=inputs_embeds, attention_mask=prepare_inputs.attention_mask,
+                                             pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id,
+                                             eos_token_id=tokenizer.eos_token_id, max_new_tokens=512, temperature=temperature, top_p=top_p,
+                                             do_sample=False if temperature == 0 else True, use_cache=True,)
 
     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
 
@@ -234,42 +225,35 @@ def unpack(dec, width, height, parallel_size=5):
 @torch.inference_mode()
 @spaces.GPU(duration=120) # Specify a duration to avoid timeout
 def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress(track_tqdm=True)):
-    # # Clear CUDA cache and avoid tracking gradients
-    # torch.cuda.empty_cache()
-    # # Set the seed for reproducible results
-    # if seed is not None:
-    #     torch.manual_seed(seed)
-    #     torch.cuda.manual_seed(seed)
-    #     np.random.seed(seed)
-    # width = 384
-    # height = 384
-    # parallel_size = 1
+    # Clear CUDA cache and avoid tracking gradients
+    torch.cuda.empty_cache()
+    # Set the seed for reproducible results
+    if seed is not None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+        np.random.seed(seed)
+    width = 384
+    height = 384
+    parallel_size = 1
 
-    # with torch.no_grad():
-    #     messages = [{'role': '<|User|>', 'content': prompt},
-    #                 {'role': '<|Assistant|>', 'content': ''}]
-    #     text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-    #                                                                        sft_format=vl_chat_processor.sft_format,
-    #                                                                        system_prompt='')
-    #     text = text + vl_chat_processor.image_start_tag
+    with torch.no_grad():
+        messages = [{'role': '<|User|>', 'content': prompt},
+                    {'role': '<|Assistant|>', 'content': ''}]
+        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
+                                                                           sft_format=vl_chat_processor.sft_format,
+                                                                           system_prompt='')
+        text = text + vl_chat_processor.image_start_tag
 
-    #     input_ids = torch.LongTensor(tokenizer.encode(text))
-    #     output, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance,
-    #                                parallel_size=parallel_size, temperature=t2i_temperature)
-    #     images = unpack(patches, width // 16 * 16, height // 16 * 16, parallel_size=parallel_size)
-
-    #     stime = time.time()
-    #     ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
-    #     print(f'upsample time: {time.time() - stime}')
-    #     return ret_images
-
-    # Depending on the variant being used, the pipeline call will slightly vary.
-    # Refer to the pipeline documentation for more details.
-    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
-    pipe.to("cuda")
+        input_ids = torch.LongTensor(tokenizer.encode(text))
+        output, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance,
+                                   parallel_size=parallel_size, temperature=t2i_temperature)
+        images = unpack(patches, width // 16 * 16, height // 16 * 16, parallel_size=parallel_size)
+
+        stime = time.time()
+        ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
+        print(f'upsample time: {time.time() - stime}')
+        return ret_images
 
-    image = pipe(prompt=prompt, guidance_scale=guidance, height=768, width=768, num_inference_steps=16,).images[0]
-    return image
 
 
 @spaces.GPU(duration=60)