rahul7star commited on
Commit
b85cf9b
·
verified ·
1 Parent(s): 566b570

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -3
app.py CHANGED
@@ -2,20 +2,35 @@ import torch
2
  import gradio as gr
3
  import imageio
4
  import os
 
5
  from safetensors.torch import load_file
6
  from torchvision import transforms
7
  from PIL import Image
8
  import numpy as np
9
 
10
- # Define model path (assuming it's in the HF Space)
11
- MODEL_PATH = "sarthak247/Wan2.1-T2V-1.3B-nf4"
12
- MODEL_FILE = f"{MODEL_PATH}/blob/main/diffusion_pytorch_model.safetensors"
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # Load model weights manually
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  print(f"Loading model on {device}...")
17
 
18
  try:
 
19
  model_weights = load_file(MODEL_FILE, device=device)
20
  print("Model loaded successfully!")
21
  except Exception as e:
 
2
  import gradio as gr
3
  import imageio
4
  import os
5
+ import requests
6
  from safetensors.torch import load_file
7
  from torchvision import transforms
8
  from PIL import Image
9
  import numpy as np
10
 
11
# Define model URL and local path
MODEL_URL = "https://huggingface.co/sarthak247/Wan2.1-T2V-1.3B-nf4/resolve/main/diffusion_pytorch_model.safetensors"
MODEL_FILE = "diffusion_pytorch_model.safetensors"


def download_model():
    """Download the model weights to MODEL_FILE if not already present.

    Streams the response in 8 KiB chunks so the multi-GB checkpoint is
    never held in memory, writes into a temporary ".part" file, and
    atomically renames it into place. The original wrote MODEL_FILE
    directly, so an interrupted download left a truncated file that the
    existence check would mistake for a complete model on the next run.

    Raises:
        RuntimeError: if the server responds with a non-200 status.
        requests.RequestException: on network-level failures (including
            the connect/read timeouts added here).
    """
    if os.path.exists(MODEL_FILE):
        return
    print("Downloading model...")
    # Timeout prevents hanging forever on a dead connection:
    # (connect timeout, per-read timeout) in seconds.
    response = requests.get(MODEL_URL, stream=True, timeout=(10, 60))
    if response.status_code != 200:
        raise RuntimeError(f"Failed to download model: {response.status_code}")
    tmp_file = MODEL_FILE + ".part"
    with open(tmp_file, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:  # filter out keep-alive chunks
                f.write(chunk)
    # Atomic rename: MODEL_FILE only ever exists fully downloaded.
    os.replace(tmp_file, MODEL_FILE)
    print("Download complete!")
27
 
28
  # Load model weights manually
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
  print(f"Loading model on {device}...")
31
 
32
  try:
33
+ download_model()
34
  model_weights = load_file(MODEL_FILE, device=device)
35
  print("Model loaded successfully!")
36
  except Exception as e: