"""Streamlit app for comparing two TFLite sentence-embedding models side by side."""

import os
import time

import numpy as np
import pandas as pd
import sentencepiece as spm
import streamlit as st
import tensorflow as tf
from sklearn.metrics.pairwise import cosine_similarity

# st.set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="Embedding Model Comparison", layout="wide")


@st.cache_resource
def load_tokenizer(tokenizer_path="sentencepiece.model"):
    """Load a SentencePiece tokenizer from disk; return None if the file is missing."""
    if not os.path.exists(tokenizer_path):
        st.error(f"Tokenizer file not found: {tokenizer_path}")
        return None

    sp = spm.SentencePieceProcessor()
    sp.load(tokenizer_path)
    return sp
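
# The tokenizer is cached with st.cache_resource, so it is loaded once and
# reused across reruns; load_model below is not cached, so the TFLite models
# are reloaded from disk on every rerun.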


def load_model(model_path):
    """Load a TFLite model and allocate its tensors; return None if the file is missing."""
    if not os.path.exists(model_path):
        st.error(f"Model file not found: {model_path}")
        return None

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    return interpreter
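
# Note: get_embedding() below pads with token id 0. Stock SentencePiece models
# reserve id 0 for <unk> and leave the pad id unset, so padding with 0 is a
# working assumption; if your tokenizer defines a dedicated pad id, use that.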


def get_embedding(text, interpreter, tokenizer):
    """Embed a single string with a TFLite interpreter; return (embedding, seconds)."""
    if interpreter is None or tokenizer is None:
        return None, 0

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Read the expected sequence length from the model input, defaulting to 64.
    input_shape = input_details[0]['shape']
    max_seq_length = input_shape[1] if len(input_shape) > 1 else 64

    # Truncate or zero-pad the token ids to the fixed sequence length.
    tokens = tokenizer.encode(text, out_type=int)
    if len(tokens) > max_seq_length:
        tokens = tokens[:max_seq_length]
    else:
        tokens = tokens + [0] * (max_seq_length - len(tokens))

    token_ids = np.array([tokens], dtype=np.int32)
    interpreter.set_tensor(input_details[0]['index'], token_ids)

    # Time only the forward pass, not tokenization or tensor setup.
    start_time = time.time()
    interpreter.invoke()
    inference_time = time.time() - start_time

    embedding = interpreter.get_tensor(output_details[0]['index'])
    return embedding, inference_time


def load_sentences(file_path):
    """Read one sentence per line; fall back to built-in samples if the file is missing."""
    if not os.path.exists(file_path):
        return ["Hello world", "This is a test", "Embedding models are useful",
                "TensorFlow Lite is great for mobile applications",
                "Streamlit makes it easy to create web apps",
                "Python is a popular programming language",
                "Machine learning is an exciting field",
                "Natural language processing helps computers understand human language",
                "Semantic search finds meaning, not just keywords",
                "Quantization reduces model size with minimal accuracy loss"]

    # utf-8 is assumed here so the corpus reads the same on every platform.
    with open(file_path, 'r', encoding='utf-8') as f:
        sentences = [line.strip() for line in f if line.strip()]

    return sentences


def find_similar_sentences(query_embedding, sentence_embeddings, sentences):
    """Rank every sentence by cosine similarity to the query, highest first."""
    if query_embedding is None or len(sentence_embeddings) == 0:
        return []

    # query_embedding is shaped (1, dim) straight from the interpreter, so
    # cosine_similarity returns a (1, n) matrix; take its only row.
    similarities = cosine_similarity(query_embedding, sentence_embeddings)[0]

    # Indices sorted by descending similarity.
    sorted_indices = np.argsort(similarities)[::-1]

    results = []
    for idx in sorted_indices:
        results.append({
            "sentence": sentences[idx],
            "similarity": similarities[idx]
        })

    return results
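
# Note: this is a brute-force scan over all stored embeddings, which is fine
# at this corpus size; a dedicated vector index would be the usual swap-in
# for much larger corpora.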


def main():
    st.title("Embedding Model Comparison")

    # Sidebar: paths for the two models, the sentence corpus, and the tokenizer.
    with st.sidebar:
        st.header("Configuration")
        old_model_path = st.text_input("Old Model Path", "old.tflite")
        new_model_path = st.text_input("New Model Path", "new.tflite")
        sentences_path = st.text_input("Sentences File Path", "sentences.txt")
        tokenizer_path = st.text_input("Tokenizer Path", "sentencepiece.model")

    # Everything downstream needs the tokenizer, so bail out early without it.
    tokenizer = load_tokenizer(tokenizer_path)
    if tokenizer:
        st.sidebar.success("Tokenizer loaded successfully")
        st.sidebar.write(f"Vocabulary size: {tokenizer.get_piece_size()}")
    else:
        st.sidebar.error("Failed to load tokenizer")
        return

    # Load both models side by side and surface their tensor shapes.
    st.header("Models")
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Old Model")
        old_model = load_model(old_model_path)
        if old_model:
            st.success("Old model loaded successfully")
            old_input_details = old_model.get_input_details()
            old_output_details = old_model.get_output_details()
            st.write(f"Input shape: {old_input_details[0]['shape']}")
            st.write(f"Output shape: {old_output_details[0]['shape']}")

    with col2:
        st.subheader("New Model")
        new_model = load_model(new_model_path)
        if new_model:
            st.success("New model loaded successfully")
            new_input_details = new_model.get_input_details()
            new_output_details = new_model.get_output_details()
            st.write(f"Input shape: {new_input_details[0]['shape']}")
            st.write(f"Output shape: {new_output_details[0]['shape']}")

    sentences = load_sentences(sentences_path)
    st.header("Sentences")
    st.write(f"Loaded {len(sentences)} sentences")
    if st.checkbox("Show loaded sentences"):
        st.write(sentences[:10])
        if len(sentences) > 10:
            st.write("...")

    # Corpus embeddings live in session state so they survive Streamlit reruns;
    # the button forces a full recompute.
    if 'old_sentence_embeddings' not in st.session_state or st.button("Recompute Embeddings"):
        st.session_state.old_sentence_embeddings = []
        st.session_state.new_sentence_embeddings = []

        if old_model and new_model:
            progress_bar = st.progress(0)
            st.write("Computing sentence embeddings...")

            for i, sentence in enumerate(sentences):
                if i % 10 == 0:
                    progress_bar.progress(i / len(sentences))

                old_embedding, _ = get_embedding(sentence, old_model, tokenizer)
                new_embedding, _ = get_embedding(sentence, new_model, tokenizer)

                # Drop the batch dimension so each stored embedding is 1-D.
                if old_embedding is not None:
                    st.session_state.old_sentence_embeddings.append(old_embedding[0])

                if new_embedding is not None:
                    st.session_state.new_sentence_embeddings.append(new_embedding[0])

            progress_bar.progress(1.0)
            st.write("Embeddings computed!")
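
    # Note: changing the model paths above does not clear these cached
    # embeddings; press "Recompute Embeddings" after swapping models.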

    st.header("Search")
    query = st.text_input("Enter a search query")

    if query and old_model and new_model:
        # Show how the query is split into SentencePiece ids and pieces.
        with st.expander("View tokenization"):
            tokens = tokenizer.encode(query, out_type=int)
            pieces = tokenizer.encode(query, out_type=str)
            st.write("Token IDs:", tokens)
            st.write("Token pieces:", pieces)

        old_query_embedding, old_time = get_embedding(query, old_model, tokenizer)
        new_query_embedding, new_time = get_embedding(query, new_model, tokenizer)

        old_results = find_similar_sentences(
            old_query_embedding,
            st.session_state.old_sentence_embeddings,
            sentences
        )

        new_results = find_similar_sentences(
            new_query_embedding,
            st.session_state.new_sentence_embeddings,
            sentences
        )

        # Results are already sorted by similarity, so rank is just position.
        for i, result in enumerate(old_results):
            result["rank"] = i + 1

        for i, result in enumerate(new_results):
            result["rank"] = i + 1

        old_df = pd.DataFrame([
            {"Sentence": r["sentence"], "Similarity": f"{r['similarity']:.4f}", "Rank": r["rank"]}
            for r in old_results
        ])

        new_df = pd.DataFrame([
            {"Sentence": r["sentence"], "Similarity": f"{r['similarity']:.4f}", "Rank": r["rank"]}
            for r in new_results
        ])

        st.subheader("Search Results")
        col1, col2 = st.columns(2)

        with col1:
            st.markdown("### Old Model Results")
            st.dataframe(old_df, use_container_width=True)

        with col2:
            st.markdown("### New Model Results")
            st.dataframe(new_df, use_container_width=True)

        st.subheader("Inference Time")
        st.write(f"Old model: {old_time * 1000:.2f} ms")
        st.write(f"New model: {new_time * 1000:.2f} ms")
        # Guard against a zero denominator from a very fast (or failed) run.
        if new_time > 0:
            st.write(f"Speed improvement: {old_time / new_time:.2f}x")

        st.subheader("Embedding Visualizations")
        col1, col2 = st.columns(2)

        with col1:
            st.write("Old Model Embedding (first 20 dimensions)")
            st.bar_chart(pd.DataFrame({
                'value': old_query_embedding[0][:20]
            }))

        with col2:
            st.write("New Model Embedding (first 20 dimensions)")
            st.bar_chart(pd.DataFrame({
                'value': new_query_embedding[0][:20]
            }))


if __name__ == "__main__":
    main()
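
# To try this locally (assuming the script is saved as app.py):
#   streamlit run app.py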