from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.retrievers.multi_query import MultiQueryRetriever

# Configure logging so the LLM-generated query variants can be inspected
import logging

logging.basicConfig()


class MultiQueryDocumentRetriever:
    """Wraps LangChain's MultiQueryRetriever: an LLM generates several
    variants of the user query and the merged results are returned."""

    def __init__(self, vector_store):
        # The vector store is expected to expose a LangChain vector store
        # on its `db` attribute (see initialize()).
        self.vector_store = vector_store
        self.retriever = None
        self.llm = None

    def initialize(self):
        # Local HuggingFace text-generation pipeline used to produce the
        # alternative phrasings of each incoming query.
        self.llm = HuggingFacePipeline.from_model_id(
            model_id="bigscience/bloomz-1b7",
            task="text-generation",
            model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
            pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
        )

        logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)
        self.retriever = MultiQueryRetriever.from_llm(
            retriever=self.vector_store.db.as_retriever(search_kwargs={"k": 4, "fetch_k": 40}),
            llm=self.llm
        )

    def retrieve(self, query: str, k: int = 4):
        # Expand the query into LLM-generated variants, merge the retrieved
        # results, and return at most k documents; initialize() must be
        # called before the first retrieval.
        return self.retriever.get_relevant_documents(query)[:k]
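

# Usage sketch (an assumption, not part of the original module): builds a tiny
# FAISS index with HuggingFace embeddings so the retriever can be exercised
# end to end. Requires the optional `faiss-cpu` and `sentence-transformers`
# packages in addition to langchain and transformers.
if __name__ == "__main__":
    from types import SimpleNamespace

    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.vectorstores import FAISS

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    db = FAISS.from_texts(
        [
            "LangChain's MultiQueryRetriever rewrites a question into variants.",
            "bloomz-1b7 is an instruction-tuned multilingual language model.",
        ],
        embeddings,
    )
    # The class above only assumes an object with a `.db` attribute, so a
    # SimpleNamespace is enough to stand in for the project's vector store.
    vector_store = SimpleNamespace(db=db)

    retriever = MultiQueryDocumentRetriever(vector_store)
    retriever.initialize()
    for doc in retriever.retrieve("Which model generates the query variants?", k=2):
        print(doc.page_content)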