# LangChain-compatible LLM wrapper around the Google PaLM text-generation API.
from typing import Any, List, Mapping, Optional, Dict
from pydantic import Extra, Field  # , root_validator, model_validator
import os, json
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
import google.generativeai as palm
from google.generativeai import types
import ast
# from langchain.llms import GooglePalm
import requests
class PalmLLM(LLM):
    """LangChain-compatible LLM backed by the Google PaLM text-generation API.

    Generation parameters are pydantic fields. The arguments of the most
    recent call are cached (``prev_*``) so that :meth:`extractJson` can
    re-issue the request when PaLM returns a truncated ```json block.
    """

    # Name of the PaLM model to query.
    model_name: str = "text-bison-001"
    # Sampling temperature; 0 makes generation deterministic.
    temperature: float = 0
    # Upper bound on generated tokens per request.
    max_tokens: int = 2048
    # Default stop sequences, used when _call() receives stop=None.
    # (pydantic copies mutable field defaults per instance, so [] is safe here.)
    stop: Optional[List] = []
    # Context of the most recent _call(), consumed by extractJson() to retry
    # generation when the model's output is cut off mid-JSON.
    prev_prompt: Optional[str] = ""
    # Widened from Optional[str]: this actually holds the stop list (or None).
    prev_stop: Optional[Any] = ""
    prev_run_manager: Optional[Any] = None

    def __init__(self, **kwargs):
        """Initialize pydantic fields and configure the PaLM client.

        palm.configure() is called without arguments, so it presumably picks
        up credentials from the environment (GOOGLE_API_KEY) — confirm.
        """
        super().__init__(**kwargs)
        palm.configure()

    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return "text2text-generation"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Send ``prompt`` to PaLM and return the generated text.

        The call context is cached on the instance so extractJson() can
        re-invoke generation with identical settings if the response is
        truncated.
        """
        self.prev_prompt = prompt
        self.prev_stop = stop
        self.prev_run_manager = run_manager
        if stop is None:
            stop = self.stop
        # BUG FIX: the original passed self.stop here, silently ignoring any
        # caller-supplied stop sequences; pass the resolved `stop` instead.
        response = palm.generate_text(
            prompt=prompt,
            stop_sequences=stop,
            temperature=self.temperature,
            max_output_tokens=self.max_tokens,
            # Threshold 4 ("block none") for every harm category, i.e. the
            # model's safety filtering is effectively disabled.
            safety_settings=[
                {"category": "HARM_CATEGORY_DEROGATORY", "threshold": 4},
                {"category": "HARM_CATEGORY_TOXICITY", "threshold": 4},
                {"category": "HARM_CATEGORY_VIOLENCE", "threshold": 4},
                {"category": "HARM_CATEGORY_SEXUAL", "threshold": 4},
                {"category": "HARM_CATEGORY_MEDICAL", "threshold": 4},
                {"category": "HARM_CATEGORY_DANGEROUS", "threshold": 4},
            ],
        )
        print("Response from palm", response)
        val = response.result
        if run_manager:
            run_manager.on_llm_end(val)
        return val

    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"name": self.model_name, "type": "palm"}

    def extractJson(self, val: str) -> Any:
        """Extract and parse the first JSON object in this LLM's output.

        PaLM typically wraps JSON as ```json ... ```; a missing trailing ```
        means the output was truncated, so generation is re-invoked (up to 7
        times) with the partial result appended to the previous prompt.
        The payload is parsed with ast.literal_eval first because it tolerates
        Python-style literals (e.g. single quotes) that json.loads rejects.
        """
        try:
            count = 0
            while val.startswith("```json") and not val.endswith("```") and count < 7:
                val = self._call(
                    prompt=self.prev_prompt + " " + val,
                    stop=self.prev_stop,
                    run_manager=self.prev_run_manager,
                )
                count += 1
            # Take the contents of the first fenced block.
            fenced = val.replace("```json", "```").split("```")[1]
            return json.loads(json.dumps(ast.literal_eval(fenced)))
        except (IndexError, ValueError, SyntaxError):
            # No usable fenced block: fall back to parsing the whole response.
            # (Narrowed from a bare `except:` so unrelated bugs surface.)
            return json.loads(json.dumps(ast.literal_eval(val)))

    def extractPython(self, val: str) -> Any:
        """Return the first ```python fenced code block in ``val``."""
        # Assumes the python snippet is the first fenced block.
        return val.replace("```python", "```").split("```")[1]