mostafa-sh's picture
add all-MiniLM-L6-v2 embeddings
0f9d3df
import json
import numpy as np
import random
import streamlit as st
from sentence_transformers import SentenceTransformer
@st.cache_resource
def load_youtube_data(base_path, embedding_model_name, chunk_tokens, overlap_tokens):
    """Load YouTube-transcript chunks and their precomputed embedding matrix.

    The file read is
    ``{base_path}/yt_embedding_space_{embedding_model_name}_tpc{chunk_tokens}_o{overlap_tokens}.json``
    and must be a JSON object containing the keys ``"chunks"`` and
    ``"embedding_space"``.

    Args:
        base_path: Directory that holds the embedding-space JSON files.
        embedding_model_name: Model tag inserted into the file name.
        chunk_tokens: Value inserted after ``_tpc`` in the file name.
        overlap_tokens: Value inserted after ``_o`` in the file name.

    Returns:
        ``(chunks, embedding_space)``: ``chunks`` is the JSON ``"chunks"`` value
        as loaded, and ``embedding_space`` is ``np.array`` of the JSON
        ``"embedding_space"`` value (presumably one row per chunk — confirm
        against the file producer).

    Raises:
        FileNotFoundError: if the JSON file does not exist.
        KeyError: if either expected key is missing.

    Note:
        ``st.cache_resource`` hands the same objects to every caller, so callers
        must not mutate the returned values.
    """
    embedding_space_file_name = (
        f'{base_path}/yt_embedding_space_{embedding_model_name}'
        f'_tpc{chunk_tokens}_o{overlap_tokens}.json'
    )
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(embedding_space_file_name, 'r', encoding='utf-8') as json_file:
        loaded_data = json.load(json_file)
    embedding_space = np.array(loaded_data['embedding_space'])
    return loaded_data['chunks'], embedding_space
@st.cache_resource
def load_book_data(base_path, embedding_model_name, chunk_tokens, overlap_tokens):
    """Load book (LaTeX, chunked by sections) text chunks and their embedding matrix.

    The file read is
    ``{base_path}/latex_embedding_space_by_sections_{embedding_model_name}_tpc{chunk_tokens}_o{overlap_tokens}.json``
    and must be a JSON object containing the keys ``"chunks"`` and
    ``"embedding_space"``.

    Args:
        base_path: Directory that holds the embedding-space JSON files.
        embedding_model_name: Model tag inserted into the file name.
        chunk_tokens: Value inserted after ``_tpc`` in the file name.
        overlap_tokens: Value inserted after ``_o`` in the file name.

    Returns:
        ``(chunks, embedding_space)``: ``chunks`` is the JSON ``"chunks"`` value
        as loaded, and ``embedding_space`` is ``np.array`` of the JSON
        ``"embedding_space"`` value (presumably one row per chunk — confirm
        against the file producer).

    Raises:
        FileNotFoundError: if the JSON file does not exist.
        KeyError: if either expected key is missing.

    Note:
        ``st.cache_resource`` hands the same objects to every caller, so callers
        must not mutate the returned values.
    """
    embedding_space_file_name = (
        f'{base_path}/latex_embedding_space_by_sections_{embedding_model_name}'
        f'_tpc{chunk_tokens}_o{overlap_tokens}.json'
    )
    # JSON is UTF-8 by spec; LaTeX sources commonly contain non-ASCII text.
    with open(embedding_space_file_name, 'r', encoding='utf-8') as json_file:
        loaded_data = json.load(json_file)
    embedding_space = np.array(loaded_data['embedding_space'])
    return loaded_data['chunks'], embedding_space
@st.cache_resource
def load_summary(file_path):
    """Load and return the parsed contents of the JSON file at *file_path*.

    The result is whatever ``json.load`` produces; its structure is defined by
    the file (presumably per-video transcript summaries — confirm with caller).

    Raises:
        FileNotFoundError: if *file_path* does not exist.
        json.JSONDecodeError: if the file is not valid JSON.

    Note:
        ``st.cache_resource`` returns the same object to every caller, so callers
        must not mutate it.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        transcripts = json.load(file)
    return transcripts
@st.cache_resource
def _load_sentence_transformer(model_name):
    """Construct a SentenceTransformer for *model_name* once and reuse it across calls."""
    return SentenceTransformer(model_name)


def embed_question_sentence_transformer(texts, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Embed *texts* with a SentenceTransformer model.

    The model is loaded once per *model_name* and cached. The original version
    rebuilt (and potentially re-downloaded) the model on every call.

    Args:
        texts: A string or a list of strings to encode.
        model_name: Hugging Face / sentence-transformers model identifier.

    Returns:
        The embeddings converted to plain Python lists via ``.tolist()``. Per
        sentence-transformers' ``encode`` semantics, a single string yields one
        vector and a list yields one vector per input text.
    """
    model = _load_sentence_transformer(model_name)
    return model.encode(texts).tolist()
def fixed_knn_retrieval(question_embedding, context_embeddings, top_k=5, min_k=1):
    """Return indices of the contexts most cosine-similar to the question.

    Args:
        question_embedding: 1-D vector (list or array) for the query.
        context_embeddings: 2-D array-like with one embedding per row.
        top_k: Number of contexts to return.
        min_k: Lower bound on the number returned; ``max(top_k, min_k)`` is used.

    Returns:
        A list of row indices into *context_embeddings*, ordered from most to
        least similar. Ties keep their original (ascending-index) order. Fewer
        indices are returned if there are fewer rows than requested.

    Zero-norm vectors are treated as having similarity 0 with everything. Dividing
    by a zero norm would produce NaN, and NaN sorts past every number, so the old
    descending-by-reversal ordering ranked such rows *first*.
    """
    question = np.asarray(question_embedding, dtype=float)
    contexts = np.asarray(context_embeddings, dtype=float)

    # Normalize the question; a zero vector stays zero (all similarities become 0).
    question_norm = np.linalg.norm(question)
    if question_norm > 0:
        question = question / question_norm

    # Normalize each context row, dividing zero-norm rows by 1 instead of 0.
    context_norms = np.linalg.norm(contexts, axis=1, keepdims=True)
    contexts = contexts / np.where(context_norms == 0, 1.0, context_norms)

    # Cosine similarity of every context with the question.
    similarities = contexts @ question

    # Stable descending sort: highest similarity first, ties broken by lower index.
    sorted_indices = np.argsort(-similarities, kind="stable")
    return sorted_indices[:max(top_k, min_k)].tolist()
def get_random_question(text_file):
    """Return a randomly chosen question from a text file with one question per line.

    Surrounding whitespace is stripped from each line, and blank or
    whitespace-only lines are skipped. A trailing empty line therefore can no
    longer be returned as an empty "question".

    Args:
        text_file: Path to a UTF-8 text file.

    Raises:
        IndexError: if the file has no non-blank lines (raised by ``random.choice``).
    """
    with open(text_file, "r", encoding="utf-8") as file:
        questions = [line.strip() for line in file if line.strip()]
    return random.choice(questions)