File size: 2,782 Bytes
15c0354
2bca1d4
 
bb13f04
 
 
15c0354
2bca1d4
 
78cf820
 
2bca1d4
78cf820
2bca1d4
bb13f04
2eba1ff
2bca1d4
bb13f04
2eba1ff
7a7170b
2bca1d4
 
bb13f04
2bca1d4
 
 
 
 
 
 
 
 
 
 
 
 
 
7a7170b
2bca1d4
 
 
7a7170b
 
2bca1d4
78cf820
bb13f04
78cf820
 
808ed79
bb13f04
2bca1d4
 
 
bb13f04
2bca1d4
5f715e8
bb13f04
2bca1d4
 
bb13f04
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import gradio as gr
from transformers import BartForSequenceClassification, BartTokenizer
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

# Run on Apple's Metal backend (MPS) when PyTorch reports it available, otherwise on CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Textual-entailment (MNLI) tokenizer/model pair used to score hypotheses against a sentence.
te_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
# Keep the entailment model on the same device its inputs are sent to; left on the
# default CPU placement it cannot consume tensors that were moved to MPS.
te_model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli').to(device)
# Instruction-following seq2seq tokenizer/model pair used to generate short answers.
qa_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small", device_map="auto")

def _one_word_answer(prompt):
    """Return the QA model's short generated answer to ``prompt`` as plain text."""
    # Send the prompt to the device the model's weights are on, so the call
    # works regardless of where the model was placed at load time.
    input_ids = qa_tokenizer(prompt, return_tensors="pt").input_ids.to(qa_model.device)
    generated = qa_model.generate(input_ids, max_length=2)[0]
    # Drop special-token markers (e.g. "<pad>", "</s>") so they do not leak
    # into the hypotheses that are built from this answer.
    return qa_tokenizer.decode(generated, skip_special_tokens=True).strip()


def predict(context, intent, multi_class):
    """Score how well ``intent`` describes the object described in ``context``.

    The QA model is asked for the opposite of ``intent`` and for the object
    that ``context`` describes. Three hypotheses ("is intent", "is opposite",
    "neither") are then checked against ``context`` with the entailment model
    and the resulting probabilities are combined into three scores.

    Args:
        context: The sentence to analyse.
        intent: The class word to test against the sentence.
        multi_class: If truthy, each score is passed through a sigmoid
            independently; otherwise the three scores are softmax-normalised.

    Returns:
        A dict with float values under the keys "agree", "neutral" and
        "disagree".
    """
    opposite_output = _one_word_answer("In one word, what is the opposite of: " + intent + "?")
    object_output = _one_word_answer("In one word, what is the following describing: " + context)

    batch = [
        'I think the ' + object_output + ' is ' + intent,
        'I think the ' + object_output + ' is ' + opposite_output,
        'I think the ' + object_output + ' are neither ' + intent + ' nor ' + opposite_output,
    ]

    outputs = []
    # Inference only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for i, hypothesis in enumerate(batch):
            input_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt').to(te_model.device)
            # -> [contradiction, neutral, entailment]
            logits = te_model(input_ids)[0][0]

            if i == 2:
                # The "neither" hypothesis is reduced to a two-way decision.
                # -> [contradiction, entailment]
                probs = logits[[0, 2]].softmax(dim=0)
            else:
                probs = logits.softmax(dim=0)
            outputs.append(probs)

    # -> [entailment, contradiction]
    outputs[2] = outputs[2].flip(dims=[0])
    # -> [entailment, neutral, contradiction]
    outputs[0] = outputs[0].flip(dims=[0])
    # outputs[1] (the opposite hypothesis) is deliberately left unflipped, so
    # contradicting the opposite lines up with entailing the intent at index 0.
    pn_tensor = (outputs[0] + outputs[1]) / 2
    # Neutral is weighted by entailment of "neither"; agree and disagree are
    # weighted by its contradiction.
    pn_tensor[1] = pn_tensor[1] * outputs[2][0]
    pn_tensor[2] = pn_tensor[2] * outputs[2][1]
    pn_tensor[0] = pn_tensor[0] * outputs[2][1]

    pn_tensor = pn_tensor.exp() - 1

    if multi_class:
        pn_tensor = torch.sigmoid(pn_tensor)
    else:
        pn_tensor = pn_tensor.softmax(dim=0)
    pn_tensor = pn_tensor.tolist()
    return {"agree": pn_tensor[0], "neutral": pn_tensor[1], "disagree": pn_tensor[2]}

# Web UI: two text boxes and a checkbox feed predict(); its score dict is shown as a label.
gradio_app = gr.Interface(
    fn=predict,
    inputs=[
        gr.Text(label="Sentence"),
        gr.Text(label="Class"),
        gr.Checkbox(label="Allow multiple true classes"),
    ],
    outputs=[gr.Label(num_top_classes=3)],
    title="Intent Analysis",
    description=(
        "This model predicts whether or not the **class** describes the "
        "**object described in the sentence.**"
    ),
)

# Start serving the interface as soon as the script runs.
gradio_app.launch()