anubhav77 committed
Commit 37419af · 1 Parent(s): c53e220
Files changed (3):
  1. src/llm/geminiLLM.py +114 -0
  2. src/llm/llmFactory.py +3 -0
  3. src/main.py +2 -2
src/llm/geminiLLM.py ADDED
@@ -0,0 +1,114 @@
+ from typing import Any, List, Mapping, Optional
+ import json
+ import ast
+ import logging
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
+ from langchain.llms.base import LLM
+ import google.generativeai as genai
+
+ logger = logging.getLogger("llm")
+
+
+ class GeminiLLM(LLM):
+
+     model_name: str = "gemini-pro"
+     temperature: float = 0
+     max_tokens: int = 2048
+     stop: Optional[List[str]] = []
+     prev_prompt: Optional[str] = ""
+     prev_stop: Optional[List[str]] = None
+     prev_run_manager: Optional[Any] = None
+     model: Optional[Any] = None
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         self.model = genai.GenerativeModel(self.model_name)
+
+     @property
+     def _llm_type(self) -> str:
+         return "text2text-generation"
+
+     def _call(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+     ) -> str:
+         # Remember the last request so extractJson() can re-prompt when the
+         # model returns a truncated fenced block.
+         self.prev_prompt = prompt
+         self.prev_stop = stop
+         self.prev_run_manager = run_manager
+         if stop is None:
+             stop = self.stop
+         logger.debug("LLM in use is: " + self._llm_type)
+         logger.debug("Request to LLM is " + prompt)
+
+         response = self.model.generate_content(
+             prompt,
+             generation_config={
+                 "stop_sequences": stop,
+                 "temperature": self.temperature,
+                 "max_output_tokens": self.max_tokens,
+             },
+             safety_settings=[
+                 {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
+                 {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
+                 {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
+                 {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
+             ],
+             stream=False,
+         )
+         try:
+             # response.text raises if generation was blocked or returned no candidate.
+             val = response.text
+             if val is None:
+                 logger.debug("Response from LLM was None")
+                 filter_str = ""
+                 for item in response.filters:
+                     for key, value in item.items():
+                         filter_str += key + ":" + str(value)
+                 logger.error("Will switch to fallback LLM as response from Gemini is None::" + filter_str)
+                 raise ValueError("Empty response from Gemini")
+             logger.debug("Response from LLM " + val)
+         except Exception:
+             logger.error("Will switch to fallback LLM as response from Gemini is None")
+             raise
+         if run_manager:
+             run_manager.on_llm_end(val)
+         return val
+
+     @property
+     def _identifying_params(self) -> Mapping[str, Any]:
+         """Get the identifying parameters."""
+         return {"name": self.model_name, "type": "gemini"}
+
+     def extractJson(self, val: str) -> Any:
+         """Helper function to extract JSON from this LLM's output."""
+         # Gemini usually wraps JSON in ```json ... ``` fences, but the response is
+         # sometimes incomplete. If the closing ``` is missing, call generation again
+         # with the previous prompt plus the partial result appended (up to 7 tries).
+         try:
+             count = 0
+             while val.startswith("```json") and not val.endswith("```") and count < 7:
+                 val = self._call(prompt=self.prev_prompt + " " + val,
+                                  stop=self.prev_stop,
+                                  run_manager=self.prev_run_manager)
+                 count += 1
+             v2 = val.replace("```json", "```").split("```")[1]
+             try:
+                 v4 = json.loads(v2)
+             except json.JSONDecodeError:
+                 # Single quotes or other Python-literal syntax: round-trip via ast.
+                 v4 = json.loads(json.dumps(ast.literal_eval(v2)))
+         except Exception:
+             # No fences at all: parse the raw string as a Python literal.
+             v4 = json.loads(json.dumps(ast.literal_eval(val)))
+         return v4
+
+     def extractPython(self, val: str) -> str:
+         """Helper function to extract Python code from this LLM's output."""
+         # Assumes the code is the first item within ``` fences.
+         return val.replace("```python", "```").split("```")[1]
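For context, a minimal usage sketch of the new class. It assumes the google.generativeai API key is configured by the caller (e.g. via genai.configure or the GOOGLE_API_KEY environment variable; GeminiLLM itself never sets one), and it relies on LangChain's standard LLM.__call__ entry point:

import google.generativeai as genai
from llm.geminiLLM import GeminiLLM

genai.configure(api_key="...")  # assumption: key setup happens outside GeminiLLM

llm = GeminiLLM(temperature=0, max_tokens=2048)
raw = llm("Reply with a JSON object {\"ok\": true} inside ```json fences.")
parsed = llm.extractJson(raw)  # re-prompts automatically if the closing fence is missing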
src/llm/llmFactory.py CHANGED
@@ -3,6 +3,7 @@ from baseInfra.dbInterface import DbInterface
  from llm.hostedLLM import HostedLLM
  from llm.togetherLLM import TogetherLLM
  from llm.palmLLM import PalmLLM
+ from llm.geminiLLM import GeminiLLM


  class LLMFactory:
@@ -49,6 +50,8 @@ class LLMFactory:
          return TogetherLLM(**llm_config)
      elif llm_type == "palmLLM":
          return PalmLLM(**llm_config)
+     elif llm_type == "geminiLLM":
+         return GeminiLLM(**llm_config)
      else:
          logger.error(f"Invalid LLM type: {llm_type}")
          raise ValueError(f"Invalid LLM type: {llm_type}")
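The factory's entry-point name and config schema sit outside these hunks; assuming it follows the dispatch pattern above, a caller would now select the new backend with something like this (method name and config keys are hypothetical, not shown in the diff):

# Hypothetical invocation; the real factory method and config keys are not in this hunk.
llm_config = {"model_name": "gemini-pro", "temperature": 0, "max_tokens": 2048}
llm = LLMFactory.get_llm("geminiLLM", llm_config)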
src/main.py CHANGED
@@ -3,7 +3,7 @@ import logging,os
  import fastapi
  from fastapi import Body, Depends
  import uvicorn
- from fastapi import HTTPException , status
+ from fastapi import BackgroundTasks, HTTPException, status
  from fastapi.responses import JSONResponse
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi import FastAPI as Response
@@ -60,7 +60,7 @@ app.add_middleware(
  api_base="/api/v1"

  @app.post(api_base+"/getMatchingDocs")
- async def get_matching_docs(inStr: str, kwargs: Dict [Any, Any] ) -> Any:
+ async def get_matching_docs(inStr: str, kwargs: Dict[Any, Any], background_tasks: BackgroundTasks) -> Any:
      """
      Gets the query embeddings and uses metadata appropriately and gets the matching docs for query
      TODO: Add parameter for type of query and number of docs to return
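The handler body is not shown, but the new background_tasks parameter follows FastAPI's standard injection pattern: work queued on it runs after the response is sent. A generic sketch, with the helper name illustrative rather than taken from this commit:

from fastapi import BackgroundTasks, FastAPI

app = FastAPI()

def record_query(in_str: str) -> None:
    # Hypothetical helper: runs after the response has been returned.
    print("query served:", in_str)

@app.post("/api/v1/getMatchingDocs")
async def get_matching_docs(inStr: str, background_tasks: BackgroundTasks):
    background_tasks.add_task(record_query, inStr)  # schedule post-response work
    return {"docs": []}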