# marigen_api / get_pattern.py
# Last commit a7571eb (verified) by jameszokah: "Update get_pattern.py"
import os
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
from typing import List, AsyncIterable, Annotated, Optional
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain_core.output_parsers import StrOutputParser
import asyncio
import datetime
import csv
# Populate os.environ from a local .env file (if one exists) before the
# configuration lookups below run.
load_dotenv()

# Settings for the OpenAI-compatible Groq endpoint used by ChatOpenAI.
# Each value is None when its environment variable is not set.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GROQ_API_BASE = os.getenv("GROQ_API_BASE")
# NOTE(review): the model name comes from OPENAI_MODEL_NAME rather than a
# GROQ_-prefixed variable — confirm this is the intended env var.
GROQ_MODEL_NAME = os.getenv("OPENAI_MODEL_NAME")
def _read_optional_text(path: str) -> str:
    """Return the UTF-8 text of *path*, or "" if the file does not exist."""
    try:
        with open(path, "r", encoding="utf-8") as file:
            return file.read()
    except FileNotFoundError:
        # A pattern may legitimately omit system.md or user.md.
        return ""


def read_pattern_files(pattern: str, pattern_dir: str = "patterns") -> tuple[str, str]:
    """Load the system and user prompt text for a named pattern.

    Reads ``<pattern_dir>/<pattern>/system.md`` and
    ``<pattern_dir>/<pattern>/user.md`` as UTF-8. Either file may be absent,
    in which case its content is returned as an empty string.

    Args:
        pattern: Name of the pattern sub-directory inside *pattern_dir*.
        pattern_dir: Directory that holds all patterns. A relative path is
            resolved against the current working directory. Defaults to
            ``"patterns"``.

    Returns:
        ``(system_content, user_content)``. Both are ``""`` when the pattern
        does not exist, and also when *pattern* would resolve to a location
        outside *pattern_dir* (e.g. ``"../other"`` or an absolute path).
    """
    base_dir = os.path.abspath(pattern_dir)
    pattern_path = os.path.abspath(os.path.join(base_dir, pattern))

    system_file_path = os.path.join(pattern_path, "system.md")
    user_file_path = os.path.join(pattern_path, "user.md")
    print(system_file_path)
    print(user_file_path)

    # *pattern* may be caller-supplied. Refuse names that escape pattern_dir
    # and treat them exactly like a missing pattern.
    try:
        inside = os.path.commonpath([base_dir, pattern_path]) == base_dir
    except ValueError:
        # commonpath raises this on Windows when the paths are on different drives.
        inside = False
    if not inside:
        return "", ""

    return _read_optional_text(system_file_path), _read_optional_text(user_file_path)
def _escape_braces(text: str) -> str:
    """Escape ``{``/``}`` so *text* is rendered literally by an f-string prompt template."""
    return text.replace("{", "{{").replace("}", "}}")


async def generate_pattern(pattern: str, query: str) -> AsyncIterable[str]:
    """Stream the model's answer to *query* using the prompts of *pattern*.

    The pattern's ``system.md`` becomes the system message. Its ``user.md``
    is placed in front of *query* in the human message. The chain runs in a
    background task, and tokens are relayed from an
    ``AsyncIteratorCallbackHandler`` as the streaming model emits them.

    Args:
        pattern: Name of the pattern directory whose prompts are loaded via
            ``read_pattern_files``.
        query: User text appended after the pattern's user prompt.

    Yields:
        Each streamed token string, in arrival order.

    Raises:
        Any exception raised by the chain is re-raised when the background
        task is awaited during cleanup. Errors raised while iterating the
        callback stream are printed and not re-raised.
    """
    callback = AsyncIteratorCallbackHandler()
    chat = ChatOpenAI(
        openai_api_base=GROQ_API_BASE,
        api_key=GROQ_API_KEY,
        temperature=0.0,
        model_name=GROQ_MODEL_NAME,
        streaming=True,  # required: tokens reach `callback` only when streaming
        verbose=True,
        callbacks=[callback],
    )

    system, usr_content = read_pattern_files(pattern=pattern)
    print('Sys Content -- > ')
    print(system)
    print('User Content --- > ')
    print(usr_content)

    # ChatPromptTemplate parses message strings as f-string templates. Any
    # literal "{" or "}" in the pattern markdown would otherwise be read as a
    # template variable and fail at format time. Escape the pattern text so
    # it is sent verbatim, leaving "{text}" as the only variable.
    human = _escape_braces(usr_content) + "{text}"
    prompt = ChatPromptTemplate.from_messages(
        [("system", _escape_braces(system)), ("human", human)]
    )
    chain = prompt | chat | StrOutputParser()

    async def _run_chain():
        # Always signal completion when the chain finishes, including when it
        # fails before the model emits any callback event (e.g. during prompt
        # formatting). Otherwise callback.aiter() below would wait forever.
        # Exceptions still propagate to `await task` in the finally clause.
        try:
            return await chain.ainvoke({"text": query})
        finally:
            callback.done.set()

    task = asyncio.create_task(_run_chain())
    index = 0
    try:
        async for token in callback.aiter():
            print(index, ": ", token, ": ", datetime.datetime.now().time())
            index += 1
            yield token
    except Exception as e:
        print(f"Caught exception: {e}")
    finally:
        # Stop the iterator if we exit early, then wait for the chain so its
        # errors surface and no task is left dangling.
        callback.done.set()
        await task