MarioCerulo committed
Commit c2f1b61 · verified · 1 Parent(s): a8207c6

Create app.py

Files changed (1):
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
import os
import streamlit as st
from dotenv import load_dotenv
from peft import PeftModel
from chromadb import HttpClient
from utils.embedding_utils import CustomEmbeddingFunction
from transformers import AutoModelForCausalLM, AutoTokenizer

st.title("FormulAI Q&A")

model_name = "unsloth/Llama-3.2-1B"

# Load the base model and its tokenizer
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the fine-tuned LoRA adapter weights on top of the base model.
# Note: PeftModel(model, peft_config) would attach a freshly initialized
# adapter; from_pretrained is what loads the trained weights from the Hub.
adapter_name = "FormulAI/FormuLLaMa-3.2-1B-LoRA"
model = PeftModel.from_pretrained(model, adapter_name)

template = """Answer the following QUESTION based on the CONTEXT given.
If you do not know the answer and the CONTEXT doesn't contain the answer, truthfully say "I don't know".

CONTEXT:
{context}

QUESTION:
{question}

ANSWER:
"""

# Keep the question/answer history in the Streamlit session state
if 'generated' not in st.session_state:
    st.session_state['generated'] = []

if 'past' not in st.session_state:
    st.session_state['past'] = []

def get_text():
    # Single-line input box for the user's question
    input_text = st.text_input("Ask something: ", "", key="input")
    return input_text

# Read the ChromaDB connection settings from chroma.env, with local defaults.
# os.getenv returns strings, so the port is cast to int explicitly.
load_dotenv("chroma.env")
chroma_host = os.getenv("CHROMA_HOST", "localhost")
chroma_port = int(os.getenv("CHROMA_PORT", "8000"))
chroma_collection = os.getenv("CHROMA_COLLECTION", "F1-wiki")

chroma_client = HttpClient(host=chroma_host, port=chroma_port)

collection = chroma_client.get_collection(name=chroma_collection, embedding_function=CustomEmbeddingFunction())

question = get_text()

if question:
    # Retrieve the five most relevant passages for the question
    response = collection.query(query_texts=question, include=['documents'], n_results=5)

    context = " ".join(response['documents'][0])

    # Fill in the prompt template and generate an answer
    # (early_stopping is dropped: it only applies to beam search)
    input_text = template.replace("{context}", context).replace("{question}", question)
    input_ids = tokenizer.encode(input_text, return_tensors="pt")

    output = model.generate(input_ids, max_new_tokens=200)
    answer = tokenizer.decode(output[0], skip_special_tokens=True).split("ANSWER:")[1]

    st.session_state.past.append(question)
    st.session_state.generated.append(answer)

    st.write(answer)
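
The app reads its ChromaDB connection settings from a chroma.env file next to app.py via load_dotenv. The file itself is not part of this commit; a minimal example matching the defaults already in the code would look like this:

    CHROMA_HOST=localhost
    CHROMA_PORT=8000
    CHROMA_COLLECTION=F1-wiki

With that file in place and a Chroma server reachable at the configured host and port, the app starts with streamlit run app.py.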
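The import of CustomEmbeddingFunction from utils.embedding_utils also refers to a file outside this diff. Chroma expects such a class to implement its EmbeddingFunction interface; the sketch below is one plausible shape for it, assuming a sentence-transformers backend (the class body and the model name are assumptions, not the repo's actual code):

    # utils/embedding_utils.py -- hypothetical sketch; the real file is not in this commit
    from chromadb import Documents, EmbeddingFunction, Embeddings
    from sentence_transformers import SentenceTransformer

    class CustomEmbeddingFunction(EmbeddingFunction):
        def __init__(self, model_name: str = "all-MiniLM-L6-v2"):  # model choice is an assumption
            self.model = SentenceTransformer(model_name)

        def __call__(self, input: Documents) -> Embeddings:
            # Chroma passes a list of strings and expects one float vector per string
            return self.model.encode(input).tolist()

Whatever the real implementation is, it must match the embedding function used when the F1-wiki collection was populated, otherwise query embeddings won't line up with the stored vectors.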