Spaces:
Running
Running
Added Callback type annotation
Browse files- pipeline.py +22 -18
pipeline.py
CHANGED
@@ -27,6 +27,8 @@ from langchain.chains import RetrievalQA, LLMChain
|
|
27 |
from langchain.prompts import PromptTemplate
|
28 |
from langchain.docstore.document import Document
|
29 |
|
|
|
|
|
30 |
# Custom chain imports
|
31 |
from classification_chain import get_classification_chain
|
32 |
from refusal_chain import get_refusal_chain
|
@@ -188,33 +190,35 @@ CACHE_SIZE_LIMIT = 1000
|
|
188 |
# logger.error(f"Failed to initialize ChatGroq: {e}")
|
189 |
# raise RuntimeError("ChatGroq initialization failed.") from e
|
190 |
|
191 |
-
# Define a dummy BaseCache class locally
|
192 |
-
class BaseCache:
|
193 |
-
def lookup(self, key: str):
|
194 |
-
return None # Always return None, meaning no cache hit
|
195 |
|
196 |
-
def update(self, key: str, value: str):
|
197 |
-
pass # Do nothing on cache update
|
198 |
|
199 |
-
# Define a no-op
|
200 |
-
class
|
201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
|
203 |
-
#
|
204 |
-
|
205 |
|
206 |
-
#
|
207 |
-
ChatGroq.model_rebuild()
|
208 |
-
|
209 |
-
# Initialize ChatGroq without using caching
|
210 |
-
fallback_groq_api_key = os.environ.get("GROQ_API_KEY_FALLBACK", "GROQ_API_KEY")
|
211 |
try:
|
|
|
212 |
groq_fallback_llm = ChatGroq(
|
213 |
-
model=GROQ_MODELS["default"],
|
214 |
temperature=0.7,
|
215 |
groq_api_key=fallback_groq_api_key,
|
216 |
-
max_tokens=2048
|
|
|
217 |
)
|
|
|
|
|
|
|
|
|
218 |
except Exception as e:
|
219 |
logger.error(f"Failed to initialize ChatGroq: {e}")
|
220 |
raise RuntimeError("ChatGroq initialization failed.") from e
|
|
|
27 |
from langchain.prompts import PromptTemplate
|
28 |
from langchain.docstore.document import Document
|
29 |
|
30 |
+
from langchain.callbacks.base import BaseCallbacks # Updated import
|
31 |
+
from langchain.callbacks.manager import CallbackManager
|
32 |
# Custom chain imports
|
33 |
from classification_chain import get_classification_chain
|
34 |
from refusal_chain import get_refusal_chain
|
|
|
190 |
# logger.error(f"Failed to initialize ChatGroq: {e}")
|
191 |
# raise RuntimeError("ChatGroq initialization failed.") from e
|
192 |
|
|
|
|
|
|
|
|
|
193 |
|
|
|
|
|
194 |
|
195 |
+
# Define a no-op callback handler
|
196 |
+
class NoOpCallbacks(BaseCallbacks):
|
197 |
+
"""No-op callback handler."""
|
198 |
+
def on_llm_start(self, *args, **kwargs): pass
|
199 |
+
def on_llm_end(self, *args, **kwargs): pass
|
200 |
+
def on_llm_error(self, *args, **kwargs): pass
|
201 |
+
def on_chain_start(self, *args, **kwargs): pass
|
202 |
+
def on_chain_end(self, *args, **kwargs): pass
|
203 |
+
def on_chain_error(self, *args, **kwargs): pass
|
204 |
|
205 |
+
# Create a callback manager with no-op callbacks
|
206 |
+
callback_manager = CallbackManager([NoOpCallbacks()])
|
207 |
|
208 |
+
# Initialize ChatGroq with the callback manager
|
|
|
|
|
|
|
|
|
209 |
try:
|
210 |
+
fallback_groq_api_key = os.environ.get("GROQ_API_KEY_FALLBACK", "GROQ_API_KEY")
|
211 |
groq_fallback_llm = ChatGroq(
|
212 |
+
model=GROQ_MODELS["default"],
|
213 |
temperature=0.7,
|
214 |
groq_api_key=fallback_groq_api_key,
|
215 |
+
max_tokens=2048,
|
216 |
+
callback_manager=callback_manager # Add the callback manager here
|
217 |
)
|
218 |
+
|
219 |
+
# Rebuild the model after initialization
|
220 |
+
ChatGroq.model_rebuild()
|
221 |
+
|
222 |
except Exception as e:
|
223 |
logger.error(f"Failed to initialize ChatGroq: {e}")
|
224 |
raise RuntimeError("ChatGroq initialization failed.") from e
|