Spaces:
Running
Running
fixed outdir bug in gmm train
Browse files
msma.py
CHANGED
@@ -231,8 +231,10 @@ def common_args(func):
|
|
231 |
)
|
232 |
@common_args
|
233 |
def train_gmm(preset, outdir, gridsearch=False, **kwargs):
|
234 |
-
|
|
|
235 |
X = torch.load(score_path).numpy()
|
|
|
236 |
|
237 |
gm = GaussianMixture(
|
238 |
n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
|
@@ -267,12 +269,14 @@ def train_gmm(preset, outdir, gridsearch=False, **kwargs):
|
|
267 |
clf.fit(X)
|
268 |
inlier_nll = -clf.score_samples(X)
|
269 |
|
|
|
270 |
os.makedirs(outdir, exist_ok=True)
|
271 |
with open(f"{outdir}/refscores.npz", "wb") as f:
|
272 |
np.savez_compressed(f, inlier_nll)
|
273 |
|
274 |
with open(f"{outdir}/gmm.pkl", "wb") as f:
|
275 |
dump(clf, f, protocol=5)
|
|
|
276 |
|
277 |
|
278 |
@cmdline.command(name="cache-scores")
|
|
|
231 |
)
|
232 |
@common_args
|
233 |
def train_gmm(preset, outdir, gridsearch=False, **kwargs):
|
234 |
+
outdir = f"{outdir}/{preset}"
|
235 |
+
score_path = f"{outdir}/imagenette_score_norms.pt"
|
236 |
X = torch.load(score_path).numpy()
|
237 |
+
print(f"Loaded score norms from: {score_path} - # Samples: {X.shape[0]}")
|
238 |
|
239 |
gm = GaussianMixture(
|
240 |
n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
|
|
|
269 |
clf.fit(X)
|
270 |
inlier_nll = -clf.score_samples(X)
|
271 |
|
272 |
+
print("Saving reference inlier scores ... ")
|
273 |
os.makedirs(outdir, exist_ok=True)
|
274 |
with open(f"{outdir}/refscores.npz", "wb") as f:
|
275 |
np.savez_compressed(f, inlier_nll)
|
276 |
|
277 |
with open(f"{outdir}/gmm.pkl", "wb") as f:
|
278 |
dump(clf, f, protocol=5)
|
279 |
+
print("Saved GMM pickle.")
|
280 |
|
281 |
|
282 |
@cmdline.command(name="cache-scores")
|