File size: 2,491 Bytes
84bd7df
 
 
 
 
 
 
 
 
fd75f40
 
84bd7df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
from typing import List, AsyncIterable, Annotated, Optional
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain_core.output_parsers import StrOutputParser
import asyncio
import datetime
import csv



# Load variables from a .env file (python-dotenv) before the settings below are read.
load_dotenv()

# Credentials/endpoint handed to the OpenAI-compatible client in generate_pattern().
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GROQ_API_BASE = os.getenv("GROQ_API_BASE")
# NOTE(review): this reads OPENAI_MODEL_NAME (not GROQ_MODEL_NAME), and
# generate_pattern() does not pass it to ChatOpenAI — confirm which is intended.
GROQ_MODEL_NAME = os.getenv("OPENAI_MODEL_NAME")

def _read_text_if_file(path: str) -> str:
    """Return the UTF-8 text of *path*, or "" when it is not an existing regular file."""
    # isfile (rather than exists) so a directory carrying a prompt file's name is
    # skipped instead of making open() raise IsADirectoryError.
    if not os.path.isfile(path):
        return ""
    with open(path, "r", encoding="utf-8") as fh:
        return fh.read()


def read_pattern_files(pattern: str, pattern_dir: str = "patterns") -> tuple[str, str]:
    """Load the system and user prompt text for *pattern*.

    Looks for ``system.md`` and ``user.md`` inside ``<pattern_dir>/<pattern>``;
    relative paths are resolved against the current working directory. A
    missing file yields an empty string instead of an error.

    Args:
        pattern: Name of the pattern sub-directory.
        pattern_dir: Root directory holding all patterns (default ``"patterns"``).

    Returns:
        ``(system_content, user_content)`` as strings, each possibly empty.
    """
    # NOTE(review): *pattern* is joined into a filesystem path unchecked; if it can
    # come from untrusted input, values such as "../.." escape pattern_dir.
    base_dir = os.path.join(pattern_dir, pattern)
    system_file_path = os.path.abspath(os.path.join(base_dir, "system.md"))
    user_file_path = os.path.abspath(os.path.join(base_dir, "user.md"))

    print(system_file_path)
    print(user_file_path)

    return _read_text_if_file(system_file_path), _read_text_if_file(user_file_path)



async def generate_pattern(
    pattern: str, query: str, *, model_name: Optional[str] = None
) -> AsyncIterable[str]:
    """Stream the model's answer to *query* using the prompts of *pattern*.

    The system and user prompts come from ``read_pattern_files``; *query* is
    appended directly after the user prompt. Tokens are yielded as the model
    streams them through an ``AsyncIteratorCallbackHandler``.

    Args:
        pattern: Name of the pattern directory to load prompts from.
        query: Text appended to the user prompt.
        model_name: Model to request; defaults to ``"mixtral-8x7b-32768"``.

    Yields:
        Tokens in the order the model produces them.

    Raises:
        Whatever ``chain.ainvoke`` raised, re-raised once the stream has ended.
    """
    callback = AsyncIteratorCallbackHandler()

    chat = ChatOpenAI(
        openai_api_base=GROQ_API_BASE,
        api_key=GROQ_API_KEY,
        temperature=0.0,
        # NOTE(review): module-level GROQ_MODEL_NAME is not used here — confirm.
        model_name=model_name or "mixtral-8x7b-32768",
        streaming=True,  # ! important: needed for per-token callback events
        verbose=True,
        callbacks=[callback],
    )

    system, usr_content = read_pattern_files(pattern=pattern)
    print('Sys Content -- > ')
    print(system)
    print('User Content --- > ')
    print(usr_content)

    # The pattern files are plain markdown, not templates: escape literal braces
    # so ChatPromptTemplate (f-string format) does not treat e.g. JSON examples
    # as input variables. Only the appended "{text}" is a real placeholder.
    def _escape_braces(text: str) -> str:
        return text.replace("{", "{{").replace("}", "}}")

    human = _escape_braces(usr_content) + "{text}"
    prompt = ChatPromptTemplate.from_messages(
        [("system", _escape_braces(system)), ("human", human)]
    )

    chain = prompt | chat | StrOutputParser()

    async def _invoke_and_signal() -> str:
        # The handler is only marked done by LLM callbacks; also mark it here so a
        # failure before the model runs cannot leave aiter() waiting forever.
        try:
            return await chain.ainvoke({"text": query})
        finally:
            callback.done.set()

    task = asyncio.create_task(_invoke_and_signal())
    index = 0
    try:
        async for token in callback.aiter():
            print(index, ": ", token, ": ", datetime.datetime.now().time())
            index += 1
            yield token
    except Exception as e:
        print(f"Caught exception: {e}")
    except BaseException:
        # The consumer closed the stream early (GeneratorExit) or we were
        # cancelled: stop the in-flight request instead of orphaning it.
        task.cancel()
        raise
    finally:
        callback.done.set()

    # Surface any error from the chain (network, auth, prompt formatting, ...).
    await task