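"""Lightweight RAG pipeline: Milvus-backed retrieval with Hugging Face generation."""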
import logging
import os
from typing import List
from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStore
from langchain_milvus import Milvus
from sandbox.light_rag.hf_embedding import HFEmbedding
from sandbox.light_rag.hf_llm import HFLLM
context_template = "Document:\n{document}\n"
token_limit = 4096
logger = logging.getLogger(__name__)
class LightRAG:
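    """Minimal retrieval-augmented generation pipeline.

    Retrieves context chunks from a Milvus vector store and generates answers with a
    Hugging Face causal LM, using the Granite chat template's document support.
    """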
def __init__(self, config: dict):
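        """Initialize the pipeline from a config dict.

        If the LAZY_LOADING environment variable is set to any non-empty value, the
        generation and embedding models are loaded on first use instead of here.
        """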
self.config = config
lazy_loading = os.environ.get("LAZY_LOADING")
self.gen_model = None if lazy_loading else HFLLM(config['generation_model_id'])
self._embedding_model = None if lazy_loading else HFEmbedding(config['embedding_model_id'])
        self._pre_cached_indices = {}
        # The Milvus vector store is now created lazily, per (collection, db) pair,
        # in _get_milvus_index() rather than eagerly here.
def _get_embedding_model(self):
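        """Return the embedding model, loading it on first access if needed."""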
if self._embedding_model is None:
self._embedding_model = HFEmbedding(self.config['embedding_model_id'])
return self._embedding_model
def precache_milvus(self, collection, db):
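        """Eagerly build and cache a Milvus index for the given collection and db URI."""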
key = self._cache_key(collection, db)
self._pre_cached_indices[key] = Milvus(
embedding_function=self._get_embedding_model(),
collection_name=collection.replace("-", "_"),
index_params={"metric_ttpe": "cosine".upper()},
# connection_args = ({"uri": "./milvus/text/milvus.db"})
connection_args=({"uri": db}),
)
def _get_milvus_index(self, collection, db):
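        """Return a cached Milvus index for (collection, db) if available, else build one."""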
key = self._cache_key(collection, db)
if key in self._pre_cached_indices:
print(f"cache hit: {key}")
return self._pre_cached_indices[key]
else:
return Milvus(
embedding_function=self._get_embedding_model(),
collection_name=collection.replace("-", "_"),
index_params={"metric_ttpe": "cosine".upper()},
# connection_args = ({"uri": "./milvus/text/milvus.db"})
connection_args=({"uri": db}),
)
def search(self, query: str, top_n: int = 5, collection=None, db=None) -> list[Document]:
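        """Retrieve up to top_n relevant documents for the query.

        Over-fetches 100 candidates from Milvus, keeps text chunks, and keeps an
        image-description chunk only if no previously kept chunk comes from the same
        document, then truncates to top_n.
        """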
# TODO: be more clever :)
col_name = self.config["milvus_collection_name"] if collection is None else collection
db = self.config["milvus_db_path"] if db is None else db
# print(f"col_name: {col_name} on db: {db}")
vs = self._get_milvus_index(col_name, db)
context = vs.similarity_search(
query=query,
k=100,
)
results = []
for d in context:
if d.metadata.get("type") == "text": # and not ("Picture placeholder" in d.page_content):
results.append(d)
elif d.metadata.get("type") == "image_description":
if not any(r.metadata["document_id"] == d.metadata.get("document_id") for r in results):
results.append(d)
top_n = min(top_n, len(results))
return results[:top_n]
def _build_prompt(self, question: str, context: List[Document]):
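        """Assemble a chat prompt from the question and the retrieved documents.

        Text chunks contribute their page content; image-description chunks contribute
        the description stored in their metadata. The documents are passed through the
        tokenizer's chat template (Granite's document support) rather than being
        concatenated into the user message.
        """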
# Prepare documents:
text_documents = []
for doc in context:
if doc.metadata['type'] == 'text':
text_documents.append(doc.page_content.strip())
elif doc.metadata['type'] == 'image_description':
text_documents.append(doc.metadata['image_description'].strip())
else:
                logger.warning("Unexpected document type: %s", doc.metadata.get("type"))
documents = [{"text": x} for x in text_documents]
prompt = self.gen_model.tokenizer.apply_chat_template(
conversation=[
{
"role": "user",
"content": question,
}
],
documents=documents, # This uses the documents support in the Granite chat template
add_generation_prompt=True,
tokenize=False,
)
return prompt
def generate(self, query, context=None):
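        """Generate an answer for the query grounded in the retrieved context.

        Loads the generation model on first call if it was deferred. Returns the answer
        string together with the exact prompt that was sent to the model.
        """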
if self.gen_model is None:
self.gen_model = HFLLM(self.config["generation_model_id"])
# build prompt
question = query
prompt = self._build_prompt(question, context)
# print(f"prompt: |||{prompt}|||")
# infer
results = self.gen_model.generate(prompt)
# print(f"results: {results}")
answer = results[0]["answer"]
return answer, prompt
def _cache_key(self, collection, db):
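        """Build the cache key used for pre-cached Milvus indices."""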
return collection + "___" + db
# if __name__ == '__main__':
# from dotenv import load_dotenv
# load_dotenv()
#
# config = {
# "embedding_model_id": "ibm-granite/granite-embedding-125m-english",
# "generation_model_id": "ibm-granite/granite-3.1-8b-instruct",
# "milvus_collection_name": "granite_vision_tech_report_text_milvus_lite_512_128_slate_125m_cosine",
# "milvus_db_path": "/dccstor/mm-rag/adi/code/RAGEval/milvus/text/milvus.db"
# }
#
# rag_app = LightRAG(config)
#
# query = "What models are available in Watsonx?"
#
# # run retrieval
# context = rag_app.search(query=query, top_n=5)
# # generate answers
# answer, prompt = rag_app.generate(query=query, context=context)
#
# print(f"Answer:\n{answer}")
# print(f"Used prompt:\n{prompt}")
# python -m debugpy --connect cccxl009.pok.ibm.com:3002 ./sandbox/light_rag/light_rag.py