1inkusFace commited on
Commit
a4a2927
·
verified ·
1 Parent(s): abdf73f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -50
app.py CHANGED
@@ -4,27 +4,30 @@ import sys
4
  import time
5
  import os
6
  import random
7
- from PIL import Image
8
- # os.environ["CUDA_VISIBLE_DEVICES"] = ""
9
- os.environ["SAFETENSORS_FAST_GPU"] = "1"
10
- os.putenv("HF_HUB_ENABLE_HF_TRANSFER","1")
11
  import torch
 
 
 
 
 
 
12
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13
 
14
- # Create the gr.State component *outside* the gr.Blocks context
 
15
 
16
- global predictor
17
 
18
  def init_predictor(task_type: str):
19
  from skyreelsinfer import TaskType
20
  from skyreelsinfer.offload import OffloadConfig
21
  from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
22
  from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError
23
- global predictor
24
  try:
25
  predictor = SkyReelsVideoInfer(
26
  task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
27
- model_id="Skywork/skyreels-v1-Hunyuan-i2v",
28
  quant_model=True,
29
  is_offload=True,
30
  offload_config=OffloadConfig(
@@ -35,23 +38,28 @@ def init_predictor(task_type: str):
35
  )
36
  return predictor
37
  except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
38
- return f"Error: Model not found. Details: {e}", None
 
39
  except Exception as e:
40
- return f"Error loading model: {e}", None
41
-
42
- predictor = init_predictor('i2v')
43
 
44
- @spaces.GPU(duration=80)
45
- def generate_video(prompt, image, predictor):
46
  from diffusers.utils import export_to_video
47
  from diffusers.utils import load_image
48
- if image == None:
49
- return "Error: For i2v, provide image path.", "{}"
50
- if not isinstance(prompt, str):
51
- return "Error: No prompt.", "{}"
52
- #if seed == -1:
 
 
 
53
  random.seed(time.time())
54
  seed = int(random.randrange(4294967294))
 
55
  kwargs = {
56
  "prompt": prompt,
57
  "height": 256,
@@ -65,41 +73,62 @@ def generate_video(prompt, image, predictor):
65
  "cfg_for": False,
66
  }
67
 
68
- kwargs["image"] = load_image(image=image)
69
- output = predictor.inference(kwargs)
70
- frames = output
71
- save_dir = f"./result/{task_type}"
 
 
 
 
 
 
 
 
72
  os.makedirs(save_dir, exist_ok=True)
73
- video_out_file = f"{save_dir}/{prompt[:100]}_{int(seed)}.mp4"
74
  print(f"Generating video: {video_out_file}")
75
- export_to_video(frames, video_out_file, fps=24)
76
- return video_out_file
77
-
 
 
 
 
 
 
78
  def display_image(file):
79
  if file is not None:
80
  return Image.open(file.name)
81
  else:
82
  return None
83
-
84
- with gr.Blocks() as demo:
85
- #predictor = gr.State({}) # Initialize as an empty dictionary
86
-
87
- image_file = gr.File(label="Image Prompt (Required)", file_types=["image"])
88
- image_file_preview = gr.Image(label="Image Prompt Preview", interactive=False)
89
- prompt_textbox = gr.Text(label="Prompt")
90
- generate_button = gr.Button("Generate")
91
- output_video = gr.Video(label="Output Video")
92
-
93
- image_file.change(
94
- display_image,
95
- inputs=[image_file],
96
- outputs=[image_file_preview]
97
- )
98
-
99
- generate_button.click(
100
- fn=generate_video,
101
- inputs=[prompt_textbox, image_file, predictor],
102
- outputs=[output_video],
103
- )
104
-
105
- demo.launch()
 
 
 
 
 
 
 
 
4
  import time
5
  import os
6
  import random
7
+ from PIL import Image
 
 
 
8
  import torch
9
+ import asyncio # Import asyncio
10
+
11
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "" # Uncomment if needed
12
+ os.environ["SAFETENSORS_FAST_GPU"] = "1"
13
+ os.putenv("HF_HUB_ENABLE_HF_TRANSFER", "1")
14
+
15
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
 
17
+ # Use gr.State to hold the predictor. Initialize it to None.
18
+ predictor_state = gr.State(None)
19
 
 
20
 
21
  def init_predictor(task_type: str):
22
  from skyreelsinfer import TaskType
23
  from skyreelsinfer.offload import OffloadConfig
24
  from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
25
  from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError
26
+
27
  try:
28
  predictor = SkyReelsVideoInfer(
29
  task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
30
+ model_id="Skywork/skyreels-v1-Hunyuan-i2v", # Adjust model ID as needed
31
  quant_model=True,
32
  is_offload=True,
33
  offload_config=OffloadConfig(
 
38
  )
39
  return predictor
40
  except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
41
+ print(f"Error: Model not found. Details: {e}")
42
+ return None
43
  except Exception as e:
44
+ print(f"Error loading model: {e}")
45
+ return None
46
+
47
 
48
+ # Make generate_video async
49
+ async def generate_video(prompt, image_file, predictor):
50
  from diffusers.utils import export_to_video
51
  from diffusers.utils import load_image
52
+
53
+ if image_file is None:
54
+ return gr.Error("Error: For i2v, provide an image.")
55
+ if not isinstance(prompt, str) or not prompt.strip():
56
+ return gr.Error("Error: Please provide a prompt.")
57
+ if predictor is None:
58
+ return gr.Error("Error: Model not loaded.")
59
+
60
  random.seed(time.time())
61
  seed = int(random.randrange(4294967294))
62
+
63
  kwargs = {
64
  "prompt": prompt,
65
  "height": 256,
 
73
  "cfg_for": False,
74
  }
75
 
76
+ try:
77
+ kwargs["image"] = load_image(image=image_file.name)
78
+ except Exception as e:
79
+ return gr.Error(f"image loading error: {e}")
80
+
81
+ try:
82
+ output = predictor.inference(kwargs)
83
+ frames = output
84
+ except Exception as e:
85
+ return gr.Error(f"Inference error: {e}")
86
+
87
+ save_dir = "./result/i2v" # Consistent directory
88
  os.makedirs(save_dir, exist_ok=True)
89
+ video_out_file = os.path.join(save_dir, f"{prompt[:100]}_{int(seed)}.mp4")
90
  print(f"Generating video: {video_out_file}")
91
+
92
+ try:
93
+ export_to_video(frames, video_out_file, fps=24)
94
+ except Exception as e:
95
+ return gr.Error(f"Video export error: {e}")
96
+
97
+ return video_out_file, predictor # Return updated predictor
98
+
99
+
100
  def display_image(file):
101
  if file is not None:
102
  return Image.open(file.name)
103
  else:
104
  return None
105
+
106
+ async def load_model():
107
+ predictor = init_predictor('i2v')
108
+ return predictor
109
+
110
+ async def main():
111
+ with gr.Blocks() as demo:
112
+ image_file = gr.File(label="Image Prompt (Required)", file_types=["image"])
113
+ image_file_preview = gr.Image(label="Image Prompt Preview", interactive=False)
114
+ prompt_textbox = gr.Text(label="Prompt")
115
+ generate_button = gr.Button("Generate")
116
+ output_video = gr.Video(label="Output Video")
117
+
118
+ image_file.change(
119
+ display_image,
120
+ inputs=[image_file],
121
+ outputs=[image_file_preview]
122
+ )
123
+
124
+ generate_button.click(
125
+ fn=generate_video,
126
+ inputs=[prompt_textbox, image_file, predictor_state],
127
+ outputs=[output_video, predictor_state], # Output predictor_state
128
+ )
129
+ predictor_state.value = await load_model() # load and set predictor
130
+
131
+ await demo.launch()
132
+
133
+ if __name__ == "__main__":
134
+ asyncio.run(main())