File size: 3,431 Bytes
dc66050
 
3674844
8586ebb
f1aaba3
 
 
 
 
 
 
 
 
 
 
af7f1fe
f1aaba3
cf29376
f1aaba3
592b501
 
 
ee14926
e1a03f2
 
ee14926
cf29376
 
 
 
 
5992538
71a1eda
ee14926
 
 
 
 
 
 
 
 
 
 
1a788a2
ee14926
 
7fdbbb2
1b0cc5f
ee14926
7a8951a
7fdbbb2
1a788a2
 
 
 
 
 
 
ee14926
 
 
 
dc66050
 
97413fe
ee14926
2d1e4f3
ee14926
f7947fc
2d1e4f3
 
 
3d2d75d
ee14926
2d1e4f3
 
 
 
 
 
 
 
 
 
ee14926
2d1e4f3
 
 
ee14926
 
 
 
 
 
 
 
 
 
 
 
2d1e4f3
 
3d2d75d
28fb60e
7a8951a
1b0cc5f
7a8951a
7fdbbb2
8586ebb
 
 
97413fe
7fdbbb2
e95efe1
1b0cc5f
e95efe1
 
 
 
 
 
 
 
 
c25281e
e95efe1
 
8bbbb33
e95efe1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import gradio as gr
from huggingface_hub import InferenceClient
import os
from smolagents import tool, CodeAgent, HfApiModel, GradioUI
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    Float,
    insert,
    inspect,
    text,
    select,
)
import spaces

from dotenv import load_dotenv

# Populate os.environ from a local .env file (if one exists) so that the
# os.getenv(...) lookup of the Hugging Face token further down can find it.
load_dotenv()

# Example prompt: "What is the average each customer paid? Create a SQL statement and invoke your sql_engine_tool tool."


@spaces.GPU
def dummy():
    """No-op placeholder carrying the ``@spaces.GPU`` decorator.

    NOTE(review): presumably present only so the hosting Space registers at
    least one GPU-decorated function — confirm before removing.
    """


@tool
def sql_engine_tool(query: str, engine: any) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
    The table is named 'receipts'. Its description is as follows:
        Columns:
        - receipt_id: INTEGER
        - customer_name: VARCHAR(16)
        - price: FLOAT
        - tip: FLOAT

    Args:
        query: The query to perform. This should be correct SQL.
        engine: The SQLAlchemy engine for the receipts database; pass the existing `engine` object.
    """
    # The query runs on a plain connection that is never committed; under
    # SQLAlchemy 2.x semantics any data-modifying SQL is rolled back on exit.
    with engine.connect() as con:
        rows = con.execute(text(query))
        # Same format as before: every row is rendered with str() and prefixed
        # by a newline, so an empty result set yields "".
        return "".join("\n" + str(row) for row in rows)


def init_db(engine):
    """Create the ``receipts`` table on *engine* and seed it with sample rows.

    Args:
        engine: SQLAlchemy engine in which the table is created and populated.

    Returns:
        A ``(engine, metadata_obj)`` tuple: the same engine that was passed in,
        and the ``MetaData`` object holding the ``receipts`` table definition.
    """
    metadata_obj = MetaData()

    receipts = Table(
        "receipts",
        metadata_obj,
        Column("receipt_id", Integer, primary_key=True),
        # NOTE(review): customer_name is also flagged primary_key, so the key is
        # the composite (receipt_id, customer_name). Kept to preserve the schema.
        Column("customer_name", String(16), primary_key=True),
        Column("price", Float),
        Column("tip", Float),
    )
    metadata_obj.create_all(engine)

    rows = [
        {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
        {"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
        {
            "receipt_id": 3,
            "customer_name": "Woodrow Wilson",
            "price": 53.43,
            "tip": 5.43,
        },
        {
            "receipt_id": 4,
            "customer_name": "Margaret James",
            "price": 21.11,
            "tip": 1.00,
        },
    ]
    # One transaction with an executemany-style insert: either every sample row
    # is committed or none is, instead of one commit per row.
    with engine.begin() as connection:
        connection.execute(insert(receipts), rows)
    return engine, metadata_obj


if __name__ == "__main__":
    from sqlalchemy.pool import StaticPool

    # An in-memory SQLite database lives inside a single connection, and
    # SQLAlchemy's default pool for it hands each thread its own connection.
    # Gradio may run the agent on worker threads, which would then see a fresh,
    # empty database without the 'receipts' table. StaticPool shares one
    # connection; check_same_thread=False lets other threads use it.
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    # metadata_objects is kept as a module-level global in case other code reads it.
    engine, metadata_objects = init_db(engine)
    model = HfApiModel(
        model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
        token=os.getenv("my_first_agents_hf_tokens"),
    )

    agent = CodeAgent(
        tools=[sql_engine_tool],
        model=model,
        max_steps=1,
        verbosity_level=1,
    )
    # Quick check without the UI:
    # agent.run("What is the average each customer paid?")
    GradioUI(agent).launch()