vaishnav commited on
Commit
e19d910
·
1 Parent(s): 3a0580c

Update Gradio SDK and add LFU caching

Browse files
Files changed (4) hide show
  1. README.md +1 -1
  2. app.py +11 -16
  3. caching/lfu.py +43 -0
  4. llm_setup/llm_setup.py +8 -7
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 💬
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.17.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py CHANGED
@@ -13,24 +13,30 @@ config.set_envs() # Set environment variables using the config module
13
  store = stores.chroma.ChromaDB(config.EMBEDDINGS)
14
  service = services.scraper.Service(store)
15
 
16
-
17
  # Scrape data and get the store vector retriever
18
  service.scrape_and_get_store_vector_retriever(config.URLS)
19
 
20
  # Initialize the LLMService with logger, prompt, and store vector retriever
21
  llm_svc = LLMService(logger, config.SYSTEM_PROMPT, store.get_chroma_instance().as_retriever())
22
 
23
- def respond(user_input, history):
24
  if user_input == "clear_chat_history_aisdb_override":
25
  llm_svc.store={}
26
  return "Memory Cache cleared"
27
  response = llm_svc.conversational_rag_chain().invoke(
28
  {"input": user_input},
29
- config={"configurable": {"session_id": "abc"}},
30
  )["answer"]
31
 
32
  return response
33
 
 
 
 
 
 
 
 
34
 
35
  def on_reset_button_click():
36
  llm_svc.store={}
@@ -40,18 +46,7 @@ if __name__ == '__main__':
40
  logging.info("Starting AIVIz Bot")
41
 
42
  with gr.Blocks() as demo:
43
- gr.Markdown("# 🚢 AIVIz Bot - Vessel Trajectory Prediction")
44
- gr.Markdown("Welcome! Ask me anything about vessel tracking, AI models.")
45
-
46
- with gr.Row():
47
- chat_interface = gr.ChatInterface(fn=respond)
48
-
49
- with gr.Row():
50
- reset_button = gr.Button("🔄 Reset Chat Memory Cache")
51
- reset_status = gr.Textbox(label="Status", interactive=False)
52
-
53
- # Bind reset button to function
54
- reset_button.click(fn=on_reset_button_click, outputs=reset_status)
55
 
56
  # Launch the interface
57
- demo.launch(share=True)
 
13
  store = stores.chroma.ChromaDB(config.EMBEDDINGS)
14
  service = services.scraper.Service(store)
15
 
 
16
  # Scrape data and get the store vector retriever
17
  service.scrape_and_get_store_vector_retriever(config.URLS)
18
 
19
  # Initialize the LLMService with logger, prompt, and store vector retriever
20
  llm_svc = LLMService(logger, config.SYSTEM_PROMPT, store.get_chroma_instance().as_retriever())
21
 
22
def respond(user_input, session_hash):
    """Answer *user_input* via the conversational RAG chain, scoped to the
    caller's Gradio session so each browser tab keeps its own chat history.

    The magic string "clear_chat_history_aisdb_override" resets the
    session-history cache instead of being sent to the model.
    """
    if user_input == "clear_chat_history_aisdb_override":
        # Rebuild the cache with the same type and capacity rather than
        # assigning a plain dict: LLMService._get_session_history calls
        # store.put(), which a dict does not provide, so `llm_svc.store = {}`
        # would raise AttributeError on the next message after a clear.
        llm_svc.store = type(llm_svc.store)(capacity=llm_svc.store.capacity)
        return "Memory Cache cleared"
    response = llm_svc.conversational_rag_chain().invoke(
        {"input": user_input},
        config={"configurable": {"session_id": session_hash}},
    )["answer"]
    return response
32
 
33
def echo(text, chat_history, request: gr.Request):
    """Gradio chat handler: forward the message to respond(), keyed by the
    caller's session hash so histories stay per-browser-session."""
    if not request:
        return "No request object received."
    return respond(text, request.session_hash)
40
 
41
  def on_reset_button_click():
42
  llm_svc.store={}
 
46
  logging.info("Starting AIVIz Bot")
47
 
48
  with gr.Blocks() as demo:
49
+ gr.ChatInterface(fn=echo, type="messages")
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # Launch the interface
52
+ demo.launch()
caching/lfu.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from collections import defaultdict, OrderedDict


class LFUCache:
    """A Least-Frequently-Used (LFU) cache keyed by session id.

    Eviction policy: when the cache is full, the entry with the lowest
    access frequency is removed; among entries tied at that frequency,
    the least recently inserted one goes first (each frequency bucket is
    an OrderedDict, so insertion order doubles as recency order and
    ``popitem(last=False)`` pops the oldest).

    NOTE(review): not thread-safe — callers sharing one instance across
    threads must synchronize externally.
    """

    def __init__(self, capacity: int):
        """Create a cache holding at most *capacity* entries.

        A capacity of 0 yields a cache that stores nothing (put is a no-op).
        """
        self.capacity = capacity
        # session_id -> (value, freq)
        self.data = {}
        # freq -> OrderedDict acting as an ordered set of session_ids
        self.freq_map = defaultdict(OrderedDict)
        # Smallest frequency currently present in freq_map (0 when empty).
        self.min_freq = 0

    def _update_freq(self, session_id) -> None:
        """Move *session_id* from its current frequency bucket to the next.

        Caller must ensure the id is present in ``self.data``.
        """
        value, freq = self.data[session_id]
        del self.freq_map[freq][session_id]
        if not self.freq_map[freq]:
            # Bucket emptied: drop it, and advance min_freq if it pointed here.
            del self.freq_map[freq]
            if self.min_freq == freq:
                self.min_freq += 1

        new_freq = freq + 1
        self.data[session_id] = (value, new_freq)
        self.freq_map[new_freq][session_id] = None

    def get(self, session_id, default=None):
        """Return the cached value for *session_id*, or *default* on a miss.

        A hit counts as an access and bumps the entry's frequency. The
        *default* parameter lets callers distinguish a miss from a stored
        ``None``; omitting it preserves the original miss-returns-None
        behavior.
        """
        if session_id not in self.data:
            return default
        self._update_freq(session_id)
        return self.data[session_id][0]

    def put(self, session_id, value) -> None:
        """Insert or update *session_id* -> *value*, evicting if full.

        An update of an existing key keeps its frequency history and counts
        as one more access; a new key starts at frequency 1.
        """
        if self.capacity == 0:
            return

        if session_id in self.data:
            # Replace the value in place, then count the write as an access.
            self.data[session_id] = (value, self.data[session_id][1])
            self._update_freq(session_id)
        else:
            if len(self.data) >= self.capacity:
                # Evict the least frequently used entry; within the
                # min-frequency bucket, popitem(last=False) takes the oldest.
                lfu_session_id, _ = self.freq_map[self.min_freq].popitem(last=False)
                del self.data[lfu_session_id]

            self.data[session_id] = (value, 1)
            self.freq_map[1][session_id] = None
            self.min_freq = 1
llm_setup/llm_setup.py CHANGED
@@ -12,7 +12,7 @@ from langchain_core.chat_history import BaseChatMessageHistory
12
  from langchain_community.chat_message_histories import ChatMessageHistory
13
  from langchain_core.runnables.history import RunnableWithMessageHistory
14
  from processing.documents import format_documents
15
-
16
 
17
  def _initialize_llm() -> ChatGoogleGenerativeAI:
18
  """
@@ -23,7 +23,7 @@ def _initialize_llm() -> ChatGoogleGenerativeAI:
23
 
24
 
25
  class LLMService:
26
- def __init__(self, logger, system_prompt: str, web_retriever: VectorStoreRetriever):
27
  self._conversational_rag_chain = None
28
  self._logger = logger
29
  self.system_prompt = system_prompt
@@ -34,7 +34,7 @@ class LLMService:
34
  self._initialize_conversational_rag_chain()
35
 
36
  ### Statefully manage chat history ###
37
- self.store = {}
38
 
39
  def _initialize_conversational_rag_chain(self):
40
  """
@@ -55,7 +55,6 @@ class LLMService:
55
  )
56
 
57
 
58
-
59
  history_aware_retriever = create_history_aware_retriever(
60
  self.llm, self._web_retriever, contextualize_q_prompt)
61
 
@@ -79,9 +78,11 @@ class LLMService:
79
  )
80
 
81
  def _get_session_history(self, session_id: str) -> BaseChatMessageHistory:
82
- if session_id not in self.store:
83
- self.store[session_id] = ChatMessageHistory()
84
- return self.store[session_id]
 
 
85
 
86
  def conversational_rag_chain(self):
87
  """
 
12
  from langchain_community.chat_message_histories import ChatMessageHistory
13
  from langchain_core.runnables.history import RunnableWithMessageHistory
14
  from processing.documents import format_documents
15
+ from caching.lfu import LFUCache
16
 
17
  def _initialize_llm() -> ChatGoogleGenerativeAI:
18
  """
 
23
 
24
 
25
  class LLMService:
26
+ def __init__(self, logger, system_prompt: str, web_retriever: VectorStoreRetriever,cache_capacity: int = 50):
27
  self._conversational_rag_chain = None
28
  self._logger = logger
29
  self.system_prompt = system_prompt
 
34
  self._initialize_conversational_rag_chain()
35
 
36
  ### Statefully manage chat history ###
37
+ self.store = LFUCache(capacity=cache_capacity)
38
 
39
  def _initialize_conversational_rag_chain(self):
40
  """
 
55
  )
56
 
57
 
 
58
  history_aware_retriever = create_history_aware_retriever(
59
  self.llm, self._web_retriever, contextualize_q_prompt)
60
 
 
78
  )
79
 
80
  def _get_session_history(self, session_id: str) -> BaseChatMessageHistory:
81
+ history = self.store.get(session_id)
82
+ if history is None:
83
+ history = ChatMessageHistory()
84
+ self.store.put(session_id, history)
85
+ return history
86
 
87
  def conversational_rag_chain(self):
88
  """