import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig

from ferret import Benchmark

DEFAULT_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"
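# The default checkpoint is a binary SST-2 sentiment classifier whose classes are
# NEGATIVE (index 0) and POSITIVE (index 1); the target selector below defaults
# to positional index 1 (POSITIVE) accordingly.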

def get_model(model_name):
    return AutoModelForSequenceClassification.from_pretrained(model_name)


def get_config(model_name):
    return AutoConfig.from_pretrained(model_name)


def get_tokenizer(tokenizer_name):
    return AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
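
# A caching sketch (an assumption, not in the original file): wrapping the loaders
# with Streamlit's cache keeps the model in memory across script reruns instead of
# reloading it on every widget interaction, e.g. in recent Streamlit versions:
#
#   @st.cache_resource
#   def get_model(model_name):
#       return AutoModelForSequenceClassification.from_pretrained(model_name)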

def body():
    st.markdown(
        """
# Welcome to the *ferret* showcase

You are now working in *single instance* mode -- i.e., you will inspect one
textual query at a time.

## Sentiment Analysis

Post-hoc explanation techniques disclose the rationale behind a prediction a
model makes while detecting the sentiment of a text. In a sense, they let you
*poke* inside the model. But **who watches the watchers**?

Let's find out!

Now choose your favourite sentiment classification model and let ferret do the
rest. We will:

1. download your model - if you're impatient, here is a
   [cute video](https://www.youtube.com/watch?v=0Xks8t-SWHU) for you;
2. explain the prediction using *ferret*'s built-in methods;
3. evaluate the explanations with state-of-the-art **faithfulness metrics**.
"""
    )

    col1, col2 = st.columns([3, 1])
    with col1:
        model_name = st.text_input("HF Model", DEFAULT_MODEL)
    with col2:
        target = st.selectbox(
            "Target",
            options=range(5),
            index=1,
            help="Positional index of your target class.",
        )

    text = st.text_input("Text")
    compute = st.button("Compute")
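
    # The pipeline below runs in three stages: predict, explain, evaluate.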
    if compute and model_name:
        with st.spinner("Preparing the magic. Hang in there..."):
            model = get_model(model_name)
            tokenizer = get_tokenizer(model_name)
            config = get_config(model_name)
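            # ferret's Benchmark ties the model and tokenizer together and exposes
            # the scoring, explanation, and evaluation entry points used below.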
            bench = Benchmark(model, tokenizer)

        st.markdown("### Prediction")
        scores = bench.score(text)
        scores_str = ", ".join(
            [f"{config.id2label[l]}: {s:.2f}" for l, s in enumerate(scores)]
        )
        st.text(scores_str)
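
        # Explanations are computed for the selected target class by ferret's
        # built-in explainers (gradient- and perturbation-based attribution methods).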
        with st.spinner("Computing Explanations..."):
            explanations = bench.explain(text, target=target)

        st.markdown("### Explanations")
        st.dataframe(bench.show_table(explanations))
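
        # Faithfulness metrics check each explanation against the model's own
        # behaviour, so no human-annotated rationales are needed.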
        with st.spinner("Evaluating Explanations..."):
            evaluations = bench.evaluate_explanations(
                explanations, target=target, apply_style=False
            )

        st.markdown("### Faithfulness Metrics")
        st.dataframe(bench.show_evaluation_table(evaluations))

        st.markdown(
            """
**Legend**

- **AOPC Comprehensiveness** (aopc_compr) measures *comprehensiveness*, i.e., how
much the model's prediction drops when the tokens the explanation marks as
important are removed (higher is better);
- **AOPC Sufficiency** (aopc_suff) measures *sufficiency*, i.e., how well the
important tokens alone preserve the model's original prediction (lower is better);
- **Leave-One-Out TAU Correlation** (taucorr_loo) measures the rank correlation
between the explanation's scores and leave-one-out token importance (higher is
better).

See the paper for details.
"""
        )
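

# Entry point (assumed here; the call is missing from the excerpt above):
# Streamlit executes the script top to bottom, so body() must be invoked at
# module level for the page to render.
body()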