WeichenFan commited on
Commit
9587e73
·
1 Parent(s): adf43eb

Add application file

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -48,7 +48,7 @@ def load_model(model_name):
48
  gc.collect() # Force garbage collection
49
 
50
  if "wan-t2v" in model_name:
51
- vae = AutoencoderKLWan.from_pretrained(model_paths[model_name], subfolder="vae", torch_dtype=torch.float32)
52
  scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=8.0)
53
  current_model = WanPipeline.from_pretrained(model_paths[model_name], vae=vae, torch_dtype=torch.float16).to("cuda")
54
  current_model.scheduler = scheduler
@@ -146,7 +146,7 @@ def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps
146
  demo = gr.Interface(
147
  fn=generate_content,
148
  inputs=[
149
- gr.Textbox(value="A capybara holding a sign that reads Hello World", label="Enter your prompt"),
150
  gr.Dropdown(choices=list(model_paths.keys()), label="Choose Model"),
151
  gr.Slider(1, 20, value=4.0, step=0.5, label="Guidance Scale"),
152
  gr.Slider(10, 100, value=28, step=5, label="Inference Steps"),
 
48
  gc.collect() # Force garbage collection
49
 
50
  if "wan-t2v" in model_name:
51
+ vae = AutoencoderKLWan.from_pretrained(model_paths[model_name], subfolder="vae", torch_dtype=torch.bfloat16)
52
  scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=8.0)
53
  current_model = WanPipeline.from_pretrained(model_paths[model_name], vae=vae, torch_dtype=torch.float16).to("cuda")
54
  current_model.scheduler = scheduler
 
146
  demo = gr.Interface(
147
  fn=generate_content,
148
  inputs=[
149
+ gr.Textbox(value="A cosmic whale swimming throught a glaxy with stars and swirling cosmic dusts.", label="Enter your prompt"),
150
  gr.Dropdown(choices=list(model_paths.keys()), label="Choose Model"),
151
  gr.Slider(1, 20, value=4.0, step=0.5, label="Guidance Scale"),
152
  gr.Slider(10, 100, value=28, step=5, label="Inference Steps"),