"""LangChain ``LLM`` wrapper around the Google PaLM text API (``google.generativeai``)."""
import ast
import json
import logging
import os
from typing import Any, Dict, List, Mapping, Optional

import google.generativeai as palm
import requests
from google.generativeai import types
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from pydantic import Extra, Field

logger = logging.getLogger(__name__)

# Safety settings sent with every request. Threshold 4 is presumably
# BLOCK_NONE in PaLM's HarmBlockThreshold enum (i.e. filtering disabled for
# these categories) -- TODO confirm against the API reference.
_SAFETY_SETTINGS = (
    {"category": "HARM_CATEGORY_DEROGATORY", "threshold": 4},
    {"category": "HARM_CATEGORY_TOXICITY", "threshold": 4},
    {"category": "HARM_CATEGORY_VIOLENCE", "threshold": 4},
    {"category": "HARM_CATEGORY_SEXUAL", "threshold": 4},
    {"category": "HARM_CATEGORY_MEDICAL", "threshold": 4},
    {"category": "HARM_CATEGORY_DANGEROUS", "threshold": 4},
)


class PalmLLM(LLM):
    """LangChain LLM backed by ``palm.generate_text``.

    Besides the LLM interface it offers helpers (``extractJson``,
    ``extractPython``) that pull fenced content out of the model's reply.
    """

    model_name: str = "text-bison-001"
    temperature: float = 0
    max_tokens: int = 2048
    # pydantic copies field defaults per instance, so the [] is not shared.
    stop: Optional[List[str]] = []
    # Arguments of the most recent ``_call``; ``extractJson`` reuses them to
    # ask the model to continue a truncated reply.
    prev_prompt: Optional[str] = ""
    prev_stop: Optional[List[str]] = None
    prev_run_manager: Optional[Any] = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Called without arguments, so credentials presumably come from the
        # environment (e.g. GOOGLE_API_KEY) -- confirm for your deployment.
        palm.configure()

    @property
    def _llm_type(self) -> str:
        return "text2text-generation"

    def _model_path(self) -> str:
        """Return ``model_name`` in the ``models/<name>`` form the PaLM API uses."""
        name = self.model_name
        return name if "/" in name else f"models/{name}"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt`` with PaLM.

        Args:
            prompt: Text prompt sent to the model.
            stop: Stop sequences; ``None`` falls back to ``self.stop``.
            run_manager: LangChain callback manager. It is only remembered for
                ``extractJson``; the LangChain base class itself reports
                ``on_llm_end`` with a proper ``LLMResult``.
            **kwargs: Extra arguments newer LangChain versions pass; ignored.

        Returns:
            The text of the top candidate.

        Raises:
            ValueError: if PaLM returned no candidate (e.g. the response was
                blocked), instead of silently returning ``None``.
        """
        self.prev_prompt = prompt
        self.prev_stop = stop
        self.prev_run_manager = run_manager
        if stop is None:
            stop = self.stop
        completion = palm.generate_text(
            model=self._model_path(),
            prompt=prompt,
            stop_sequences=stop,  # previously self.stop, which ignored the caller
            temperature=self.temperature,
            max_output_tokens=self.max_tokens,
            safety_settings=[dict(setting) for setting in _SAFETY_SETTINGS],
        )
        logger.debug("Response from palm %s", completion)
        result = completion.result
        if result is None:
            raise ValueError(
                "PaLM returned no candidates; the response may have been blocked"
            )
        return result

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"name": self.model_name, "type": "palm"}

    @staticmethod
    def _has_unclosed_json_fence(text: str) -> bool:
        """True when ``text`` opens a ```json fence that is never closed."""
        stripped = text.lstrip()
        if not stripped.startswith("```json"):
            return False
        return "```" not in stripped[len("```json"):]

    @staticmethod
    def _parse_json_text(text: str) -> Any:
        """Parse ``text`` as JSON, falling back to a Python literal.

        The fallback accepts e.g. single-quoted dicts. Its result is
        round-tripped through ``json`` so only JSON types (lists, not tuples)
        come back. ``ast.literal_eval`` evaluates literals only, never
        arbitrary code.
        """
        text = text.strip()
        try:
            return json.loads(text)
        except ValueError:
            pass
        return json.loads(json.dumps(ast.literal_eval(text)))

    def extractJson(self, val: str, max_continuations: int = 7) -> Any:
        """Extract and parse JSON from this LLM's output.

        PaLM tends to wrap JSON in ```json ... ```, but long answers can be cut
        off before the closing fence. In that case the model is re-prompted
        with the original prompt plus the partial answer, and each continuation
        is appended, up to ``max_continuations`` times.

        The JSON is assumed to be the first fenced block. If that block cannot
        be parsed, the whole text is parsed instead.

        Args:
            val: Raw model output.
            max_continuations: Maximum number of continuation requests.

        Returns:
            The decoded JSON value.

        Raises:
            ValueError / SyntaxError: if neither the fenced block nor the whole
                text can be parsed.
        """
        original_prompt = self.prev_prompt
        original_stop = self.prev_stop
        original_run_manager = self.prev_run_manager
        try:
            count = 0
            while count < max_continuations and self._has_unclosed_json_fence(val):
                continuation = self._call(
                    prompt=(original_prompt or "") + " " + val,
                    stop=original_stop,
                    run_manager=original_run_manager,
                )
                if not continuation:
                    break  # the model has nothing more to add
                val += continuation
                count += 1
        finally:
            # _call overwrites prev_*; restore the caller's original request.
            self.prev_prompt = original_prompt
            self.prev_stop = original_stop
            self.prev_run_manager = original_run_manager

        parts = val.replace("```json", "```").split("```")
        if len(parts) > 1:
            try:
                return self._parse_json_text(parts[1])
            except (ValueError, SyntaxError, TypeError, RecursionError):
                pass  # fall back to parsing the whole reply
        return self._parse_json_text(val)

    def extractPython(self, val: str) -> Any:
        """Return the body of the first fenced code block in this LLM's output.

        A leading ```python tag is dropped. If ``val`` contains no fence at
        all, it is returned unchanged rather than raising ``IndexError``.
        """
        parts = val.replace("```python", "```").split("```")
        return parts[1] if len(parts) > 1 else val