Spaces:
Running
Running
bug fix
Browse files
app.py
CHANGED
@@ -483,18 +483,29 @@ class ExplainerCheckbox(Component):
|
|
483 |
|
484 |
data_id = self.gallery.selected_index
|
485 |
|
486 |
-
|
487 |
-
|
488 |
explainer_id=self.default_exp_id,
|
489 |
metric_id=self.obj_metric,
|
490 |
direction='maximize',
|
491 |
sampler=SAMPLE_METHOD,
|
492 |
n_trials=OPT_N_TRIALS,
|
493 |
)
|
|
|
494 |
|
495 |
-
|
496 |
-
|
497 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
498 |
self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
|
499 |
self.optimal_exp_id = opt_explainer_id
|
500 |
checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
|
@@ -628,6 +639,8 @@ from torch.utils.data import DataLoader
|
|
628 |
from helpers import get_imagenet_dataset, get_torchvision_model, denormalize_image
|
629 |
|
630 |
os.environ['GRADIO_TEMP_DIR'] = '.tmp'
|
|
|
|
|
631 |
|
632 |
def target_visualizer(x): return dataset.dataset.idx_to_label(x.item())
|
633 |
|
@@ -637,7 +650,7 @@ model, transform = get_torchvision_model('resnet18')
|
|
637 |
dataset = get_imagenet_dataset(transform)
|
638 |
loader = DataLoader(dataset, batch_size=4, shuffle=False)
|
639 |
experiment1 = AutoExplanationForImageClassification(
|
640 |
-
model=model,
|
641 |
data=loader,
|
642 |
input_extractor=lambda batch: batch[0],
|
643 |
label_extractor=lambda batch: batch[-1],
|
@@ -657,7 +670,7 @@ model, transform = get_torchvision_model('vit_b_16')
|
|
657 |
dataset = get_imagenet_dataset(transform)
|
658 |
loader = DataLoader(dataset, batch_size=4, shuffle=False)
|
659 |
experiment2 = AutoExplanationForImageClassification(
|
660 |
-
model=model,
|
661 |
data=loader,
|
662 |
input_extractor=lambda batch: batch[0],
|
663 |
label_extractor=lambda batch: batch[-1],
|
|
|
483 |
|
484 |
data_id = self.gallery.selected_index
|
485 |
|
486 |
+
opt_output = self.experiment.optimize(
|
487 |
+
data_ids=data_id.value,
|
488 |
explainer_id=self.default_exp_id,
|
489 |
metric_id=self.obj_metric,
|
490 |
direction='maximize',
|
491 |
sampler=SAMPLE_METHOD,
|
492 |
n_trials=OPT_N_TRIALS,
|
493 |
)
|
494 |
+
|
495 |
|
496 |
+
def get_str_ppid(pp_obj):
|
497 |
+
return pp_obj.pooling_fn.__class__.__name__ + pp_obj.normalization_fn.__class__.__name__
|
498 |
+
|
499 |
+
str_id = get_str_ppid(opt_output.postprocessor)
|
500 |
+
for pp_obj, pp_id in zip(*self.experiment.manager.get_postprocessors()):
|
501 |
+
if get_str_ppid(pp_obj) == str_id:
|
502 |
+
opt_postprocessor_id = pp_id
|
503 |
+
break
|
504 |
+
|
505 |
+
opt_explainer_id = max([x['id'] for x in self.groups.info]) + 1
|
506 |
+
opt_output.explainer.model = self.experiment.model
|
507 |
+
self.experiment.manager._explainers.append(opt_output.explainer)
|
508 |
+
self.experiment.manager._explainer_ids.append(opt_explainer_id)
|
509 |
self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
|
510 |
self.optimal_exp_id = opt_explainer_id
|
511 |
checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
|
|
|
639 |
from helpers import get_imagenet_dataset, get_torchvision_model, denormalize_image
|
640 |
|
641 |
os.environ['GRADIO_TEMP_DIR'] = '.tmp'
|
642 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
643 |
+
device = torch.device("cpu")
|
644 |
|
645 |
def target_visualizer(x): return dataset.dataset.idx_to_label(x.item())
|
646 |
|
|
|
650 |
dataset = get_imagenet_dataset(transform)
|
651 |
loader = DataLoader(dataset, batch_size=4, shuffle=False)
|
652 |
experiment1 = AutoExplanationForImageClassification(
|
653 |
+
model=model.to(device),
|
654 |
data=loader,
|
655 |
input_extractor=lambda batch: batch[0],
|
656 |
label_extractor=lambda batch: batch[-1],
|
|
|
670 |
dataset = get_imagenet_dataset(transform)
|
671 |
loader = DataLoader(dataset, batch_size=4, shuffle=False)
|
672 |
experiment2 = AutoExplanationForImageClassification(
|
673 |
+
model=model.to(device),
|
674 |
data=loader,
|
675 |
input_extractor=lambda batch: batch[0],
|
676 |
label_extractor=lambda batch: batch[-1],
|