QHL067 commited on
Commit
0f478f2
·
1 Parent(s): fced74d
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -121,13 +121,13 @@ state_dict = torch.load(checkpoint_path, map_location=device)
121
  nnet_1.load_state_dict(state_dict)
122
  nnet_1.eval()
123
 
124
- # filename = "pretrained_models/t2i_512px_clip_dimr.pth"
125
- # checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
126
- # nnet_2 = utils.get_nnet(**config_2.nnet)
127
- # nnet_2 = nnet_2.to(device)
128
- # state_dict = torch.load(checkpoint_path, map_location=device)
129
- # nnet_2.load_state_dict(state_dict)
130
- # nnet_2.eval()
131
 
132
  # Initialize text model.
133
  llm = "clip"
@@ -181,11 +181,14 @@ def infer(
181
  else:
182
  assert num_of_interpolation == 3, "For arithmetic, please sample three images."
183
 
184
- # if num_of_interpolation == 3:
185
- # nnet = nnet_2
186
- # else:
187
- # nnet = nnet_1
188
- nnet = nnet_1
 
 
 
189
 
190
  # Get text embeddings and tokens.
191
  _context, _token_mask, _token, _caption = get_caption(
 
121
  nnet_1.load_state_dict(state_dict)
122
  nnet_1.eval()
123
 
124
+ filename = "pretrained_models/t2i_512px_clip_dimr.pth"
125
+ checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
126
+ nnet_2 = utils.get_nnet(**config_2.nnet)
127
+ nnet_2 = nnet_2.to(device)
128
+ state_dict = torch.load(checkpoint_path, map_location=device)
129
+ nnet_2.load_state_dict(state_dict)
130
+ nnet_2.eval()
131
 
132
  # Initialize text model.
133
  llm = "clip"
 
181
  else:
182
  assert num_of_interpolation == 3, "For arithmetic, please sample three images."
183
 
184
+ if num_of_interpolation == 3:
185
+ nnet = nnet_2
186
+ config = config_2
187
+ else:
188
+ nnet = nnet_1
189
+ config = config_1
190
+ # nnet = nnet_1
191
+ # config = config_1
192
 
193
  # Get text embeddings and tokens.
194
  _context, _token_mask, _token, _caption = get_caption(