Spaces:
Runtime error
Runtime error
from typing import Any, List, Mapping, Optional, Dict | |
from pydantic import Extra, Field #, root_validator, model_validator | |
import os,json | |
from langchain.callbacks.manager import CallbackManagerForLLMRun | |
from langchain.llms.base import LLM | |
import google.generativeai as genai | |
from google.generativeai import types | |
import ast | |
#from langchain.llms import GooglePalm | |
import requests,logging | |
logger=logging.getLogger("llm") | |
class GeminiLLM(LLM):
    """LangChain ``LLM`` wrapper around a Google Gemini ``GenerativeModel``.

    Besides the standard ``_call`` entry point, the wrapper records the
    arguments of the most recent call (the ``prev_*`` fields). ``extractJson``
    uses them to re-prompt the model when a fenced JSON answer looks truncated.
    """

    model_name: str = "gemini-pro"
    temperature: float = 0
    max_tokens: int = 2048
    # Default stop sequences, used when ``_call`` receives ``stop=None``.
    stop: Optional[List] = []
    # Arguments of the most recent ``_call``; read by ``extractJson`` to re-prompt.
    prev_prompt: Optional[str] = ""
    prev_stop: Optional[List[str]] = None
    prev_run_manager: Optional[Any] = None
    # The genai.GenerativeModel instance, built in __init__ from ``model_name``.
    model: Optional[Any] = None

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self.model = genai.GenerativeModel(self.model_name)

    @property
    def _llm_type(self) -> str:
        """Type identifier reported to LangChain (and written to the debug log)."""
        # Must be a property: LangChain reads ``self._llm_type`` as a string.
        return "text2text-generation"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send *prompt* to Gemini and return the generated text.

        Args:
            prompt: Text passed to ``generate_content``.
            stop: Stop sequences for this call; ``None`` means use ``self.stop``.
            run_manager: LangChain callback manager; only recorded for re-prompts.
            **kwargs: Extra invocation arguments LangChain may forward; ignored.

        Returns:
            The text of the model's response.

        Raises:
            RuntimeError: If the response text cannot be read or is ``None``.
                Callers are expected to switch to a fallback LLM.
        """
        # Remember this call so extractJson can re-prompt with the same settings.
        self.prev_prompt = prompt
        self.prev_stop = stop
        self.prev_run_manager = run_manager
        if stop is None:
            stop = self.stop
        logger.debug("\nLLM in use is:%s", self._llm_type)
        logger.debug("Request to LLM is %s", prompt)
        response = self.model.generate_content(
            prompt,
            generation_config={
                # Send the resolved per-call stop list, not always self.stop.
                "stop_sequences": stop,
                "temperature": self.temperature,
                "max_output_tokens": self.max_tokens,
            },
            safety_settings=[
                {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
            ],
            stream=False,
        )
        try:
            val = response.text
        except Exception as ex:
            # NOTE(review): the SDK presumably raises here when the candidate was
            # blocked or empty; confirm the exact exception type.
            logger.error(
                "Will switch to fallback LLM as response from Gemini could not be read:: %s", ex
            )
            raise RuntimeError("Gemini response text could not be read") from ex
        if val is None:
            # getattr keeps this diagnostic safe if the attribute is absent.
            feedback = getattr(response, "prompt_feedback", "")
            logger.error("Will switch to fallback LLM as response from Gemini is None:: %s", feedback)
            raise RuntimeError("Gemini returned no text")
        logger.debug("Response from LLM %s", val)
        return val

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        # Must be a property: LangChain reads this attribute as a mapping.
        # NOTE(review): "type" still says "palm"; kept unchanged because this
        # mapping may feed LangChain cache keys.
        return {"name": self.model_name, "type": "palm"}

    @staticmethod
    def _parse_json_like(text: str) -> Any:
        """Parse *text* as JSON, falling back to Python-literal syntax (e.g. single quotes)."""
        text = text.strip()  # literal_eval rejects leading indentation before Python 3.10
        try:
            return json.loads(text)
        except ValueError:
            # ast.literal_eval only evaluates literals and never calls code.
            # Round-tripping through json gives JSON types (tuples -> lists, keys -> str).
            return json.loads(json.dumps(ast.literal_eval(text)))

    def extractJson(self, val: str) -> Any:
        """Extract and parse the JSON payload from this LLM's output.

        The payload is the first ``` fenced block (a leading ```json tag is
        accepted). If there is no fence, or the fenced text does not parse, the
        whole answer is parsed instead. Strict JSON is tried first, then
        Python-literal syntax.

        If the answer opens with ```json but has no closing fence, it is treated
        as truncated. The model is then re-prompted, at most 7 times, with the
        previous prompt plus the partial output.

        Args:
            val: Raw text returned by the model.

        Returns:
            The parsed value (dict, list, str, number, bool or None).

        Raises:
            ValueError, SyntaxError or TypeError: If no candidate text parses.
            Exceptions raised by ``_call`` propagate if a re-prompt fails.
        """
        text = val.strip()  # trailing whitespace must not hide a closing fence
        attempts = 0
        while text.startswith("```json") and not text.endswith("```") and attempts < 7:
            # NOTE(review): the new output replaces `text` instead of being appended.
            # Because _call stores the extended prompt in prev_prompt, each retry's
            # prompt grows. This matches the previous behaviour; confirm it suits the model.
            text = self._call(
                prompt=self.prev_prompt + " " + text,
                stop=self.prev_stop,
                run_manager=self.prev_run_manager,
            ).strip()
            attempts += 1
        pieces = text.replace("```json", "```").split("```")
        if len(pieces) > 1:
            try:
                return self._parse_json_like(pieces[1])
            except (ValueError, SyntaxError, TypeError):
                # The fence may sit inside an unfenced answer (e.g. a string value
                # containing code); fall through and parse the whole text.
                pass
        return self._parse_json_like(text)

    def extractPython(self, val: str) -> str:
        """Return the contents of the first ``` fenced block in *val*.

        A leading ```python tag is accepted.

        Raises:
            IndexError: If *val* contains no ``` fence.
        """
        return val.replace("```python", "```").split("```")[1]