maya-persistence / src / llm / geminiLLM.py
from typing import Any, List, Mapping, Optional, Dict
from pydantic import Extra, Field  # , root_validator, model_validator
import os, json
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
import google.generativeai as genai
from google.generativeai import types
import ast
#from langchain.llms import GooglePalm
import requests, logging

logger = logging.getLogger("llm")

class GeminiLLM(LLM):
    """LangChain LLM wrapper around the Gemini text generation API."""

    model_name: str = "gemini-pro"
    temperature: float = 0
    max_tokens: int = 2048
    stop: Optional[List[str]] = []
    # The last request is remembered so that an incomplete generation can be
    # continued (see extractJson below).
    prev_prompt: Optional[str] = ""
    prev_stop: Optional[List[str]] = None
    prev_run_manager: Optional[Any] = None
    model: Optional[Any] = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = genai.GenerativeModel(self.model_name)
        #self.model = palm.Text2Text(self.model_name)

    @property
    def _llm_type(self) -> str:
        return "text2text-generation"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        # Remember the request so extractJson can re-issue it if the model
        # returns a truncated fenced block.
        self.prev_prompt = prompt
        self.prev_stop = stop
        self.prev_run_manager = run_manager
        #print(types.SafetySettingDict)
        if stop is None:
            stop = self.stop
        logger.debug("\nLLM in use is:" + self._llm_type)
        logger.debug("Request to LLM is " + prompt)
        response = self.model.generate_content(
            prompt,
            generation_config={
                "stop_sequences": stop,
                "temperature": self.temperature,
                "max_output_tokens": self.max_tokens,
            },
            safety_settings=[
                {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
            ],
            stream=False,
        )
        try:
            val = response.text
            if val is None:
                logger.debug("Response from LLM was None\n")
                filterStr = ""
                for item in response.filters:
                    for key, value in item.items():
                        filterStr += key + ":" + str(value)
                logger.error("Will switch to fallback LLM as response from Gemini is None::" + filterStr)
                raise Exception("Empty response from Gemini: " + filterStr)
            else:
                logger.debug("Response from LLM " + val)
        except Exception as ex:
            logger.error("Will switch to fallback LLM as response from Gemini is None::")
            raise Exception("Gemini returned no usable text") from ex
        if run_manager:
            pass
            #run_manager.on_llm_end(val)
        return val

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"name": self.model_name, "type": "palm"}

    def extractJson(self, val: str) -> Any:
        """Helper function to extract JSON from this LLM's output."""
        # This assumes the JSON is the first item inside a ``` fence.
        # The model usually answers with ```json ... ```; when the trailing
        # ``` is missing the generation was cut short, so we call the model
        # again with the previous prompt plus the partial result appended.
        try:
            count = 0
            while val.startswith("```json") and not val.endswith("```") and count < 7:
                val = self._call(prompt=self.prev_prompt + " " + val,
                                 stop=self.prev_stop,
                                 run_manager=self.prev_run_manager)
                count += 1
            v2 = val.replace("```json", "```").split("```")[1]
            try:
                v4 = json.loads(v2)
            except Exception:
                # Fall back to Python literal syntax (single quotes etc.).
                v3 = json.dumps(ast.literal_eval(v2))
                v4 = json.loads(v3)
        except Exception:
            # No fenced block found: try to parse the raw output directly.
            v3 = json.dumps(ast.literal_eval(val))
            v4 = json.loads(v3)
        return v4

    def extractPython(self, val: str) -> Any:
        """Helper function to extract Python code from this LLM's output."""
        # This assumes the Python code is the first item inside a ``` fence.
        v2 = val.replace("```python", "```").split("```")[1]
        return v2
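
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): assumes a valid
# Google API key is available in the GOOGLE_API_KEY environment variable and
# that network access is permitted; the env-var name and the prompt below are
# illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
    llm = GeminiLLM(temperature=0, max_tokens=1024)
    # On older LangChain releases the LLM object is callable directly; newer
    # releases prefer llm.invoke(prompt).
    answer = llm('Return a JSON object like {"greeting": "..."} inside a ```json fence.')
    print(answer)
    # extractJson retries truncated generations and parses the fenced block
    # into a Python object.
    print(llm.extractJson(answer))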