# IBMHackRAG / rag.py — last update f94d355 (verified) by BaRiDo
import os
import random
import string
import json
import gzip
import chromadb
from ibm_watsonx_ai.client import APIClient
from ibm_watsonx_ai.foundation_models import ModelInference, Rerank
from ibm_watsonx_ai.foundation_models.embeddings.sentence_transformer_embeddings import SentenceTransformerEmbeddings
# Data-asset id of the Watsonx.ai vector index whose stored embeddings
# are downloaded and loaded into ChromaDB by _hydrate_chromadb().
VECTOR_DB = "c8af7dfa-bcad-46e5-b69d-cd85ce9315d1"
def get_credentials():
    """
    Assemble the Watsonx.ai credential dictionary.

    The API key is read from the IBM_API_KEY environment variable; the
    endpoint URL is fixed to the us-south region.

    Returns:
        dict: {"url": ..., "apikey": ...} suitable for APIClient.
    """
    credentials = {"url": "https://us-south.ml.cloud.ibm.com"}
    credentials["apikey"] = os.getenv("IBM_API_KEY")
    return credentials
def rerank(client, documents, query, top_n):
    """
    Reorder *documents* by relevance to *query* using the Rerank model.

    Args:
        client: Authenticated Watsonx.ai APIClient.
        documents: List of document strings to score.
        query: Query text the documents are scored against.
        top_n: Number of highest-scoring documents to keep.

    Returns:
        list: The selected documents, most relevant first.
    """
    reranker = Rerank(
        model_id="cross-encoder/ms-marco-minilm-l-12-v2",
        api_client=client,
        params={
            "return_options": {
                "top_n": top_n
            },
            "truncate_input_tokens": 512
        }
    )
    scored = reranker.generate(query=query, inputs=documents)["results"]
    # Each result carries the index of the document it scored; map the
    # ranked order back onto the original document texts.
    return [documents[entry["index"]] for entry in scored]
def RAGinit():
    """
    Build every object needed by RAG_proximity_search.

    Creates the Watsonx.ai API client, the generation model, the sentence
    embedding model, and a ChromaDB collection hydrated from the stored
    vector index, then derives how many results a query should return.

    Returns:
        tuple: (client, model, emb, chroma_collection,
                vector_index_properties, top_n)
    """
    # Project/Space identifiers come from the environment.
    project_id = os.getenv("IBM_PROJECT_ID")
    space_id = os.getenv("IBM_SPACE_ID")

    # Watsonx.ai client bound to the project.
    client = APIClient(credentials=get_credentials(), project_id=project_id)

    # Greedy decoding keeps generation deterministic.
    generation_params = {
        "decoding_method": "greedy",
        "max_new_tokens": 900,
        "min_new_tokens": 0,
        "repetition_penalty": 1
    }
    model = ModelInference(
        model_id="ibm/granite-3-8b-instruct",
        params=generation_params,
        credentials=get_credentials(),
        project_id=project_id,
        space_id=space_id
    )

    # Pull the vector index metadata from the data-asset catalogue.
    vector_index_id = VECTOR_DB
    details = client.data_assets.get_details(vector_index_id)
    vector_index_properties = details["entity"]["vector_index"]
    settings = vector_index_properties["settings"]

    # With reranking enabled we over-fetch 20 candidates and let the
    # reranker trim them; otherwise honour the configured top_k directly.
    top_n = 20 if settings.get("rerank") else int(settings["top_k"])

    # Embedding model used to encode queries at search time.
    emb = SentenceTransformerEmbeddings('sentence-transformers/all-MiniLM-L6-v2')

    # Load the stored embeddings into a local ChromaDB collection.
    chroma_collection = _hydrate_chromadb(client, vector_index_id)

    return client, model, emb, chroma_collection, vector_index_properties, top_n
def _hydrate_chromadb(client, vector_index_id):
    """
    Download the stored embedding payload for *vector_index_id* from
    Watsonx.ai and load it into a fresh on-disk ChromaDB collection.

    The payload is a gzip-compressed JSON list of records, each carrying
    "embedding", "content" and "metadata" keys.

    Args:
        client: Authenticated Watsonx.ai APIClient.
        vector_index_id: Data-asset id of the vector index to hydrate.

    Returns:
        The populated chromadb collection.
    """
    raw = client.data_assets.get_content(vector_index_id)
    # Decompress once and parse the JSON payload; keep the name distinct
    # from the per-record "content" field used below.
    vectors = json.loads(gzip.decompress(raw).decode("utf-8"))

    # Persistent (on-disk) ChromaDB client.
    chroma_client = chromadb.PersistentClient(path="./chroma_db")

    # Recreate the collection from scratch so stale entries never linger.
    collection_name = "my_collection"
    try:
        chroma_client.delete_collection(name=collection_name)
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt /
        # SystemExit. A missing collection is the expected, harmless case.
        print("Collection didn't exist - nothing to do.")
    collection = chroma_client.create_collection(name=collection_name)

    # Accumulate parallel lists for a single bulk insert.
    vector_embeddings = []
    vector_documents = []
    vector_metadatas = []
    vector_ids = []
    for vector in vectors:
        metadata = vector["metadata"]
        lines = metadata["loc"]["lines"]

        vector_embeddings.append(vector["embedding"])
        vector_documents.append(vector["content"])
        vector_metadatas.append({
            "asset_id": metadata["asset_id"],
            "asset_name": metadata["asset_name"],
            "url": metadata["url"],
            "from": lines["from"],
            "to": lines["to"]
        })

        # Random suffix keeps ids unique even if the same source span
        # appears more than once in the payload.
        random_string = ''.join(random.choices(string.ascii_uppercase + string.digits, k=10))
        vector_ids.append(f"{metadata['asset_id']}:{lines['from']}-{lines['to']}-{random_string}")

    # One bulk add instead of per-record inserts.
    collection.add(
        embeddings=vector_embeddings,
        documents=vector_documents,
        metadatas=vector_metadatas,
        ids=vector_ids
    )
    return collection
def RAG_proximity_search(question, client, model, emb, chroma_collection, vector_index_properties, top_n):
    """
    Find the documents in *chroma_collection* closest to *question*.

    The question is embedded, the nearest *top_n* chunks are fetched from
    ChromaDB, and — when the vector index enables it — the hits are
    reranked before being joined into a single context string.

    Returns:
        str: Matching document chunks joined by newlines, best match first.
    """
    # Encode the question with the same embedding model used at indexing time.
    question_vector = emb.embed_query(question)

    hits = chroma_collection.query(
        query_embeddings=question_vector,
        n_results=top_n,
        include=["documents", "metadatas", "distances"]
    )

    # One query was issued, so the first (and only) result row holds the
    # documents, already ordered by ascending distance.
    documents = hits["documents"][0]

    # Optionally reorder by cross-encoder relevance.
    settings = vector_index_properties["settings"]
    if settings.get("rerank"):
        documents = rerank(client, documents, question, settings["top_k"])

    return "\n".join(documents)