philipp-zettl commited on
Commit
f33e554
·
1 Parent(s): ac1234c
Files changed (1) hide show
  1. app.py +58 -2
app.py CHANGED
@@ -2,7 +2,63 @@ from aura_sr import AuraSR
2
  import gradio as gr
3
  import spaces
4
 
5
- aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR", device='cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  @spaces.GPU()
@@ -10,7 +66,7 @@ def predict(img):
10
  return aura_sr.upscale_4x(img)
11
 
12
 
13
- demo = gr.Interface(predict, inputs=gr.Image(type="filepath"),
14
  outputs=gr.Image())
15
 
16
 
 
2
  import gradio as gr
3
  import spaces
4
 
5
+
6
+ class ZeroGPUAuraSR(AuraSR):
7
+ @classmethod
8
+ def from_pretrained(cls, model_id: str = "fal-ai/AuraSR", use_safetensors: bool = True):
9
+ import json
10
+ import torch
11
+ from pathlib import Path
12
+ from huggingface_hub import snapshot_download
13
+
14
+ # Check if model_id is a local file
15
+ if Path(model_id).is_file():
16
+ local_file = Path(model_id)
17
+ if local_file.suffix == '.safetensors':
18
+ use_safetensors = True
19
+ elif local_file.suffix == '.ckpt':
20
+ use_safetensors = False
21
+ else:
22
+ raise ValueError(f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files.")
23
+
24
+ # For local files, we need to provide the config separately
25
+ config_path = local_file.with_name('config.json')
26
+ if not config_path.exists():
27
+ raise FileNotFoundError(
28
+ f"Config file not found: {config_path}. "
29
+ f"When loading from a local file, ensure that 'config.json' "
30
+ f"is present in the same directory as '{local_file.name}'. "
31
+ f"If you're trying to load a model from Hugging Face, "
32
+ f"please provide the model ID instead of a file path."
33
+ )
34
+
35
+ config = json.loads(config_path.read_text())
36
+ hf_model_path = local_file.parent
37
+ else:
38
+ hf_model_path = Path(snapshot_download(model_id))
39
+ config = json.loads((hf_model_path / "config.json").read_text())
40
+
41
+ model = cls(config)
42
+
43
+ if use_safetensors:
44
+ try:
45
+ from safetensors.torch import load_file
46
+ checkpoint = load_file(hf_model_path / "model.safetensors" if not Path(model_id).is_file() else model_id)
47
+ except ImportError:
48
+ raise ImportError(
49
+ "The safetensors library is not installed. "
50
+ "Please install it with `pip install safetensors` "
51
+ "or use `use_safetensors=False` to load the model with PyTorch."
52
+ )
53
+ else:
54
+ checkpoint = torch.load(hf_model_path / "model.ckpt" if not Path(model_id).is_file() else model_id)
55
+
56
+ model.upsampler.load_state_dict(checkpoint, strict=True)
57
+ return model
58
+
59
+
60
+
61
+ aura_sr = ZeroGPUAuraSR.from_pretrained("fal-ai/AuraSR")
62
 
63
 
64
  @spaces.GPU()
 
66
  return aura_sr.upscale_4x(img)
67
 
68
 
69
+ demo = gr.Interface(predict, inputs=gr.Image(),
70
  outputs=gr.Image())
71
 
72