2model
app.py CHANGED
@@ -121,13 +121,13 @@ state_dict = torch.load(checkpoint_path, map_location=device)
 nnet_1.load_state_dict(state_dict)
 nnet_1.eval()
 
-
-
-
-
-
-
-
+filename = "pretrained_models/t2i_512px_clip_dimr.pth"
+checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
+nnet_2 = utils.get_nnet(**config_2.nnet)
+nnet_2 = nnet_2.to(device)
+state_dict = torch.load(checkpoint_path, map_location=device)
+nnet_2.load_state_dict(state_dict)
+nnet_2.eval()
 
 # Initialize text model.
 llm = "clip"
@@ -181,11 +181,14 @@ def infer(
     else:
         assert num_of_interpolation == 3, "For arithmetic, please sample three images."
 
-
-
-
-
-
+    if num_of_interpolation == 3:
+        nnet = nnet_2
+        config = config_2
+    else:
+        nnet = nnet_1
+        config = config_1
+    # nnet = nnet_1
+    # config = config_1
 
     # Get text embeddings and tokens.
     _context, _token_mask, _token, _caption = get_caption(
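Taken together, the two hunks load a second backbone checkpoint at startup and then pick between the two networks inside infer(). Below is a minimal sketch of that flow, not the app's literal code: it assumes the utils.get_nnet helper, the config_1/config_2 objects, and the repo_id, device, nnet_1, and num_of_interpolation variables defined elsewhere in app.py, and the load_nnet wrapper is a hypothetical consolidation of the load steps the diff repeats for each model.

import torch
from huggingface_hub import hf_hub_download

import utils  # the app's own module providing get_nnet (assumed importable)


def load_nnet(config, filename, repo_id, device):
    # Hypothetical helper: fetch a checkpoint from the Hub and load it
    # into a freshly built network in eval mode.
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
    nnet = utils.get_nnet(**config.nnet)
    nnet = nnet.to(device)
    state_dict = torch.load(checkpoint_path, map_location=device)
    nnet.load_state_dict(state_dict)
    nnet.eval()
    return nnet


# Second model introduced by this commit (filename taken from the diff).
nnet_2 = load_nnet(config_2, "pretrained_models/t2i_512px_clip_dimr.pth",
                   repo_id, device)

# Inside infer(): the arithmetic branch asserts exactly three sampled images,
# so num_of_interpolation == 3 routes to the new model; all other modes
# keep the original nnet_1.
if num_of_interpolation == 3:
    nnet, config = nnet_2, config_2
else:
    nnet, config = nnet_1, config_1

Since the arithmetic branch already asserts num_of_interpolation == 3, the selection effectively reads "use nnet_2 for arithmetic, nnet_1 for everything else," with the commented-out nnet = nnet_1 / config = config_1 lines preserving the pre-commit single-model behavior as a fallback.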