bhavyagiri commited on
Commit
9c7a90d
·
1 Parent(s): 7449338

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -0
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer, util
2
+ from huggingface_hub import hf_hub_download
3
+ import os
4
+ import pickle
5
+ import pandas as pd
6
+ import gradio as gr
7
+
8
+ pd.options.mode.chained_assignment = None # Turn off SettingWithCopyWarning
9
+
10
+ corpus_embeddings = pickle.load(open(hf_hub_download("bhavyagiri/semantic-memes", repo_type="dataset", filename="meme-embeddings.pkl"), "rb"))
11
+ df = pd.read_csv(hf_hub_download("bhavyagiri/semantic-memes", repo_type="dataset", filename="input.csv"))
12
+
13
+ model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
14
+
15
+ def generate_memes(prompt):
16
+ prompt_embedding = model.encode(prompt, convert_to_tensor=True)
17
+ hits = util.semantic_search(prompt_embedding, embeddings, top_k=5)
18
+ hits = pd.DataFrame(hits[0], columns=['corpus_id', 'score'])
19
+ desired_ids = hits["corpus_id"]
20
+ filtered_df = df.loc[df['id'].isin(desired_ids)]
21
+ filtered_list = list(filtered_df["url"])
22
+ images = [gr.Image.update(value=img, visible=True) for img in filtered_list]
23
+ return (
24
+ images
25
+ )
26
+ input_textbox = gr.inputs.Textbox(lines=2, label="Search something cool", max_length=256)
27
+ output_gallery = gr.output.Gallery(
28
+ label="Retrieved Memes", show_label=False, elem_id="gallery"
29
+ ).style(columns=[3], rows=[2], object_fit="contain", height="auto")
30
+ title = "Semantic Search for Memes"
31
+ description = "Search Memes from small dataset of 6k memes"
32
+ examples = ['Spiderman giving lecture', 'Angry Karen']
33
+ interpretation='default'
34
+ enable_queue=True
35
+
36
+ iface = gr.Interface(fn=classify_garbage, inputs=input_textbox, outputs=label,examples=examples,title=title,description=description,interpretation=interpretation,enable_queue=enable_queue)
37
+ iface.launch(inline=False)