# Page header captured with this file (Spaces file viewer):
#   status: Runtime error; file size: 1,470 bytes; revision: e04cd14
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.retrievers.multi_query import MultiQueryRetriever
# Set logging for the queries
import logging
logging.basicConfig()
class MultiQueryDocumentRetriever:
    """Fetch documents from a vector store via LangChain's ``MultiQueryRetriever``.

    Construction is cheap: the language model and the retriever are only
    created when :meth:`initialize` is called. :meth:`retrieve` requires
    that :meth:`initialize` has run first.
    """

    def __init__(self, vector_store):
        """Remember the vector store; nothing is loaded yet.

        Args:
            vector_store: Object whose ``db`` attribute provides
                ``as_retriever(search_kwargs=...)``; it is only accessed
                inside :meth:`initialize`.
        """
        self.vector_store = vector_store
        # Both remain None until initialize() has been called.
        self.retriever = None
        self.llm = None

    def initialize(self):
        """Load the text-generation model and build the multi-query retriever.

        Sets ``self.llm`` to a ``HuggingFacePipeline`` for
        ``bigscience/bloomz-1b7`` and ``self.retriever`` to a
        ``MultiQueryRetriever`` wrapping ``self.vector_store.db``.
        """
        self.llm = HuggingFacePipeline.from_model_id(
            model_id="bigscience/bloomz-1b7",
            task="text-generation",
            model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
            pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
        )

        # INFO level on this logger is what makes the multi-query retriever
        # report the queries it generates (see LangChain docs).
        logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)

        base_retriever = self.vector_store.db.as_retriever(
            search_kwargs={"k": 4, "fetch_k": 40}
        )
        self.retriever = MultiQueryRetriever.from_llm(
            retriever=base_retriever,
            llm=self.llm,
        )

    def retrieve(self, query: str, k: int = 4):
        """Return at most ``k`` documents relevant to ``query``.

        Args:
            query: The user question handed to the retriever.
            k: Maximum number of documents to return. The underlying
                retriever may yield more than ``k`` results, so the result
                is truncated here.

        Returns:
            A list with the first ``k`` documents produced by the retriever,
            in the order the retriever returned them.

        Raises:
            ValueError: If ``k`` is negative.
            RuntimeError: If :meth:`initialize` has not been called yet.
        """
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")
        if self.retriever is None:
            raise RuntimeError(
                "MultiQueryDocumentRetriever.initialize() must be called "
                "before retrieve()"
            )
        if k == 0:
            # Nothing requested: skip the (expensive) LLM round-trip.
            return []

        # Newer LangChain retrievers expose `invoke`; older releases only
        # have `get_relevant_documents`. Support both.
        run_query = getattr(self.retriever, "invoke", None)
        if run_query is None:
            run_query = self.retriever.get_relevant_documents

        documents = run_query(query)
        return list(documents)[:k]