import os
import torch

import streamlit as st

from dotenv import load_dotenv
from peft import PeftModel
from chromadb import HttpClient
from utils.embedding_utils import CustomEmbeddingFunction
from transformers import AutoModelForCausalLM, AutoTokenizer

st.title("FormulAI")

# Device and model configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "unsloth/Llama-3.2-1B"

# Load pretrained model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the trained LoRA adapter and apply it to the base model
adapter_name = "FormulAI/FormuLLaMa-3.2-1B-LoRA"
model = PeftModel.from_pretrained(model, adapter_name).to(device)

template = """Answer the following QUESTION based on the CONTEXT given.
If you do not know the answer and the CONTEXT doesn't contain the answer, truthfully say "I don't know".

CONTEXT:
{context}

QUESTION:
{question}

ANSWER:
"""

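# Keep the chat history (questions asked and answers generated) in the session state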
if 'generated' not in st.session_state:
    st.session_state['generated'] = []

if 'past' not in st.session_state:
    st.session_state['past'] = []

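# Text input widget for the user's Formula 1 question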
def get_text():
    input_text = st.text_input("Ask a question regarding Formula 1: ", "", key="input")
    return input_text 

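# Read the Chroma connection settings (host, port, collection) from chroma.env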
load_dotenv("chroma.env")
chroma_host = os.getenv("CHROMA_HOST", "localhost")
chroma_port = int(os.getenv("CHROMA_PORT", "8000"))
chroma_collection = os.getenv("CHROMA_COLLECTION", "F1-wiki")

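# Connect to the Chroma server and open the collection used for retrieval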
chroma_client = HttpClient(host=chroma_host, port=chroma_port)

collection = chroma_client.get_collection(name=chroma_collection, embedding_function=CustomEmbeddingFunction())

question = get_text()

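# Llama tokenizers ship without a pad token, so reuse the EOS token for padding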
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

if question:
    with st.spinner("Generating answer... "):
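        # Retrieve the five most relevant context passages from the Chroma collection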
        response = collection.query(query_texts=question, include=['documents'], n_results=5)

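        # Concatenate the retrieved documents into a single context string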
        context = " ".join(response['documents'][0])

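        # Fill the prompt template and tokenize it for the model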
        input_text = template.replace("{context}", context).replace("{question}", question)
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
        attention_mask = (input_ids != tokenizer.pad_token_id).to(device)

        # Generate up to 200 new tokens (greedy decoding); pass pad_token_id explicitly
        # so generate() does not warn when the pad token falls back to EOS
        output = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=200,
                                pad_token_id=tokenizer.pad_token_id)
        # The prompt ends with "ANSWER:", so keep only the text generated after it
        answer = tokenizer.decode(output[0], skip_special_tokens=True).split("ANSWER:")[1].strip()

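        # Record the question/answer pair in the session history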
        st.session_state.past.append(question)
        st.session_state.generated.append(answer)

    st.write(answer)