import logging
import os
from typing import List

from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStore
from langchain_milvus import Milvus

from sandbox.light_rag.hf_embedding import HFEmbedding
from sandbox.light_rag.hf_llm import HFLLM

context_template = "Document:\n{document}\n"
token_limit = 4096

logger = logging.getLogger()


class LightRAG:
    def __init__(self, config: dict):
        self.config = config
        lazy_loading = os.environ.get("LAZY_LOADING")
        # With LAZY_LOADING set, models are created on first use instead of at startup.
        self.gen_model = None if lazy_loading else HFLLM(config["generation_model_id"])
        self._embedding_model = None if lazy_loading else HFEmbedding(config["embedding_model_id"])
        # Milvus indices are created lazily and cached per (collection, db) pair.
        self._pre_cached_indices = {}

    def _get_embedding_model(self):
        if self._embedding_model is None:
            self._embedding_model = HFEmbedding(self.config["embedding_model_id"])
        return self._embedding_model

    def precache_milvus(self, collection, db):
        key = self._cache_key(collection, db)
        self._pre_cached_indices[key] = Milvus(
            embedding_function=self._get_embedding_model(),
            collection_name=collection.replace("-", "_"),
            index_params={"metric_type": "COSINE"},
            connection_args={"uri": db},
        )

    def _get_milvus_index(self, collection, db):
        key = self._cache_key(collection, db)
        if key in self._pre_cached_indices:
            logger.debug(f"cache hit: {key}")
            return self._pre_cached_indices[key]
        return Milvus(
            embedding_function=self._get_embedding_model(),
            collection_name=collection.replace("-", "_"),
            index_params={"metric_type": "COSINE"},
            connection_args={"uri": db},
        )

    def search(self, query: str, top_n: int = 5, collection=None, db=None) -> List[Document]:
        col_name = self.config["milvus_collection_name"] if collection is None else collection
        db = self.config["milvus_db_path"] if db is None else db
        vs = self._get_milvus_index(col_name, db)

        context = vs.similarity_search(
            query=query,
            k=100,
        )

        results = []
        for d in context:
            if d.metadata.get("type") == "text":
                results.append(d)
            elif d.metadata.get("type") == "image_description":
                # Keep at most one image description per source document.
                if not any(r.metadata["document_id"] == d.metadata.get("document_id") for r in results):
                    results.append(d)
        top_n = min(top_n, len(results))
        return results[:top_n]

    def _build_prompt(self, question: str, context: List[Document]):
        # Prepare documents:
        text_documents = []
        for doc in context:
            if doc.metadata["type"] == "text":
                text_documents.append(doc.page_content.strip())
            elif doc.metadata["type"] == "image_description":
                text_documents.append(doc.metadata["image_description"].strip())
            else:
                logger.warning("Should not get here!")
        documents = [{"text": x} for x in text_documents]

        prompt = self.gen_model.tokenizer.apply_chat_template(
            conversation=[
                {
                    "role": "user",
                    "content": question,
                }
            ],
            documents=documents,  # Uses the documents support in the Granite chat template
            add_generation_prompt=True,
            tokenize=False,
        )
        return prompt

    def generate(self, query, context=None):
        if self.gen_model is None:
            self.gen_model = HFLLM(self.config["generation_model_id"])
        # Build the prompt, then run inference.
        question = query
        prompt = self._build_prompt(question, context)
        results = self.gen_model.generate(prompt)
        answer = results[0]["answer"]
        return answer, prompt

    def _cache_key(self, collection, db):
        return collection + "___" + db


# if __name__ == '__main__':
#     from dotenv import load_dotenv
#     load_dotenv()
#
#     config = {
#         "embedding_model_id": "ibm-granite/granite-embedding-125m-english",
#         "generation_model_id": "ibm-granite/granite-3.1-8b-instruct",
#         "milvus_collection_name": "granite_vision_tech_report_text_milvus_lite_512_128_slate_125m_cosine",
#         "milvus_db_path": "/dccstor/mm-rag/adi/code/RAGEval/milvus/text/milvus.db",
#     }
#
#     rag_app = LightRAG(config)
#
#     query = "What models are available in Watsonx?"
#
#     # run retrieval
#     context = rag_app.search(query=query, top_n=5)
#     # generate answers
#     answer, prompt = rag_app.generate(query=query, context=context)
#
#     print(f"Answer:\n{answer}")
#     print(f"Used prompt:\n{prompt}")

# python -m debugpy --connect cccxl009.pok.ibm.com:3002 ./sandbox/light_rag/light_rag.py