Tamil Eniyan commited on
Commit
1acd638
·
1 Parent(s): d01a902
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -4,7 +4,12 @@ import numpy as np
4
  import pickle
5
  import json
6
  from sentence_transformers import SentenceTransformer
7
- from transformers import pipeline, RagTokenizer, RagRetriever, RagSequenceForGeneration
 
 
 
 
 
8
  import torch
9
 
10
  # ========================
@@ -102,13 +107,11 @@ class CustomRagRetriever(RagRetriever):
102
  self.embed_model = embed_model # Embedding model used for encoding queries
103
  self.n_docs = n_docs # Number of top documents to retrieve
104
  self.tokenizer = tokenizer # Save tokenizer for internal use if needed
105
- # Call the base class constructor with the required arguments.
 
 
106
  super().__init__(config, question_encoder_tokenizer=tokenizer, generator_tokenizer=tokenizer)
107
 
108
- def init_retrieval(self):
109
- # Override to bypass loading the default DPR passages.
110
- return
111
-
112
  def retrieve(self, query, n_docs=None):
113
  try:
114
  if n_docs is None:
 
4
  import pickle
5
  import json
6
  from sentence_transformers import SentenceTransformer
7
+ from transformers import (
8
+ pipeline,
9
+ RagTokenizer,
10
+ RagRetriever,
11
+ RagSequenceForGeneration,
12
+ )
13
  import torch
14
 
15
  # ========================
 
107
  self.embed_model = embed_model # Embedding model used for encoding queries
108
  self.n_docs = n_docs # Number of top documents to retrieve
109
  self.tokenizer = tokenizer # Save tokenizer for internal use if needed
110
+ # Override init_retrieval to bypass loading default passages.
111
+ self.init_retrieval = lambda: None
112
+ # Call the parent constructor with the required arguments.
113
  super().__init__(config, question_encoder_tokenizer=tokenizer, generator_tokenizer=tokenizer)
114
 
 
 
 
 
115
  def retrieve(self, query, n_docs=None):
116
  try:
117
  if n_docs is None: