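"""LightRAG: a lightweight retrieval-augmented generation (RAG) pipeline.

Retrieves document chunks from a Milvus vector store with a Hugging Face
embedding model, then answers the query with a Hugging Face generation model
using the documents support of the Granite chat template. Models and Milvus
indices can be loaded lazily (set the LAZY_LOADING environment variable) to
keep start-up cheap.
"""
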
import logging
import os
from typing import List
from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStore
from langchain_milvus import Milvus
from sandbox.light_rag.hf_embedding import HFEmbedding
from sandbox.light_rag.hf_llm import HFLLM

context_template = "Document:\n{document}\n"
token_limit = 4096

logger = logging.getLogger()


class LightRAG:
    def __init__(self, config: dict):
        self.config = config
        # When LAZY_LOADING is set, defer loading the HF models (and the Milvus
        # index) until they are first needed; otherwise load them eagerly here.
        lazy_loading = os.environ.get("LAZY_LOADING")
        self.gen_model = None if lazy_loading else HFLLM(config['generation_model_id'])
        self._embedding_model = None if lazy_loading else HFEmbedding(config['embedding_model_id'])
        # Milvus indices are created lazily in _get_milvus_index and can be
        # pre-built per (collection, db) pair via precache_milvus.
        self._pre_cached_indices = {}

    def _get_embedding_model(self):
        # Load the embedding model on first use (lazy loading path).
        if self._embedding_model is None:
            self._embedding_model = HFEmbedding(self.config['embedding_model_id'])
        return self._embedding_model

    def precache_milvus(self, collection, db):
        # Build and cache a Milvus index for the given (collection, db) pair so
        # the first search against it does not pay the connection cost.
        key = self._cache_key(collection, db)
        self._pre_cached_indices[key] = Milvus(
            embedding_function=self._get_embedding_model(),
            collection_name=collection.replace("-", "_"),
            index_params={"metric_type": "COSINE"},
            connection_args={"uri": db},
        )

    def _get_milvus_index(self, collection, db):
        # Reuse a pre-cached index for this (collection, db) pair if available,
        # otherwise create a fresh Milvus vector store.
        key = self._cache_key(collection, db)
        if key in self._pre_cached_indices:
            logger.debug(f"Milvus index cache hit: {key}")
            return self._pre_cached_indices[key]
        return Milvus(
            embedding_function=self._get_embedding_model(),
            collection_name=collection.replace("-", "_"),
            index_params={"metric_type": "COSINE"},
            connection_args={"uri": db},
        )

    def search(self, query: str, top_n: int = 5, collection=None, db=None) -> List[Document]:
        # Fall back to the configured collection/db when none are given explicitly.
        col_name = self.config["milvus_collection_name"] if collection is None else collection
        db = self.config["milvus_db_path"] if db is None else db
        vs = self._get_milvus_index(col_name, db)
        context = vs.similarity_search(
            query=query,
            k=100,
        )
        # Keep all text chunks; keep an image-description chunk only if no chunk
        # from the same document has been selected already.
        results = []
        for d in context:
            if d.metadata.get("type") == "text":  # and not ("Picture placeholder" in d.page_content)
                results.append(d)
            elif d.metadata.get("type") == "image_description":
                if not any(r.metadata["document_id"] == d.metadata.get("document_id") for r in results):
                    results.append(d)
        top_n = min(top_n, len(results))
        return results[:top_n]

    def _build_prompt(self, question: str, context: List[Document]):
        # Collect the plain-text content of each retrieved chunk; image chunks
        # contribute their generated description instead of raw page content.
        text_documents = []
        for doc in context:
            if doc.metadata['type'] == 'text':
                text_documents.append(doc.page_content.strip())
            elif doc.metadata['type'] == 'image_description':
                text_documents.append(doc.metadata['image_description'].strip())
            else:
                logger.warning('Unexpected document type: %s', doc.metadata.get('type'))
        documents = [{"text": x} for x in text_documents]
        prompt = self.gen_model.tokenizer.apply_chat_template(
            conversation=[
                {
                    "role": "user",
                    "content": question,
                }
            ],
            documents=documents,  # uses the documents support in the Granite chat template
            add_generation_prompt=True,
            tokenize=False,
        )
        return prompt

    def generate(self, query, context=None):
        # Load the generation model on first use (lazy loading path).
        if self.gen_model is None:
            self.gen_model = HFLLM(self.config["generation_model_id"])
        # Build the prompt from the question and retrieved context, then infer.
        prompt = self._build_prompt(query, context)
        results = self.gen_model.generate(prompt)
        answer = results[0]["answer"]
        return answer, prompt

    def _cache_key(self, collection, db):
        return collection + "___" + db


# if __name__ == '__main__':
#     from dotenv import load_dotenv
#     load_dotenv()
#
#     config = {
#         "embedding_model_id": "ibm-granite/granite-embedding-125m-english",
#         "generation_model_id": "ibm-granite/granite-3.1-8b-instruct",
#         "milvus_collection_name": "granite_vision_tech_report_text_milvus_lite_512_128_slate_125m_cosine",
#         "milvus_db_path": "/dccstor/mm-rag/adi/code/RAGEval/milvus/text/milvus.db"
#     }
#
#     rag_app = LightRAG(config)
#
#     query = "What models are available in Watsonx?"
#
#     # run retrieval
#     context = rag_app.search(query=query, top_n=5)
#     # generate answers
#     answer, prompt = rag_app.generate(query=query, context=context)
#
#     print(f"Answer:\n{answer}")
#     print(f"Used prompt:\n{prompt}")

# python -m debugpy --connect cccxl009.pok.ibm.com:3002 ./sandbox/light_rag/light_rag.py
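
# Minimal sketch of warming the Milvus index cache before the first query,
# assuming the same config keys as the example above:
#
#     rag_app = LightRAG(config)
#     rag_app.precache_milvus(
#         collection=config["milvus_collection_name"],
#         db=config["milvus_db_path"],
#     )
#     context = rag_app.search(query="What models are available in Watsonx?", top_n=5)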