# maya-persistence/src/llm/palmLLM.py
# anubhav77: "updating for palm changes" (commit c53e220, 4.2 kB)
from typing import Any, List, Mapping, Optional, Dict
from pydantic import Extra, Field #, root_validator, model_validator
import os,json
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
import google.generativeai as palm
from google.generativeai import types
import ast
#from langchain.llms import GooglePalm
import requests
class PalmLLM(LLM):
    """LangChain ``LLM`` wrapper that generates text via ``google.generativeai``.

    Besides the standard ``_call`` entry point, the class offers two helpers
    (``extractJson`` and ``extractPython``) that pull the first fenced block
    out of a model response.

    NOTE(review): ``model_name`` is only reported through
    ``_identifying_params``; it is not passed to ``palm.generate_text``, so the
    library's default text model is what actually gets called.
    """

    model_name: str = "text-bison-001"
    temperature: float = 0
    max_tokens: int = 2048
    # Default stop sequences, used when ``_call`` is invoked without ``stop``.
    stop: Optional[List] = []
    # Arguments of the most recent ``_call``; ``extractJson`` reuses them to
    # re-issue the request when a response arrives truncated.
    prev_prompt: Optional[str] = ""
    prev_stop: Optional[List[str]] = None
    prev_run_manager: Optional[Any] = None

    def __init__(self, **kwargs):
        """Initialise the pydantic fields and configure the palm client.

        ``palm.configure()`` is called without arguments, so credentials are
        presumably taken from the environment (e.g. ``GOOGLE_API_KEY``) --
        confirm against the deployment setup.
        """
        super().__init__(**kwargs)
        palm.configure()

    @property
    def _llm_type(self) -> str:
        """Return the LLM type label reported for this wrapper."""
        return "text2text-generation"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Text sent to the model.
            stop: Stop sequences for this call; when ``None`` the instance
                default ``self.stop`` is used.
            run_manager: Optional callback manager for this run.

        Returns:
            The text of the first candidate, or ``""`` when the service
            returned no candidate (``result`` is ``None``).
        """
        # Remember the request so extractJson can ask for more output later.
        self.prev_prompt = prompt
        self.prev_stop = stop
        self.prev_run_manager = run_manager

        if stop is None:
            stop = self.stop

        # Categories are passed by name. Threshold 4 is presumably the
        # "block none" level of the API's HarmBlockThreshold enum -- confirm.
        safety_settings = [
            {"category": category, "threshold": 4}
            for category in (
                "HARM_CATEGORY_DEROGATORY",
                "HARM_CATEGORY_TOXICITY",
                "HARM_CATEGORY_VIOLENCE",
                "HARM_CATEGORY_SEXUAL",
                "HARM_CATEGORY_MEDICAL",
                "HARM_CATEGORY_DANGEROUS",
            )
        ]
        text = palm.generate_text(
            prompt=prompt,
            # Use the resolved stop list; passing self.stop here would ignore
            # stop sequences supplied by the caller.
            stop_sequences=stop,
            temperature=self.temperature,
            max_output_tokens=self.max_tokens,
            safety_settings=safety_settings,
        )
        print("Response from palm", text)

        val = text.result
        if val is None:
            # No candidate came back (e.g. filtered); honour the ``-> str``
            # contract instead of leaking None to callers.
            val = ""
        if run_manager:
            # NOTE(review): this passes a plain string; LangChain's callback
            # API is typed for an LLMResult and the base class may already
            # report completion itself -- confirm this call is still wanted.
            run_manager.on_llm_end(val)
        return val

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"name": self.model_name, "type": "palm"}

    @staticmethod
    def _parse_json_like(text: str) -> Any:
        """Parse ``text`` as JSON, falling back to a Python literal.

        Strict JSON is tried first so ``true``/``false``/``null`` work; the
        ``ast.literal_eval`` fallback accepts single-quoted, Python-style
        output and is normalised through a JSON round trip.
        """
        try:
            return json.loads(text)
        except ValueError:
            pass
        return json.loads(json.dumps(ast.literal_eval(text)))

    def extractJson(self, val: str) -> Any:
        """Helper function to extract json from this LLM's output.

        Assumes the JSON is the first fenced block in ``val``. Responses
        normally open with ```json and close with ```, but are sometimes cut
        short. While the closing fence is missing, generation is requested
        again (at most 7 times) with the previous prompt plus the text
        received so far.

        Args:
            val: Raw model output.

        Returns:
            The decoded JSON value.

        Raises:
            ValueError, SyntaxError: if neither the fenced block nor the whole
                text can be parsed. Errors raised by the retry call itself
                propagate unchanged.
        """
        # _call overwrites prev_prompt, so capture the original prompt once.
        base_prompt = self.prev_prompt or ""
        attempts = 0
        while (
            val.startswith("```json")
            and not val.endswith("```")
            and attempts < 7
        ):
            chunk = self._call(
                prompt=base_prompt + " " + val,
                stop=self.prev_stop,
                run_manager=self.prev_run_manager,
            )
            attempts += 1
            if not chunk:
                # Nothing new was produced; further retries cannot help.
                break
            if chunk.startswith("```json"):
                # The model restarted the answer from scratch.
                val = chunk
            else:
                # The model continued where it stopped; keep what we had.
                val = val + chunk

        try:
            fenced = val.replace("```json", "```").split("```")[1]
            return self._parse_json_like(fenced)
        except (
            IndexError,
            ValueError,
            SyntaxError,
            TypeError,
            MemoryError,
            RecursionError,
        ):
            # No usable fenced block: try the whole response instead.
            return self._parse_json_like(val)

    def extractPython(self, val: str) -> Any:
        """Helper function to extract python from this LLM's output.

        Assumes the code is the first fenced block (```python or plain ```).
        When ``val`` contains no fence at all it is returned unchanged.
        """
        parts = val.replace("```python", "```").split("```")
        if len(parts) < 2:
            return val
        return parts[1]