# Listing metadata: Space status "Running"; file size 2,398 bytes; revision a2c10b6.
import os
from neo4j import GraphDatabase
from llama_index.core import (
SimpleDirectoryReader,
VectorStoreIndex,
StorageContext,
load_index_from_storage,
Settings,
)
from llama_index.llms.ollama import Ollama
from llama_index.vector_stores.neo4jvector import Neo4jVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
class GraphRAGRetriever:
    """Index documents into a Neo4j vector store and answer questions over them.

    The constructor wires a HuggingFace embedding model and an Ollama LLM into
    the global llama_index ``Settings`` object, opens a Neo4j driver, and
    creates a ``Neo4jVectorStore``. ``ingest_documents`` embeds a directory of
    files into that store; ``query`` answers a question against it.
    """

    def __init__(self, neo4j_url, neo4j_username, neo4j_password):
        """Configure the models and connect to Neo4j.

        Args:
            neo4j_url: Connection URL of the Neo4j server.
            neo4j_username: Neo4j user name.
            neo4j_password: Password for ``neo4j_username``.
        """
        # Set up the embedding model
        self.embed_model = HuggingFaceEmbedding(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Set up the Ollama LLM
        # NOTE(review): request_timeout=None presumably disables the client
        # timeout -- confirm against the installed Ollama integration.
        self.llm = Ollama(model="gemma:2b", request_timeout=None)

        # Configure Settings. These are assigned on the shared Settings
        # object, so they apply to every index/query engine built afterwards.
        Settings.llm = self.llm
        Settings.embed_model = self.embed_model

        # Set up the Neo4j driver
        self.driver = GraphDatabase.driver(
            neo4j_url, auth=(neo4j_username, neo4j_password)
        )

        # Set up the Neo4j vector store
        # NOTE(review): the store is also given url/username/password, so it
        # may open its own connection and ignore ``driver`` -- confirm.
        self.vector_store = Neo4jVectorStore(
            url=neo4j_url,
            username=neo4j_username,
            password=neo4j_password,
            embedding_dimension=384,  # Matches MiniLM model
            driver=self.driver,
        )

    def ingest_documents(self, directory_path):
        """Load every document under ``directory_path`` and index it in Neo4j.

        Args:
            directory_path: Directory handed to ``SimpleDirectoryReader``.
        """
        # Load documents from the specified directory
        documents = SimpleDirectoryReader(directory_path).load_data()

        # The vector store has to be supplied through a StorageContext;
        # passing ``vector_store=`` straight to ``from_documents`` is not a
        # recognised parameter and leaves the index on the default
        # in-memory store instead of Neo4j.
        storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )
        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
        )

        # Persist the remaining index metadata to the default persist
        # directory; the embeddings themselves are written to Neo4j.
        index.storage_context.persist()

    def query(self, question):
        """Answer ``question`` using the documents stored in Neo4j.

        Args:
            question: Natural-language question to run through the query engine.

        Returns:
            The query engine's response converted to ``str``.
        """
        # Build the index directly on top of the Neo4j vector store so the
        # query reads the embeddings written by ``ingest_documents`` rather
        # than an on-disk default store.
        index = VectorStoreIndex.from_vector_store(self.vector_store)

        # Create a query engine and execute the query
        query_engine = index.as_query_engine()
        response = query_engine.query(question)
        return str(response)

    def close(self):
        """Close the Neo4j driver opened by the constructor."""
        self.driver.close()
if __name__ == "__main__":
    # Demo run: index a local data directory, then ask one question.
    # Fail fast with a clear message instead of handing a None password to
    # the Neo4j driver, which would surface later as an obscure auth error.
    neo4j_password = os.getenv("NEO4J_PASSWORD")
    if neo4j_password is None:
        raise SystemExit("NEO4J_PASSWORD environment variable is not set")

    retriever = GraphRAGRetriever(
        neo4j_url="bolt://localhost:7687/",
        neo4j_username="neo4j",
        neo4j_password=neo4j_password,
    )
    retriever.ingest_documents("/home/bapary/Music/AI Finance Agent/rag/data")
    answer = retriever.query("What is the revenue of Company Microsoft in 2021?")
    print(answer)
# (end of listing)