WeichenFan commited on
Commit
1298c80
·
1 Parent(s): 174e58d

update demo

Browse files
Files changed (1) hide show
  1. app.py +17 -17
app.py CHANGED
@@ -53,23 +53,6 @@ from datetime import datetime, timedelta
53
  import spaces
54
  import moviepy.editor as mp
55
 
56
-
57
- import os
58
- from huggingface_hub import login
59
- login(token=os.getenv('HF_TOKEN'))
60
-
61
- dtype = torch.float16
62
- device = "cuda" if torch.cuda.is_available() else "cpu"
63
- pipe = VchitectXLPipeline("Vchitect/Vchitect-XL-2B",device)
64
-
65
- # pipe.acc_call = acc_call.__get__(pipe)
66
- import types
67
- # pipe.__call__ = types.MethodType(acc_call, pipe)
68
- pipe.__class__.__call__ = acc_call
69
-
70
- os.makedirs("./output", exist_ok=True)
71
- os.makedirs("./gradio_tmp", exist_ok=True)
72
-
73
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
74
  def retrieve_timesteps(
75
  scheduler,
@@ -357,6 +340,23 @@ def acc_call(
357
  videos.append(image[0])
358
 
359
  return videos
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
 
361
  @spaces.GPU(duration=120)
362
  def infer(prompt: str, progress=gr.Progress(track_tqdm=True)):
 
53
  import spaces
54
  import moviepy.editor as mp
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
57
  def retrieve_timesteps(
58
  scheduler,
 
340
  videos.append(image[0])
341
 
342
  return videos
343
+
344
+ import os
345
+ from huggingface_hub import login
346
+ login(token=os.getenv('HF_TOKEN'))
347
+
348
+ dtype = torch.float16
349
+ device = "cuda" if torch.cuda.is_available() else "cpu"
350
+ pipe = VchitectXLPipeline("Vchitect/Vchitect-XL-2B",device)
351
+
352
+ # pipe.acc_call = acc_call.__get__(pipe)
353
+ import types
354
+ # pipe.__call__ = types.MethodType(acc_call, pipe)
355
+ pipe.__class__.__call__ = acc_call
356
+
357
+ os.makedirs("./output", exist_ok=True)
358
+ os.makedirs("./gradio_tmp", exist_ok=True)
359
+
360
 
361
  @spaces.GPU(duration=120)
362
  def infer(prompt: str, progress=gr.Progress(track_tqdm=True)):