import os
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
from typing import List, AsyncIterable, Annotated, Optional
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain_core.output_parsers import StrOutputParser
import asyncio
import datetime
import csv
load_dotenv()
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
GROQ_API_BASE = os.environ.get("GROQ_API_BASE")
GROQ_MODEL_NAME = os.environ.get("GROQ_MODEL_NAME")
def read_pattern_files(pattern: str) -> tuple[str, str]:
    """Read the system.md and user.md prompt files for the given pattern."""
    system_file = 'system.md'
    user_file = 'user.md'
    system_content = ""
    user_content = ""
    pattern_dir = "patterns"
    # Construct the full paths
    system_file_path = os.path.abspath(os.path.join(pattern_dir, pattern, system_file))
    user_file_path = os.path.abspath(os.path.join(pattern_dir, pattern, user_file))
    print(system_file_path)
    print(user_file_path)
    # Check if system.md exists
    if os.path.exists(system_file_path):
        with open(system_file_path, 'r') as file:
            system_content = file.read()
    # Check if user.md exists
    if os.path.exists(user_file_path):
        with open(user_file_path, 'r') as file:
            user_content = file.read()
    return system_content, user_content
async def generate_pattern(pattern: str, query: str) -> AsyncIterable[str]:
    """Stream the model's response for a pattern, yielding tokens as they arrive."""
    callback = AsyncIteratorCallbackHandler()
    chat = ChatOpenAI(
        openai_api_base=GROQ_API_BASE,
        api_key=GROQ_API_KEY,
        temperature=0.0,
        model_name="mixtral-8x7b-32768",  # GROQ_MODEL_NAME,
        streaming=True,  # ! important
        verbose=True,
        callbacks=[callback]
    )
    system, usr_content = read_pattern_files(pattern=pattern)
    print('Sys Content -- > ')
    print(system)
    print('User Content --- > ')
    print(usr_content)
    human = usr_content + "{text}"
    prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])
    chain = prompt | chat | StrOutputParser()
    # Run the chain in the background; tokens are surfaced through the callback handler.
    task = asyncio.create_task(
        chain.ainvoke({"text": query})
    )
    index = 0
    try:
        async for token in callback.aiter():
            print(index, ": ", token, ": ", datetime.datetime.now().time())
            index = index + 1
            yield token
    except Exception as e:
        print(f"Caught exception: {e}")
    finally:
        callback.done.set()
        await task
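

# Usage sketch: a minimal, hypothetical example of consuming the async token
# stream returned by generate_pattern(). The pattern name "summarize" and the
# query string are placeholders; any directory under patterns/ that contains
# system.md and user.md would work the same way.
async def _demo() -> None:
    async for token in generate_pattern(
        pattern="summarize",
        query="LangChain streams tokens through a callback handler.",
    ):
        print(token, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(_demo())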