File size: 3,992 Bytes
578bab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
from PIL import Image

import datasets
from datasets import load_dataset
from huggingface_hub import delete_repo 

# Position in `imgs` of the next image to hand out for labeling.
idx = 0
# Current batch of images awaiting labels, fetched from the Hub at startup.
data_to_label = load_dataset("active-learning/to_label_samples")
imgs = data_to_label["train"]["image"]


def get_image():
    """Return the image at the cursor ``idx`` and move the cursor forward by one."""
    global idx
    image = imgs[idx]
    idx = idx + 1
    return image


# [image, label] pairs collected for the current batch.
labeled_data = []

# Markdown shown at the top of the page explaining the active-learning loop.
information = """# Active Learning Demo

This demo showcases Active Learning, which is great when labeling is expensive. In this demo, you will label images by choosing a digit (0-9).

How does this work?
* There is a large pool of unlabeled images
* A model is trained with the few labeled images
* We can then use the model to pick the images with the lowest confidence or with the lowest probability of corresponding to an image. These are the images the model is confused about, so by labeling them, the quality of the model can improve much more than by labeling images for which the model was already doing well!
* In this UI, you will be provided a couple of images to label
* Once all the provided images are labeled, the model is retrained, and a new set of images is chosen!
"""

# Markdown shown once a full batch has been labeled and submitted.
webhook_info = """## Model Retraining

There are new labeled images. The model is retraining. Follow progress [here](https://huggingface.co/spaces/active-learning/webhook).
"""

with gr.Blocks() as demo:
    gr.Markdown(information)

    # Labeling widgets: the image to label (the first one is fetched eagerly),
    # a digit picker and the button that records the chosen digit.
    img_to_label = gr.Image(value=get_image(), shape=[28, 28])
    label_dropdown = gr.Dropdown(
        choices=list(range(10)),
        value=0,
        interactive=True,
    )
    save_btn = gr.Button("Save label")

    # Hidden until a batch is complete; toggled by save_data / reload_data.
    output_box = gr.Markdown(value=webhook_info, visible=False)
    reload_btn = gr.Button("Reload", visible=False)

    def save_data(img, label):
        """Record the label for the displayed image and advance the queue.

        While images remain in the current batch, the next one is shown.
        Once the last image is labeled, every collected sample is appended
        to the ``active-learning/labeled_samples`` Hub dataset. The
        to-label repo is deleted, and the UI switches to the
        "Model Retraining" view.

        Args:
            img: the image currently shown in ``img_to_label``. The
                ``Image.fromarray`` call below implies it is an array, most
                likely NumPy; confirm against the Gradio ``Image`` config.
            label: the digit selected in ``label_dropdown``.

        Returns:
            A dict mapping components to ``gr.update`` objects. Components
            not in the dict are left unchanged.
        """
        global labeled_data
        global idx

        labeled_data.append([img, label])

        if idx < len(imgs):
            # More images remain in this batch: show the next one.
            return {
                img_to_label: gr.update(value=get_image())
            }

        # The whole batch is labeled.
        # Remove dataset of queries to label.
        # The datasets library does not allow pushing an empty dataset, so as
        # a workaround we just delete the repo.
        # NOTE(review): this runs before the push below. If the push fails,
        # the queue repo is already gone. Pushing labeled_samples presumably
        # triggers the retraining webhook that recreates the queue, so
        # confirm before reordering.
        delete_repo(repo_id="active-learning/to_label_samples", repo_type="dataset")

        # Save to dataset.
        labeled_dataset = load_dataset("active-learning/labeled_samples")["train"]
        feature = datasets.Image(decode=False)
        for sample_img, sample_label in labeled_data:
            # Hack due to https://github.com/huggingface/datasets/issues/4796
            labeled_dataset = labeled_dataset.add_item({
                "image": feature.encode_example(Image.fromarray(sample_img)),
                "label": sample_label
            })
        labeled_dataset.push_to_hub("active-learning/labeled_samples")

        # Reset per-batch state so reload_data starts a fresh batch.
        labeled_data = []
        idx = 0
        return {
            img_to_label: gr.update(visible=False),
            label_dropdown: gr.update(visible=False),
            save_btn: gr.update(visible=False),
            output_box: gr.update(visible=True),
            reload_btn: gr.update(visible=True)
        }
    
    def reload_data():
        """Fetch the next batch of images to label and restore the labeling UI.

        This is the handler for the "Reload" button shown after a batch was
        submitted. ``save_data`` deletes the to-label repo, and the
        retraining job presumably recreates it. If the repo is still missing
        or its batch is empty, the "Reload" screen is kept so the user can
        try again later.

        Returns:
            A dict mapping components to ``gr.update`` objects. Components
            not in the dict are left unchanged.
        """
        global data_to_label
        global imgs
        global idx

        # Returned while no new batch is available: keep the Reload screen.
        still_waiting = {reload_btn: gr.update(visible=True)}

        try:
            new_data = load_dataset("active-learning/to_label_samples")
        except FileNotFoundError:
            # NOTE(review): `datasets` raises FileNotFoundError (or its
            # subclass DatasetNotFoundError in newer releases) for a missing
            # Hub dataset. Confirm this for the pinned version.
            return still_waiting
        data_to_label = new_data
        imgs = data_to_label["train"]["image"]
        if len(imgs) == 0:
            return still_waiting

        # Start the new batch from its first image. Without refreshing the
        # image, the last image of the previous batch would stay on screen
        # and be labeled again.
        idx = 0
        return {
            img_to_label: gr.update(visible=True, value=get_image()),
            label_dropdown: gr.update(visible=True),
            save_btn: gr.update(visible=True),
            output_box: gr.update(visible=False),
            reload_btn: gr.update(visible=False)
        }

    save_btn.click(
        save_data,
        inputs=[img_to_label, label_dropdown],
        outputs=[img_to_label, label_dropdown, save_btn, output_box, reload_btn]
    )

    reload_btn.click(
        reload_data,
        outputs=[img_to_label, label_dropdown, save_btn, output_box, reload_btn]
    )
    
demo.launch(debug=True)