|
from typing import List, Optional, Tuple

import gradio as gr
from langchain_core.vectorstores import VectorStore

from config import (
    LLM_MODEL_REPOS,
    EMBED_MODEL_REPOS,
    SUBTITLES_LANGUAGES,
    GENERATE_KWARGS,
)
from utils import (
    load_llm_model,
    load_embed_model,
    load_documents_and_create_db,
    user_message_to_chatbot,
    update_user_message_with_context,
    get_llm_response,
    get_gguf_model_names,
    add_new_model_repo,
    clear_llm_folder,
    clear_embed_folder,
    get_memory_usage,
)
|
|
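# Component factories. Event handlers call these again and route the result to the
# existing components, which is how Gradio updates values and visibility at runtime.
# get_rag_settings builds the retrieval controls; with render=False the caller
# places them later via .render().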
def get_rag_settings(rag_mode: bool, render: bool = True) -> Tuple[gr.Radio, gr.Slider]:
    k = gr.Radio(
        choices=[1, 2, 3, 4, 5, 'all'],
        value=2,
        label='Number of relevant documents for search',
        visible=rag_mode,
        render=render,
    )
    score_threshold = gr.Slider(
        minimum=0,
        maximum=1,
        value=0.5,
        step=0.05,
        label='relevance_scores_threshold',
        visible=rag_mode,
        render=render,
    )
    return k, score_threshold
|
|
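# Read-only view of the user message after context injection; grows with the
# text up to 10 visible lines and is only shown in RAG mode.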
def get_user_message_with_context(text: str, rag_mode: bool) -> gr.Textbox:
    max_lines = 10
    num_lines = min(len(text.split('\n')), max_lines)
    return gr.Textbox(
        text,
        visible=rag_mode,
        interactive=False,
        label='User Message With Context',
        lines=num_lines,
    )
|
|
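# System prompt textbox; locked with an explanatory note when the current model
# does not support a system role.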
def get_system_prompt_component(interactive: bool) -> gr.Textbox:
    value = '' if interactive else 'System prompt is not supported by this model'
    return gr.Textbox(value=value, label='System prompt', interactive=interactive)
|
|
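# Sampling parameter sliders, visible only while do_sample is enabled;
# defaults come from GENERATE_KWARGS.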
def get_generate_args(do_sample: bool) -> List[gr.Slider]:
    generate_args = [
        gr.Slider(minimum=0.1, maximum=3, value=GENERATE_KWARGS['temperature'], step=0.1, label='temperature', visible=do_sample),
        gr.Slider(minimum=0, maximum=1, value=GENERATE_KWARGS['top_p'], step=0.01, label='top_p', visible=do_sample),
        gr.Slider(minimum=1, maximum=50, value=GENERATE_KWARGS['top_k'], step=1, label='top_k', visible=do_sample),
        gr.Slider(minimum=1, maximum=5, value=GENERATE_KWARGS['repeat_penalty'], step=0.1, label='repeat_penalty', visible=do_sample),
    ]
    return generate_args
|
|
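# RAG mode checkbox, checked and visible only once a vector database exists.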
def get_rag_mode_component(db: Optional[VectorStore]) -> gr.Checkbox:
    value = visible = db is not None
    return gr.Checkbox(value=value, label='RAG Mode', scale=1, visible=visible)
|
|
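# Load the default LLM and embedding model once at application startup.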
start_llm_model, start_support_system_role, load_log = load_llm_model(LLM_MODEL_REPOS[0], 'gemma-2-2b-it-Q8_0.gguf')
start_embed_model, load_log = load_embed_model(EMBED_MODEL_REPOS[0])
|
|
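# Build the interface: narrow the app container and lay out the tabs.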
css = '''.gradio-container {width: 60% !important}'''

with gr.Blocks(css=css) as interface:
|
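    # Cross-tab state: loaded documents and vector DB for RAG, the repo lists,
    # and the currently active LLM / embedding models.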
    documents = gr.State([])
    db = gr.State(None)
    user_message_with_context = gr.State('')
    support_system_role = gr.State(start_support_system_role)
    llm_model_repos = gr.State(LLM_MODEL_REPOS)
    embed_model_repos = gr.State(EMBED_MODEL_REPOS)
    llm_model = gr.State(start_llm_model)
    embed_model = gr.State(start_embed_model)
|
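    # Chatbot tab: chat window with Send/Stop/Clear plus history, generation,
    # and RAG settings.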
    with gr.Tab(label='Chatbot'):
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    type='messages',
                    show_copy_button=True,
                    bubble_full_width=False,
                    height=480,
                )
                user_message = gr.Textbox(label='User')

                with gr.Row():
                    user_message_btn = gr.Button('Send')
                    stop_btn = gr.Button('Stop')
                    clear_btn = gr.Button('Clear')
|
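            # Settings column: history length, sampling parameters, and RAG controls.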
            with gr.Column(scale=1, min_width=80):
                with gr.Group():
                    gr.Markdown('History size')
                    history_len = gr.Slider(
                        minimum=0,
                        maximum=5,
                        value=0,
                        step=1,
                        info='Number of previous messages taken into account in history',
                        label='history_len',
                        show_label=False,
                    )

                with gr.Group():
                    gr.Markdown('Generation parameters')
                    do_sample = gr.Checkbox(
                        value=False,
                        label='do_sample',
                        info='Activate random sampling',
                    )
                    generate_args = get_generate_args(do_sample.value)
                    do_sample.change(
                        fn=get_generate_args,
                        inputs=do_sample,
                        outputs=generate_args,
                        show_progress=False,
                    )
|
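                # RAG controls are created unrendered, wired to the checkbox, and
                # placed side by side in the row below; toggling RAG mode rebuilds them.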
                rag_mode = get_rag_mode_component(db=db.value)
                k, score_threshold = get_rag_settings(rag_mode=rag_mode.value, render=False)
                rag_mode.change(
                    fn=get_rag_settings,
                    inputs=[rag_mode],
                    outputs=[k, score_threshold],
                )
                with gr.Row():
                    k.render()
                    score_threshold.render()
|
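        # Prompt view; the Textbox returned here takes over the
        # user_message_with_context name from the placeholder State above.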
        with gr.Accordion('Prompt', open=True):
            system_prompt = get_system_prompt_component(interactive=support_system_role.value)
            user_message_with_context = get_user_message_with_context(text='', rag_mode=rag_mode.value)
|
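        # Generation chain: append the user message, inject retrieved context in
        # RAG mode, display the augmented prompt, then query the LLM.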
        generate_event = gr.on(
            triggers=[user_message.submit, user_message_btn.click],
            fn=user_message_to_chatbot,
            inputs=[user_message, chatbot],
            outputs=[user_message, chatbot],
            queue=False,
        ).then(
            fn=update_user_message_with_context,
            inputs=[chatbot, rag_mode, db, k, score_threshold],
            outputs=[user_message_with_context],
        ).then(
            fn=get_user_message_with_context,
            inputs=[user_message_with_context, rag_mode],
            outputs=[user_message_with_context],
        ).then(
            fn=get_llm_response,
            inputs=[chatbot, llm_model, user_message_with_context, rag_mode, system_prompt,
                    support_system_role, history_len, do_sample, *generate_args],
            outputs=[chatbot],
        )
|
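        # Stop cancels every step of the generation chain.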
        stop_btn.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=generate_event,
            queue=False,
        )
|
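        # Clear empties the chat history and the context textbox.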
        clear_btn.click(
            fn=lambda: (None, ''),
            inputs=None,
            outputs=[chatbot, user_message_with_context],
            queue=False,
        )
|
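    # Load documents tab: upload files or fetch web/YouTube links, split them into
    # chunks, and build the vector database; a successful load re-enables RAG mode.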
    with gr.Tab(label='Load documents'):
        with gr.Row(variant='compact'):
            upload_files = gr.File(file_count='multiple', label='Upload text files')
            web_links = gr.Textbox(lines=6, label='Links to Web sites or YouTube')

        with gr.Row(variant='compact'):
            chunk_size = gr.Slider(50, 2000, value=500, step=50, label='Chunk size')
            chunk_overlap = gr.Slider(0, 200, value=20, step=10, label='Chunk overlap')

        subtitles_lang = gr.Radio(
            SUBTITLES_LANGUAGES,
            value=SUBTITLES_LANGUAGES[0],
            label='YouTube subtitle language',
        )

        load_documents_btn = gr.Button(value='Upload documents and initialize database')
        load_docs_log = gr.Textbox(label='Status of loading and splitting documents', interactive=False)

        load_documents_btn.click(
            fn=load_documents_and_create_db,
            inputs=[upload_files, web_links, subtitles_lang, chunk_size, chunk_overlap, embed_model],
            outputs=[documents, db, load_docs_log],
        ).success(
            fn=get_rag_mode_component,
            inputs=[db],
            outputs=[rag_mode],
        )

        gr.HTML("""<h3 style='text-align: center'>
        <a href="https://github.com/sergey21000/chatbot-rag" target='_blank'>GitHub Repository</a></h3>
        """)
|
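    # View documents tab: print the loaded chunks separated by '=' dividers.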
    with gr.Tab(label='View documents'):
        view_documents_btn = gr.Button(value='Show loaded text chunks')
        view_documents_textbox = gr.Textbox(
            lines=1,
            placeholder='To view chunks, load documents in the Load documents tab',
            label='Loaded chunks',
        )
        sep = '=' * 20
        view_documents_btn.click(
            lambda documents: f'\n{sep}\n\n'.join([doc.page_content for doc in documents]),
            inputs=[documents],
            outputs=[view_documents_textbox],
        )
|
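    # Load LLM model tab: register HF repos, pick a GGUF file from the selected
    # repo, load the model, and clean up old downloads.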
    with gr.Tab('Load LLM model'):
        new_llm_model_repo = gr.Textbox(
            value='',
            label='Add repository',
            placeholder='Link to repository of HF models in GGUF format',
        )
        new_llm_model_repo_btn = gr.Button('Add repository')
        curr_llm_model_repo = gr.Dropdown(
            choices=LLM_MODEL_REPOS,
            value=None,
            label='HF model repository',
        )
        curr_llm_model_path = gr.Dropdown(
            choices=[],
            value=None,
            label='GGUF model file',
        )
        load_llm_model_btn = gr.Button('Load and initialize model')
        load_llm_model_log = gr.Textbox(
            value=f'Model {LLM_MODEL_REPOS[0]} loaded at application startup',
            label='Model loading status',
            lines=6,
        )

        with gr.Group():
            gr.Markdown('Free up disk space by deleting all models except the currently selected one')
            clear_llm_folder_btn = gr.Button('Clear folder')
|
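        # Register a new repo, then clear the input field on success.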
        new_llm_model_repo_btn.click(
            fn=add_new_model_repo,
            inputs=[new_llm_model_repo, llm_model_repos],
            outputs=[curr_llm_model_repo, load_llm_model_log],
        ).success(
            fn=lambda: '',
            inputs=None,
            outputs=[new_llm_model_repo],
        )
|
        curr_llm_model_repo.change(
            fn=get_gguf_model_names,
            inputs=[curr_llm_model_repo],
            outputs=[curr_llm_model_path],
        )
|
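        # Load the selected model, append memory usage to the log, and rebuild
        # the system prompt box according to the new model's capabilities.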
        load_llm_model_btn.click(
            fn=load_llm_model,
            inputs=[curr_llm_model_repo, curr_llm_model_path],
            outputs=[llm_model, support_system_role, load_llm_model_log],
        ).success(
            fn=lambda log: log + get_memory_usage(),
            inputs=[load_llm_model_log],
            outputs=[load_llm_model_log],
        ).then(
            fn=get_system_prompt_component,
            inputs=[support_system_role],
            outputs=[system_prompt],
        )
|
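        # Delete all downloaded LLMs except the selected one and report the result.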
        clear_llm_folder_btn.click(
            fn=clear_llm_folder,
            inputs=[curr_llm_model_path],
            outputs=None,
        ).success(
            fn=lambda model_path: f'Models other than {model_path} removed',
            inputs=[curr_llm_model_path],
            outputs=[load_llm_model_log],
        )
|
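    # Load embed model tab: same flow as the LLM tab, but for embedding models.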
    with gr.Tab('Load embed model'):
        new_embed_model_repo = gr.Textbox(
            value='',
            label='Add repository',
            placeholder='Link to HF model repository',
        )
        new_embed_model_repo_btn = gr.Button('Add repository')
        curr_embed_model_repo = gr.Dropdown(
            choices=EMBED_MODEL_REPOS,
            value=None,
            label='HF model repository',
        )

        load_embed_model_btn = gr.Button('Load and initialize model')
        load_embed_model_log = gr.Textbox(
            value=f'Model {EMBED_MODEL_REPOS[0]} loaded at application startup',
            label='Model loading status',
            lines=7,
        )

        with gr.Group():
            gr.Markdown('Free up disk space by deleting all models except the currently selected one')
            clear_embed_folder_btn = gr.Button('Clear folder')
|
        new_embed_model_repo_btn.click(
            fn=add_new_model_repo,
            inputs=[new_embed_model_repo, embed_model_repos],
            outputs=[curr_embed_model_repo, load_embed_model_log],
        ).success(
            fn=lambda: '',
            inputs=None,
            outputs=[new_embed_model_repo],
        )
|
        load_embed_model_btn.click(
            fn=load_embed_model,
            inputs=[curr_embed_model_repo],
            outputs=[embed_model, load_embed_model_log],
        ).success(
            fn=lambda log: log + get_memory_usage(),
            inputs=[load_embed_model_log],
            outputs=[load_embed_model_log],
        )
|
        clear_embed_folder_btn.click(
            fn=clear_embed_folder,
            inputs=[curr_embed_model_repo],
            outputs=None,
        ).success(
            fn=lambda model_repo: f'Models other than {model_repo} removed',
            inputs=[curr_embed_model_repo],
            outputs=[load_embed_model_log],
        )
|
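# Listen on all network interfaces on port 7860.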
interface.launch(server_name='0.0.0.0', server_port=7860) |