"""LangChain LLM wrapper around Google's Gemini ``generate_content`` API."""

from typing import Any, List, Mapping, Optional, Dict
import ast
import json
import logging
import os

import google.generativeai as genai
import requests
from google.generativeai import types
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from pydantic import Extra, Field

logger = logging.getLogger("llm")

# Harm categories whose blocking threshold is lowered to BLOCK_NONE per request.
_UNBLOCKED_HARM_CATEGORIES = (
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_HARASSMENT",
    "HARM_CATEGORY_DANGEROUS_CONTENT",
)

# Upper bound on how many times extractJson re-queries the model when a
# ```json block is missing its closing fence.
_MAX_CONTINUATIONS = 7


def _parse_json_like(text: str) -> Any:
    """Parse *text* as JSON, falling back to a Python literal.

    Strict JSON is tried first. If that fails, the text is evaluated with
    ``ast.literal_eval`` (which only accepts literals, never arbitrary code)
    and round-tripped through ``json`` so the result has JSON-compatible
    types (tuples become lists, non-string keys become strings).

    Raises:
        ValueError, SyntaxError, TypeError: if *text* is neither valid JSON
            nor a JSON-serialisable Python literal.
    """
    try:
        return json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return json.loads(json.dumps(ast.literal_eval(text)))


class GeminiLLM(LLM):
    """LangChain ``LLM`` that sends prompts to a Gemini generative model.

    The most recent ``_call`` arguments are remembered in the ``prev_*``
    fields so that ``extractJson`` can re-issue the request when the model's
    output appears to be truncated.
    """

    model_name: str = "gemini-pro"
    temperature: float = 0
    max_tokens: int = 2048
    # Default stop sequences, used when _call is not given any.
    stop: Optional[List] = []
    # Arguments of the most recent _call, replayed by extractJson.
    prev_prompt: Optional[str] = ""
    prev_stop: Optional[List] = None
    prev_run_manager: Optional[Any] = None
    # genai.GenerativeModel instance, created in __init__.
    model: Optional[Any] = None

    def __init__(self, **kwargs):
        """Initialise the pydantic fields, then build the Gemini client."""
        super().__init__(**kwargs)
        self.model = genai.GenerativeModel(self.model_name)

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this LLM type."""
        return "text2text-generation"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Send *prompt* to the model and return the generated text.

        Args:
            prompt: Text sent to the model.
            stop: Stop sequences for this request; ``self.stop`` is used when
                this is ``None``.
            run_manager: LangChain callback manager. It is remembered for
                ``extractJson`` but not otherwise used here.

        Returns:
            The text of the model's response.

        Raises:
            RuntimeError: if no text can be read from the response. Exceptions
                raised by ``generate_content`` itself propagate unchanged.
        """
        # Remember the raw arguments so extractJson can replay this request.
        self.prev_prompt = prompt
        self.prev_stop = stop
        self.prev_run_manager = run_manager

        # Honour the per-call stop sequences; fall back to the instance default.
        stop_sequences = self.stop if stop is None else stop

        logger.debug("\nLLM in use is:%s", self._llm_type)
        logger.debug("Request to LLM is %s", prompt)

        response = self.model.generate_content(
            prompt,
            generation_config={
                "stop_sequences": stop_sequences,
                "temperature": self.temperature,
                "max_output_tokens": self.max_tokens,
            },
            safety_settings=[
                {"category": category, "threshold": "BLOCK_NONE"}
                for category in _UNBLOCKED_HARM_CATEGORIES
            ],
            stream=False,
        )

        try:
            # NOTE(review): the `text` accessor is expected to raise (rather
            # than return None) when the response has no usable candidate,
            # e.g. a blocked prompt -- confirm against the installed SDK.
            text = response.text
        except Exception as ex:
            logger.error(
                "Will switch to fallback LLM as response from palm is None::%r", ex
            )
            raise RuntimeError(
                f"Could not read text from the LLM response: {ex!r}"
            ) from ex

        if text is None:
            logger.debug("Response from LLM was None\n")
            # prompt_feedback is read defensively: it is only diagnostic.
            feedback = getattr(response, "prompt_feedback", None)
            logger.error(
                "Will switch to fallback LLM as response from palm is None::%s",
                feedback,
            )
            raise RuntimeError("LLM response contained no text")

        logger.debug("Response from LLM %s", text)
        return text

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"name": self.model_name, "type": "palm"}

    def extractJson(self, val: str) -> Any:
        """Extract and parse the first JSON payload in the model output *val*.

        The payload is expected inside the first fenced block (```json ... ```
        or plain ``` ... ```). When the output opens a ```json block but has no
        closing fence, the model is queried again (at most
        ``_MAX_CONTINUATIONS`` times) with the previous prompt and the partial
        output appended. If the fenced content cannot be obtained or parsed,
        the whole of *val* is parsed instead.

        NOTE(review): each re-query replaces ``val`` with the new reply rather
        than concatenating it, and ``_call`` overwrites ``prev_prompt`` with
        the extended prompt. This assumes the model re-emits the full block --
        confirm.

        Args:
            val: Raw text returned by the model.

        Returns:
            The decoded JSON value (dict, list, str, number, bool or None).

        Raises:
            ValueError, SyntaxError, TypeError: if neither the fenced content
                nor the raw text can be parsed.
        """
        try:
            attempts = 0
            while (
                val.startswith("```json")
                and not val.endswith("```")
                and attempts < _MAX_CONTINUATIONS
            ):
                val = self._call(
                    prompt=self.prev_prompt + " " + val,
                    stop=self.prev_stop,
                    run_manager=self.prev_run_manager,
                )
                attempts += 1
            # Normalise the opening fence so the payload is always the text
            # between the first pair of ``` markers.
            fenced = val.replace("```json", "```").split("```")[1]
            return _parse_json_like(fenced)
        except Exception:
            # Best effort: no fence, unparsable fenced content, or a failed
            # re-query. Fall back to treating the whole output as the payload.
            return _parse_json_like(val)

    def extractPython(self, val: str) -> Any:
        """Return the contents of the first fenced code block in *val*.

        Both ```python and plain ``` opening fences are recognised.

        Raises:
            IndexError: if *val* contains no ``` fence at all.
        """
        return val.replace("```python", "```").split("```")[1]