vivekvar committed on
Commit 6b5d076 · verified · 1 Parent(s): 58cd55d

Update app.py

Files changed (1): app.py +19 -16
app.py CHANGED
@@ -1,30 +1,33 @@
 import streamlit as st
-from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate
+from llama_index import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate
 from llama_index.llms.huggingface import HuggingFaceInferenceAPI
 from dotenv import load_dotenv
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
-from llama_index.core import Settings
+from llama_index import set_global_service_context
 import os
 import base64
 
 # Load environment variables
 load_dotenv()
 
-# Configure the Llama index settings for using Hugging Face LLaMA model
-Settings.llm = HuggingFaceInferenceAPI(
-    model_name="facebook/bedrock-llama-7b",  # Use LLaMA 7B model here
-    tokenizer_name="facebook/bedrock-llama-7b",  # Tokenizer for the LLaMA model
-    context_window=30000,  # Set context window size (adjust if necessary)
+# Configure the Llama index settings for using Hugging Face model
+llm = HuggingFaceInferenceAPI(
+    model_name="bigscience/bloom-7b1",  # Use a model available on Hugging Face Inference API
+    tokenizer_name="bigscience/bloom-7b1",
+    context_window=2048,  # Adjust context window based on the model
     api_token=os.getenv("HF_TOKEN"),  # Hugging Face API Token
     max_new_tokens=512,
-    generate_kwargs={"temperature": 0.1},  # Control the generation temperature
+    generate_kwargs={"temperature": 0.1},
 )
 
-# Set up Hugging Face Embedding model to use powerful LLaMA model
-Settings.embed_model = HuggingFaceEmbedding(
-    model_name="facebook/bedrock-llama-7b"  # Powerful model for embeddings
+# Set up Hugging Face Embedding model
+embed_model = HuggingFaceEmbedding(
+    model_name="sentence-transformers/all-MiniLM-L6-v2"  # Use a suitable embedding model
 )
 
+# Set global service context
+service_context = set_global_service_context(llm=llm, embed_model=embed_model)
+
 # Define the directory for persistent storage and data
 PERSIST_DIR = "./db"
 DATA_DIR = "data"
@@ -41,13 +44,13 @@ def displayPDF(file):
 
 def data_ingestion():
     documents = SimpleDirectoryReader(DATA_DIR).load_data()
-    storage_context = StorageContext.from_defaults()
-    index = VectorStoreIndex.from_documents(documents)
-    index.storage_context.persist(persist_dir=PERSIST_DIR)
+    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
+    index = VectorStoreIndex.from_documents(documents, service_context=service_context)
+    index.storage_context.persist()
 
 def handle_query(query):
     storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
-    index = load_index_from_storage(storage_context)
+    index = load_index_from_storage(storage_context, service_context=service_context)
     chat_text_qa_msgs = [
         (
             "user",
@@ -94,4 +97,4 @@ if user_prompt:
 
 for message in st.session_state.messages:
     with st.chat_message(message['role']):
-        st.write(message['content'])
+        st.write(message['content'])
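Review note: in llama_index 0.9.x (the API these new "from llama_index import ..." lines target), set_global_service_context() expects a ServiceContext object and returns None. New line 29 would therefore raise a TypeError, and even if it ran, service_context would be bound to None at the later service_context=service_context call sites. A minimal corrected sketch, assuming llama_index 0.9.x (ServiceContext was superseded by Settings in 0.10, which is what the removed lines used):

    from llama_index import ServiceContext, set_global_service_context

    # Build the ServiceContext first, then register it globally; keep the
    # reference so it can be passed to VectorStoreIndex and load_index_from_storage.
    service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
    set_global_service_context(service_context)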
 
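Review note: the commit also swaps the embedding model from the LLM checkpoint to sentence-transformers/all-MiniLM-L6-v2, which produces 384-dimensional vectors, so an index persisted under ./db by the previous revision cannot be reused with the new embeddings. (The rebuilt data_ingestion() also creates a storage_context it never passes to from_documents, which appears vestigial, and StorageContext.from_defaults(persist_dir=...) generally expects that directory to already hold a persisted store.) A minimal rebuild sketch, assuming the PERSIST_DIR constant and data_ingestion() defined in this file:

    import os
    import shutil

    # Vectors persisted with the old embedding model are incompatible with the
    # new 384-dim embeddings; drop the stale index and re-ingest from ./data.
    if os.path.isdir(PERSIST_DIR):
        shutil.rmtree(PERSIST_DIR)
    data_ingestion()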