NikhilJoson committed
Commit 418b14d · verified · 1 Parent(s): 8e9eece

Update app.py

Files changed (1)
  1. app.py +35 -29
app.py CHANGED
@@ -4,6 +4,9 @@ from transformers import AutoConfig, AutoModelForCausalLM
 from janus.models import MultiModalityCausalLM, VLChatProcessor
 from janus.utils.io import load_pil_images
 from PIL import Image
+from diffusers import FluxPipeline
+pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
 
 import numpy as np
 import os
@@ -18,9 +21,7 @@ model_path = "deepseek-ai/Janus-Pro-7B"
 config = AutoConfig.from_pretrained(model_path)
 language_config = config.language_config
 language_config._attn_implementation = 'eager'
-vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
-                                              language_config=language_config,
-                                              trust_remote_code=True)
+vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, language_config=language_config, trust_remote_code=True)
 if torch.cuda.is_available():
     vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
 else:
@@ -234,34 +235,39 @@ def unpack(dec, width, height, parallel_size=5):
 @torch.inference_mode()
 @spaces.GPU(duration=120)  # Specify a duration to avoid timeout
 def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress(track_tqdm=True)):
-    # Clear CUDA cache and avoid tracking gradients
-    torch.cuda.empty_cache()
-    # Set the seed for reproducible results
-    if seed is not None:
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
-        np.random.seed(seed)
-    width = 384
-    height = 384
-    parallel_size = 1
+    # # Clear CUDA cache and avoid tracking gradients
+    # torch.cuda.empty_cache()
+    # # Set the seed for reproducible results
+    # if seed is not None:
+    #     torch.manual_seed(seed)
+    #     torch.cuda.manual_seed(seed)
+    #     np.random.seed(seed)
+    # width = 384
+    # height = 384
+    # parallel_size = 1
 
-    with torch.no_grad():
-        messages = [{'role': '<|User|>', 'content': prompt},
-                    {'role': '<|Assistant|>', 'content': ''}]
-        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-                                                                           sft_format=vl_chat_processor.sft_format,
-                                                                           system_prompt='')
-        text = text + vl_chat_processor.image_start_tag
+    # with torch.no_grad():
+    #     messages = [{'role': '<|User|>', 'content': prompt},
+    #                 {'role': '<|Assistant|>', 'content': ''}]
+    #     text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
+    #                                                                        sft_format=vl_chat_processor.sft_format,
+    #                                                                        system_prompt='')
+    #     text = text + vl_chat_processor.image_start_tag
 
-        input_ids = torch.LongTensor(tokenizer.encode(text))
-        output, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance,
-                                   parallel_size=parallel_size, temperature=t2i_temperature)
-        images = unpack(patches, width // 16 * 16, height // 16 * 16, parallel_size=parallel_size)
-
-        stime = time.time()
-        ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
-        print(f'upsample time: {time.time() - stime}')
-        return ret_images
+    #     input_ids = torch.LongTensor(tokenizer.encode(text))
+    #     output, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance,
+    #                                parallel_size=parallel_size, temperature=t2i_temperature)
+    #     images = unpack(patches, width // 16 * 16, height // 16 * 16, parallel_size=parallel_size)
+
+    #     stime = time.time()
+    #     ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
+    #     print(f'upsample time: {time.time() - stime}')
+    #     return ret_images
+
+    # Depending on the variant being used, the pipeline call will slightly vary.
+    # Refer to the pipeline documentation for more details.
+    image = pipe(prompt=prompt, guidance_scale=guidance, height=768, width=768, num_inference_steps=16,).images[0]
+    return image
 
 
 @spaces.GPU(duration=60)
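
The new generate_image body keeps the comment that the pipe(...) call varies with the FLUX variant being loaded. As a reference sketch only (not part of this commit), the snippet below shows how the call typically differs between the two public FLUX.1 variants when using diffusers' FluxPipeline; the guidance_scale, num_inference_steps, and max_sequence_length values follow the respective model cards, and the prompt, seed, and output filename are illustrative.

import torch
from diffusers import FluxPipeline

# FLUX.1-schnell is timestep-distilled: its model card recommends
# guidance_scale=0.0, roughly 4 steps, and max_sequence_length <= 256.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
    prompt="a red fox standing in fresh snow",          # illustrative prompt
    guidance_scale=0.0,
    num_inference_steps=4,
    height=768,
    width=768,
    max_sequence_length=256,
    generator=torch.Generator("cuda").manual_seed(0),   # seeded for reproducibility
).images[0]
image.save("schnell.png")

# FLUX.1-dev is guidance-distilled: its model card recommends a meaningful
# guidance_scale (around 3.5) and more steps (around 50).
# pipe = FluxPipeline.from_pretrained(
#     "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
# ).to("cuda")
# image = pipe(prompt="a red fox standing in fresh snow", guidance_scale=3.5,
#              num_inference_steps=50, height=768, width=768).images[0]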