gavinzli committed on
Commit e83b975 · 1 Parent(s): c529966

Enhance model integration and error handling in retriever module

Files changed (5)
  1. chain/__init__.py +27 -1
  2. main.py +5 -5
  3. models/llm/__init__.py +71 -0
  4. retriever/__init__.py +28 -22
  5. token.pickle +0 -0
chain/__init__.py CHANGED
@@ -4,6 +4,7 @@ import json
 from datetime import datetime
 from venv import logger
 
+import torch
 from pymongo import errors
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.messages import BaseMessage, message_to_dict
@@ -11,10 +12,35 @@ from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain.chains.retrieval import create_retrieval_chain
 from langchain.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
 from langchain_mongodb import MongoDBChatMessageHistory
+from langchain_huggingface import HuggingFacePipeline
 
-from models.llm import GPTModel
+from models.llm import GPTModel, Phi4MiniONNXLLM, HuggingfaceModel
 
 llm = GPTModel()
+REPO_ID = "microsoft/Phi-4-mini-instruct-onnx"
+SUBFOLDER = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"
+phi4_llm = Phi4MiniONNXLLM(REPO_ID, SUBFOLDER)
+
+MODEL_NAME = "openai-community/gpt2"
+MODEL_NAME = "microsoft/phi-1_5"
+hf_llm = HuggingfaceModel(MODEL_NAME)
+
+phi4_llm = HuggingFacePipeline.from_model_id(
+    model_id="microsoft/Phi-4",
+    task="text-generation",
+    pipeline_kwargs={
+        "max_new_tokens": 128,
+        "temperature": 0.3,
+        "top_k": 50,
+        "do_sample": True
+    },
+    model_kwargs={
+        "torch_dtype": "auto",
+        "device_map": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+        "max_memory": {0: "10GB"},
+        "use_cache": False
+    }
+)
 
 SYS_PROMPT = """You are a knowledgeable financial professional. You can provide well elaborated and credible answers to user queries in economic and finance by referring to retrieved contexts.
 You should answer user queries strictly following the instructions below, and do not provide anything irrelevant. \n
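
The diff shows this module's model setup but not the chain construction itself. For orientation, a minimal sketch of how these imports are typically wired together in LangChain; the prompt layout and the `vector_retriever` name are assumptions, not code from this commit:

# Hypothetical wiring sketch; vector_retriever is assumed to be a
# DocRetriever instance and is not defined in this diff.
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.retrieval import create_retrieval_chain
from langchain.prompts.chat import ChatPromptTemplate, MessagesPlaceholder

prompt = ChatPromptTemplate.from_messages([
    ("system", SYS_PROMPT + "\n\nContext: {context}"),  # {context} receives the stuffed documents
    MessagesPlaceholder("chat_history"),                # filled in by RunnableWithMessageHistory
    ("human", "{input}"),
])
docs_chain = create_stuff_documents_chain(llm, prompt)  # llm is the GPTModel built above
rag_chain = create_retrieval_chain(vector_retriever, docs_chain)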
main.py CHANGED
@@ -1,19 +1,19 @@
 """Module to run the mail collection process."""
 from dotenv import load_dotenv
 
-from controllers import mail
+# from controllers import mail
 from chain import RAGChain
 from retriever import DocRetriever
 
 load_dotenv()
 
 if __name__ == "__main__":
-    mail.collect()
-    mail.get_documents()
+    # mail.collect()
+    # mail.get_documents()
     req = {
         "query": "What is the latest news on the stock market?",
     }
     chain = RAGChain(DocRetriever(req=req))
     result = chain.invoke({"input": req['query']},
-                          config={"configurable": {"session_id": "abc"}})
-    print(result)
+                          config={"configurable": {"session_id": "123"}})
+    print(result.get("answer"))
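
Chat history is keyed by session_id and persisted via MongoDBChatMessageHistory (see chain/__init__.py), so reusing the same id carries the first answer into the next turn. A short hypothetical follow-up call:

# Hypothetical second turn: the same session_id replays the stored history.
followup = chain.invoke(
    {"input": "How does that compare to last week?"},
    config={"configurable": {"session_id": "123"}},
)
print(followup.get("answer"))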
models/llm/__init__.py CHANGED
@@ -1,5 +1,10 @@
 """Module for OpenAI model and embeddings."""
+import os
+import onnxruntime as ort
 from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
+from langchain_huggingface import HuggingFacePipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from huggingface_hub import hf_hub_download
 
 class GPTModel(AzureChatOpenAI):
     """
@@ -31,3 +36,69 @@ class GPTEmbeddings(AzureOpenAIEmbeddings):
     Methods:
         Inherits all methods from AzureOpenAIEmbeddings.
     """
+
+class Phi4MiniONNXLLM:
+    """
+    A class for interfacing with a pre-trained ONNX model for inference.
+
+    Attributes:
+        session (onnxruntime.InferenceSession): The ONNX runtime inference session for the model.
+        input_name (str): The name of the input node in the ONNX model.
+        output_name (str): The name of the output node in the ONNX model.
+
+    Methods:
+        __init__(model_path):
+            Initializes the Phi4MiniONNXLLM instance by loading the ONNX model from specified path.
+
+        __call__(input_ids):
+            Performs inference on the given input data and returns the model's output.
+    """
+    def __init__(self, repo_id, subfolder, onnx_file="model.onnx", weights_file="model.onnx.data"):
+        model_path = hf_hub_download(repo_id=repo_id, filename=f"{subfolder}/{onnx_file}")
+        weights_path = hf_hub_download(repo_id=repo_id, filename=f"{subfolder}/{weights_file}")
+        self.session = ort.InferenceSession(model_path)
+        # Verify both files exist
+        print(f"Model path: {model_path}, Exists: {os.path.exists(model_path)}")
+        print(f"Weights path: {weights_path}, Exists: {os.path.exists(weights_path)}")
+        self.input_name = self.session.get_inputs()[0].name
+        self.output_name = self.session.get_outputs()[0].name
+
+    def __call__(self, input_ids):
+        # Assuming input_ids is a tensor or numpy array
+        outputs = self.session.run([self.output_name], {self.input_name: input_ids})
+        return outputs[0]
+
+class HuggingfaceModel(HuggingFacePipeline):
+    """
+    HuggingfaceModel is a wrapper class for the Hugging Face text-generation pipeline.
+
+    Attributes:
+        name (str): The name or path of the pre-trained model to load from Hugging Face.
+        max_tokens (int): The maximum number of new tokens to generate in the text output.
+            Defaults to 200.
+
+    Methods:
+        __init__(name, max_tokens=200):
+            Initializes the HuggingfaceModel with the specified model name and maximum token limit.
+    """
+    def __init__(self, name, max_tokens=200):
+        super().__init__(pipeline=pipeline(
+            "text-generation",
+            model=AutoModelForCausalLM.from_pretrained(name),
+            tokenizer=AutoTokenizer.from_pretrained(name),
+            max_new_tokens=max_tokens))
+
+# model_name = "microsoft/phi-1_5"
+# tokenizer = AutoTokenizer.from_pretrained(model_name)
+# model = AutoModelForCausalLM.from_pretrained(model_name)
+# pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200)
+
+# phi4_llm = HuggingFacePipeline(pipeline=pipe)
+
+# tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", pad_token_id=50256)
+# model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+# pipe = pipeline(
+#     "text-generation", model=model, tokenizer=tokenizer,
+#     max_new_tokens=10, truncation=True,  # Truncate input sequences
+# )
+# phi4_llm = HuggingFacePipeline(pipeline=pipe)
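
Note that Phi4MiniONNXLLM.__call__ runs a single forward pass and returns raw logits; there is no generation loop. A minimal sketch of driving it, assuming the tokenizer from microsoft/Phi-4-mini-instruct and that the exported graph accepts input_ids alone (some ONNX exports also require attention_mask or past key/value inputs, in which case session.run will report the missing names):

# Sketch only: the tokenizer repo and single-input graph are assumptions.
import numpy as np
from transformers import AutoTokenizer

onnx_llm = Phi4MiniONNXLLM(
    "microsoft/Phi-4-mini-instruct-onnx",
    "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4",
)
tok = AutoTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct")
ids = tok("What moved markets today?", return_tensors="np")["input_ids"].astype(np.int64)
logits = onnx_llm(ids)                   # shape ~ (batch, seq_len, vocab_size)
next_id = int(np.argmax(logits[0, -1]))  # greedy pick of the next token
print(tok.decode([next_id]))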
retriever/__init__.py CHANGED
@@ -1,4 +1,5 @@
 """Module for retrievers that fetch documents from various sources."""
+from venv import logger
 from langchain_core.retrievers import BaseRetriever
 from langchain_core.vectorstores import VectorStoreRetriever
 from langchain_core.documents import Document
@@ -22,9 +23,9 @@ class DocRetriever(BaseRetriever):
         list: A list of Document objects with relevant metadata.
     """
     retriever: VectorStoreRetriever = None
-    k: int = 10
+    k: int = 5
 
-    def __init__(self, req, k: int = 10) -> None:
+    def __init__(self, req, k: int = 2) -> None:
         super().__init__()
         # _filter={}
         # if req.site != []:
@@ -32,30 +33,35 @@ class DocRetriever(BaseRetriever):
         # if req.id != []:
         #     _filter.update({"id": {"$in": req.id}})
         self.retriever = vectorstore.as_retriever(
-            search_type='similarity_score_threshold',
+            search_type='similarity',
             search_kwargs={
                 "k": k,
                 # "filter": _filter,
-                "score_threshold": .1
+                # "score_threshold": .1
             }
         )
 
     def _get_relevant_documents(self, query: str, *, run_manager) -> list:
-        retrieved_docs = self.retriever.invoke(query)
-        doc_lst = []
-        for doc in retrieved_docs:
-            # date = str(doc.metadata['publishDate'])
-            doc_lst.append(Document(
-                page_content = doc.page_content,
-                metadata = {
-                    "content": doc.page_content,
-                    # "id": doc.metadata['id'],
-                    # "title": doc.metadata['title'],
-                    # "site": doc.metadata['site'],
-                    # "link": doc.metadata['link'],
-                    # "publishDate": doc.metadata['publishDate'].strftime('%Y-%m-%d'),
-                    # 'web': False,
-                    # "source": "Finfast"
-                }
-            ))
-        return doc_lst
+        try:
+            retrieved_docs = self.retriever.invoke(query)
+            doc_lst = []
+            for doc in retrieved_docs:
+                # date = str(doc.metadata['publishDate'])
+                doc_lst.append(Document(
+                    page_content = doc.page_content,
+                    metadata = {
+                        "content": doc.page_content,
+                        # "id": doc.metadata['id'],
+                        # "title": doc.metadata['title'],
+                        # "site": doc.metadata['site'],
+                        # "link": doc.metadata['link'],
+                        # "publishDate": doc.metadata['publishDate'].strftime('%Y-%m-%d'),
+                        # 'web': False,
+                        # "source": "Finfast"
+                    }
+                ))
+            # print(doc_lst)
+            return doc_lst
+        except RuntimeError as e:
+            logger.error("Error retrieving documents: %s", e)
+            return []
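
With the try/except in place, a failed vector search now degrades to an empty list instead of raising. A quick hypothetical smoke test, assuming the vectorstore behind DocRetriever is reachable:

# Hypothetical check: BaseRetriever exposes the runnable interface, so
# .invoke() dispatches to _get_relevant_documents under the hood.
retriever = DocRetriever(req={"query": "stock market"}, k=2)
docs = retriever.invoke("What is the latest news on the stock market?")
print(len(docs))  # 0 if a RuntimeError was caught and logged
for doc in docs:
    print(doc.page_content[:80])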
token.pickle CHANGED
Binary files a/token.pickle and b/token.pickle differ