# BookWorld/modules/llm/Gemini.py
# alienet's picture
# first commit
# e636070
from .BaseLLM import BaseLLM
import google.generativeai as genai
import os
import time
class Gemini(BaseLLM):
    """Adapter exposing Google Gemini models through the ``BaseLLM`` interface.

    The conversation is buffered in ``self.messages`` as a list of
    ``{"role": ..., "parts": ...}`` dicts (roles ``"system"``, ``"user"`` and
    ``"model"``) and is only sent to the API when ``get_response`` is called.
    The API key is read from the ``GEMINI_API_KEY`` environment variable.
    """

    # Seconds to pause before every request made by ``get_response``.
    # Presumably a crude rate-limit guard (TODO confirm); kept at the value
    # that used to be hard-coded, but can now be overridden per instance or
    # subclass.
    request_delay = 3

    def __init__(self, model="gemini-1.5-flash"):
        """Configure the SDK and create the underlying generative model.

        Args:
            model: Name of the Gemini model to use.
        """
        super().__init__()
        genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
        self.model_name = model
        self.model = genai.GenerativeModel(model)
        self.messages = []

    def initialize_message(self):
        """Discard every buffered message and start a new conversation."""
        self.messages = []

    def ai_message(self, payload):
        """Record a model (assistant) turn in the buffered conversation."""
        self.messages.append({"role": "model", "parts": payload})

    def system_message(self, payload):
        """Record a system instruction in the buffered conversation."""
        self.messages.append({"role": "system", "parts": payload})

    def user_message(self, payload):
        """Record a user turn in the buffered conversation."""
        self.messages.append({"role": "user", "parts": payload})

    def get_response(self, temperature=0.8):
        """Send the buffered conversation to Gemini and return the reply text.

        System messages are passed to the model as its system instruction;
        the remaining turns are split into the chat history (all but the
        last) and the message to send (the last one). ``self.messages`` is
        not modified, so callers that want the reply kept in the
        conversation must record it with ``ai_message``.

        Args:
            temperature: Sampling temperature forwarded to the API.

        Returns:
            The text of the model's response.

        Raises:
            ValueError: If no user/model message has been recorded yet.
        """
        system_parts = [
            message["parts"]
            for message in self.messages
            if message["role"] == "system"
        ]
        turns = [
            message for message in self.messages if message["role"] != "system"
        ]
        if not turns:
            # Fail before sleeping or touching the network.
            raise ValueError(
                "get_response() needs at least one user or model message"
            )

        time.sleep(self.request_delay)

        # Chat history is expected to contain only "user" and "model" roles,
        # so system text is attached to the model as a system instruction
        # instead. This requires an SDK version that supports
        # ``system_instruction``. The pre-built model is reused when there
        # are no system messages.
        model = self.model
        if system_parts:
            model = genai.GenerativeModel(
                self.model_name,
                system_instruction=system_parts,
            )

        # ``send_message`` requires the content to send, so the latest turn
        # becomes that content and everything before it is the history.
        # Only its parts are passed, so it goes out as a plain message
        # regardless of the role it was recorded under.
        *history, latest = turns
        session = model.start_chat(history=history)
        response = session.send_message(
            latest["parts"],
            generation_config=genai.GenerationConfig(temperature=temperature),
        )
        return response.text

    def chat(self, text):
        """Send ``text`` in a fresh, history-free chat and return the reply.

        The buffered conversation in ``self.messages`` is neither read nor
        updated by this call.
        """
        session = self.model.start_chat()
        response = session.send_message(text)
        return response.text

    def print_prompt(self):
        """Print every buffered message dict, one per line."""
        for message in self.messages:
            print(message)