nolanzandi commited on
Commit
554d139
·
verified ·
1 Parent(s): 32f5b77

Upload 2 files

Browse files
Files changed (2) hide show
  1. pipelines.py +91 -0
  2. sqlite_functions.py +35 -0
pipelines.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from haystack import Pipeline
2
+ from haystack.components.builders import PromptBuilder
3
+ from haystack.components.generators.openai import OpenAIGenerator
4
+ from haystack.components.routers import ConditionalRouter
5
+
6
+ from functions import SQLiteQuery
7
+
8
+ from typing import List
9
+ import sqlite3
10
+
11
+ import os
12
+ from getpass import getpass
13
+ from dotenv import load_dotenv
14
+
15
+ load_dotenv()
16
+
17
+ if "OPENAI_API_KEY" not in os.environ:
18
+ os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
19
+
20
+ from haystack.components.builders import PromptBuilder
21
+ from haystack.components.generators import OpenAIGenerator
22
+
23
+ llm = OpenAIGenerator(model="gpt-4o")
24
+ sql_query = SQLiteQuery('data_source.db')
25
+
26
+ connection = sqlite3.connect('data_source.db')
27
+ cur=connection.execute('select * from data_source')
28
+ columns = [i[0] for i in cur.description]
29
+ cur.close()
30
+
31
+ #Rag Pipeline
32
+ prompt = PromptBuilder(template="""Please generate an SQL query. The query should answer the following Question: {{question}};
33
+ If the question cannot be answered given the provided table and columns, return 'no_answer'
34
+ The query is to be answered for the table is called 'data_source' with the following
35
+ Columns: {{columns}};
36
+ Answer:""")
37
+
38
+ routes = [
39
+ {
40
+ "condition": "{{'no_answer' not in replies[0]}}",
41
+ "output": "{{replies}}",
42
+ "output_name": "sql",
43
+ "output_type": List[str],
44
+ },
45
+ {
46
+ "condition": "{{'no_answer' in replies[0]}}",
47
+ "output": "{{question}}",
48
+ "output_name": "go_to_fallback",
49
+ "output_type": str,
50
+ },
51
+ ]
52
+
53
+ router = ConditionalRouter(routes)
54
+
55
+ fallback_prompt = PromptBuilder(template="""User entered a query that cannot be answered with the given table.
56
+ The query was: {{question}} and the table had columns: {{columns}}.
57
+ Let the user know why the question cannot be answered""")
58
+ fallback_llm = OpenAIGenerator(model="gpt-4")
59
+
60
+ conditional_sql_pipeline = Pipeline()
61
+ conditional_sql_pipeline.add_component("prompt", prompt)
62
+ conditional_sql_pipeline.add_component("llm", llm)
63
+ conditional_sql_pipeline.add_component("router", router)
64
+ conditional_sql_pipeline.add_component("fallback_prompt", fallback_prompt)
65
+ conditional_sql_pipeline.add_component("fallback_llm", fallback_llm)
66
+ conditional_sql_pipeline.add_component("sql_querier", sql_query)
67
+
68
+ conditional_sql_pipeline.connect("prompt", "llm")
69
+ conditional_sql_pipeline.connect("llm.replies", "router.replies")
70
+ conditional_sql_pipeline.connect("router.sql", "sql_querier.queries")
71
+ conditional_sql_pipeline.connect("router.go_to_fallback", "fallback_prompt.question")
72
+ conditional_sql_pipeline.connect("fallback_prompt", "fallback_llm")
73
+
74
+ def rag_pipeline_func(queries: str, columns: str):
75
+ print("RAG PIPELINE FUNCTION")
76
+ result = conditional_sql_pipeline.run({"prompt": {"question": queries,
77
+ "columns": columns},
78
+ "router": {"question": queries},
79
+ "fallback_prompt": {"columns": columns}})
80
+
81
+ if 'sql_querier' in result:
82
+ reply = result['sql_querier']['results'][0]
83
+ elif 'fallback_llm' in result:
84
+ reply = result['fallback_llm']['replies'][0]
85
+ else:
86
+ reply = result["llm"]["replies"][0]
87
+
88
+ print("reply content")
89
+ print(reply.content)
90
+
91
+ return {"reply": reply.content}
sqlite_functions.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from haystack import component
3
+ import pandas as pd
4
+ import sqlite3
5
+
6
+ @component
7
+ class SQLiteQuery:
8
+
9
+ def __init__(self, sql_database: str):
10
+ self.connection = sqlite3.connect(sql_database, check_same_thread=False)
11
+
12
+ @component.output_types(results=List[str], queries=List[str])
13
+ def run(self, queries: List[str]):
14
+ print("ATTEMPTING TO RUN QUERY")
15
+ results = []
16
+ for query in queries:
17
+ result = pd.read_sql(query, self.connection)
18
+ results.append(f"{result}")
19
+ "self.connection.close()"
20
+ return {"results": results, "queries": queries}
21
+
22
+
23
+ sql_query = SQLiteQuery('data_source.db')
24
+
25
+ def sqlite_query_func(queries: List[str]):
26
+ try:
27
+ result = sql_query.run(queries)
28
+ return {"reply": result["results"][0]}
29
+
30
+ except Exception as e:
31
+ reply = f"""There was an error running the SQL Query = {queries}
32
+ The error is {e},
33
+ You should probably try again.
34
+ """
35
+ return {"reply": reply}