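"""Light RAG pipeline: Milvus-backed retrieval plus Hugging Face generation."""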
import logging
import os
from typing import List

from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStore
from langchain_milvus import Milvus

from sandbox.light_rag.hf_embedding import HFEmbedding
from sandbox.light_rag.hf_llm import HFLLM

context_template = "Document:\n{document}\n"
token_limit = 4096

logger = logging.getLogger()


class LightRAG:
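    """Lightweight retrieval-augmented generation (RAG) pipeline.

    Retrieves context passages from a Milvus vector store with a Hugging Face
    embedding model and generates answers with a Hugging Face LLM. Model and
    index construction can be deferred via the LAZY_LOADING environment variable.
    """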

    def __init__(self, config: dict):
        self.config = config
        # When LAZY_LOADING is set, defer model construction until first use.
        lazy_loading = os.environ.get("LAZY_LOADING")
        self.gen_model = None if lazy_loading else HFLLM(config['generation_model_id'])
        self._embedding_model = None if lazy_loading else HFEmbedding(config['embedding_model_id'])
        # Milvus indices are opened lazily and memoized per (collection, db) pair.
        self._pre_cached_indices = {}

    def _get_embedding_model(self):
        # Lazily construct the embedding model on first use.
        if self._embedding_model is None:
            self._embedding_model = HFEmbedding(self.config['embedding_model_id'])
        return self._embedding_model

    def precache_milvus(self, collection, db):
        # Eagerly open a Milvus index and memoize it for later searches.
        key = self._cache_key(collection, db)
        self._pre_cached_indices[key] = Milvus(
            embedding_function=self._get_embedding_model(),
            collection_name=collection.replace("-", "_"),
            index_params={"metric_type": "COSINE"},
            connection_args={"uri": db},
        )

    def _get_milvus_index(self, collection, db):
        # Return a memoized Milvus index if available, otherwise open a new one.
        key = self._cache_key(collection, db)
        if key in self._pre_cached_indices:
            logger.info(f"cache hit: {key}")
            return self._pre_cached_indices[key]
        return Milvus(
            embedding_function=self._get_embedding_model(),
            collection_name=collection.replace("-", "_"),
            index_params={"metric_type": "COSINE"},
            connection_args={"uri": db},
        )

    def search(self, query: str, top_n: int = 5, collection=None, db=None) -> list[Document]:
        # Fall back to the configured collection and database when none are given.
        col_name = self.config["milvus_collection_name"] if collection is None else collection
        db = self.config["milvus_db_path"] if db is None else db
        vs = self._get_milvus_index(col_name, db)

        # Over-retrieve, then filter down to the requested number of results.
        context = vs.similarity_search(
            query=query,
            k=100,
        )

        results = []
        for d in context:
            if d.metadata.get("type") == "text":
                results.append(d)
            elif d.metadata.get("type") == "image_description":
                # Keep at most one image description per source document.
                if not any(r.metadata["document_id"] == d.metadata.get("document_id") for r in results):
                    results.append(d)

        top_n = min(top_n, len(results))
        return results[:top_n]

    def _build_prompt(self, question: str, context: List[Document]):
        # Collect the text content of each retrieved document.
        text_documents = []
        for doc in context:
            if doc.metadata['type'] == 'text':
                text_documents.append(doc.page_content.strip())
            elif doc.metadata['type'] == 'image_description':
                text_documents.append(doc.metadata['image_description'].strip())
            else:
                logger.warning("Unexpected document type: %s", doc.metadata['type'])

        documents = [{"text": x} for x in text_documents]
        prompt = self.gen_model.tokenizer.apply_chat_template(
            conversation=[
                {
                    "role": "user",
                    "content": question,
                }
            ],
            documents=documents,  # This uses the documents support in the Granite chat template
            add_generation_prompt=True,
            tokenize=False,
        )
        return prompt

    def generate(self, query, context=None):
        # Lazily construct the generation model on first use.
        if self.gen_model is None:
            self.gen_model = HFLLM(self.config["generation_model_id"])

        prompt = self._build_prompt(query, context)
        results = self.gen_model.generate(prompt)
        answer = results[0]["answer"]
        return answer, prompt

    def _cache_key(self, collection, db):
        return f"{collection}___{db}"


if __name__ == '__main__':
    from dotenv import load_dotenv
    load_dotenv()

    config = {
        "embedding_model_id": "ibm-granite/granite-embedding-125m-english",
        "generation_model_id": "ibm-granite/granite-3.1-8b-instruct",
        "milvus_collection_name": "granite_vision_tech_report_text_milvus_lite_512_128_slate_125m_cosine",
        "milvus_db_path": "/dccstor/mm-rag/adi/code/RAGEval/milvus/text/milvus.db",
    }

    rag_app = LightRAG(config)

    query = "What models are available in Watsonx?"

    # Run retrieval.
    context = rag_app.search(query=query, top_n=5)
    # Generate an answer from the retrieved context.
    answer, prompt = rag_app.generate(query=query, context=context)

    print(f"Answer:\n{answer}")
    print(f"Used prompt:\n{prompt}")

# Example debug invocation:
# python -m debugpy --connect cccxl009.pok.ibm.com:3002 ./sandbox/light_rag/light_rag.py