# updated for torch (commit 9dc951f)
import json
import os

import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer

from app.utils.file_handler import extract_text_from_file
# Select GPU when available so the same module also loads on CPU-only hosts.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pre-trained sentence-embedding model (384-dim output), loaded once at import
# time.  BUG FIX: the device was hardcoded to "cuda", which crashes wherever
# no GPU is present despite `device` being computed just above.
obj_embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device=device)
def process_files_to_vectors(v_folder_path, v_update=False, v_existing_index_path=None):
    """
    Processes files in a folder into a FAISS vector database.

    Walks ``v_folder_path`` recursively, embeds the extracted text of every
    ``.pdf`` / ``.pptx`` / ``.csv`` file with the module-level
    sentence-transformer model, and adds each embedding to a FAISS
    ``IndexFlatL2``.  The index and an id -> file-path metadata mapping are
    written to ``<v_folder_path>/vectors``.

    Args:
        v_folder_path (str): Path to the folder containing input files.
        v_update (bool): Whether to update an existing vector database.
        v_existing_index_path (str): Path to the existing FAISS index file
            (required when ``v_update`` is True).

    Returns:
        str: Path to the folder containing the saved vector database.
    """
    v_vector_folder = os.path.join(v_folder_path, 'vectors')
    os.makedirs(v_vector_folder, exist_ok=True)

    v_metadata_path = os.path.join(v_vector_folder, 'metadata.json')

    if v_update and v_existing_index_path:
        v_index = faiss.read_index(v_existing_index_path)
        # Robustness: tolerate an index that was saved without a metadata
        # file instead of raising FileNotFoundError.
        if os.path.exists(v_metadata_path):
            with open(v_metadata_path, 'r') as obj_meta:
                v_metadata = json.load(obj_meta)
        else:
            v_metadata = {}
    else:
        # all-MiniLM-L6-v2 produces 384-dimensional embeddings.
        v_index = faiss.IndexFlatL2(384)
        v_metadata = {}

    # Process files and add their embeddings to the index.
    for v_root, _, v_files in os.walk(v_folder_path):
        for v_file in v_files:
            if not v_file.endswith(('.pdf', '.pptx', '.csv')):
                continue
            v_file_path = os.path.join(v_root, v_file)
            v_text = extract_text_from_file(v_file_path)
            v_embeddings = obj_embedding_model.encode([v_text], convert_to_tensor=True).cpu().numpy()
            v_index.add(v_embeddings)
            # BUG FIX: faiss indexes do not implement len(); the vector count
            # is exposed as .ntotal.  Keys are stored as strings so they
            # round-trip through JSON unchanged (JSON object keys are always
            # strings, so the update path previously mixed int and str keys).
            v_metadata[str(v_index.ntotal - 1)] = v_file_path

    # Persist the updated index and its metadata.
    v_index_path = os.path.join(v_vector_folder, 'vector_index.faiss')
    faiss.write_index(v_index, v_index_path)
    with open(v_metadata_path, 'w') as obj_meta:
        json.dump(v_metadata, obj_meta, indent=4)

    return v_vector_folder