# ferret / single.py
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from ferret import Benchmark
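# Default to a DistilBERT checkpoint fine-tuned on SST-2 (binary sentiment).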
DEFAULT_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"

# allow_output_mutation=True stops Streamlit from trying to hash the (mutable) torch model.
@st.cache(allow_output_mutation=True)
def get_model(model_name):
    return AutoModelForSequenceClassification.from_pretrained(model_name)

@st.cache()
def get_config(model_name):
    return AutoConfig.from_pretrained(model_name)

@st.cache(allow_output_mutation=True)  # cache the tokenizer too, consistently with the loaders above
def get_tokenizer(tokenizer_name):
    return AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
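
# A minimal, non-Streamlit sketch of the pipeline implemented in body() below,
# kept as comments; the example text and target index are illustrative only:
#
#   model = get_model(DEFAULT_MODEL)
#   tokenizer = get_tokenizer(DEFAULT_MODEL)
#   bench = Benchmark(model, tokenizer)
#   scores = bench.score("A truly lovely film.")                  # per-class probabilities
#   explanations = bench.explain("A truly lovely film.", target=1)
#   evaluations = bench.evaluate_explanations(explanations, target=1)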

def body():
    st.markdown(
        """
        # Welcome to the *ferret* showcase

        You are now in *single instance* mode, i.e., you will inspect
        one textual query at a time.

        ## Sentiment Analysis

        Post-hoc explanation techniques disclose the rationale behind the prediction
        a model makes, e.g., while detecting the sentiment of a text. In a sense,
        they let you *poke* inside the model. But **who watches the watchers**?

        Let's find out!

        Choose your favourite sentiment classification model and let ferret do the rest.
        We will:

        1. download your model - if you're impatient, here is a [cute video](https://www.youtube.com/watch?v=0Xks8t-SWHU) 🦜 for you;
        2. explain using *ferret*'s built-in methods ⚙️;
        3. evaluate explanations with state-of-the-art **faithfulness metrics** 🚀
        """
    )
    col1, col2 = st.columns([3, 1])
    with col1:
        model_name = st.text_input("HF Model", DEFAULT_MODEL)
    with col2:
        target = st.selectbox(
            "Target",
            options=range(5),
            index=1,
            help="Positional index of your target class.",
        )

    text = st.text_input("Text")
    compute = st.button("Compute")
    if compute and model_name:
        with st.spinner("Preparing the magic. Hang in there..."):
            model = get_model(model_name)
            tokenizer = get_tokenizer(model_name)
            config = get_config(model_name)
            bench = Benchmark(model, tokenizer)
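        # bench.score returns one probability per class; config.id2label maps class indices to names.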
st.markdown("### Prediction")
scores = bench.score(text)
scores_str = ", ".join(
[f"{config.id2label[l]}: {s:.2f}" for l, s in enumerate(scores)]
)
st.text(scores_str)
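        # bench.explain runs ferret's built-in attribution methods on the input text.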
with st.spinner("Computing Explanations.."):
explanations = bench.explain(text, target=target)
st.markdown("### Explanations")
st.dataframe(bench.show_table(explanations))
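        # Score each explanation with the faithfulness metrics described in the legend below.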
with st.spinner("Evaluating Explanations..."):
evaluations = bench.evaluate_explanations(
explanations, target=target, apply_style=False
)
st.markdown("### Faithfulness Metrics")
st.dataframe(bench.show_evaluation_table(evaluations))
        st.markdown(
            """
            **Legend**

            - **AOPC Comprehensiveness** (aopc_compr) measures *comprehensiveness*, i.e., how much
            the prediction changes once the tokens the explanation marks as relevant are removed.
            The higher, the better.
            - **AOPC Sufficiency** (aopc_suff) measures *sufficiency*, i.e., how well the relevant
            tokens alone preserve the original prediction. The lower, the better.
            - **Leave-One-Out TAU Correlation** (taucorr_loo) measures the Kendall rank correlation
            between the explanation scores and leave-one-out token importances. The closer to 1,
            the better.

            See the paper for details.
            """
        )
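
# Assumption (sketch): the Space's entry point normally imports this module and calls
# body(); the guard below additionally lets you run this page directly with
# `streamlit run single.py`.
if __name__ == "__main__":
    body()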