# Feelings_to_Emoji / emoji_processor.py
# Last commit cfb0d15 by Dan Mo: "Add script to generate and save embeddings for models"
# (file size at that commit: 9.18 kB)
"""
Core emoji processing logic for the Emoji Mashup application.
"""
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import requests
from PIL import Image
from io import BytesIO
import os
from config import CONFIG, EMBEDDING_MODELS
from utils import (logger, kitchen_txt_to_dict,
save_embeddings_to_pickle, load_embeddings_from_pickle,
get_embeddings_pickle_path)
class EmojiProcessor:
    """Match free text to an (emotion, event) emoji pair and fetch their mashup.

    Emoji descriptions are loaded from two text files via
    ``load_emoji_dictionaries``, embedded with a SentenceTransformer model, and
    compared against the embedding of the user's sentence by cosine similarity.
    """

    def __init__(self, model_name=None, model_key=None, use_cached_embeddings=True):
        """Initialize the emoji processor with the specified model.

        Args:
            model_name: Direct name of the sentence transformer model to use.
            model_key: Key from EMBEDDING_MODELS to use (takes precedence over
                model_name when it is a known key).
            use_cached_embeddings: Whether to use cached embeddings from pickle files.
        """
        # Resolve the model: a known key wins, then the explicit name, then
        # the configured default.
        if model_key and model_key in EMBEDDING_MODELS:
            model_name = EMBEDDING_MODELS[model_key]['id']
        else:
            if model_key:
                logger.warning(f"Unknown model key: {model_key}, ignoring it")
            if not model_name:
                model_name = CONFIG["model_name"]
        logger.info(f"Loading model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.current_model_name = model_name
        # emoji -> description text, filled by load_emoji_dictionaries()
        self.emotion_dict = {}
        self.event_dict = {}
        # emoji -> embedding vector produced by self.model
        self.emotion_embeddings = {}
        self.event_embeddings = {}
        self.use_cached_embeddings = use_cached_embeddings

    def load_emoji_dictionaries(self, emotion_file=CONFIG["emotion_file"], item_file=CONFIG["item_file"]):
        """Load emoji dictionaries from text files and prepare their embeddings.

        Args:
            emotion_file: Path to the emotion emoji file.
            item_file: Path to the item emoji file.
        """
        logger.info("Loading emoji dictionaries")
        self.emotion_dict = kitchen_txt_to_dict(emotion_file)
        self.event_dict = kitchen_txt_to_dict(item_file)
        # Load or compute embeddings
        self._load_or_compute_embeddings()

    def _encode_descriptions(self, descriptions):
        """Embed every description in a single batched model call.

        Args:
            descriptions: Dict mapping emoji -> description text.
        Returns:
            Dict mapping each emoji to its embedding vector (empty for empty input).
        """
        if not descriptions:
            return {}
        emojis = list(descriptions)
        # One batched encode is much faster than one call per description.
        vectors = self.model.encode([descriptions[emoji] for emoji in emojis])
        return dict(zip(emojis, vectors))

    def _load_or_compute(self, kind, descriptions):
        """Return embeddings for one dictionary, preferring a valid pickle cache.

        Args:
            kind: Cache label passed to get_embeddings_pickle_path ("emotion" or "event").
            descriptions: Dict mapping emoji -> description text.
        Returns:
            Dict mapping every emoji in ``descriptions`` to its embedding.
        """
        pickle_path = get_embeddings_pickle_path(self.current_model_name, kind)
        cached = load_embeddings_from_pickle(pickle_path)
        if cached is not None:
            missing = [emoji for emoji in descriptions if emoji not in cached]
            if not missing:
                # Keep only emojis still present in the dictionary so stale
                # cache entries can never be returned as matches.
                return {emoji: cached[emoji] for emoji in descriptions}
            logger.info(f"Cached {kind} embeddings missing emoji: {missing[0]}, will recompute")
        logger.info(f"Computing {kind} embeddings for model: {self.current_model_name}")
        embeddings = self._encode_descriptions(descriptions)
        # Save for future use
        save_embeddings_to_pickle(embeddings, pickle_path)
        return embeddings

    def _load_or_compute_embeddings(self):
        """Populate emotion/event embeddings, using pickle caches when enabled."""
        if self.use_cached_embeddings:
            self.emotion_embeddings = self._load_or_compute("emotion", self.emotion_dict)
            self.event_embeddings = self._load_or_compute("event", self.event_dict)
        else:
            # Compute embeddings without caching
            logger.info("Computing embeddings for emoji dictionaries (no caching)")
            self.emotion_embeddings = self._encode_descriptions(self.emotion_dict)
            self.event_embeddings = self._encode_descriptions(self.event_dict)

    def switch_model(self, model_key):
        """Switch to a different embedding model.

        On failure the previous model and its embeddings are restored, so the
        processor stays usable.

        Args:
            model_key: Key from EMBEDDING_MODELS to use.
        Returns:
            True if model was switched successfully (or already loaded), False otherwise.
        """
        if model_key not in EMBEDDING_MODELS:
            logger.error(f"Unknown model key: {model_key}")
            return False
        model_name = EMBEDDING_MODELS[model_key]['id']
        if model_name == self.current_model_name:
            logger.info(f"Model {model_key} is already loaded")
            return True
        previous_state = (self.model, self.current_model_name,
                          self.emotion_embeddings, self.event_embeddings)
        try:
            logger.info(f"Switching to model: {model_name}")
            self.model = SentenceTransformer(model_name)
            self.current_model_name = model_name
            # Embeddings from the old model are not comparable with the new
            # model's sentence vectors, so refresh them if ANY dictionary is loaded.
            if self.emotion_dict or self.event_dict:
                self._load_or_compute_embeddings()
            return True
        except Exception as e:
            logger.error(f"Error switching model: {e}")
            (self.model, self.current_model_name,
             self.emotion_embeddings, self.event_embeddings) = previous_state
            return False

    def find_top_emojis(self, embedding, emoji_embeddings, top_n=1):
        """Find top matching emojis based on cosine similarity.

        Args:
            embedding: Sentence embedding to compare.
            emoji_embeddings: Dictionary of emoji -> embedding.
            top_n: Number of top emojis to return.
        Returns:
            List of up to ``top_n`` emojis, most similar first (ties keep the
            dictionary's order); empty when ``emoji_embeddings`` is empty.
        """
        if not emoji_embeddings:
            return []
        emojis = list(emoji_embeddings)
        # Score the sentence against all emojis in one vectorized call.
        scores = cosine_similarity([embedding], [emoji_embeddings[emoji] for emoji in emojis])[0]
        # sorted() is stable, so equal scores keep their original order.
        ranked = sorted(zip(emojis, scores), key=lambda pair: pair[1], reverse=True)
        return [emoji for emoji, _ in ranked[:top_n]]

    def get_emoji_mashup_url(self, emoji1, emoji2, size=CONFIG["default_size"]):
        """Generate URL for emoji mashup.

        Args:
            emoji1: First emoji character.
            emoji2: Second emoji character.
            size: Image size in pixels.
        Returns:
            URL for the emoji mashup, with ``size`` added as a query parameter.
        """
        base_url = CONFIG['emoji_kitchen_url'].format(emoji1=emoji1, emoji2=emoji2)
        # Append with "&" if the configured template already carries a query string.
        separator = "&" if "?" in base_url else "?"
        return f"{base_url}{separator}size={size}"

    def fetch_mashup_image(self, url, timeout=10):
        """Fetch emoji mashup image from URL.

        Args:
            url: URL of the emoji mashup image.
            timeout: Seconds to wait for the server before giving up.
        Returns:
            PIL Image object or None if the fetch or decode failed.
        """
        try:
            # Without a timeout, requests can block indefinitely on a stalled server.
            response = requests.get(url, timeout=timeout)
            content_type = response.headers.get("Content-Type", "")
            if response.status_code == 200 and "image" in content_type:
                image = Image.open(BytesIO(response.content))
                # Image.open is lazy; decode now so corrupt data is caught here.
                image.load()
                return image
            logger.warning(
                f"Failed to fetch image: Status code {response.status_code}, "
                f"Content-Type {content_type!r}"
            )
            return None
        except Exception as e:
            logger.error(f"Error fetching image: {e}")
            return None

    def sentence_to_emojis(self, sentence):
        """Process sentence to find matching emojis and generate mashup.

        Args:
            sentence: User input text.
        Returns:
            Tuple of (emotion emoji, event emoji, mashup image). Returns
            ("❓", "❓", None) for empty input and ("❌", "❌", None) on error.
        """
        if not sentence or not sentence.strip():
            return "❓", "❓", None
        if not self.emotion_embeddings or not self.event_embeddings:
            logger.error("Emoji embeddings are not loaded; call load_emoji_dictionaries() first")
            return "❌", "❌", None
        try:
            embedding = self.model.encode(sentence)
            top_emotion = self.find_top_emojis(embedding, self.emotion_embeddings, top_n=1)[0]
            top_event = self.find_top_emojis(embedding, self.event_embeddings, top_n=1)[0]
            mashup_url = self.get_emoji_mashup_url(top_emotion, top_event)
            mashup_image = self.fetch_mashup_image(mashup_url)
            return top_emotion, top_event, mashup_image
        except Exception as e:
            logger.error(f"Error processing sentence: {e}")
            return "❌", "❌", None