Spaces:
Sleeping
Sleeping
Dan Mo
committed on
Commit
·
cfb0d15
1
Parent(s):
cf957e4
Add script to generate and save embeddings for models
Browse files
- Implemented `generate_embeddings.py` to load embedding models and generate embeddings for emotion and event dictionaries.
- Added functionality to save generated embeddings as pickle files in the 'embeddings' directory.
- Included error handling and logging for better debugging and tracking of the embedding generation process.
- .gitignore +58 -0
- app.py +136 -17
- config.py +19 -0
- embeddings/BAAI_bge-large-en-v1.5_emotion.pkl +3 -0
- embeddings/BAAI_bge-large-en-v1.5_event.pkl +3 -0
- embeddings/all-mpnet-base-v2_emotion.pkl +3 -0
- embeddings/all-mpnet-base-v2_event.pkl +3 -0
- embeddings/thenlper_gte-large_emotion.pkl +3 -0
- embeddings/thenlper_gte-large_event.pkl +3 -0
- emoji_processor.py +106 -8
- generate_embeddings.py +99 -0
- utils.py +59 -1
.gitignore
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python cache files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
*.so
|
6 |
+
.Python
|
7 |
+
|
8 |
+
# Distribution / packaging
|
9 |
+
dist/
|
10 |
+
build/
|
11 |
+
*.egg-info/
|
12 |
+
|
13 |
+
# Virtual environments
|
14 |
+
venv/
|
15 |
+
env/
|
16 |
+
ENV/
|
17 |
+
|
18 |
+
# Jupyter Notebook
|
19 |
+
.ipynb_checkpoints
|
20 |
+
|
21 |
+
# VS Code
|
22 |
+
.vscode/
|
23 |
+
*.code-workspace
|
24 |
+
|
25 |
+
# PyCharm
|
26 |
+
.idea/
|
27 |
+
|
28 |
+
# Logs
|
29 |
+
*.log
|
30 |
+
logs/
|
31 |
+
|
32 |
+
# OS specific files
|
33 |
+
.DS_Store
|
34 |
+
Thumbs.db
|
35 |
+
desktop.ini
|
36 |
+
|
37 |
+
# Environment variables
|
38 |
+
.env
|
39 |
+
.env.local
|
40 |
+
|
41 |
+
# Temporary files
|
42 |
+
*.swp
|
43 |
+
*.swo
|
44 |
+
*~
|
45 |
+
.temp/
|
46 |
+
|
47 |
+
# NOTE: We're keeping the embeddings/*.pkl files since they're pre-generated
|
48 |
+
# for faster startup. They're managed by Git LFS as specified in .gitattributes.
|
49 |
+
|
50 |
+
# Gradio specific
|
51 |
+
gradio_cached_examples/
|
52 |
+
flagged/
|
53 |
+
|
54 |
+
# Local development files
|
55 |
+
.jupyter/
|
56 |
+
.local/
|
57 |
+
.bash_history
|
58 |
+
.python_history
|
app.py
CHANGED
@@ -6,36 +6,155 @@ This module handles the Gradio interface and application setup.
|
|
6 |
import gradio as gr
|
7 |
from utils import logger
|
8 |
from emoji_processor import EmojiProcessor
|
|
|
9 |
|
10 |
class EmojiMashupApp:
|
11 |
def __init__(self):
|
12 |
"""Initialize the Gradio application."""
|
13 |
logger.info("Initializing Emoji Mashup App")
|
14 |
-
self.processor = EmojiProcessor()
|
15 |
self.processor.load_emoji_dictionaries()
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
def create_interface(self):
|
18 |
"""Create and configure the Gradio interface.
|
19 |
|
20 |
Returns:
|
21 |
Gradio Interface object
|
22 |
"""
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
gr.
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
def run(self, share=True):
|
41 |
"""Launch the Gradio application.
|
|
|
6 |
import gradio as gr
|
7 |
from utils import logger
|
8 |
from emoji_processor import EmojiProcessor
|
9 |
+
from config import EMBEDDING_MODELS
|
10 |
|
11 |
class EmojiMashupApp:
|
12 |
def __init__(self):
    """Initialize the Gradio application."""
    logger.info("Initializing Emoji Mashup App")
    # Default to the mpnet model with on-disk embedding caching enabled;
    # the UI lets the user switch models and toggle caching later.
    self.processor = EmojiProcessor(model_key="mpnet", use_cached_embeddings=True)  # Default to mpnet
    # Load the emotion/event emoji dictionaries (and their embeddings) up front.
    self.processor.load_emoji_dictionaries()
|
17 |
|
18 |
+
def create_model_dropdown_choices(self):
    """Create formatted choices for the model dropdown.

    Returns:
        List of formatted model choices, one per entry in EMBEDDING_MODELS.
    """
    # Each label starts with the model key so handlers can recover it
    # with a simple split on whitespace.
    labels = []
    for key, info in EMBEDDING_MODELS.items():
        labels.append(f"{key} ({info['size']}) - {info['notes']}")
    return labels
|
28 |
+
|
29 |
+
def handle_model_change(self, dropdown_value, use_cached_embeddings):
    """Handle model selection change from dropdown.

    Args:
        dropdown_value: Selected value from dropdown
        use_cached_embeddings: Whether to use cached embeddings

    Returns:
        Status message about model change
    """
    # Dropdown labels begin with the model key (e.g. "mpnet (110M) - ...").
    model_key = dropdown_value.split()[0] if dropdown_value else "mpnet"

    # Propagate the caching preference before any switch happens.
    self.processor.use_cached_embeddings = use_cached_embeddings

    # Guard clauses: unknown key, then failed switch, then success.
    if model_key not in EMBEDDING_MODELS:
        return f"Unknown model: {model_key}"

    if not self.processor.switch_model(model_key):
        return f"Failed to switch to {model_key} model"

    cache_status = "using cached embeddings" if use_cached_embeddings else "computing fresh embeddings"
    return f"Switched to {model_key} model ({cache_status}): {EMBEDDING_MODELS[model_key]['notes']}"
|
54 |
+
|
55 |
+
def process_with_model(self, model_selection, text, use_cached_embeddings):
    """Process text with selected model.

    Args:
        model_selection: Selected model from dropdown
        text: User input text
        use_cached_embeddings: Whether to use cached embeddings

    Returns:
        Tuple of (emotion emoji, event emoji, mashup image)
    """
    # Dropdown labels begin with the model key; fall back to mpnet.
    if model_selection:
        model_key = model_selection.split()[0]
    else:
        model_key = "mpnet"

    self.processor.use_cached_embeddings = use_cached_embeddings

    # Switch only when the selection names a known model; otherwise keep
    # whatever model is currently loaded.
    if model_key in EMBEDDING_MODELS:
        self.processor.switch_model(model_key)

    # Run the sentence through the currently loaded model.
    return self.processor.sentence_to_emojis(text)
|
77 |
+
|
78 |
def create_interface(self):
    """Create and configure the Gradio interface.

    Returns:
        Gradio Interface object
    """
    # Compute the dropdown choices once; the first entry is the mpnet default.
    choices = self.create_model_dropdown_choices()

    with gr.Blocks(title="Sentence → Emoji Mashup") as demo:
        gr.Markdown("# Sentence → Emoji Mashup")
        gr.Markdown("Get the top emotion and event emoji from your sentence, and view the mashup!")

        with gr.Row():
            with gr.Column(scale=3):
                # Model selection dropdown
                model_dropdown = gr.Dropdown(
                    choices=choices,
                    value=choices[0],  # Default to first model (mpnet)
                    label="Embedding Model",
                    info="Select the model used for text-emoji matching"
                )

                # Cache toggle
                cache_toggle = gr.Checkbox(
                    label="Use cached embeddings",
                    value=True,
                    info="When enabled, embeddings will be saved to and loaded from disk"
                )

                # Text input
                text_input = gr.Textbox(
                    lines=2,
                    placeholder="Type a sentence...",
                    label="Your message"
                )

                # Process button
                submit_btn = gr.Button("Generate Emoji Mashup", variant="primary")

            with gr.Column(scale=2):
                # Model info display
                model_info = gr.Textbox(
                    value=f"Using mpnet model (using cached embeddings): {EMBEDDING_MODELS['mpnet']['notes']}",
                    label="Model Info",
                    interactive=False
                )

                # Output displays
                emotion_out = gr.Text(label="Top Emotion Emoji")
                event_out = gr.Text(label="Top Event Emoji")
                mashup_out = gr.Image(label="Mashup Emoji")

        # Refresh the model-info box whenever the model or cache setting changes.
        model_dropdown.change(
            fn=self.handle_model_change,
            inputs=[model_dropdown, cache_toggle],
            outputs=[model_info]
        )
        cache_toggle.change(
            fn=self.handle_model_change,
            inputs=[model_dropdown, cache_toggle],
            outputs=[model_info]
        )

        # Generate the mashup on click.
        submit_btn.click(
            fn=self.process_with_model,
            inputs=[model_dropdown, text_input, cache_toggle],
            outputs=[emotion_out, event_out, mashup_out]
        )

        # Examples
        gr.Examples(
            examples=[
                ["I feel so happy today!"],
                ["I'm really angry right now"],
                ["Feeling tired after a long day"]
            ],
            inputs=text_input
        )

    return demo
|
158 |
|
159 |
def run(self, share=True):
|
160 |
"""Launch the Gradio application.
|
config.py
CHANGED
@@ -9,4 +9,23 @@ CONFIG = {
|
|
9 |
"item_file": "google-emoji-kitchen-item.txt",
|
10 |
"emoji_kitchen_url": "https://emojik.vercel.app/s/{emoji1}_{emoji2}",
|
11 |
"default_size": 256
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
}
|
|
|
9 |
"item_file": "google-emoji-kitchen-item.txt",
|
10 |
"emoji_kitchen_url": "https://emojik.vercel.app/s/{emoji1}_{emoji2}",
|
11 |
"default_size": 256
|
12 |
+
}
|
13 |
+
|
14 |
+
# Available embedding models
# Keys are the short identifiers selected in the app's model dropdown;
# "id" is the sentence-transformers model identifier, "size" and "notes"
# are display-only strings shown in the UI.
EMBEDDING_MODELS = {
    "mpnet": {
        "id": "all-mpnet-base-v2",
        "size": "110M",
        "notes": "Balanced, great general-purpose model"
    },
    "gte": {
        "id": "thenlper/gte-large",
        "size": "335M",
        "notes": "Context-rich, good for emotion & nuance"
    },
    "bge": {
        "id": "BAAI/bge-large-en-v1.5",
        "size": "350M",
        "notes": "Tuned for ranking & high-precision similarity"
    }
}
|
embeddings/BAAI_bge-large-en-v1.5_emotion.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5456af7ceaa04bdc28b9b125e317eaebf503c60b6937f006b54c595850c3830a
|
3 |
+
size 463549
|
embeddings/BAAI_bge-large-en-v1.5_event.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c32a321359bd0a197e906c85731003294a835e7590c655043e6e9ebdfa607de9
|
3 |
+
size 2238733
|
embeddings/all-mpnet-base-v2_emotion.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6db3183f80970f30c7dee0cf846c832b5505890a071c1af8009f6ff452083f7c
|
3 |
+
size 348852
|
embeddings/all-mpnet-base-v2_event.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:65d434cb2cdd1034e494a87d354345e67bfd25a90a44247cfa3406dc100334c0
|
3 |
+
size 1684668
|
embeddings/thenlper_gte-large_emotion.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e8b5472bf5008613f76ac06738fa55c91ff2fd6ae7472c9a1f739d210b5f2f0e
|
3 |
+
size 463549
|
embeddings/thenlper_gte-large_event.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:81961955fad517b578deb1969c8e84594fe92c8eed32d6b43f85e804f5214b82
|
3 |
+
size 2238733
|
emoji_processor.py
CHANGED
@@ -7,23 +7,36 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|
7 |
import requests
|
8 |
from PIL import Image
|
9 |
from io import BytesIO
|
|
|
10 |
|
11 |
-
from config import CONFIG
|
12 |
-
from utils import logger, kitchen_txt_to_dict
|
|
|
|
|
13 |
|
14 |
class EmojiProcessor:
|
15 |
-
def __init__(self, model_name=
|
16 |
"""Initialize the emoji processor with the specified model.
|
17 |
|
18 |
Args:
|
19 |
-
model_name:
|
|
|
|
|
20 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
logger.info(f"Loading model: {model_name}")
|
22 |
self.model = SentenceTransformer(model_name)
|
|
|
23 |
self.emotion_dict = {}
|
24 |
self.event_dict = {}
|
25 |
self.emotion_embeddings = {}
|
26 |
self.event_embeddings = {}
|
|
|
27 |
|
28 |
def load_emoji_dictionaries(self, emotion_file=CONFIG["emotion_file"], item_file=CONFIG["item_file"]):
|
29 |
"""Load emoji dictionaries from text files.
|
@@ -36,10 +49,95 @@ class EmojiProcessor:
|
|
36 |
self.emotion_dict = kitchen_txt_to_dict(emotion_file)
|
37 |
self.event_dict = kitchen_txt_to_dict(item_file)
|
38 |
|
39 |
-
#
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
def find_top_emojis(self, embedding, emoji_embeddings, top_n=1):
|
45 |
"""Find top matching emojis based on cosine similarity.
|
|
|
7 |
import requests
|
8 |
from PIL import Image
|
9 |
from io import BytesIO
|
10 |
+
import os
|
11 |
|
12 |
+
from config import CONFIG, EMBEDDING_MODELS
|
13 |
+
from utils import (logger, kitchen_txt_to_dict,
|
14 |
+
save_embeddings_to_pickle, load_embeddings_from_pickle,
|
15 |
+
get_embeddings_pickle_path)
|
16 |
|
17 |
class EmojiProcessor:
|
18 |
+
def __init__(self, model_name=None, model_key=None, use_cached_embeddings=True):
    """Initialize the emoji processor with the specified model.

    Args:
        model_name: Direct name of the sentence transformer model to use
        model_key: Key from EMBEDDING_MODELS to use (takes precedence over model_name)
        use_cached_embeddings: Whether to use cached embeddings from pickle files
    """
    # Resolve the model name: a known key wins, then an explicit name,
    # then the configured default.
    if model_key and model_key in EMBEDDING_MODELS:
        resolved_name = EMBEDDING_MODELS[model_key]['id']
    elif model_name:
        resolved_name = model_name
    else:
        resolved_name = CONFIG["model_name"]

    logger.info(f"Loading model: {resolved_name}")
    self.model = SentenceTransformer(resolved_name)
    self.current_model_name = resolved_name

    # Dictionaries and embeddings are populated later by
    # load_emoji_dictionaries().
    self.emotion_dict = {}
    self.event_dict = {}
    self.emotion_embeddings = {}
    self.event_embeddings = {}
    self.use_cached_embeddings = use_cached_embeddings
|
40 |
|
41 |
def load_emoji_dictionaries(self, emotion_file=CONFIG["emotion_file"], item_file=CONFIG["item_file"]):
|
42 |
"""Load emoji dictionaries from text files.
|
|
|
49 |
self.emotion_dict = kitchen_txt_to_dict(emotion_file)
|
50 |
self.event_dict = kitchen_txt_to_dict(item_file)
|
51 |
|
52 |
+
# Load or compute embeddings
|
53 |
+
self._load_or_compute_embeddings()
|
54 |
+
|
55 |
+
def _load_or_compute_embeddings(self):
    """Load embeddings from pickle files if available, otherwise compute them.

    When caching is disabled, both dictionaries are encoded fresh and
    nothing is written to disk. When enabled, each dictionary is resolved
    independently via _resolve_embeddings.
    """
    if not self.use_cached_embeddings:
        # Compute embeddings without caching
        logger.info("Computing embeddings for emoji dictionaries (no caching)")
        self.emotion_embeddings = {emoji: self.model.encode(desc) for emoji, desc in self.emotion_dict.items()}
        self.event_embeddings = {emoji: self.model.encode(desc) for emoji, desc in self.event_dict.items()}
        return

    # The emotion and event paths are identical except for the dictionary
    # involved, so share one helper instead of duplicating the logic.
    self.emotion_embeddings = self._resolve_embeddings("emotion", self.emotion_dict)
    self.event_embeddings = self._resolve_embeddings("event", self.event_dict)

def _resolve_embeddings(self, emoji_type, emoji_dict):
    """Return embeddings for one emoji dictionary, preferring the cache.

    Args:
        emoji_type: "emotion" or "event" (used for the pickle filename and logs)
        emoji_dict: Mapping of emoji -> description to embed

    Returns:
        Dict of emoji -> embedding. Loaded from the pickle when it exists and
        covers every emoji in emoji_dict; otherwise recomputed and re-cached.
    """
    pickle_path = get_embeddings_pickle_path(self.current_model_name, emoji_type)
    cached = load_embeddings_from_pickle(pickle_path)

    if cached is not None:
        # Verify all emoji keys are present in the loaded embeddings.
        missing = next((emoji for emoji in emoji_dict if emoji not in cached), None)
        if missing is None:
            return cached
        logger.info(f"Cached {emoji_type} embeddings missing emoji: {missing}, will recompute")

    logger.info(f"Computing {emoji_type} embeddings for model: {self.current_model_name}")
    embeddings = {emoji: self.model.encode(desc) for emoji, desc in emoji_dict.items()}
    # Save for future use.
    save_embeddings_to_pickle(embeddings, pickle_path)
    return embeddings
|
109 |
+
|
110 |
+
def switch_model(self, model_key):
    """Switch to a different embedding model.

    Args:
        model_key: Key from EMBEDDING_MODELS to use

    Returns:
        True if model was switched successfully, False otherwise
    """
    if model_key not in EMBEDDING_MODELS:
        logger.error(f"Unknown model key: {model_key}")
        return False

    model_name = EMBEDDING_MODELS[model_key]['id']
    if model_name == self.current_model_name:
        logger.info(f"Model {model_key} is already loaded")
        return True

    # Remember the current state so a failed switch can be rolled back;
    # otherwise self.model/current_model_name could end up pointing at the
    # new model while the embeddings still belong to the old one.
    previous_model = self.model
    previous_model_name = self.current_model_name

    try:
        logger.info(f"Switching to model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.current_model_name = model_name

        # Load or recompute embeddings with new model
        if self.emotion_dict and self.event_dict:
            self._load_or_compute_embeddings()

        return True
    except Exception as e:
        logger.error(f"Error switching model: {e}")
        # Restore the previous model so the processor stays consistent.
        self.model = previous_model
        self.current_model_name = previous_model_name
        return False
|
141 |
|
142 |
def find_top_emojis(self, embedding, emoji_embeddings, top_n=1):
|
143 |
"""Find top matching emojis based on cosine similarity.
|
generate_embeddings.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utility script to pre-generate embedding pickle files for all models.
|
3 |
+
|
4 |
+
This script will:
|
5 |
+
1. Load each embedding model
|
6 |
+
2. Generate embeddings for both emotion and event dictionaries
|
7 |
+
3. Save the embeddings as pickle files in the 'embeddings' directory
|
8 |
+
|
9 |
+
Run this script once locally to create all pickle files before uploading to the repository.
|
10 |
+
"""
|
11 |
+
|
12 |
+
import os
|
13 |
+
from sentence_transformers import SentenceTransformer
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from config import CONFIG, EMBEDDING_MODELS
|
17 |
+
from utils import (logger, kitchen_txt_to_dict,
|
18 |
+
save_embeddings_to_pickle, get_embeddings_pickle_path)
|
19 |
+
|
20 |
+
def generate_embeddings_for_model(model_key, model_info, emotion_dict=None, event_dict=None):
    """Generate and save embeddings for a specific model.

    Args:
        model_key: Key of the model in EMBEDDING_MODELS
        model_info: Model information dictionary
        emotion_dict: Optional pre-loaded emotion dictionary; loaded from
            CONFIG["emotion_file"] when omitted. Passing it avoids re-reading
            the same file for every model.
        event_dict: Optional pre-loaded event dictionary; loaded from
            CONFIG["item_file"] when omitted.

    Returns:
        Tuple of (success_emotion, success_event)
    """
    model_id = model_info['id']
    print(f"\nProcessing model: {model_key} ({model_id}) - {model_info['size']}")

    try:
        # Load the model
        print(f"Loading {model_key} model...")
        model = SentenceTransformer(model_id)

        # Load emoji dictionaries unless the caller supplied them.
        if emotion_dict is None or event_dict is None:
            print("Loading emoji dictionaries...")
        if emotion_dict is None:
            emotion_dict = kitchen_txt_to_dict(CONFIG["emotion_file"])
        if event_dict is None:
            event_dict = kitchen_txt_to_dict(CONFIG["item_file"])

        if not emotion_dict or not event_dict:
            print("Error: Failed to load emoji dictionaries")
            return False, False

        # Generate emotion embeddings
        print(f"Generating {len(emotion_dict)} emotion embeddings...")
        emotion_embeddings = {}
        for emoji, desc in tqdm(emotion_dict.items()):
            emotion_embeddings[emoji] = model.encode(desc)

        # Generate event embeddings
        print(f"Generating {len(event_dict)} event embeddings...")
        event_embeddings = {}
        for emoji, desc in tqdm(event_dict.items()):
            event_embeddings[emoji] = model.encode(desc)

        # Save embeddings
        emotion_pickle_path = get_embeddings_pickle_path(model_id, "emotion")
        event_pickle_path = get_embeddings_pickle_path(model_id, "event")

        success_emotion = save_embeddings_to_pickle(emotion_embeddings, emotion_pickle_path)
        success_event = save_embeddings_to_pickle(event_embeddings, event_pickle_path)

        return success_emotion, success_event
    except Exception as e:
        print(f"Error generating embeddings for model {model_key}: {e}")
        return False, False
|
70 |
+
|
71 |
+
def main():
    """Generate and save embedding pickle files for every configured model."""
    # Make sure the output directory exists before any model runs.
    os.makedirs('embeddings', exist_ok=True)

    print(f"Generating embeddings for {len(EMBEDDING_MODELS)} models...")

    # Collect a (success_emotion, success_event) pair per model key.
    outcomes = {}
    for key, info in EMBEDDING_MODELS.items():
        outcomes[key] = generate_embeddings_for_model(key, info)

    # Report per-model success for both embedding types.
    print("\n=== Embedding Generation Summary ===")
    for key, (ok_emotion, ok_event) in outcomes.items():
        emotion_label = "✓ Success" if ok_emotion else "✗ Failed"
        event_label = "✓ Success" if ok_event else "✗ Failed"
        print(f"{key:<10}: Emotion: {emotion_label}, Event: {event_label}")

    print("\nDone! Embedding pickle files are stored in the 'embeddings' directory.")
    print("You can now upload these files to your repository.")

if __name__ == "__main__":
    main()
|
utils.py
CHANGED
@@ -3,6 +3,8 @@ Utility functions for the Emoji Mashup application.
|
|
3 |
"""
|
4 |
|
5 |
import logging
|
|
|
|
|
6 |
|
7 |
# Configure logging
|
8 |
def setup_logging():
|
@@ -36,4 +38,60 @@ def kitchen_txt_to_dict(filepath):
|
|
36 |
return emoji_dict
|
37 |
except Exception as e:
|
38 |
logger.error(f"Error loading emoji dictionary from {filepath}: {e}")
|
39 |
-
return {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
"""
|
4 |
|
5 |
import logging
|
6 |
+
import os
|
7 |
+
import pickle
|
8 |
|
9 |
# Configure logging
|
10 |
def setup_logging():
|
|
|
38 |
return emoji_dict
|
39 |
except Exception as e:
|
40 |
logger.error(f"Error loading emoji dictionary from {filepath}: {e}")
|
41 |
+
return {}
|
42 |
+
|
43 |
+
def save_embeddings_to_pickle(embeddings, filepath):
    """Save embeddings dictionary to a pickle file.

    Args:
        embeddings: Dictionary of embeddings to save
        filepath: Path to save the pickle file to

    Returns:
        True if successful, False otherwise
    """
    try:
        parent = os.path.dirname(filepath)
        # os.makedirs('') raises FileNotFoundError, so only create the
        # directory when the path actually has a directory component.
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(filepath, 'wb') as f:
            pickle.dump(embeddings, f)
        logger.info(f"Saved embeddings to {filepath}")
        return True
    except Exception as e:
        logger.error(f"Error saving embeddings to {filepath}: {e}")
        return False
|
62 |
+
|
63 |
+
def load_embeddings_from_pickle(filepath):
    """Load embeddings dictionary from a pickle file.

    Args:
        filepath: Path to load the pickle file from

    Returns:
        Dictionary of embeddings if successful, None otherwise
    """
    # A missing cache file is expected on first run, so log at info level.
    if not os.path.exists(filepath):
        logger.info(f"Pickle file {filepath} does not exist")
        return None

    try:
        with open(filepath, 'rb') as handle:
            data = pickle.load(handle)
            logger.info(f"Loaded embeddings from {filepath}")
            return data
    except Exception as err:
        logger.error(f"Error loading embeddings from {filepath}: {err}")
        return None
|
84 |
+
|
85 |
+
def get_embeddings_pickle_path(model_id, emoji_type):
    """Generate the path for an embeddings pickle file.

    Args:
        model_id: ID of the embedding model
        emoji_type: Type of emoji ('emotion' or 'event')

    Returns:
        Path to the embeddings pickle file
    """
    # Slashes in model IDs (e.g. "BAAI/bge-large-en-v1.5") would otherwise
    # introduce extra path components, so map both separators to underscores
    # in a single translate pass.
    separator_map = {ord('/'): '_', ord('\\'): '_'}
    safe_model_id = model_id.translate(separator_map)
    return os.path.join('embeddings', f"{safe_model_id}_{emoji_type}.pkl")
|