# Streamlit app: runs a small RAG pipeline (llama_index + HuggingFace
# Inference API) over a user-uploaded HTML file.
# NOTE: scraped HuggingFace Spaces page chrome (status lines, commit hashes,
# and a line-number gutter) previously sat here and was not valid Python.
import streamlit as st
import os
from io import StringIO
from llama_index.llms import HuggingFaceInferenceAPI
from llama_index.embeddings import HuggingFaceInferenceAPIEmbedding
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.schema import Document
import uuid
from llama_index.vector_stores.types import MetadataFilters, ExactMatchFilter
# NOTE(review): the secrets key below is misspelled ("INFRERENCE"); it is kept
# as-is because it must match the key configured in .streamlit/secrets.toml —
# rename both together, never just one side.
inference_api_key = st.secrets["INFRERENCE_API_TOKEN"]

# Model and query inputs. The LLM input's label previously duplicated the
# embed input's label ("Embed Model name") — a copy-paste defect.
embed_model_name = st.text_input(
    'Embed Model name', "Gooly/gte-small-en-fine-tuned-e-commerce")
llm_model_name = st.text_input(
    'LLM Model name', "mistralai/Mistral-7B-Instruct-v0.2")
query = st.text_input(
    'Query', "What is the price of the product?")
html_file = st.file_uploader("Upload a html file", type=["html"])

if st.button('Start Pipeline'):
    # st.text_input returns "" (never None) when cleared, so the original
    # `is not None` checks could not catch empty fields; test truthiness.
    if html_file is not None and embed_model_name and llm_model_name and query:
        st.write('Running Pipeline')

        # Both the LLM and the embedding model run remotely on the
        # HuggingFace Inference API, sharing one token.
        llm = HuggingFaceInferenceAPI(
            model_name=llm_model_name, token=inference_api_key)
        embed_model = HuggingFaceInferenceAPIEmbedding(
            model_name=embed_model_name,
            token=inference_api_key,
            model_kwargs={"device": ""},
            encode_kwargs={"normalize_embeddings": True},
        )
        service_context = ServiceContext.from_defaults(
            embed_model=embed_model, llm=llm)

        # Decode the uploaded bytes directly; the StringIO round-trip in the
        # original (`StringIO(bytes.decode()).read()`) added nothing.
        string_data = html_file.getvalue().decode("utf-8")
        with st.expander("Uploaded HTML"):
            st.write(string_data)

        # Tag the document with a fresh UUID so the query engine can be
        # filtered down to exactly this upload.
        document_id = str(uuid.uuid4())
        document = Document(text=string_data)
        document.metadata["id"] = document_id
        documents = [document]
        filters = MetadataFilters(
            filters=[ExactMatchFilter(key="id", value=document_id)])

        index = VectorStoreIndex.from_documents(
            documents, show_progress=True, metadata={"source": "HTML"},
            service_context=service_context)

        # Surface the retriever's ranked nodes for the query before answering,
        # so the user can inspect what grounded the response.
        retriever = index.as_retriever()
        ranked_nodes = retriever.retrieve(query)
        with st.expander("Ranked Nodes"):
            for node in ranked_nodes:
                st.write(node.node.get_content(), "-> Score:", node.score)

        query_engine = index.as_query_engine(
            filters=filters, service_context=service_context)
        response = query_engine.query(query)
        st.write(response)
    else:
        st.error('Please fill in all the fields')
else:
    st.write('Press start to begin')
# if html_file is not None:
# stringio = StringIO(html_file.getvalue().decode("utf-8"))
# string_data = stringio.read()
# with st.expander("Uploaded HTML"):
# st.write(string_data)
# document_id = str(uuid.uuid4())
# document = Document(text=string_data)
# document.metadata["id"] = document_id
# documents = [document]
# filters = MetadataFilters(
# filters=[ExactMatchFilter(key="id", value=document_id)])
# index = VectorStoreIndex.from_documents(
# documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
# retriever = index.as_retriever()
# ranked_nodes = retriever.retrieve(
# "Get me all the information about the product")
# with st.expander("Ranked Nodes"):
# for node in ranked_nodes:
# st.write(node.node.get_content(), "-> Score:", node.score)
# query_engine = index.as_query_engine(
# filters=filters, service_context=service_context)
# response = query_engine.query(
# "Get me all the information about the product")
# st.write(response)