OMilosh committed
Commit f8d575d · verified · 1 Parent(s): 4ee1d89

Update app.py

Files changed (1):
  1. app.py +8 -2
app.py CHANGED
@@ -20,10 +20,16 @@ available_models = [
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
+
 def init_model(model_repo_id):
     torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
     pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-    return pipe.to(device)
+    return pipe
+
+
+loaded_models = {}
+
+for model in available_models:
+    loaded_models[model] = init_model(model)
 
 # @spaces.GPU #[uncomment to use ZeroGPU]
 def infer(
@@ -38,7 +44,7 @@ def infer(
     num_inference_steps,
     progress=gr.Progress(track_tqdm=True),
 ):
-    pipe = init_model(model_repo_id)
+    pipe = loaded_models[model_repo_id].to(device)
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
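
In short, the commit stops reloading the diffusion pipeline on every call to infer: init_model now returns the pipeline without moving it to the device, every entry of available_models is loaded once into the loaded_models dict at import time, and infer only looks up the cached pipeline and moves it to the device. A minimal standalone sketch of that pattern, assuming a single hypothetical model id and omitting the Gradio UI that the real app.py builds around infer:

import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical model list for illustration; the real available_models is
# defined earlier in app.py.
available_models = ["stabilityai/sdxl-turbo"]


def init_model(model_repo_id):
    # Load the pipeline once; device placement is deferred to infer().
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
    return pipe


# Preload every pipeline at import time instead of on every request.
loaded_models = {model: init_model(model) for model in available_models}


def infer(model_repo_id, prompt, num_inference_steps=4):
    # Reuse the cached pipeline; only the device transfer happens per call.
    pipe = loaded_models[model_repo_id].to(device)
    return pipe(prompt=prompt, num_inference_steps=num_inference_steps).images[0]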