# NOTE: the lines that stood here were extraction residue from the hosting
# page's file viewer, not part of the program:
#   Spaces status: "Runtime error"; file size: 2,560 Bytes.
#   (Per-line commit hashes and the 1-75 line-number gutter were page chrome.)
import streamlit as st
import numpy as np
from html import escape
import torch
from transformers import RobertaModel, AutoTokenizer
# Streamlit re-executes this whole script on every widget interaction, so the
# heavyweight assets are loaded through a resource cache instead of at bare
# module level. Prefer the current API name and fall back to the legacy one
# that shipped alongside `st.experimental_memo`.
_cache_assets = getattr(st, 'cache_resource', None) or st.experimental_singleton


@_cache_assets
def _load_search_assets():
    """Load the tokenizer, text encoder, image embeddings and image links once.

    Returns:
        A tuple ``(tokenizer, text_encoder, image_embeddings, links)``.
    """
    model_name = 'SajjadAyoubi/clip-fa-text'
    tok = AutoTokenizer.from_pretrained(model_name)
    # eval() disables dropout and other training-only behaviour.
    encoder = RobertaModel.from_pretrained(model_name).eval()
    # The encoder is never moved off the CPU in this file, so pin the stored
    # embeddings to CPU as well; this also lets a GPU-saved file load on a
    # CPU-only host.
    embeddings = torch.load('embeddings.pt', map_location='cpu')
    # NOTE(security): allow_pickle=True unpickles arbitrary objects, so
    # data.npy must remain a trusted file.
    image_links = np.load('data.npy', allow_pickle=True)
    return tok, encoder, embeddings, image_links


tokenizer, text_encoder, image_embeddings, links = _load_search_assets()
# `st.experimental_memo` was superseded by `st.cache_data`; use the newer name
# when the installed Streamlit provides it and keep the legacy one otherwise.
_memoize_query = getattr(st, 'cache_data', None) or st.experimental_memo


@_memoize_query
def image_search(query, top_k=10):
    """Return the links of the images that best match a text query.

    The query is embedded with the module-level text encoder and compared to
    the stored image embeddings by cosine similarity. Results are memoized
    per ``(query, top_k)`` by the Streamlit cache decorator.

    Args:
        query: Search text to embed.
        top_k: Maximum number of links to return.

    Returns:
        A list of entries from ``links``, ordered from most to least similar.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        encoded = tokenizer(query, return_tensors='pt')
        text_embedding = text_encoder(**encoded).pooler_output
        # Presumably image_embeddings is (N, D) and text_embedding is (1, D),
        # broadcasting to one score per image -- TODO confirm.
        similarity = torch.cosine_similarity(image_embeddings, text_embedding)
        _, indices = similarity.sort(descending=True)
    # Convert to plain ints so `links` is indexed with Python integers rather
    # than 0-dim tensors.
    return [links[i] for i in indices[:top_k].tolist()]
def get_html(url_list):
    """Build an HTML gallery snippet for the given image URLs.

    Args:
        url_list: Iterable of image URL strings.

    Returns:
        A string containing one flex-wrapping ``<div>`` with an ``<img>`` tag
        per URL, in input order. Each URL is HTML-escaped before being placed
        in the single-quoted ``src`` attribute.
    """
    container_open = "<div style='margin-top: 50px; max-width: 1100px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    # join() assembles the tags in one pass instead of re-copying the growing
    # string on every iteration.
    images = "".join(
        f"<img style='height: 180px; margin: 2px' src='{escape(url)}'>"
        for url in url_list
    )
    return f"{container_open}{images}</div>"
# Markdown text for the app's sidebar (passed to st.sidebar.markdown in main()).
# Whitespace inside the literal is part of the rendered markdown, so edit it
# with care.
description = '''
# Persian (fa) image search
- Enter your query and hit enter
- Note: We used a small set of images to keep this app almost real-time, but it's obvious that the quality of image search depends heavily on the size of the image database.
Built with [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) model and 25k images from [Unsplash](https://unsplash.com/)
'''
def main():
    """Render the search page: CSS overrides, sidebar text, query box, results."""
    # The stylesheet lines stay at column 0 inside the literal so markdown
    # does not treat them as an indented code block.
    page_css = '''
<style>
.block-container{
max-width: 1200px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>'''
    st.markdown(page_css, unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # Three columns in a 1:3:1 ratio; only the middle one hosts widgets.
    center = st.columns((1, 3, 1))[1]
    query = center.text_input('Search Box (type in fa)', value='قطره های باران روی شیشه')
    center.text("It'll take about 30s to load all new images")

    # Nothing to search for until the box contains text.
    if not query:
        return
    gallery = get_html(image_search(query))
    st.markdown(gallery, unsafe_allow_html=True)
# Script entry point: build the page when this file is executed as the main
# module.
if __name__ == '__main__':
    main()