user committed on
Commit
cc19159
·
1 Parent(s): 576b273

Fix UnhashableParamError and improve caching in create_embeddings

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -27,12 +27,12 @@ def load_and_process_text(file_path):
27
  return []
28
 
29
  @st.cache_data
30
- def create_embeddings(chunks, tokenizer, embedding_model):
31
  embeddings = []
32
  for chunk in chunks:
33
  inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
34
  with torch.no_grad():
35
- outputs = embedding_model(**inputs)
36
  embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().numpy())
37
  return np.array(embeddings)
38
 
@@ -80,7 +80,7 @@ st.markdown('<p class="big-font">Chat with the Muse of A.R. Ammons. Ask question
80
  with st.spinner("Loading models and data..."):
81
  tokenizer, embedding_model, generation_model = load_models()
82
  chunks = load_and_process_text('ammons_muse.txt')
83
- embeddings = create_embeddings(chunks, tokenizer, embedding_model)
84
  index = create_faiss_index(embeddings)
85
 
86
  if tokenizer is None or embedding_model is None or generation_model is None or not chunks:
 
27
  return []
28
 
29
  @st.cache_data
30
+ def create_embeddings(chunks, _embedding_model):
31
  embeddings = []
32
  for chunk in chunks:
33
  inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
34
  with torch.no_grad():
35
+ outputs = _embedding_model(**inputs)
36
  embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().numpy())
37
  return np.array(embeddings)
38
 
 
80
  with st.spinner("Loading models and data..."):
81
  tokenizer, embedding_model, generation_model = load_models()
82
  chunks = load_and_process_text('ammons_muse.txt')
83
+ embeddings = create_embeddings(chunks, embedding_model)
84
  index = create_faiss_index(embeddings)
85
 
86
  if tokenizer is None or embedding_model is None or generation_model is None or not chunks: