import streamlit as st
import os
from io import StringIO
from llama_index.llms import HuggingFaceInferenceAPI
from llama_index.embeddings import HuggingFaceInferenceAPIEmbedding
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.schema import Document
import uuid
from llama_index.vector_stores.types import MetadataFilters, ExactMatchFilter
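# Streamlit demo: upload an HTML page, embed it with a Hugging Face Inference API
# embedding model, index it with llama_index, and answer a query about it with a
# remotely hosted LLM. The API token is read from Streamlit secrets, so the key
# name used below must match the entry in .streamlit/secrets.toml.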

inference_api_key = st.secrets["INFERENCE_API_TOKEN"]

embed_model_name = st.text_input(
    'Embed Model name', "Gooly/gte-small-en-fine-tuned-e-commerce")

llm_model_name = st.text_input(
    'LLM Model name', "mistralai/Mistral-7B-Instruct-v0.2")

query = st.text_input(
    'Query', "What is the price of the product?")

html_file = st.file_uploader("Upload an HTML file", type=["html"])

if st.button('Start Pipeline'):
    if html_file is not None and embed_model_name and llm_model_name and query:
        st.write('Running Pipeline')
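        # LLM hosted on the Hugging Face Inference API; nothing runs locally.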
        llm = HuggingFaceInferenceAPI(
            model_name=llm_model_name, token=inference_api_key)

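        # Embedding model, also served remotely via the Inference API.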
        embed_model = HuggingFaceInferenceAPIEmbedding(
            model_name=embed_model_name,
            token=inference_api_key,
        )

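        # ServiceContext bundles the LLM and embedding model for the index and query engine.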
        service_context = ServiceContext.from_defaults(
            embed_model=embed_model, llm=llm)

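        # Decode the uploaded HTML bytes and show the raw markup for inspection.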
        stringio = StringIO(html_file.getvalue().decode("utf-8"))
        string_data = stringio.read()
        with st.expander("Uploaded HTML"):
            st.write(string_data)

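        # Wrap the HTML in a llama_index Document, tagged with a unique id.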
        document_id = str(uuid.uuid4())

        document = Document(text=string_data)
        document.metadata["id"] = document_id
        document.metadata["source"] = "HTML"
        documents = [document]

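        # Metadata filter restricting queries to nodes from this upload.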
        filters = MetadataFilters(
            filters=[ExactMatchFilter(key="id", value=document_id)])

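        # Build an in-memory vector index; chunks are embedded through the Inference API.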
        index = VectorStoreIndex.from_documents(
            documents, show_progress=True, service_context=service_context)

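        # Retrieve the chunks most similar to the query and display their scores.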
        retriever = index.as_retriever()

        ranked_nodes = retriever.retrieve(query)

        with st.expander("Ranked Nodes"):
            for node in ranked_nodes:
                st.write(node.node.get_content(), "-> Score:", node.score)

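        # Query engine: filtered retrieval followed by LLM answer synthesis.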
        query_engine = index.as_query_engine(
            filters=filters, service_context=service_context)

        response = query_engine.query(query)

        st.write(response.response)

        st.write(response.source_nodes)

    else:
        st.error('Please fill in all the fields')
else:
    st.write('Press start to begin')
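# Usage sketch (assumes this script is saved as app.py and that the API token is
# stored under the matching key in .streamlit/secrets.toml):
#   streamlit run app.py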
