import os
import torch
import streamlit as st
from dotenv import load_dotenv
from peft import PeftModel
from chromadb import HttpClient
from utils.embedding_utils import CustomEmbeddingFunction
from transformers import AutoModelForCausalLM, AutoTokenizer
st.title("FormulAI")
# Device and model configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "unsloth/Llama-3.2-1B"
# Load pretrained model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the LoRA adapter weights and apply them to the base model on the same device
adapter_name = "FormulAI/FormuLLaMa-3.2-1B-LoRA"
model = PeftModel.from_pretrained(model, adapter_name).to(device)
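# Prompt template for retrieval-augmented generation: the retrieved CONTEXT
# and the user's QUESTION are substituted in before generation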
template = """Answer the following QUESTION based on the CONTEXT given.
If you do not know the answer or the CONTEXT does not contain the answer, truthfully say "I don't know".
CONTEXT:
{context}
QUESTION:
{question}
ANSWER:
"""
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []
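# Text input widget for the user's Formula 1 question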
def get_text():
    input_text = st.text_input("Ask a question regarding Formula 1: ", "", key="input")
    return input_text
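# Connect to the ChromaDB instance that stores the Formula 1 wiki passages;
# connection details come from chroma.env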
load_dotenv("chroma.env")
chroma_host = os.getenv("CHROMA_HOST", "localhost")
chroma_port = int(os.getenv("CHROMA_PORT", "8000"))
chroma_collection = os.getenv("CHROMA_COLLECTION", "F1-wiki")
chroma_client = HttpClient(host=chroma_host, port=chroma_port)
collection = chroma_client.get_collection(name=chroma_collection, embedding_function=CustomEmbeddingFunction())
question = get_text()
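# Llama tokenizers ship without a pad token; fall back to the EOS token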
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
if question:
    with st.spinner("Generating answer... "):
        # Retrieve the most relevant passages from the vector store
        response = collection.query(query_texts=question, include=['documents'], n_results=5)
        context = " ".join(response['documents'][0])
        # Fill the prompt template and generate an answer with the LoRA-adapted model
        input_text = template.replace("{context}", context).replace("{question}", question)
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
        attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
        output = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=200)
        # Keep only the text after "ANSWER:" and store the exchange in the chat history
        answer = tokenizer.decode(output[0], skip_special_tokens=True).split("ANSWER:")[1].strip()
        st.session_state.past.append(question)
        st.session_state.generated.append(answer)
        st.write(answer)