Tamil Eniyan
committed on
Commit
·
1acd638
1
Parent(s):
d01a902
mod app
Browse files
app.py
CHANGED
@@ -4,7 +4,12 @@ import numpy as np
|
|
4 |
import pickle
|
5 |
import json
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
8 |
import torch
|
9 |
|
10 |
# ========================
|
@@ -102,13 +107,11 @@ class CustomRagRetriever(RagRetriever):
|
|
102 |
self.embed_model = embed_model # Embedding model used for encoding queries
|
103 |
self.n_docs = n_docs # Number of top documents to retrieve
|
104 |
self.tokenizer = tokenizer # Save tokenizer for internal use if needed
|
105 |
-
#
|
|
|
|
|
106 |
super().__init__(config, question_encoder_tokenizer=tokenizer, generator_tokenizer=tokenizer)
|
107 |
|
108 |
-
def init_retrieval(self):
|
109 |
-
# Override to bypass loading the default DPR passages.
|
110 |
-
return
|
111 |
-
|
112 |
def retrieve(self, query, n_docs=None):
|
113 |
try:
|
114 |
if n_docs is None:
|
|
|
4 |
import pickle
|
5 |
import json
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
+
from transformers import (
|
8 |
+
pipeline,
|
9 |
+
RagTokenizer,
|
10 |
+
RagRetriever,
|
11 |
+
RagSequenceForGeneration,
|
12 |
+
)
|
13 |
import torch
|
14 |
|
15 |
# ========================
|
|
|
107 |
self.embed_model = embed_model # Embedding model used for encoding queries
|
108 |
self.n_docs = n_docs # Number of top documents to retrieve
|
109 |
self.tokenizer = tokenizer # Save tokenizer for internal use if needed
|
110 |
+
# Override init_retrieval to bypass loading default passages.
|
111 |
+
self.init_retrieval = lambda: None
|
112 |
+
# Call the parent constructor with the required arguments.
|
113 |
super().__init__(config, question_encoder_tokenizer=tokenizer, generator_tokenizer=tokenizer)
|
114 |
|
|
|
|
|
|
|
|
|
115 |
def retrieve(self, query, n_docs=None):
|
116 |
try:
|
117 |
if n_docs is None:
|