Spaces:
Runtime error
Runtime error
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.retrievers.multi_query import MultiQueryRetriever
# Logging setup: basicConfig() attaches a default handler to the root logger so
# that records from the "langchain.retrievers.multi_query" logger (raised to
# INFO in MultiQueryDocumentRetriever.initialize()) are actually emitted.
# NOTE: this runs at import time and configures root logging process-wide.
import logging
logging.basicConfig()
class MultiQueryDocumentRetriever:
    """Retrieve documents through LangChain's ``MultiQueryRetriever``.

    A local Hugging Face text-generation pipeline (``bigscience/bloomz-1b7``)
    is passed as the ``llm`` to ``MultiQueryRetriever.from_llm``, wrapping the
    retriever obtained from ``vector_store.db``.
    """

    def __init__(self, vector_store):
        """Store the vector store; the LLM and retriever are built later.

        Args:
            vector_store: Object exposing a ``db`` attribute that provides
                ``as_retriever(search_kwargs=...)`` (used by ``initialize``).
        """
        self.vector_store = vector_store
        self.retriever = None
        self.llm = None

    def initialize(self):
        """Load the Hugging Face LLM and build the multi-query retriever.

        The model is loaded on every call, so this should normally run once;
        ``retrieve`` calls it automatically if it has not been run yet.
        """
        self.llm = HuggingFacePipeline.from_model_id(
            model_id="bigscience/bloomz-1b7",
            task="text-generation",
            # NOTE(review): from_model_id presumably forwards model_kwargs to
            # model loading rather than to generation; confirm these sampling
            # settings take effect, or move them into pipeline_kwargs.
            model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
            pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
        )
        # Let MultiQueryRetriever's INFO-level messages through.
        logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)
        self.retriever = MultiQueryRetriever.from_llm(
            # NOTE(review): "fetch_k" is only honoured by some search types
            # (e.g. MMR); confirm the underlying store accepts it.
            retriever=self.vector_store.db.as_retriever(
                search_kwargs={"k": 4, "fetch_k": 40}
            ),
            llm=self.llm,
        )

    def retrieve(self, query: str, k: int = 4):
        """Return at most ``k`` documents relevant to ``query``.

        Builds the LLM and retriever on first use if ``initialize`` has not
        been called.

        Args:
            query: Natural-language query to search for.
            k: Maximum number of documents to return; must be non-negative.

        Returns:
            The first ``k`` documents returned by the retriever, in its order.

        Raises:
            ValueError: If ``k`` is negative.
        """
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")
        if self.retriever is None:
            self.initialize()
        docs = self.retriever.get_relevant_documents(query)
        # Cap here so callers get at most k documents regardless of how many
        # the wrapped retriever returns.
        return docs[:k]