File size: 3,510 Bytes
3b2519a
 
 
 
 
 
 
 
 
 
 
 
436442c
 
 
3b2519a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import ast
import pickle
from collections import Counter

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
from huggingface_hub import hf_hub_download
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from sentence_transformers import SentenceTransformer, util

# Sidebar
with st.sidebar:
    # NOTE(review): no examples are listed under this header in the visible
    # source; confirm whether example queries were meant to go here.
    st.header("Examples:")
    st.markdown("This search finds content in Medium.")


# Main content
st.header("Semantic Search Engine on [Medium](https://medium.com/) articles")
st.markdown("This is a small demo project of a semantic search engine over a dataset of ~190k Medium articles.")

# Temporary loading message, shown while load_data() runs and cleared
# right after it returns.
st_placeholder_loading = st.empty()
st_placeholder_loading.text('Loading medium articles data...')

@st.cache(allow_output_mutation=True)
def load_data():
    """Download the Medium dataset artifacts and load the query encoder.

    ``allow_output_mutation=True`` makes st.cache skip hashing the returned
    objects. Without it, st.cache hashes them to warn when a cached value is
    mutated.

    Returns:
        A tuple ``(df_articles, corpus_embeddings, embedder)``:
        the article table read from "medium_articles_no_text.csv", the
        precomputed corpus embeddings unpickled from
        "medium_articles_embeddings.pickle", and the 'all-MiniLM-L6-v2'
        SentenceTransformer used to embed search queries.
    """
    repo_id = "fabiochiu/medium-articles"

    csv_path = hf_hub_download(repo_id, repo_type="dataset", filename="medium_articles_no_text.csv")
    df_articles = pd.read_csv(csv_path)

    embeddings_path = hf_hub_download(repo_id, repo_type="dataset", filename="medium_articles_embeddings.pickle")
    # Use a context manager so the file handle is closed promptly.
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. This trusts the remote dataset repo; prefer a non-pickle format
    # (e.g. .npy loaded with allow_pickle=False) if that trust is not warranted.
    with open(embeddings_path, "rb") as embeddings_file:
        corpus_embeddings = pickle.load(embeddings_file)

    embedder = SentenceTransformer('all-MiniLM-L6-v2')
    return df_articles, corpus_embeddings, embedder

# Load the article table, corpus embeddings and query encoder (memoized by
# st.cache on load_data), then clear the loading message.
df_articles, corpus_embeddings, embedder = load_data()
st_placeholder_loading.empty()

# Number of most frequent tags plotted by load_chart_top_tags().
n_top_tags = 20
@st.cache()
def load_chart_top_tags():
    """Build a bar chart of the ``n_top_tags`` most frequent article tags.

    Reads the module-level ``df_articles["tags"]`` column. Each cell is
    expected to be the string form of a Python list of tags; presumably it
    looks like "['Health', 'Mental Health']", so confirm this against the
    dataset. Missing (non-string) cells are skipped.

    Returns:
        A plotly bar figure with tags on the x axis and their counts on y.
        If no tags are found, the figure has empty axes.
    """
    # ast.literal_eval only accepts Python literals, so a malformed or hostile
    # cell in the downloaded CSV cannot execute code (eval could).
    tag_counter = Counter(
        tag
        for tags_list in df_articles["tags"]
        if isinstance(tags_list, str)
        for tag in ast.literal_eval(tags_list)
    )
    top_tags = tag_counter.most_common(n=n_top_tags)

    # Build the axes explicitly: unpacking zip(*top_tags) into two names
    # raises ValueError when top_tags is empty.
    tags = [tag for tag, _ in top_tags]
    frequencies = [count for _, count in top_tags]

    fig = px.bar(x=tags, y=frequencies)
    fig.update_xaxes(title="tags")
    fig.update_yaxes(title="frequencies")
    return fig

# Build the top-tags chart (memoized by st.cache on load_chart_top_tags).
# NOTE(review): fig_top_tags is never rendered in the visible source (there is
# no st.plotly_chart call); confirm whether it is meant to be displayed.
fig_top_tags = load_chart_top_tags()

# Free-text search query, limited to 100 characters.
st_query = st.text_input("Write your query here", max_chars=100)

def on_click_search():
    """Run a semantic search for ``st_query`` and store the results in session state.

    For an empty query, ``st.session_state.article_dicts`` is cleared to ``[]``
    and ``st.session_state.empty_query`` is set to True.

    Otherwise the query is embedded with ``embedder`` and matched against
    ``corpus_embeddings``. Up to 10 hits pass the filter: the title must be a
    string of at least 10 characters and be detected as English. Those hits
    are stored, in ranking order, in ``st.session_state.article_dicts`` as
    dicts with keys "title", "url" and "score", and
    ``st.session_state.empty_query`` is set to False.
    """
    if st_query == "":
        st.session_state.article_dicts = []
        st.session_state.empty_query = True
        return

    top_k = 10
    query_embedding = embedder.encode(st_query, convert_to_tensor=True)
    # Fetch twice as many candidates as needed, since the filter below drops some.
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k * 2)[0]

    article_dicts = []
    for hit in hits:
        article_row = df_articles.iloc[hit['corpus_id']]
        title = article_row["title"]
        # Run the cheap checks before language detection. The isinstance test
        # also skips missing titles (NaN floats), which would otherwise raise
        # TypeError in len() / detect().
        if not isinstance(title, str) or len(title) < 10:
            continue
        try:
            detected_lang = detect(title)
        except LangDetectException:
            # langdetect could not classify this title; treat it as non-English.
            continue
        if detected_lang != "en":
            continue
        article_dicts.append({
            "title": title,
            "url": article_row['url'],
            "score": hit['score'],
        })
        if len(article_dicts) >= top_k:
            break

    st.session_state.article_dicts = article_dicts
    st.session_state.empty_query = False
# Run the search at most once per script run. If on_click_search were
# registered as the button's on_click callback AND also called here, every
# click would embed and search the query twice.
search_clicked = st.button("Search")
if st_query != "" or search_clicked:
    # Non-empty query: search. Empty query with a click: on_click_search clears
    # the results and sets empty_query, which makes the hint below appear.
    on_click_search()
else:
    st.session_state.empty_query = True

# Render either the result list or, for an empty query, a usage hint.
if st.session_state.empty_query:
    # article_dicts is only set by on_click_search, so the hint is not shown
    # before any search has been attempted.
    if "article_dicts" in st.session_state:
        st.markdown("Please write a query and then press the search button.")
else:
    st.markdown("### Results")
    st.markdown("*Scores between parentheses represent the similarity between the article and the query.*")
    for article_dict in st.session_state.article_dicts:
        # NOTE(review): str.capitalize() also lowercases every character after
        # the first (e.g. "AI" becomes "Ai"); confirm this is intended.
        st.markdown(
            "- [{}]({}) ({:.2f})".format(
                article_dict['title'].capitalize(),
                article_dict['url'],
                article_dict['score'],
            )
        )