Spaces:
Running
Running
import os | |
import random | |
import string | |
import json | |
import gzip | |
import chromadb | |
from ibm_watsonx_ai.client import APIClient | |
from ibm_watsonx_ai.foundation_models import ModelInference, Rerank | |
from ibm_watsonx_ai.foundation_models.embeddings.sentence_transformer_embeddings import SentenceTransformerEmbeddings | |
# ID of the watsonx.ai data asset that holds the vector index (RAGinit reads
# its details and content through client.data_assets). It can be overridden
# with the IBM_VECTOR_INDEX_ID environment variable; when that is unset or
# empty, the original hard-coded asset is used.
VECTOR_DB = os.getenv("IBM_VECTOR_INDEX_ID") or "c8af7dfa-bcad-46e5-b69d-cd85ce9315d1"
def get_credentials():
    """
    Build the watsonx.ai credential dictionary from the environment.

    Environment variables:
        IBM_API_KEY: IBM Cloud API key. If unset, ``"apikey"`` is ``None``.
        IBM_URL: optional watsonx.ai endpoint. When unset or empty, the
            us-south endpoint ``https://us-south.ml.cloud.ibm.com`` is used.

    Returns:
        dict: ``{"url": <endpoint>, "apikey": <key or None>}``, in the shape
        this module passes to ``APIClient`` and ``ModelInference``.
    """
    return {
        # ``or`` (not a getenv default) so an empty IBM_URL also falls back.
        "url": os.getenv("IBM_URL") or "https://us-south.ml.cloud.ibm.com",
        "apikey": os.getenv("IBM_API_KEY"),
    }
def rerank(client, documents, query, top_n):
    """
    Reorder ``documents`` by relevance to ``query`` with a watsonx.ai reranker.

    Uses the ``cross-encoder/ms-marco-minilm-l-12-v2`` model through
    ``Rerank``, truncating inputs to 512 tokens and asking the service for at
    most ``top_n`` results.

    Args:
        client: authenticated watsonx.ai ``APIClient``.
        documents: candidate passages (list of strings).
        query: the user question.
        top_n: maximum number of documents to return. It is coerced with
            ``int()`` because vector-index settings may hold it as a string.
            ``None`` is passed through unchanged.

    Returns:
        list: the documents selected by the reranker, in the order the
        service returns them (highest relevance first). An empty
        ``documents`` list yields ``[]`` without calling the service.
    """
    if not documents:
        # Nothing to rank; avoid a pointless remote call.
        return []

    if top_n is not None:
        top_n = int(top_n)

    reranker = Rerank(
        model_id="cross-encoder/ms-marco-minilm-l-12-v2",
        api_client=client,
        params={
            "return_options": {
                "top_n": top_n,
            },
            "truncate_input_tokens": 512,
        },
    )
    reranked_results = reranker.generate(query=query, inputs=documents)["results"]
    # Each result's "index" is the position of that document in the input list.
    return [documents[result["index"]] for result in reranked_results]
def RAGinit():
    """
    Set up every object that ``RAG_proximity_search`` needs.

    Steps, in order:
      1. Create a watsonx.ai ``APIClient`` for the project in ``IBM_PROJECT_ID``.
      2. Create the ``ibm/granite-3-8b-instruct`` ``ModelInference`` with greedy
         decoding, scoped to ``IBM_PROJECT_ID`` / ``IBM_SPACE_ID``.
      3. Read the vector-index properties of the data asset ``VECTOR_DB``.
      4. Choose how many neighbours to fetch from ChromaDB (``top_n``).
      5. Load the ``all-MiniLM-L6-v2`` sentence-transformer embedder.
      6. Rebuild the ChromaDB collection from the vector-index content.

    Returns:
        tuple: ``(client, model, emb, chroma_collection,
        vector_index_properties, top_n)``. This is the same order in which
        ``RAG_proximity_search`` takes them after ``question``.
    """
    proj = os.getenv("IBM_PROJECT_ID")
    space = os.getenv("IBM_SPACE_ID")

    api_client = APIClient(credentials=get_credentials(), project_id=proj)

    generation_params = {
        "decoding_method": "greedy",
        "max_new_tokens": 900,
        "min_new_tokens": 0,
        "repetition_penalty": 1,
    }
    llm = ModelInference(
        model_id="ibm/granite-3-8b-instruct",
        params=generation_params,
        credentials=get_credentials(),
        project_id=proj,
        space_id=space,
    )

    index_id = VECTOR_DB
    index_props = api_client.data_assets.get_details(index_id)["entity"]["vector_index"]

    index_settings = index_props["settings"]
    if index_settings.get("rerank"):
        # Over-fetch candidates; RAG_proximity_search reranks them down to top_k.
        neighbours = 20
    else:
        neighbours = int(index_settings["top_k"])

    embedder = SentenceTransformerEmbeddings('sentence-transformers/all-MiniLM-L6-v2')

    collection = _hydrate_chromadb(api_client, index_id)

    return api_client, llm, embedder, collection, index_props, neighbours
def _hydrate_chromadb(client, vector_index_id):
    """
    Load the stored embeddings of a watsonx.ai vector index into ChromaDB.

    The data-asset content is downloaded, gunzipped, decoded as UTF-8 JSON and
    written into the collection ``"my_collection"`` of an on-disk ChromaDB
    store at ``./chroma_db``. Any existing collection with that name is
    deleted first, so every call rebuilds it from scratch.

    Args:
        client: watsonx.ai ``APIClient`` used to download the asset content.
        vector_index_id: ID of the vector-index data asset.

    Returns:
        The freshly created and populated ChromaDB collection.
    """
    raw = client.data_assets.get_content(vector_index_id)
    vectors = json.loads(gzip.decompress(raw).decode("utf-8"))

    # Use a persistent (on-disk) ChromaDB client.
    chroma_client = chromadb.PersistentClient(path="./chroma_db")

    collection_name = "my_collection"
    try:
        chroma_client.delete_collection(name=collection_name)
    except Exception:
        # The exception type for a missing collection differs between chromadb
        # versions, so catch Exception, but no longer swallow
        # KeyboardInterrupt / SystemExit as the previous bare ``except:`` did.
        print("Collection didn't exist - nothing to do.")
    collection = chroma_client.create_collection(name=collection_name)

    embeddings, documents, metadatas, ids = _prepare_chroma_records(vectors)

    # Skip the insert for an empty index: there is nothing to add, and chromadb
    # may reject an empty ID list.
    # NOTE(review): chromadb caps how many records a single add() may carry;
    # very large indexes may need chunked adds. Confirm against the installed
    # chromadb version.
    if ids:
        collection.add(
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
            ids=ids,
        )
    return collection


def _prepare_chroma_records(vectors):
    """
    Convert vector-index records into the parallel lists ``Collection.add`` takes.

    Each record is read as ``{"embedding", "content", "metadata"}``, where
    ``metadata`` carries ``asset_id``, ``asset_name``, ``url`` and
    ``loc.lines.from`` / ``loc.lines.to``.

    Returns:
        tuple: ``(embeddings, documents, metadatas, ids)``, index-aligned.
    """
    embeddings, documents, metadatas, ids = [], [], [], []
    alphabet = string.ascii_uppercase + string.digits
    for vector in vectors:
        metadata = vector["metadata"]
        lines = metadata["loc"]["lines"]

        embeddings.append(vector["embedding"])
        documents.append(vector["content"])
        # Keep only flat scalar fields in the Chroma metadata.
        metadatas.append({
            "asset_id": metadata["asset_id"],
            "asset_name": metadata["asset_name"],
            "url": metadata["url"],
            "from": lines["from"],
            "to": lines["to"],
        })

        # A random suffix keeps IDs distinct even when two chunks share the
        # same asset and line span (not security-sensitive, so `random` is fine).
        suffix = ''.join(random.choices(alphabet, k=10))
        ids.append(f"{metadata['asset_id']}:{lines['from']}-{lines['to']}-{suffix}")
    return embeddings, documents, metadatas, ids
def RAG_proximity_search(question, client, model, emb, chroma_collection, vector_index_properties, top_n):
    """
    Retrieve the passages most relevant to ``question`` as a single string.

    The question is embedded with ``emb`` and the ``top_n`` nearest chunks are
    fetched from ``chroma_collection``. When the vector-index settings enable
    reranking, the chunks are reordered by ``rerank``, which is asked for the
    index's ``top_k`` best.

    Args:
        question: user query text.
        client: watsonx.ai ``APIClient``; only used when reranking.
        model: not used by this function. The parameter order mirrors the
            tuple returned by ``RAGinit``.
        emb: embedder exposing ``embed_query(text)``.
        chroma_collection: ChromaDB collection to query.
        vector_index_properties: dict with a ``"settings"`` mapping; reads
            ``"rerank"`` and, when reranking, ``"top_k"``.
        top_n: number of nearest neighbours to request from ChromaDB.

    Returns:
        str: the selected documents joined with newlines, best match first;
        an empty string when nothing is retrieved.
    """
    query_vector = emb.embed_query(question)

    query_result = chroma_collection.query(
        query_embeddings=query_vector,
        n_results=top_n,
        include=["documents", "metadatas", "distances"],
    )
    # A single query was sent, so its hits are the first list; ChromaDB returns
    # them in ascending distance (best match at index 0).
    documents = query_result["documents"][0]

    settings = vector_index_properties["settings"]
    # Only rerank when there is something to reorder.
    if documents and settings.get("rerank"):
        # top_k may be stored as a string (RAGinit also casts it with int()).
        documents = rerank(client, documents, question, int(settings["top_k"]))

    return "\n".join(documents)