File size: 2,560 Bytes
4f12085
4f58d6c
4f12085
 
 
 
764a07d
 
ccd28c4
764a07d
7d696ff
4f12085
e6c4b07
c7e078f
 
 
 
 
 
 
9213df3
c3b6f23
e3b861a
764a07d
4f12085
 
260e0f1
4f12085
 
 
cd19d7a
 
58f811a
cd19d7a
55b4896
4f12085
 
 
 
30aaa45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
684dd3e
9213df3
b5870b3
4f12085
3c93ca3
69143ff
4f12085
665ec8e
62a6cd4
8f4e395
 
4f12085
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import streamlit as st
import numpy as np
from html import escape
import torch
from transformers import RobertaModel, AutoTokenizer

_MODEL_ID = 'SajjadAyoubi/clip-fa-text'


@st.experimental_singleton
def _load_resources():
    """Load the text encoder, its tokenizer, and the precomputed search index once.

    Streamlit re-executes this script on every widget interaction; caching the
    loaded objects as a process-wide singleton avoids reloading the model and
    data files on each rerun.

    Returns:
        A ``(tokenizer, text_encoder, image_embeddings, links)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(_MODEL_ID)
    encoder = RobertaModel.from_pretrained(_MODEL_ID).eval()
    # The encoder runs on CPU, so force the embeddings onto CPU as well; a
    # GPU-saved tensor would otherwise break torch.cosine_similarity.
    embeddings = torch.load('embeddings.pt', map_location='cpu')
    # NOTE(review): allow_pickle=True unpickles arbitrary objects - only ever
    # load a data.npy produced by a trusted source.
    url_array = np.load('data.npy', allow_pickle=True)
    return tok, encoder, embeddings, url_array


# Keep the original module-level names so image_search() reads them unchanged.
tokenizer, text_encoder, image_embeddings, links = _load_resources()


@st.experimental_memo
def image_search(query, top_k=10):
    """Return the links of the ``top_k`` images most similar to ``query``.

    The query is encoded with the CLIP-fa text encoder (``pooler_output``) and
    scored against the precomputed ``image_embeddings`` by cosine similarity.

    Args:
        query: Search text.
        top_k: Maximum number of results; clamped to ``[0, number of images]``.

    Returns:
        A list of entries from ``links``, best match first.
    """
    with torch.no_grad():
        # truncation=True caps over-long queries at the tokenizer's
        # model_max_length instead of overflowing the position embeddings.
        inputs = tokenizer(query, return_tensors='pt', truncation=True)
        text_embedding = text_encoder(**inputs).pooler_output
        # assumes image_embeddings is (num_images, dim) and text_embedding is
        # (1, dim), giving one score per image - TODO confirm
        scores = torch.cosine_similarity(image_embeddings, text_embedding)
    k = max(0, min(top_k, scores.shape[-1]))
    # topk selects only the best k scores (sorted descending) instead of
    # sorting every score.
    _, indices = scores.topk(k)
    return [links[i] for i in indices.tolist()]


def get_html(url_list):
    """Build an HTML gallery of ``<img>`` tags inside a flex-wrapping ``<div>``.

    Each URL is converted to ``str`` and HTML-escaped (quotes included) before
    being placed in the single-quoted ``src`` attribute.

    Args:
        url_list: Iterable of image URLs.

    Returns:
        The HTML markup as a single string.
    """
    container_open = (
        "<div style='margin-top: 50px; max-width: 1100px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
    )
    # join avoids repeatedly re-copying the growing string inside the loop.
    images = "".join(
        f"<img style='height: 180px; margin: 2px' src='{escape(str(url))}'>"
        for url in url_list
    )
    return container_open + images + "</div>"


# Markdown text shown in the sidebar by main(): usage notes and data/model credits.
description = '''
# Persian (fa) image search 
- Enter your query and hit enter
- Note: We used a small set of images to keep this app almost real-time, but it's obvious that the quality of image search depends heavily on the size of the image database.

Built with [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) model and 25k images from [Unsplash](https://unsplash.com/)
'''


def main():
    """Render the search page: custom CSS, sidebar description, query box, results.

    Streamlit re-executes the script (and therefore this function) on every
    widget interaction; a search is only issued when the query contains
    non-whitespace text.
    """
    # Hide Streamlit's menu/footer and tighten the default layout widths/padding.
    st.markdown('''
                <style>
                .block-container{
                  max-width: 1200px;
                }
                section.main>div:first-child {
                  padding-top: 0px;
                }
                section:not(.main)>div:first-child {
                  padding-top: 30px;
                }
                div.reportview-container > section:first-child{
                  max-width: 320px;
                }
                #MainMenu {
                  visibility: hidden;
                }
                footer {
                  visibility: hidden;
                }
                </style>''',
                unsafe_allow_html=True)

    st.sidebar.markdown(description)
    # Centre the search box in the middle column of a 1:3:1 split.
    _, c, _ = st.columns((1, 3, 1))
    query = c.text_input('Search Box (type in fa)', value='قطره های باران روی شیشه')
    c.text("It'll take about 30s to load all new images")
    # Surrounding whitespace carries no meaning; stripping also lets equivalent
    # queries share image_search's memo cache entry.
    query = query.strip()
    if query:
        results = image_search(query)
        st.markdown(get_html(results), unsafe_allow_html=True)


# Render the app when this file is executed as a script (`streamlit run` runs it as __main__).
if __name__ == '__main__':
    main()