ahsanMah commited on
Commit
d904b3a
·
1 Parent(s): 133982f

fixed outdir bug in gmm train

Browse files
Files changed (1) hide show
  1. msma.py +5 -1
msma.py CHANGED
@@ -231,8 +231,10 @@ def common_args(func):
231
  )
232
  @common_args
233
  def train_gmm(preset, outdir, gridsearch=False, **kwargs):
234
- score_path = f"{outdir}/{preset}/imagenette_score_norms.pt"
 
235
  X = torch.load(score_path).numpy()
 
236
 
237
  gm = GaussianMixture(
238
  n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
@@ -267,12 +269,14 @@ def train_gmm(preset, outdir, gridsearch=False, **kwargs):
267
  clf.fit(X)
268
  inlier_nll = -clf.score_samples(X)
269
 
 
270
  os.makedirs(outdir, exist_ok=True)
271
  with open(f"{outdir}/refscores.npz", "wb") as f:
272
  np.savez_compressed(f, inlier_nll)
273
 
274
  with open(f"{outdir}/gmm.pkl", "wb") as f:
275
  dump(clf, f, protocol=5)
 
276
 
277
 
278
  @cmdline.command(name="cache-scores")
 
231
  )
232
  @common_args
233
  def train_gmm(preset, outdir, gridsearch=False, **kwargs):
234
+ outdir = f"{outdir}/{preset}"
235
+ score_path = f"{outdir}/imagenette_score_norms.pt"
236
  X = torch.load(score_path).numpy()
237
+ print(f"Loaded score norms from: {score_path} - # Samples: {X.shape[0]}")
238
 
239
  gm = GaussianMixture(
240
  n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
 
269
  clf.fit(X)
270
  inlier_nll = -clf.score_samples(X)
271
 
272
+ print("Saving reference inlier scores ... ")
273
  os.makedirs(outdir, exist_ok=True)
274
  with open(f"{outdir}/refscores.npz", "wb") as f:
275
  np.savez_compressed(f, inlier_nll)
276
 
277
  with open(f"{outdir}/gmm.pkl", "wb") as f:
278
  dump(clf, f, protocol=5)
279
+ print("Saved GMM pickle.")
280
 
281
 
282
  @cmdline.command(name="cache-scores")