Spaces:
Sleeping
Sleeping
import os | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_openai import ChatOpenAI | |
from dotenv import load_dotenv | |
from typing import List, AsyncIterable, Annotated, Optional | |
from langchain.callbacks import AsyncIteratorCallbackHandler | |
from langchain_core.output_parsers import StrOutputParser | |
import asyncio | |
import datetime | |
import csv | |
# Run python-dotenv's loader first so the lookups below see any values it adds
# to the process environment.
load_dotenv()

# Settings handed to the ChatOpenAI client in generate_pattern; each one is
# None when the corresponding environment variable is not set.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GROQ_API_BASE = os.getenv("GROQ_API_BASE")
# NOTE(review): the model name comes from OPENAI_MODEL_NAME rather than a
# GROQ_-prefixed variable like its siblings -- confirm this is intentional.
GROQ_MODEL_NAME = os.getenv("OPENAI_MODEL_NAME")
def _read_text_if_exists(path: str) -> str:
    """Return the UTF-8 text stored at *path*, or ``""`` if nothing exists there."""
    if not os.path.exists(path):
        return ""
    # Pin the encoding: the platform default is not UTF-8 everywhere, and a
    # locale-dependent decode would garble or reject non-ASCII prompt text.
    with open(path, "r", encoding="utf-8") as file:
        return file.read()


def read_pattern_files(pattern: str) -> tuple[str, str]:
    """Load the system and user prompt texts for a named pattern.

    The files looked up are ``patterns/<pattern>/system.md`` and
    ``patterns/<pattern>/user.md``. The ``patterns`` directory is relative, so
    it is resolved against the current working directory.

    Args:
        pattern: Name of the sub-directory under ``patterns`` to read from.

    Returns:
        A ``(system_content, user_content)`` tuple. Each element is the file's
        full text, or an empty string when that file does not exist.
    """
    pattern_dir = "patterns"
    system_file_path = os.path.abspath(os.path.join(pattern_dir, pattern, "system.md"))
    user_file_path = os.path.abspath(os.path.join(pattern_dir, pattern, "user.md"))

    # Echo the resolved locations to stdout, as the original did, so a wrong
    # working directory is easy to spot.
    print(system_file_path)
    print(user_file_path)

    return _read_text_if_exists(system_file_path), _read_text_if_exists(user_file_path)
async def generate_pattern(pattern: str, query: str) -> AsyncIterable[str]:
    """Stream the chat model's reply to *query* using the prompts of *pattern*.

    The system and user prompt texts are loaded with ``read_pattern_files``;
    *query* is appended to the user text. The chain runs as a background task
    while this generator relays tokens collected by the callback handler.

    Args:
        pattern: Name of the pattern whose prompt files are used.
        query: Text substituted for the ``{text}`` placeholder at the end of
            the human message.

    Yields:
        Tokens, in the order the callback handler delivers them.

    Raises:
        Exception: Whatever the chain task raised, re-raised when the task is
            awaited during cleanup. Errors raised while iterating tokens are
            printed and do not propagate.
    """
    callback = AsyncIteratorCallbackHandler()
    chat = ChatOpenAI(
        openai_api_base=GROQ_API_BASE,
        api_key=GROQ_API_KEY,
        temperature=0.0,
        model_name=GROQ_MODEL_NAME,
        streaming=True,  # marked important by the original author: keep enabled
        verbose=True,
        callbacks=[callback],
    )

    system, usr_content = read_pattern_files(pattern=pattern)
    print('Sys Content -- > ')
    print(system)
    print('User Content --- > ')
    print(usr_content)

    # NOTE(review): both texts are used as prompt templates, so literal "{" or
    # "}" in the pattern files will be parsed as template variables -- confirm
    # the files never contain braces, or escape them before this point.
    human = usr_content + "{text}"
    prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])
    chain = prompt | chat | StrOutputParser()

    task = asyncio.create_task(
        chain.ainvoke({"text": query})
    )
    # If the chain fails before the model emits any callback (for example while
    # formatting the prompt), nothing would ever set callback.done and the loop
    # below would wait forever. Setting it whenever the task finishes, for any
    # reason, guarantees the iteration terminates.
    task.add_done_callback(lambda _finished: callback.done.set())

    index = 0
    try:
        async for token in callback.aiter():
            # Debug trace: running token index, the token, and wall-clock time.
            print(index, ": ", token, ": ", datetime.datetime.now().time())
            index += 1
            yield token
    except Exception as e:
        # Best-effort streaming: report the problem and fall through to cleanup.
        print(f"Caught exception: {e}")
    finally:
        callback.done.set()
        # Awaiting the task surfaces any exception raised inside the chain.
        await task