In [13]:
import pandas as pd
pd.set_option('display.max_colwidth', 0)
import numpy as np
import torch
import math
import openai
from collections import defaultdict
from tqdm import tqdm
from transformers import AutoTokenizer
import nltk
from nltk.cluster import KMeansClusterer
import scipy.spatial.distance as sdist
from scipy.spatial import distance_matrix
import matplotlib.pyplot as plt
from datasets import load_dataset


In [9]:
import sys
sys.path.append('../../seal/')
from run_inference import run_inference

In [5]:
# Dataset name and model checkpoint id used throughout the notebook: both are
# passed to run_inference, and `model` is also handed to AutoTokenizer.from_pretrained.
dataset = 'imdb'
model = 'lvwerra/distilbert-imdb'

In [14]:
# Tokenizer for the checkpoint named by `model`; the token-counting loop later
# in the notebook calls it once per review.
tokenizer = AutoTokenizer.from_pretrained(model)

In [15]:
# Run inference for the configured dataset/model on the 'test' split.
# NOTE(review): the recorded output of this cell is a ValueError ("Connection error ...
# cannot find the requested files in the cached path"), so this call presumably needs
# network access or a populated cache -- confirm before relying on Run All.
run_inference(dataset=dataset,model=model,split='test')



ValueError: Connection error, and we cannot find the requested files in the cached path. Please try again or make sure your Internet connection is on.

In [2]:
def kmeans(df, num_clusters=3):
    """Cluster the rows of ``df`` on their ``'embedding'`` column with cosine-distance k-means.

    The frame is modified in place and also returned: a ``'cluster'`` column
    (integer cluster id per row) and a ``'centroid'`` column (the mean vector
    of the row's cluster) are added.

    Args:
        df: DataFrame with an ``'embedding'`` column.
            # assumes every entry is an equal-length numeric vector -- TODO confirm
        num_clusters: number of clusters requested from the clusterer.

    Returns:
        The same ``df`` object, with ``'cluster'`` and ``'centroid'`` columns set.
    """
    vectors = np.array(df['embedding'].tolist())
    clusterer = KMeansClusterer(
        num_clusters,
        distance=nltk.cluster.util.cosine_distance,
        repeats=25,
        avoid_empty_clusters=True,
    )
    labels = clusterer.cluster(vectors, assign_clusters=True)

    # Re-attach the labels to the caller's index so rows line up on assignment.
    df['cluster'] = pd.Series(labels, index=df.index).astype('int')

    # Look the cluster means up once and map every row to its cluster's mean.
    centroids = clusterer.means()
    df['centroid'] = df['cluster'].apply(lambda cluster_id: centroids[cluster_id])
    return df

In [3]:
def cluster_errors(data_hl):
    """Cluster each frame in ``data_hl`` separately and merge the results.

    Every frame is clustered with ``kmeans`` using ``int(sqrt(n / 2))`` clusters
    (at least one), where ``n`` is the number of rows in the frame. Cluster ids
    are shifted by the number of clusters requested for the preceding frames so
    that ids from different frames do not collide in the merged result.

    Args:
        data_hl: iterable of DataFrames. Frames that already carry ``'cluster'``
            / ``'centroid'`` columns (e.g. the groups produced by a previous
            pass) are re-clustered from scratch.

    Returns:
        One DataFrame with all rows of the non-empty input frames, a fresh
        0..N-1 index, and the ``'cluster'`` / ``'centroid'`` columns written by
        ``kmeans``. An empty DataFrame is returned if there is nothing to cluster.
    """
    clustered = []
    cluster_offset = 0  # number of cluster ids handed out so far
    for df in data_hl:
        if len(df) == 0:
            # Nothing to cluster; k-means cannot run on zero vectors.
            continue
        # drop() returns a new frame, so the caller's frame is never mutated,
        # and errors='ignore' covers frames without earlier cluster columns.
        df = df.drop(columns=['cluster', 'centroid'], errors='ignore')
        # int(sqrt(n / 2)) is 0 for a single-row frame; always ask for >= 1 cluster.
        k = max(1, int(math.sqrt(len(df) / 2)))
        df = kmeans(df, num_clusters=k)
        df['cluster'] = df['cluster'] + cluster_offset
        cluster_offset += k
        clustered.append(df)
    if not clustered:
        return pd.DataFrame()
    # Concatenate once instead of growing the result frame inside the loop.
    return pd.concat(clustered, ignore_index=True)

In [3]:
from sklearn.metrics import accuracy_score
def generate_groups(merged):
    """Split ``merged`` by cluster id and score each cluster.

    Args:
        merged: DataFrame with ``'cluster'``, ``'label'`` and ``'pred'`` columns.

    Returns:
        A ``(groups, cluster_acc)`` pair of dicts keyed by cluster id:
        ``groups`` maps to the sub-frame of that cluster and ``cluster_acc``
        to ``accuracy_score(label, pred)`` computed over that sub-frame.
    """
    by_cluster = merged.groupby('cluster')
    groups = {}
    cluster_acc = {}
    for cluster_id in by_cluster.groups:
        members = by_cluster.get_group(cluster_id)
        groups[cluster_id] = members
        cluster_acc[cluster_id] = accuracy_score(
            members['label'].values.tolist(),
            members['pred'].values.tolist(),
        )
    return groups, cluster_acc

def dict_zip(*dicts):
    """Merge several dicts into one dict of lists.

    For every key that appears in at least one input, the result holds the list
    of that key's values taken from the inputs in argument order; inputs that
    lack the key are skipped.

    Keys appear in the result in first-seen order (first dict's keys, then new
    keys from the second, and so on), so the output order is reproducible
    instead of depending on set iteration order.

    Args:
        *dicts: any number of mappings.

    Returns:
        dict mapping each key to the list of its values across ``dicts``.
    """
    # dict.fromkeys de-duplicates while preserving insertion order.
    ordered_keys = dict.fromkeys(key for d in dicts for key in d)
    return {key: [d[key] for d in dicts if key in d] for key in ordered_keys}

def semantic_labeling(groups):
    """Attach a generated text label and basic stats to every cluster.

    Args:
        groups: dict mapping a cluster id to the DataFrame of that cluster's
            rows; each frame needs a ``'label'`` column and at least one row.

    Returns:
        dict mapping each cluster id to ``[text_label, [size, first_row_label]]``,
        where ``text_label`` comes from ``generate_group_label``, ``size`` is the
        number of rows in the cluster and ``first_row_label`` is the ``label``
        value of the cluster's first row.
    """
    # One label request per cluster, with a progress bar.
    labels_by_cluster = {}
    for cluster_id, members in tqdm(groups.items()):
        labels_by_cluster[cluster_id] = generate_group_label(members)

    stats_by_cluster = {}
    for cluster_id, members in groups.items():
        first_row = members.iloc[0]
        stats_by_cluster[cluster_id] = [len(members), first_row.label]

    return dict_zip(labels_by_cluster, stats_by_cluster)

def generate_group_label(cluster):
    """Request a short topical label for one cluster of documents.

    Builds a prompt from a fixed instruction plus the cluster's documents (via
    ``build_prompt``) and sends it to ``openai.Completion.create`` with the
    ``text-davinci-002`` engine.

    Args:
        cluster: DataFrame of the cluster's rows, passed through to ``build_prompt``.

    Returns:
        The text of the first completion choice, stripped of surrounding whitespace.
    """
    # A longer few-shot variant of this instruction (restaurant-review examples ending in
    # "Label: overpriced and spicy food") used to sit here as a commented-out line. It was
    # removed because the comment wrapped onto an uncommented line and broke the cell.
    instruction = "In this task, we'll assign a short and precise label to a cluster of documents based on the topics or concepts most relevant to these documents. The documents are all subsets of a sentiment classification dataset.\n"
    prompt = build_prompt(instruction, cluster)
    resp = openai.Completion.create(
        prompt=prompt,
        engine='text-davinci-002',
        # frequency_penalty=1.2 is left disabled, as in the original call.
    )
    label = resp['choices'][0]['text']
    return label.strip()

#code to build prompt and query gpt3 for labeling
def build_prompt(instruction, cluster_df, many_docs_threshold=10,
                 short_char_limit=600, long_char_limit=1000):
    """Assemble the labeling prompt for one cluster.

    The prompt is the instruction, then the cluster's documents as a dash list,
    then a trailing ``Cluster label:`` cue. Each document is cut to a character
    budget: ``short_char_limit`` when the cluster has more than
    ``many_docs_threshold`` rows, ``long_char_limit`` otherwise.

    Args:
        instruction: text placed at the start of the prompt.
        cluster_df: DataFrame with a ``'content'`` column of document texts.
        many_docs_threshold: row count above which the shorter budget applies.
        short_char_limit: per-document character budget for large clusters.
        long_char_limit: per-document character budget for small clusters.

    Returns:
        The assembled prompt with leading/trailing whitespace stripped.
    """
    # The budget is chosen from the full row count, before missing texts are dropped.
    if len(cluster_df) > many_docs_threshold:
        char_limit = short_char_limit
    else:
        char_limit = long_char_limit
    # Missing content would make str.join fail on a float NaN, so skip those rows.
    documents = cluster_df['content'].dropna().astype(str).str[:char_limit].tolist()
    examples = '\n - '.join(documents)
    text = instruction + '- ' + examples + '\n Cluster label:'
    return text.strip()

In [6]:
# Load the stored predictions and keep one row per distinct review text.
data_df_imdb = (
    pd.read_parquet('./imdb_test_distillbert.parquet')
    .drop_duplicates(subset=['content'])
)
data_df_imdb['loss'] = data_df_imdb['loss'].astype(float)

# Rows whose loss lies strictly above the 99th percentile form the 'high-loss'
# slice; every other row is 'low-loss'.
losses = data_df_imdb['loss']
high_loss = losses.quantile(0.99)
data_df_imdb['slice'] = 'low-loss'
data_df_imdb.loc[losses > high_loss, 'slice'] = 'high-loss'

# Split by slice: each frame is the full frame minus the rows of the other slice.
data_hl_imdb = data_df_imdb.drop(
    index=data_df_imdb.loc[data_df_imdb['slice'] == 'low-loss'].index)
data_ll_imdb = data_df_imdb.drop(
    index=data_df_imdb.loc[data_df_imdb['slice'] == 'high-loss'].index)

# One high-loss frame per label value; these are the inputs to cluster_errors.
df_list_imdb = [frame for _key, frame in data_hl_imdb.groupby(['label'])]

In [7]:
# De-duplicate on 'content' again (the loading cell already applies the same
# drop_duplicates) and hold the result in `tmp`.
tmp = data_df_imdb.drop_duplicates(subset=['content'])

In [8]:
# Rebind the working frame to the de-duplicated frame held in `tmp`.
data_df_imdb = tmp

In [9]:
# Accuracy of the 'pred' column against 'label' over the working frame.
accuracy_score(data_df_imdb['label'],data_df_imdb['pred'])

0.8622382730076287

In [22]:
# Level 0: cluster the per-label high-loss frames, then split the merged
# result by cluster id and compute per-cluster accuracy.
merged_lvl0=cluster_errors(df_list_imdb)
groups_lvl0,acc_lvl0 = generate_groups(merged_lvl0)

In [23]:
# Level 1: re-cluster every level-0 group on its own (cluster_errors drops the
# old cluster columns and hands out fresh ids), then regroup and score again.
merged_lvl1=cluster_errors(groups_lvl0.values())
groups_lvl1,acc_lvl1 = generate_groups(merged_lvl1)

In [24]:
# Sanity check on the level-1 ids: cluster_errors starts numbering at 0, so
# gap-free ids give max id == number of groups - 1.
print(merged_lvl1['cluster'].max())
print(len(groups_lvl1))

32
33


In [27]:
# Generate a text label for every level-1 cluster (one completion request per cluster).
cluster_labels= semantic_labeling(groups_lvl1)

100%|██████████| 33/33 [00:41<00:00,  1.27s/it]


In [52]:
# Spot check: count and accuracy for reviews whose text contains "hamlet"
# (case-insensitive match).
match = data_df_imdb[data_df_imdb['content'].str.contains('hamlet' , case=False)]
print(len(match), accuracy_score(match['label'], match['pred']))

43 0.6976744186046512


In [51]:
# Show the losses that lie strictly above the current `high_loss` threshold.
losses.loc[losses>high_loss]

6        5.065617
61       3.967897
143      4.062117
285      4.036336
313      4.622490
           ...   
24580    3.634398
24605    4.533525
24659    3.830003
24708    4.107304
24751    4.207838
Name: loss, Length: 493, dtype: float64

In [37]:
# Rebuild the threshold at the 98th percentile and derive loss-proportional weights:
# rows at or below the threshold get weight 0, the rest keep their loss, and the
# weights are then normalised to sum to 1.
# NOTE(review): the loading cell uses quantile(0.99) for the same `high_loss` name --
# confirm which threshold is intended.
losses = data_df_imdb['loss']
high_loss = losses.quantile(0.98)
loss_weights = np.where(losses > high_loss,losses,0.0)
loss_weights = loss_weights / loss_weights.sum()

In [50]:
# Row count of the working frame.
len(data_df_imdb)

24644

In [14]:
# Tokenize every review once. For each review keep the flattened id sequence
# (`tokens`) and the distinct ids it contains (`unique_tokens`).
# NOTE(review): frequent_tokens repeats this same loop internally.
unique_tokens = []
tokens = []
for row in tqdm(data_df_imdb['content']):
    tokenized = tokenizer(row,padding=True, return_tensors='pt', truncation=True)
    tokens.append(tokenized['input_ids'].flatten())
    unique_tokens.append(torch.unique(tokenized['input_ids']))

100%|██████████| 24644/24644 [00:16<00:00, 1519.28it/s]


In [36]:
# Display the current loss weights.
loss_weights

0        9.110757e-06
1        3.830380e-04
2        1.652545e-06
3        7.110300e-05
4        1.504268e-04
             ...     
24816    1.350308e-05
24828    2.056361e-06
24829    6.765310e-07
24830    6.537001e-07
24831    6.824863e-06
Name: loss, Length: 24644, dtype: float64

In [29]:
# Row count of the working frame.
len(data_df_imdb)

24644

In [55]:
def frequent_tokens(data, tokenizer, loss_quantile=0.98, top_k=200, smoothing=0.005):
    """Rank tokens by how over-represented they are among high-loss documents.

    For every token id two document frequencies are computed:

    * ``freq``: the fraction of all documents that contain the token;
    * ``error-freq``: the same count, but each document is weighted by its loss
      if that loss is strictly above the ``loss_quantile`` quantile and by 0
      otherwise, with the weights normalised to sum to 1.

    Tokens are ranked by ``(smoothing + error-freq) / (smoothing + freq)``.

    Args:
        data: DataFrame with ``'content'`` (text) and ``'loss'`` columns.
        tokenizer: callable returning a mapping with an ``'input_ids'`` tensor
            for one text, and offering ``decode(token_id)``.
        loss_quantile: quantile of the losses used as the high-loss threshold.
        top_k: number of top-ranked tokens to return.
        smoothing: constant added to numerator and denominator of the ratio.

    Returns:
        DataFrame with string-formatted columns ``token``, ``freq``,
        ``error-freq`` and ``ratio``, highest ratio first. It is empty when
        ``data`` has no rows.
    """
    columns = ['token', 'freq', 'error-freq', 'ratio']
    if len(data) == 0:
        # No documents: nothing to tokenize and no uniform weight to compute.
        return pd.DataFrame(columns=columns)

    # Only token presence per document matters, so keep just the distinct ids
    # (torch.unique returns them sorted) rather than the full id sequences.
    unique_tokens = []
    for row in tqdm(data['content']):
        tokenized = tokenizer(row, padding=True, truncation=True, return_tensors='pt')
        unique_tokens.append(torch.unique(tokenized['input_ids']).tolist())

    losses = data['loss'].astype(float)
    high_loss = losses.quantile(loss_quantile)
    loss_weights = np.where(losses > high_loss, losses, 0.0)
    total_weight = loss_weights.sum()
    if total_weight != 0:
        loss_weights = loss_weights / total_weight
    # else: no loss exceeds the threshold (e.g. all losses equal). Keep the
    # weights at zero rather than dividing by zero and filling the table with NaN.

    uniform_weight = 1 / len(loss_weights)
    token_frequencies = defaultdict(float)
    token_frequencies_error = defaultdict(float)
    for i in tqdm(range(len(data))):
        for token in unique_tokens[i]:
            token_frequencies[token] += uniform_weight
            token_frequencies_error[token] += loss_weights[i]

    token_lrs = {
        token: (smoothing + token_frequencies_error[token]) / (smoothing + freq)
        for token, freq in token_frequencies.items()
    }
    # Ascending stable sort then reversal: the tie order of the original
    # implementation is kept.
    tokens_sorted = sorted(token_lrs, key=token_lrs.get)[::-1]

    top_tokens = [
        [
            '%10s' % (tokenizer.decode(token)),
            '%.4f' % (token_frequencies[token]),
            '%.4f' % (token_frequencies_error[token]),
            '%4.2f' % (token_lrs[token]),
        ]
        for token in tokens_sorted[:top_k]
    ]
    return pd.DataFrame(top_tokens, columns=columns)

In [56]:
# Rank tokens over the whole working frame with the default settings. A fresh
# tokenizer is loaded for 'lvwerra/distilbert-imdb', the same id stored in `model`.
commontokens = frequent_tokens(data_df_imdb,AutoTokenizer.from_pretrained('lvwerra/distilbert-imdb'))

100%|██████████| 24644/24644 [00:15<00:00, 1551.93it/s]
100%|██████████| 24644/24644 [00:05<00:00, 4428.85it/s]


In [58]:
# Write the token table to a parquet file under ./assets/data/.
commontokens.to_parquet('./assets/data/imdb_test_distillbert_tokens.parquet')