import os
import json
import re
import numpy as np
import streamlit as st
# from openai import OpenAI
import random
from utils.help import get_disclaimer
from utils.format import sec_to_time, fix_latex, get_youtube_embed
from utils.rag_utils import load_youtube_data, load_book_data, load_summary, fixed_knn_retrieval, get_random_question
from utils.system_prompts import get_expert_system_prompt, get_synthesis_system_prompt
from utils.openai_utils import embed_question_openai, openai_domain_specific_answer_generation, openai_context_integration
from utils.llama_utils import get_bnb_config, load_base_model, load_fine_tuned_model, generate_response

st.set_page_config(page_title="AI University")

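# Responsive 16:9 wrapper so embedded YouTube players scale with the page width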
st.markdown(""" | |
<style> | |
.video-wrapper { | |
position: relative; | |
padding-bottom: 56.25%; | |
height: 0; | |
} | |
.video-wrapper iframe { | |
position: absolute; | |
top: 0; | |
left: 0; | |
width: 100%; | |
height: 100%; | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
# Set the cache directory to persistent storage
os.environ["HF_HOME"] = "/data/.cache/huggingface"

# # client = OpenAI(api_key=st.secrets["general"]["OpenAI_API"])
# client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# ---------------------------------------
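# Paths to the RAG corpus and to the fine-tuned Llama adapter weights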
base_path = "data/"
base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
adapter_path = "./llm_files/llama-tommi-v0.35-weights/"

st.title(":red[AI University]")
st.markdown("### Finite Element Methods")
# st.markdown("### Based on Introduction to Finite Element Methods (FEM) by Prof. Krishna Garikipati")
# st.markdown("##### [YouTube playlist of the FEM lectures](https://www.youtube.com/playlist?list=PLJhG_d-Sp_JHKVRhfTgDqbic_4MHpltXZ)")

st.markdown(":gray[Welcome to] :red[AI University]:gray[, developed at the] :red[University of Southern California]:gray[. This app leverages AI to provide expert answers to queries related to] :red[Finite Element Methods (FEM)]:gray[.]")
# As the content is AI-generated, we strongly recommend independently verifying the information provided.

st.markdown(" ")
st.markdown(" ")
# st.divider()

# Sidebar for settings
with st.sidebar:
    st.header("Settings")

    # with st.container(border=True):
    # Embedding model
    model_name = st.selectbox("Choose content embedding model", [
        "text-embedding-3-small",
        # "text-embedding-3-large",
        # "all-MiniLM-L6-v2",
        # "all-mpnet-base-v2"
        ],
        # help="""
        # Select the embedding model to use for encoding the retrieved text data.
        # Options include OpenAI's `text-embedding-3` models and two widely
        # used SentenceTransformers models.
        # """
    )
    with st.container(border=True):
        st.write('**Video lectures**')
        yt_token_choice = st.select_slider("Token per content", [256, 512, 1024], value=256, help="Larger values lead to an increase in the length of each retrieved piece of content", key="yt_token_len")
        yt_chunk_tokens = yt_token_choice
        yt_max_content = {128: 32, 256: 16, 512: 8, 1024: 4}[yt_chunk_tokens]
        top_k_YT = st.slider("Number of relevant content pieces to retrieve", 0, yt_max_content, 4, key="yt_token_num")
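        # Consecutive transcript chunks overlap by a quarter of the chunk length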
        yt_overlap_tokens = yt_chunk_tokens // 4

    # st.divider()
    with st.container(border=True):
        st.write('**Textbook**')
        show_textbook = False
        # show_textbook = st.toggle("Show Textbook Content", value=False)
        latex_token_choice = st.select_slider("Token per content", [128, 256, 512, 1024], value=256, help="Larger values lead to an increase in the length of each retrieved piece of content", key="latex_token_len")
        latex_chunk_tokens = latex_token_choice
        latex_max_content = {128: 32, 256: 16, 512: 8, 1024: 4}[latex_chunk_tokens]
        top_k_Latex = st.slider("Number of relevant content pieces to retrieve", 0, latex_max_content, 4, key="latex_token_num")
        # latex_overlap_tokens = latex_chunk_tokens // 4
        latex_overlap_tokens = 0

    st.write(' ')
    with st.expander('Expert model', expanded=False):
        use_expert_answer = st.toggle("Use expert answer", value=True)
        show_expert_response = st.toggle("Show initial expert answer", value=False)

        st.session_state.expert_model = st.selectbox(
            "Choose the LLM model",
            ["gpt-4o-mini",
             "gpt-3.5-turbo",
             "llama-tommi-0.35"],
            key='a1model'
        )

        if st.session_state.expert_model == "llama-tommi-0.35":
            tommi_do_sample = st.toggle("Enable Sampling", value=False, key='tommi_sample')
            if tommi_do_sample:
                tommi_temperature = st.slider("Temperature", 0.0, 1.5, 0.7, key='tommi_temp')
                tommi_top_k = st.slider("Top K", 0, 100, 50, key='tommi_top_k')
                tommi_top_p = st.slider("Top P", 0.0, 1.0, 0.95, key='tommi_top_p')
            else:
                tommi_num_beams = st.slider("Num Beams", 1, 4, 1, key='tommi_num_beams')
            # max_new_tokens is used by both the sampling and beam-search paths
            tommi_max_new_tokens = st.slider("Max New Tokens", 100, 2000, 500, step=50, key='tommi_max_new_tokens')
        else:
            expert_temperature = st.slider("Temperature", 0.0, 1.5, 0.7, key='a1t')
            expert_top_p = st.slider("Top P", 0.0, 1.0, 0.9, key='a1p')
    with st.expander('Synthesis model', expanded=False):
        # with st.container(border=True):
        # Choose the LLM model
        model = st.selectbox("Choose the LLM model", ["gpt-4o-mini", "gpt-3.5-turbo"], key='a2model')
        # Temperature (default must lie within the slider range)
        integration_temperature = st.slider("Temperature", 0.0, 0.5, 0.3, help="Defines the randomness in the next token prediction. Lower: More predictable and focused. Higher: More adventurous and diverse.", key='a2t')
        integration_top_p = st.slider("Top P", 0.1, 0.5, 0.3, help="Defines the range of token choices the model can consider in the next prediction. Lower: More focused and restricted to high-probability options. Higher: More creative, allowing consideration of less likely options.", key='a2p')

# Main content area
if "question" not in st.session_state:
    st.session_state.question = ""

text_area_placeholder = st.empty()
question_help = "Including details or instructions improves the answer."
st.session_state.question = text_area_placeholder.text_area(
    "**Enter your question/query about Finite Element Method**",
    height=120,
    value=st.session_state.question,
    help=question_help
)

_, col1, col2, _ = st.columns([4, 2, 4, 3])
with col1:
    submit_button_placeholder = st.empty()
with col2:
    if st.button("Random Question"):
        while True:
            random_question = get_random_question(os.path.join(base_path, "questions.txt"))
            if random_question != st.session_state.question:
                break
        st.session_state.question = random_question
        # Re-render the text area with the new question; the distinct label
        # avoids a duplicate-widget clash with the text area created above
        text_area_placeholder.text_area(
            "**Enter your question:**",
            height=120,
            value=st.session_state.question,
            help=question_help
        )

# Load YouTube and LaTeX data
text_data_YT, context_embeddings_YT = load_youtube_data(base_path, model_name, yt_chunk_tokens, yt_overlap_tokens)
text_data_Latex, context_embeddings_Latex = load_book_data(base_path, model_name, latex_chunk_tokens, latex_overlap_tokens)
summary = load_summary('data/KG_FEM_summary.json')

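# Initialize session state that must survive Streamlit reruns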
if 'question_answered' not in st.session_state:
    st.session_state.question_answered = False
if 'context_by_video' not in st.session_state:
    st.session_state.context_by_video = {}
if 'context_by_section' not in st.session_state:
    st.session_state.context_by_section = {}
if 'answer' not in st.session_state:
    st.session_state.answer = ""
if 'playing_video_id' not in st.session_state:
    st.session_state.playing_video_id = None

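# --- RAG pipeline: embed the question, retrieve context, synthesize an answer ---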
if submit_button_placeholder.button("AI Answer", type="primary"):
    if st.session_state.question != "":
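        # `fixed_knn_retrieval` is assumed to rank chunks by embedding similarity,
        # roughly: scores = context_embeddings @ question_embedding;
        #          idx = np.argsort(scores)[::-1][:top_k]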
with st.spinner("Finding relevant contexts..."): | |
question_embedding = embed_question_openai(st.session_state.question, model_name) | |
initial_max_k = int(0.1 * context_embeddings_YT.shape[0]) | |
idx_YT = fixed_knn_retrieval(question_embedding, context_embeddings_YT, top_k=top_k_YT, min_k=0) | |
idx_Latex = fixed_knn_retrieval(question_embedding, context_embeddings_Latex, top_k=top_k_Latex, min_k=0) | |
with st.spinner("Answering the question..."): | |
relevant_contexts_YT = sorted([text_data_YT[i] for i in idx_YT], key=lambda x: x['order']) | |
relevant_contexts_Latex = sorted([text_data_Latex[i] for i in idx_Latex], key=lambda x: x['order']) | |
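            # Group retrieved chunks by source video / book section for display and prompting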
            st.session_state.context_by_video = {}
            for context_item in relevant_contexts_YT:
                video_id = context_item['video_id']
                if video_id not in st.session_state.context_by_video:
                    st.session_state.context_by_video[video_id] = []
                st.session_state.context_by_video[video_id].append(context_item)

            st.session_state.context_by_section = {}
            for context_item in relevant_contexts_Latex:
                section_id = context_item['section']
                if section_id not in st.session_state.context_by_section:
                    st.session_state.context_by_section[section_id] = []
                st.session_state.context_by_section[section_id].append(context_item)
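            # Build the plain-text context block that is passed to the synthesis model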
            context = ''
            for i, (video_id, contexts) in enumerate(st.session_state.context_by_video.items(), start=1):
                for context_item in contexts:
                    start_time = int(context_item['start'])
                    context += f'Video {i}, time: {sec_to_time(start_time)}:' + context_item['text'] + '\n\n'
            for i, (section_id, contexts) in enumerate(st.session_state.context_by_section.items(), start=1):
                context += f'Section {i} ({section_id}):\n'
                for context_item in contexts:
                    context += context_item['text'] + '\n\n'
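            # Step 1 (optional): draft a domain-expert answer, either from the
            # fine-tuned Llama adapter or from an OpenAI chat model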
            if use_expert_answer:
                if st.session_state.expert_model == "llama-tommi-0.35":
                    # Load the fine-tuned model once and cache it in session state
                    if 'tommi_model' not in st.session_state:
                        tommi_model, tommi_tokenizer = load_fine_tuned_model(adapter_path, base_model_path)
                        st.session_state.tommi_model = tommi_model
                        st.session_state.tommi_tokenizer = tommi_tokenizer

                    messages = [
                        {"role": "system", "content": "You are an expert in Finite Element Methods."},
                        {"role": "user", "content": st.session_state.question}
                    ]

                    expert_answer = generate_response(
                        model=st.session_state.tommi_model,
                        tokenizer=st.session_state.tommi_tokenizer,
                        messages=messages,
                        do_sample=tommi_do_sample,
                        temperature=tommi_temperature if tommi_do_sample else None,
                        top_k=tommi_top_k if tommi_do_sample else None,
                        top_p=tommi_top_p if tommi_do_sample else None,
                        num_beams=tommi_num_beams if not tommi_do_sample else 1,
                        max_new_tokens=tommi_max_new_tokens
                    )
                elif st.session_state.expert_model in ["gpt-4o-mini", "gpt-3.5-turbo"]:
                    expert_answer = openai_domain_specific_answer_generation(
                        get_expert_system_prompt(),
                        st.session_state.question,
                        model=st.session_state.expert_model,
                        temperature=expert_temperature,
                        top_p=expert_top_p
                    )
                st.session_state.expert_answer = fix_latex(expert_answer)
            else:
                st.session_state.expert_answer = 'No Expert Answer. Only use the context.'
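            # Step 2: fuse the expert draft with the retrieved context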
            answer = openai_context_integration(
                get_synthesis_system_prompt("Finite Element Method"),
                st.session_state.question,
                st.session_state.expert_answer,
                context,
                model=model,
                temperature=integration_temperature,
                top_p=integration_top_p
            )
            answer = fix_latex(answer)
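            # The synthesis prompt is expected to return the sentinel NOT_ENOUGH_INFO
            # when the retrieved context cannot support an answer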
            if answer.strip().startswith("NOT_ENOUGH_INFO"):
                st.markdown("")
                st.markdown("#### Query:")
                st.markdown(fix_latex(st.session_state.question))
                if show_expert_response:
                    st.markdown("#### Initial Expert Answer:")
                    st.markdown(st.session_state.expert_answer)
                st.markdown("#### Answer:")
                st.write(":smiling_face_with_tear:")
                st.markdown(answer.split('NOT_ENOUGH_INFO')[1])
                st.divider()
                st.caption(get_disclaimer())
                # st.caption("The AI Teaching Assistant project")
                st.session_state.question_answered = False
                st.stop()
            else:
                st.session_state.answer = answer
                st.session_state.question_answered = True
    else:
        st.markdown("")
        st.write("Please enter a question. :smirk:")
        st.session_state.question_answered = False

if st.session_state.question_answered:
    st.markdown("")
    st.markdown("#### Query:")
    st.markdown(fix_latex(st.session_state.question))
    if show_expert_response:
        st.markdown("#### Initial Expert Answer:")
        st.markdown(st.session_state.expert_answer)
    st.markdown("#### Answer:")
    st.markdown(st.session_state.answer)
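    # Show retrieved lecture excerpts; each timestamp button reloads the
    # embedded player at the matching point in the video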
    if top_k_YT > 0:
        st.markdown("#### Retrieved content in lecture videos")
        for i, (video_id, contexts) in enumerate(st.session_state.context_by_video.items(), start=1):
            # with st.expander(f"**Video {i}** | {contexts[0]['title']}", expanded=True):
            with st.container(border=True):
                st.markdown(f"**Video {i} | {contexts[0]['title']}**")
                video_placeholder = st.empty()
                video_placeholder.markdown(get_youtube_embed(video_id, 0, 0), unsafe_allow_html=True)
                st.markdown('')
                with st.container(border=False):
                    st.markdown("Retrieved Times")
                    # Keep the filler column width positive: with many retrieved
                    # chunks, 9 - len(contexts) could go non-positive and crash st.columns
                    cols = st.columns([1 for _ in range(len(contexts))] + [max(9 - len(contexts), 1)])
                    for j, context_item in enumerate(contexts):
                        start_time = int(context_item['start'])
                        label = sec_to_time(start_time)
                        if cols[j].button(label, key=f"{video_id}_{start_time}"):
                            if st.session_state.playing_video_id is not None:
                                st.session_state.playing_video_id = None
                                video_placeholder.empty()
                            video_placeholder.markdown(get_youtube_embed(video_id, start_time, 1), unsafe_allow_html=True)
                            st.session_state.playing_video_id = video_id
                with st.expander("Video Summary", expanded=False):
                    # st.write("##### Video Overview:")
                    st.markdown(summary[video_id])
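    # Textbook excerpts stay hidden while show_textbook is hardcoded to False above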
    if show_textbook and top_k_Latex > 0:
        st.markdown("#### Retrieved content in textbook", help="The Finite Element Method: Linear Static and Dynamic Finite Element Analysis")
        for i, (section_id, contexts) in enumerate(st.session_state.context_by_section.items(), start=1):
            # with st.expander(f"**Section {i} | {section_id}**", expanded=True):
            st.markdown(f"**Section {i} | {section_id}**")
            for context_item in contexts:
                st.markdown(context_item['text'])
            st.divider()
st.markdown(" ") | |
st.divider() | |
st.caption(get_disclaimer()) |