Spaces:
Sleeping
Sleeping
File size: 2,782 Bytes
15c0354 2bca1d4 bb13f04 15c0354 2bca1d4 78cf820 2bca1d4 78cf820 2bca1d4 bb13f04 2eba1ff 2bca1d4 bb13f04 2eba1ff 7a7170b 2bca1d4 bb13f04 2bca1d4 7a7170b 2bca1d4 7a7170b 2bca1d4 78cf820 bb13f04 78cf820 808ed79 bb13f04 2bca1d4 bb13f04 2bca1d4 5f715e8 bb13f04 2bca1d4 bb13f04 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import gradio as gr
from transformers import BartForSequenceClassification, BartTokenizer
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
# Prefer Apple-Silicon GPU (MPS) when available; fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Textual-entailment (NLI) model: scores hypothesis strings against the
# user's sentence. BUGFIX: move the model to `device` — predict() sends its
# input_ids there, and a CPU-resident model fed MPS tensors raises a
# device-mismatch RuntimeError.
te_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
te_model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli').to(device)

# Small instruction-tuned QA model used to extract an antonym of the class
# word and the object the sentence describes. device_map="auto" lets
# accelerate choose placement; predict() still moves inputs to `device` —
# NOTE(review): confirm the auto placement matches `device` on MPS hosts.
qa_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small", device_map="auto")
def predict(context, intent, multi_class):
    """Score how well *intent* describes the object in *context*.

    Pipeline:
      1. Ask the QA model for a one-word antonym of *intent*.
      2. Ask the QA model what object *context* is describing.
      3. Build three entailment hypotheses (intent / antonym / neither) and
         score each against *context* with the MNLI model.
      4. Combine the probabilities into agree / neutral / disagree scores.

    Args:
        context: The user's sentence.
        intent: The class word to test against the sentence.
        multi_class: If truthy, squash scores independently with sigmoid
            (classes may all be "true"); otherwise softmax to a distribution.

    Returns:
        dict mapping "agree" / "neutral" / "disagree" to float scores,
        suitable for a gr.Label output.
    """
    # Inference only — no gradients needed.
    with torch.no_grad():
        # Antonym of the intent word. BUGFIX: skip_special_tokens=True —
        # a plain decode() leaves "<pad>"/"</s>" text in the output, which
        # would then pollute every hypothesis string built below.
        prompt = "In one word, what is the opposite of: " + intent + "?"
        prompt_ids = qa_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        opposite_output = qa_tokenizer.decode(
            qa_model.generate(prompt_ids, max_length=2)[0],
            skip_special_tokens=True,
        )

        # The object the sentence is describing (same BUGFIX as above).
        prompt = "In one word, what is the following describing: " + context
        prompt_ids = qa_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        object_output = qa_tokenizer.decode(
            qa_model.generate(prompt_ids, max_length=2)[0],
            skip_special_tokens=True,
        )

        hypotheses = [
            'I think the ' + object_output + ' is ' + intent,
            'I think the ' + object_output + ' is ' + opposite_output,
            'I think the ' + object_output + ' are neither ' + intent + ' nor ' + opposite_output,
        ]

        outputs = []
        for i, hypothesis in enumerate(hypotheses):
            pair_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt').to(device)
            # MNLI logits order: [contradiction, neutral, entailment]
            logits = te_model(pair_ids)[0][0]
            if i == 2:
                # "neither" hypothesis: drop the neutral logit and renormalise
                # over [contradiction, entailment] only.
                probs = logits[[0, 2]].softmax(dim=0)
            else:
                probs = logits.softmax(dim=0)
            outputs.append(probs)

    # Reorder so index 0 means entailment:
    # outputs[2]: [contradiction, entailment] -> [entailment, contradiction]
    outputs[2] = outputs[2].flip(dims=[0])
    # outputs[0]: [contradiction, neutral, entailment] -> [entailment, neutral, contradiction]
    outputs[0] = outputs[0].flip(dims=[0])

    # Average the intent and antonym hypothesis scores, then gate by the
    # "neither" hypothesis: its entailment boosts neutral, its contradiction
    # boosts agree/disagree.
    pn_tensor = (outputs[0] + outputs[1]) / 2
    pn_tensor[1] = pn_tensor[1] * outputs[2][0]
    pn_tensor[2] = pn_tensor[2] * outputs[2][1]
    pn_tensor[0] = pn_tensor[0] * outputs[2][1]
    # exp(x) - 1 sharpens the gap between strong and weak scores.
    pn_tensor = pn_tensor.exp() - 1

    if multi_class:
        # Independent per-class scores: several classes may be "true".
        pn_tensor = torch.sigmoid(pn_tensor)
    else:
        # Mutually exclusive classes: normalise to a distribution.
        pn_tensor = pn_tensor.softmax(dim=0)

    scores = pn_tensor.tolist()
    return {"agree": scores[0], "neutral": scores[1], "disagree": scores[2]}
# Gradio UI: sentence + candidate class in, agree/neutral/disagree label out.
gradio_app = gr.Interface(
    predict,
    inputs=[
        gr.Text(label="Sentence"),
        gr.Text(label="Class"),
        gr.Checkbox(label="Allow multiple true classes"),
    ],
    outputs=[gr.Label(num_top_classes=3)],
    title="Intent Analysis",
    description="This model predicts whether or not the **class** describes the **object described in the sentence.**",
)

# BUGFIX: the original line ended with a stray " |" (scrape artifact) that
# made it a syntax error. Guard launch() so importing this module (e.g. in
# tests) does not start the server.
if __name__ == "__main__":
    gradio_app.launch()