anubhav77 committed
Commit 9a1d7f1 · 1 Parent(s): ca5636b

move to gemini-1.5-flash
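In short: the default model in the GeminiLLM wrapper moves from `gemini-pro` to `gemini-1.5-flash`, alongside a formatter-style cleanup of both Python files. A minimal sketch of the new default, using only calls that already appear in this diff (the `configure` step is an assumption, since API-key setup is not part of this commit):

```python
import google.generativeai as genai

# genai.configure(api_key=...)  # assumed to happen elsewhere in the app
model = genai.GenerativeModel("gemini-1.5-flash")  # previous default: "gemini-pro"
response = model.generate_content("hello", stream=False)
print(response.text)
```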

Files changed (3)
  1. persistence.log +14 -0
  2. src/chromaIntf.py +119 -91
  3. src/llm/geminiLLM.py +87 -69
persistence.log CHANGED
@@ -64,3 +64,17 @@
 2024-01-11 23:38:12,386 - posthog.py - __init__() - 20 - INFO - Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information.
 2024-01-11 23:38:14,482 - llmFactory.py - get_llm() - 36 - DEBUG - executor3
 2024-01-11 23:38:14,482 - llmFactory.py - get_llm() - 37 - DEBUG - {'llm_config': {'max_tokens': 1024, 'temperature': 0.1}, 'llm_type': 'geminiLLM'}
+2024-01-12 10:12:12,735 - dropbox_client.py - refresh_access_token() - 390 - INFO - Refreshing access token.
+2024-01-12 10:12:13,184 - dropbox_client.py - request_json_string_with_retry() - 474 - INFO - Request to users/get_current_account
+2024-01-12 10:12:22,103 - SentenceTransformer.py - __init__() - 66 - INFO - Load pretrained SentenceTransformer: BAAI/bge-large-en-v1.5
+2024-01-12 10:12:27,492 - posthog.py - __init__() - 20 - INFO - Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information.
+2024-01-12 10:12:29,074 - llmFactory.py - get_llm() - 36 - DEBUG - executor3
+2024-01-12 10:12:29,074 - llmFactory.py - get_llm() - 37 - DEBUG - {'llm_config': {'max_tokens': 1024, 'temperature': 0.1}, 'llm_type': 'geminiLLM'}
+2024-01-12 10:12:29,132 - proactor_events.py - __init__() - 629 - DEBUG - Using proactor: IocpProactor
+2024-01-12 10:12:29,240 - connectionpool.py - _make_request() - 456 - DEBUG - https://app.posthog.com:443 "POST /batch/ HTTP/1.1" 200 None
+2024-01-12 10:12:29,293 - SentenceTransformer.py - __init__() - 66 - INFO - Load pretrained SentenceTransformer: BAAI/bge-large-en-v1.5
+2024-01-12 10:12:41,220 - connectionpool.py - _new_conn() - 1003 - DEBUG - Starting new HTTPS connection (1): device-1a455.firebaseio.com:443
+2024-01-12 10:12:41,324 - connectionpool.py - _make_request() - 456 - DEBUG - https://app.posthog.com:443 "POST /batch/ HTTP/1.1" 200 None
+2024-01-12 10:12:41,762 - connectionpool.py - _make_request() - 456 - DEBUG - https://device-1a455.firebaseio.com:443 "GET /users/131251/llm_config/executor3.json?auth=eyJ0eXAiOiAiSldUIiwgImFsZyI6ICJIUzI1NiJ9.eyJhZG1pbiI6IGZhbHNlLCAiZGVidWciOiBmYWxzZSwgInYiOiAwLCAiaWF0IjogMTcwNTAxNDc2MSwgImQiOiB7ImlkIjogIjEzMTI1MSIsICJkZWJ1ZyI6IGZhbHNlLCAiYWRtaW4iOiBmYWxzZSwgImVtYWlsIjogImFudWJoYXY3N0BnbWFpbC5jb20iLCAicHJvdmlkZXIiOiAicGFzc3dvcmQifX0.vRs8wPErJN9HLbVChqjLnOO-W7pkPq3LIVUmN1jVPGU HTTP/1.1" 200 75
+2024-01-12 10:12:41,764 - llmFactory.py - get_llm() - 36 - DEBUG - executor3
+2024-01-12 10:12:41,764 - llmFactory.py - get_llm() - 37 - DEBUG - {'llm_config': {'max_tokens': 1024, 'temperature': 0.1}, 'llm_type': 'geminiLLM'}
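The added lines trace one startup pass: a Dropbox token refresh, the BGE embedding model load, and `get_llm()` resolving the `executor3` config from Firebase. A hypothetical sketch of how that logged payload could map onto the wrapper changed below; the real `llmFactory.get_llm()` body is not part of this commit, so the dispatch shown here is an assumption:

```python
# Config shape copied verbatim from the DEBUG lines above.
cfg = {"llm_config": {"max_tokens": 1024, "temperature": 0.1}, "llm_type": "geminiLLM"}

def get_llm_sketch(cfg: dict):
    # Hypothetical dispatch; only GeminiLLM and its kwargs come from this commit.
    if cfg["llm_type"] == "geminiLLM":
        from llm.geminiLLM import GeminiLLM  # import path assumed from the src/ layout
        return GeminiLLM(**cfg["llm_config"])
    raise ValueError(f"unsupported llm_type: {cfg['llm_type']}")
```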
src/chromaIntf.py CHANGED
@@ -1,12 +1,15 @@
 import sys
+
 try:
     import pysqlite3
+
     sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
 except:
     pass
 import chromadb
 from langchain.vectorstores import Chroma
-from chromadb.api.fastapi import requests
+
+# from chromadb.api.fastapi import requests
 from langchain.schema import Document
 from langchain.chains import RetrievalQA
 from langchain.embeddings import HuggingFaceBgeEmbeddings
@@ -21,31 +24,45 @@ from uuid import UUID
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 import logging, asyncio
 
-logger=logging.getLogger("root")
+logger = logging.getLogger("root")
+
 
 class myChromaTranslator(ChromaTranslator):
     allowed_operators = ["$and", "$or"]
     """Subset of allowed logical operators."""
-    allowed_comparators = [ "$eq","$ne","$gt","$gte","$lt","$lte",
-             "$contains","$not_contains","$in","$nin"]
+    allowed_comparators = [
+        "$eq",
+        "$ne",
+        "$gt",
+        "$gte",
+        "$lt",
+        "$lte",
+        "$contains",
+        "$not_contains",
+        "$in",
+        "$nin",
+    ]
+
 
-class ChromaIntf():
+class ChromaIntf:
     def __init__(self):
-        self.db_interface=DbInterface()
+        self.db_interface = DbInterface()
 
         model_name = "BAAI/bge-large-en-v1.5"
-        encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
+        encode_kwargs = {
+            "normalize_embeddings": True
+        }  # set True to compute cosine similarity
 
         self.embedding = HuggingFaceBgeEmbeddings(
-            model_name=model_name,
-            model_kwargs={'device': 'cpu'},
-            encode_kwargs=encode_kwargs
+            model_name=model_name,
+            model_kwargs={"device": "cpu"},
+            encode_kwargs=encode_kwargs,
         )
 
-        self.persist_db_directory = 'db'
+        self.persist_db_directory = "db"
        self.persist_docs_directory = "persistence-docs"
        self.logger_file = "persistence.log"
-        loop=asyncio.get_event_loop()
+        loop = asyncio.get_event_loop()
         try:
             loop.run_until_complete(dbh.restoreFolder(self.persist_db_directory))
             loop.run_until_complete(dbh.restoreFolder(self.persist_docs_directory))
@@ -54,25 +71,31 @@ class ChromaIntf():
         docs = [
             Document(
                 page_content="this is test doc",
-                metadata={"timestamp":1696743148.474055,"ID":"2000-01-01 15:57:11::664165-test","source":"test"},
-                id="2000-01-01 15:57:11::664165-test"
-            ),
-        ]
+                metadata={
+                    "timestamp": 1696743148.474055,
+                    "ID": "2000-01-01 15:57:11::664165-test",
+                    "source": "test",
+                },
+                id="2000-01-01 15:57:11::664165-test",
+            ),
+        ]
 
-        self.vectorstore = Chroma.from_documents(documents=docs,
-                                                 embedding=self.embedding,
-                                                 persist_directory=self.persist_db_directory)
-        #self.vectorstore._client.
+        self.vectorstore = Chroma.from_documents(
+            documents=docs,
+            embedding=self.embedding,
+            persist_directory=self.persist_db_directory,
+        )
+        # self.vectorstore._client.
 
         # timestamp --> time when added
         # source --> notes/references/web/youtube/book/conversation, default conversation
         # title --> of document , will be conversation when source is conversation, default blank
         # author --> will default to blank
         # "Year": 2024,
-        #"Month": 1,
-        #"Day": 3,
-        #"Hour": 11,
-        #"Minute": 29
+        # "Month": 1,
+        # "Day": 3,
+        # "Hour": 11,
+        # "Minute": 29
         self.metadata_field_info = [
             AttributeInfo(
                 name="timestamp",
@@ -118,12 +141,14 @@ class ChromaIntf():
                 name="author",
                 description="Author of the entry",
                 type="string",
-            )
-        ]
-        self.document_content_description = "Information to store for retrival from LLM based chatbot"
-        lf=LLMFactory()
-        #self.llm=lf.get_llm("executor2")
-        self.llm=lf.get_llm("executor3")
+            ),
+        ]
+        self.document_content_description = (
+            "Information to store for retrival from LLM based chatbot"
+        )
+        lf = LLMFactory()
+        # self.llm=lf.get_llm("executor2")
+        self.llm = lf.get_llm("executor3")
 
         self.retriever = SelfQueryRetriever.from_llm(
             self.llm,
@@ -131,63 +156,61 @@ class ChromaIntf():
             self.document_content_description,
             self.metadata_field_info,
             structured_query_translator=ChromaTranslator(),
-            verbose=True
+            verbose=True,
         )
 
-
-    async def getRelevantDocs(self,query:str,kwargs:dict):
+    async def getRelevantDocs(self, query: str, kwargs: dict):
         """This should also post the result to firebase"""
-        print("retriver state",self.retriever.search_kwargs)
-        print("retriver state",self.retriever.search_type)
+        print("retriver state", self.retriever.search_kwargs)
+        print("retriver state", self.retriever.search_type)
         try:
             for key in kwargs.keys():
                 if "search_type" in key:
-                    self.retriever.search_type=kwargs[key]
+                    self.retriever.search_type = kwargs[key]
                 else:
-                    self.retriever.search_kwargs[key]=kwargs[key]
+                    self.retriever.search_kwargs[key] = kwargs[key]
         except:
             print("setting search args failed")
         print("reaching step2")
         try:
-            #loop=asyncio.get_event_loop()
-            retVal=self.retriever.get_relevant_documents(query)
+            # loop=asyncio.get_event_loop()
+            retVal = self.retriever.get_relevant_documents(query)
         except Exception as ex:
-            logger.exception("Exception occured:",exc_info=True)
-        value=[]
-        excludeMeta=True
+            logger.exception("Exception occured:", exc_info=True)
+        value = []
+        excludeMeta = True
         print("reaching step3")
         print(str(len(retVal)))
         print("reaching step4")
         try:
             for item in retVal:
                 if excludeMeta:
-                    v=item.page_content+" \n"
+                    v = item.page_content + " \n"
                 else:
-                    v="Info:"+item.page_content+" "
+                    v = "Info:" + item.page_content + " "
                 for key in item.metadata.keys():
                     if key != "ID":
-                        v+=key+":"+str(item.metadata[key])+" "
+                        v += key + ":" + str(item.metadata[key]) + " "
                 value.append(v)
             print("reaching step5")
-            self.db_interface.add_to_cache(input=query,value=value)
+            self.db_interface.add_to_cache(input=query, value=value)
         except:
             print("reaching step6")
             for item in retVal:
                 if excludeMeta:
-                    v=item['page_content']+" \n"
+                    v = item["page_content"] + " \n"
                 else:
-                    v="Info:"+item['page_content']+" "
-                for key in item['metadata'].keys():
+                    v = "Info:" + item["page_content"] + " "
+                for key in item["metadata"].keys():
                     if key != "ID":
-                        v+=key+":"+str(item['metadata'][key])+" "
+                        v += key + ":" + str(item["metadata"][key]) + " "
                 value.append(v)
             print("reaching step7")
-            self.db_interface.add_to_cache(input=query,value=value)
+            self.db_interface.add_to_cache(input=query, value=value)
         print("reaching step8")
         return retVal
-
 
-    async def addText(self,inStr:str,metadata):
+    async def addText(self, inStr: str, metadata):
         # metadata expected is some of following
         # timestamp --> time when added
         # source --> notes/references/web/youtube/book/conversation, default conversation
@@ -195,78 +218,83 @@ class ChromaIntf():
         # author --> will default to blank
 
         ##TODO: Preprocess inStr to remove any html, markdown tags etc.
-        metadata=metadata.dict()
+        metadata = metadata.dict()
         if "timestamp" not in metadata.keys():
-            metadata['timestamp']=datetime.now().isoformat()
+            metadata["timestamp"] = datetime.now().isoformat()
         else:
-            metadata['timestamp']=datetime.fromisoformat(metadata['timestamp'])
+            metadata["timestamp"] = datetime.fromisoformat(metadata["timestamp"])
             pass
         if "source" not in metadata.keys():
-            metadata['source']="conversation"
-        if "title" not in metadata.keys():
+            metadata["source"] = "conversation"
+        if "title" not in metadata.keys():
             metadata["title"] = ""
-        if metadata["source"] == "conversation":
+        if metadata["source"] == "conversation":
             metadata["title"] == "conversation"
-        if "author" not in metadata.keys():
+        if "author" not in metadata.keys():
             metadata["author"] = ""
-
-        #TODO: If url is present in input or when the splitting need to be done, then we'll need to change how we
+
+        # TODO: If url is present in input or when the splitting need to be done, then we'll need to change how we
         # formulate the ID and may be filename to store information
-        metadata['ID']=metadata['timestamp'].strftime("%Y-%m-%d %H-%M-%S")+"-"+metadata['title']
-        metadata['Year']=metadata['timestamp'].year
-        metadata['Month']=metadata['timestamp'].month
-        metadata['Day']=int(metadata['timestamp'].strftime("%d"))
-        metadata['Hour']=metadata['timestamp'].hour
-        metadata['Minute']=metadata['timestamp'].minute
-        metadata['timestamp']=metadata['timestamp'].isoformat()
+        metadata["ID"] = (
+            metadata["timestamp"].strftime("%Y-%m-%d %H-%M-%S")
+            + "-"
+            + metadata["title"]
+        )
+        metadata["Year"] = metadata["timestamp"].year
+        metadata["Month"] = metadata["timestamp"].month
+        metadata["Day"] = int(metadata["timestamp"].strftime("%d"))
+        metadata["Hour"] = metadata["timestamp"].hour
+        metadata["Minute"] = metadata["timestamp"].minute
+        metadata["timestamp"] = metadata["timestamp"].isoformat()
         print("Metadata is:")
         print(metadata)
-        #md.pop("timestamp")
-        with open("./docs/"+metadata['ID']+".txt","w") as fd:
+        # md.pop("timestamp")
+        with open("./docs/" + metadata["ID"] + ".txt", "w") as fd:
             fd.write(inStr)
             print("written to file", inStr)
         text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=800,
             chunk_overlap=50,
             length_function=len,
-            is_separator_regex=False)
-        #docs = [ Document(page_content=inStr, metadata=metadata)]
-        docs=text_splitter.create_documents([inStr],[metadata])
-        partNumber=0
+            is_separator_regex=False,
+        )
+        # docs = [ Document(page_content=inStr, metadata=metadata)]
+        docs = text_splitter.create_documents([inStr], [metadata])
+        partNumber = 0
         for doc in docs:
             if partNumber > 0:
-                doc.metadata['ID']+=f"__{partNumber}"
-            partNumber+=1
+                doc.metadata["ID"] += f"__{partNumber}"
+            partNumber += 1
             print(f"{partNumber} follows:")
             print(doc)
         try:
-            print(metadata['ID'])
-            ids=[doc.metadata['ID'] for doc in docs]
+            print(metadata["ID"])
+            ids = [doc.metadata["ID"] for doc in docs]
             print("ids are:")
             print(ids)
-            return await self.vectorstore.aadd_documents(docs,ids=ids)
+            return await self.vectorstore.aadd_documents(docs, ids=ids)
         except Exception as ex:
-            logger.exception("exception in adding",exc_info=True)
+            logger.exception("exception in adding", exc_info=True)
             print("inside expect of addText")
-            return await self.vectorstore.aadd_documents(docs,ids=[metadata.ID])
-
+            return await self.vectorstore.aadd_documents(docs, ids=[metadata.ID])
+
     async def listDocs(self):
-        collection=self.vectorstore._client.get_collection(self.vectorstore._LANGCHAIN_DEFAULT_COLLECTION_NAME,embedding_function=self.embedding)
+        collection = self.vectorstore._client.get_collection(
+            self.vectorstore._LANGCHAIN_DEFAULT_COLLECTION_NAME,
+            embedding_function=self.embedding,
+        )
         return collection.get()
-        #return self.vectorstore._client._get(collection_id=self._uuid(collectionInfo.id))
-
-
+        # return self.vectorstore._client._get(collection_id=self._uuid(collectionInfo.id))
+
     async def persist(self):
         self.vectorstore.persist()
         await dbh.backupFile(self.logger_file)
         await dbh.backupFolder(self.persist_db_directory)
         return await dbh.backupFolder(self.persist_docs_directory)
-
-    def _uuid(self,uuid_str: str) -> UUID:
+
+    def _uuid(self, uuid_str: str) -> UUID:
         try:
             return UUID(uuid_str)
         except ValueError:
             print("Error generating uuid")
             raise ValueError(f"Could not parse {uuid_str} as a UUID")
-
-
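For review context, this is roughly how the reformatted interface is driven (illustrative only: `ChromaIntf`, `addText`, and `getRelevantDocs` come from the diff above; the metadata stand-in and import path are assumptions, since `addText` only requires an object with a `.dict()` method):

```python
import asyncio
from types import SimpleNamespace

# from chromaIntf import ChromaIntf  # import path assumed

intf = ChromaIntf()  # restores the persisted Chroma db/docs folders on startup

# Stand-in for the app's pydantic model; addText() calls metadata.dict().
# An ISO timestamp is supplied so addText() takes its fromisoformat() branch.
meta = SimpleNamespace(
    dict=lambda: {"timestamp": "2024-01-12T10:12:00", "source": "notes", "title": "gemini"}
)

async def demo():
    await intf.addText("Moved executor3 to gemini-1.5-flash.", meta)
    # "k" lands in retriever.search_kwargs; "search_type" keys are handled separately.
    docs = await intf.getRelevantDocs("which model does executor3 use?", {"k": 5})
    for d in docs:
        print(d.page_content)

asyncio.run(demo())
```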
src/llm/geminiLLM.py CHANGED
@@ -1,34 +1,33 @@
-from typing import Any, List, Mapping, Optional, Dict
-from pydantic import Extra, Field #, root_validator, model_validator
-import os,json
+from typing import Any, List, Mapping, Optional, Dict
+from pydantic import Extra, Field  # , root_validator, model_validator
+import os, json
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 import google.generativeai as genai
 from google.generativeai import types
 import ast
-#from langchain.llms import GooglePalm
-import requests,logging
 
-logger=logging.getLogger("llm")
+# from langchain.llms import GooglePalm
+import requests, logging
+
+logger = logging.getLogger("llm")
+
 
 class GeminiLLM(LLM):
-
-    model_name: str = "gemini-pro"
+
+    model_name: str = "gemini-1.5-flash"  # "gemini-pro"
     temperature: float = 0
     max_tokens: int = 2048
     stop: Optional[List] = []
-    prev_prompt: Optional[str]=""
-    prev_stop: Optional[str]=""
-    prev_run_manager:Optional[Any]=None
-    model: Optional[Any]=None
+    prev_prompt: Optional[str] = ""
+    prev_stop: Optional[str] = ""
+    prev_run_manager: Optional[Any] = None
+    model: Optional[Any] = None
 
-    def __init__(
-        self,
-        **kwargs
-    ):
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.model=genai.GenerativeModel(self.model_name)
-        #self.model = palm.Text2Text(self.model_name)
+        self.model = genai.GenerativeModel(self.model_name)
+        # self.model = palm.Text2Text(self.model_name)
 
     @property
     def _llm_type(self) -> str:
@@ -40,76 +39,95 @@ class GeminiLLM(LLM):
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
     ) -> str:
-        self.prev_prompt=prompt
-        self.prev_stop=stop
-        self.prev_run_manager=run_manager
-        #print(types.SafetySettingDict)
+        self.prev_prompt = prompt
+        self.prev_stop = stop
+        self.prev_run_manager = run_manager
+        # print(types.SafetySettingDict)
         if stop == None:
-            stop=self.stop
-        logger.debug("\nLLM in use is:" +self._llm_type)
-        logger.debug("Request to LLM is "+prompt)
-
-        response=self.model.generate_content(prompt,
-                generation_config={"stop_sequences":self.stop,
-                "temperature":self.temperature, "max_output_tokens":self.max_tokens},
-                safety_settings=[{"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":"BLOCK_NONE"},
-                {"category":"HARM_CATEGORY_HATE_SPEECH","threshold":"BLOCK_NONE"},
-                {"category":"HARM_CATEGORY_HARASSMENT","threshold":"BLOCK_NONE"},
-                {"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":"BLOCK_NONE"}],
-                stream=False
-                )
+            stop = self.stop
+        logger.debug("\nLLM in use is:" + self._llm_type)
+        logger.debug("Request to LLM is " + prompt)
+
+        response = self.model.generate_content(
+            prompt,
+            generation_config={
+                "stop_sequences": self.stop,
+                "temperature": self.temperature,
+                "max_output_tokens": self.max_tokens,
+            },
+            safety_settings=[
+                {
+                    "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                    "threshold": "BLOCK_NONE",
+                },
+                {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
+                {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
+                {
+                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                    "threshold": "BLOCK_NONE",
+                },
+            ],
+            stream=False,
+        )
         try:
-            val=response.text
+            val = response.text
             if val == None:
                 logger.debug("Response from LLM was None\n")
-                filterStr=""
+                filterStr = ""
                 for item in response.filters:
-                    for key,val in item.items():
-                        filterStr+=key+":"+str(val)
-                logger.error("Will switch to fallback LLM as response from palm is None::"+filterStr)
-                raise(Exception)
+                    for key, val in item.items():
+                        filterStr += key + ":" + str(val)
+                logger.error(
+                    "Will switch to fallback LLM as response from palm is None::"
+                    + filterStr
+                )
+                raise (Exception)
             else:
-                logger.debug("Response from LLM "+val)
+                logger.debug("Response from LLM " + val)
         except Exception as ex:
-            logger.error("Will switch to fallback LLM as response from palm is None::")
-            raise(Exception)
+            logger.error("Will switch to fallback LLM as response from palm is None::")
+            raise (Exception)
         if run_manager:
             pass
-            #run_manager.on_llm_end(val)
+            # run_manager.on_llm_end(val)
         return val
-
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
-        return {"name": self.model_name, "type": "palm"}
-
-    def extractJson(self,val:str) -> Any:
+        return {"name": self.model_name, "type": "palm"}
+
+    def extractJson(self, val: str) -> Any:
         """Helper function to extract json from this LLMs output"""
-        #This is assuming the json is the first item within ````
+        # This is assuming the json is the first item within ````
         # palm is responding always with ```json and ending with ```, however sometimes response is not complete
         # in case trailing ``` is not seen, we will call generation again with prev_prompt and result appended to it
         try:
-            count=0
-            while val.startswith("```json") and not val.endswith("```") and count<7:
-                val=self._call(prompt=self.prev_prompt+" "+val,stop=self.prev_stop,run_manager=self.prev_run_manager)
-                count+=1
-            v2=val.replace("```json","```").split("```")[1]
+            count = 0
+            while val.startswith("```json") and not val.endswith("```") and count < 7:
+                val = self._call(
+                    prompt=self.prev_prompt + " " + val,
+                    stop=self.prev_stop,
+                    run_manager=self.prev_run_manager,
+                )
+                count += 1
+            v2 = val.replace("```json", "```").split("```")[1]
             try:
-                v4=json.loads(v2)
+                v4 = json.loads(v2)
             except:
-                #v3=v2.replace("\n","").replace("\r","").replace("'","\"")
-                v3=json.dumps(ast.literal_eval(v2))
-                v4=json.loads(v3)
+                # v3=v2.replace("\n","").replace("\r","").replace("'","\"")
+                v3 = json.dumps(ast.literal_eval(v2))
+                v4 = json.loads(v3)
         except:
-            v2=val.replace("\n","").replace("\r","")
-            v3=json.dumps(ast.literal_eval(val))
-            #v3=v2.replace("'","\"")
-            v4=json.loads(v3)
-            #v4=json.loads(v2)
+            v2 = val.replace("\n", "").replace("\r", "")
+            v3 = json.dumps(ast.literal_eval(val))
+            # v3=v2.replace("'","\"")
+            v4 = json.loads(v3)
+            # v4=json.loads(v2)
         return v4
-
-    def extractPython(self,val:str) -> Any:
+
+    def extractPython(self, val: str) -> Any:
         """Helper function to extract python from this LLMs output"""
-        #This is assuming the python is the first item within ````
-        v2=val.replace("```python","```").split("```")[1]
-        return v2
+        # This is assuming the python is the first item within ````
+        v2 = val.replace("```python", "```").split("```")[1]
+        return v2
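A short usage sketch of the updated wrapper, with constructor kwargs mirroring the `llm_config` seen in persistence.log (the API-key line is an assumption; key setup is not shown in this commit):

```python
import google.generativeai as genai

# genai.configure(api_key=...)  # assumed; not part of this diff

llm = GeminiLLM(max_tokens=1024, temperature=0.1)  # model_name now defaults to "gemini-1.5-flash"
raw = llm('Return {"ok": true} inside a json code fence')
parsed = llm.extractJson(raw)  # re-prompts (up to 7 extra calls) if the closing fence is missing
print(parsed)
```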