|
from typing import List, Tuple, Optional
|
|
|
|
|
|
|
|
from llama_cpp import Llama
|
|
|
|
import gradio as gr
|
|
from langchain_core.vectorstores import VectorStore
|
|
|
|
from config import (
|
|
LLM_MODEL_REPOS,
|
|
START_LLM_MODEL_FILE,
|
|
EMBED_MODEL_REPOS,
|
|
SUBTITLES_LANGUAGES,
|
|
GENERATE_KWARGS,
|
|
CONTEXT_TEMPLATE,
|
|
)
|
|
|
|
from utils import (
|
|
load_llm_model,
|
|
load_embed_model,
|
|
load_documents_and_create_db,
|
|
user_message_to_chatbot,
|
|
update_user_message_with_context,
|
|
get_llm_response,
|
|
get_gguf_model_names,
|
|
add_new_model_repo,
|
|
clear_llm_folder,
|
|
clear_embed_folder,
|
|
get_memory_usage,
|
|
)
|
|
|
|
|
|
|
|
|
|
def get_rag_mode_component(db: Optional[VectorStore]) -> gr.Checkbox:
    """Build the 'RAG Mode' checkbox for the current vector-store state.

    Args:
        db: The vector store, or ``None`` when no database has been created.

    Returns:
        A checkbox that is both ticked and visible when ``db`` is not
        ``None``, and unticked and hidden otherwise.
    """
    has_db = db is not None
    return gr.Checkbox(
        value=has_db,
        label='RAG Mode',
        scale=1,
        visible=has_db,
    )
|
|
|
|
|
|
def get_rag_settings(
    rag_mode: bool,
    context_template_value: str,
    render: bool = True,
) -> Tuple[gr.Radio, gr.Slider, gr.Textbox]:
    """Build the three RAG setting components.

    Args:
        rag_mode: Whether RAG mode is on; used as the ``visible`` flag of
            every returned component.
        context_template_value: Text shown in the context template box; its
            number of lines also sets the height of the box.
        render: Forwarded to each component's ``render`` argument.

    Returns:
        A ``(k, score_threshold, context_template)`` tuple: the radio for the
        number of documents to search, the relevance score threshold slider
        and the context template textbox.
    """
    # Visibility and rendering are identical for all three components
    shared = dict(visible=rag_mode, render=render)

    k = gr.Radio(
        choices=[1, 2, 3, 4, 5, 'all'],
        value=2,
        label='Number of relevant documents for search',
        **shared,
    )
    score_threshold = gr.Slider(
        minimum=0,
        maximum=1,
        value=0.5,
        step=0.05,
        label='relevance_scores_threshold',
        **shared,
    )
    context_template = gr.Textbox(
        value=context_template_value,
        label='Context Template',
        # One textbox row per line of the template
        lines=len(context_template_value.split('\n')),
        **shared,
    )
    return k, score_threshold, context_template
|
|
|
|
|
|
def get_user_message_with_context(text: str, rag_mode: bool) -> gr.Textbox:
    """Build the read-only textbox showing the user message with context.

    Args:
        text: The text to display.
        rag_mode: Whether RAG mode is on; used as the ``visible`` flag.

    Returns:
        A non-interactive textbox whose height follows the number of lines
        in ``text``, capped at 10 lines.
    """
    max_lines = 10
    # Grow with the text, but never beyond ``max_lines`` rows
    num_lines = min(len(text.split('\n')), max_lines)
    return gr.Textbox(
        text,
        visible=rag_mode,
        interactive=False,
        label='User Message With Context',
        lines=num_lines,
    )
|
|
|
|
|
|
def get_system_prompt_component(interactive: bool) -> gr.Textbox:
    """Build the system prompt textbox.

    Args:
        interactive: Whether the textbox is editable.

    Returns:
        An empty editable textbox when ``interactive`` is true; otherwise a
        read-only textbox containing a "not supported" notice.
    """
    if interactive:
        prompt_text = ''
    else:
        prompt_text = 'System prompt is not supported by this model'
    return gr.Textbox(
        value=prompt_text,
        label='System prompt',
        interactive=interactive,
    )
|
|
|
|
|
|
def get_generate_args(do_sample: bool) -> List[gr.Slider]:
    """Build the sliders for the sampling parameters.

    Args:
        do_sample: Whether sampling is on; used as the ``visible`` flag of
            every slider.

    Returns:
        Sliders for ``temperature``, ``top_p``, ``top_k`` and
        ``repeat_penalty``, in that order. Each slider starts at the value
        stored under the same key in ``GENERATE_KWARGS``.
    """
    # (name, minimum, maximum, step); the name is both the label and the
    # GENERATE_KWARGS key holding the initial value
    slider_specs = (
        ('temperature', 0.1, 3, 0.1),
        ('top_p', 0, 1, 0.01),
        ('top_k', 1, 50, 1),
        ('repeat_penalty', 1, 5, 0.1),
    )
    return [
        gr.Slider(
            minimum=minimum,
            maximum=maximum,
            value=GENERATE_KWARGS[name],
            step=step,
            label=name,
            visible=do_sample,
        )
        for name, minimum, maximum, step in slider_specs
    ]
|
|
|
|
|
|
|
|
|
|
# Load the default LLM (first configured repo, configured start file) at startup
start_llm_model, start_support_system_role, load_log = load_llm_model(
    model_repo=LLM_MODEL_REPOS[0],
    model_file=START_LLM_MODEL_FILE,
)
# The app cannot start without an LLM, so abort with the loader's status message
if start_llm_model['llm_model'] is None:
    raise Exception(f'LLM model not initialized, status message: {load_log}')

# Load the default embedding model (first configured repo) at startup
start_embed_model, load_log = load_embed_model(model_repo=EMBED_MODEL_REPOS[0])
# Same rule for the embedding model
if start_embed_model['embed_model'] is None:
    raise Exception(f'Embed model not initialized, status message: {load_log}')
|
|
|
|
|
|
|
|
|
|
# Custom CSS passed to gr.Blocks: narrows the app container to 70% width
# and centers it horizontally
css = '''
.gradio-container {
    width: 70% !important;
    margin: 0 auto !important;
}
'''
|
|
|
|
# Definition of the whole Gradio UI: shared state, five tabs and their event wiring.
# NOTE(review): container nesting was reconstructed from the order of the
# statements - confirm the layout against the rendered UI.
with gr.Blocks(css=css) as interface:

    # ==================== SHARED STATE ====================

    documents = gr.State([])
    db = gr.State(None)
    support_system_role = gr.State(start_support_system_role)
    llm_model_repos = gr.State(LLM_MODEL_REPOS)
    embed_model_repos = gr.State(EMBED_MODEL_REPOS)
    llm_model = gr.State(start_llm_model)
    embed_model = gr.State(start_embed_model)

    # ==================== CHATBOT TAB ====================

    with gr.Tab(label='Chatbot'):
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    type='messages',
                    show_copy_button=True,
                    height=480,
                )
                user_message = gr.Textbox(label='User')

                with gr.Row():
                    user_message_btn = gr.Button('Send')
                    stop_btn = gr.Button('Stop')
                    clear_btn = gr.Button('Clear')

            # ---------- history and generation parameters ----------

            with gr.Column(scale=1, min_width=80):
                with gr.Group():
                    gr.Markdown('History size')
                    history_len = gr.Slider(
                        minimum=0,
                        maximum=5,
                        value=0,
                        step=1,
                        info='Number of previous messages taken into account in history',
                        label='history_len',
                        show_label=False,
                    )

                with gr.Group():
                    gr.Markdown('Generation parameters')
                    do_sample = gr.Checkbox(
                        value=False,
                        label='do_sample',
                        info='Activate random sampling',
                    )
                    generate_args = get_generate_args(do_sample.value)
                    # Rebuild the sliders whenever the checkbox is toggled
                    do_sample.change(
                        fn=get_generate_args,
                        inputs=do_sample,
                        outputs=generate_args,
                        show_progress=False,
                    )

        # ---------- RAG settings ----------

        rag_mode = get_rag_mode_component(db=db.value)
        # Created with render=False so they can be placed further down via .render()
        k, score_threshold, context_template = get_rag_settings(
            rag_mode=rag_mode.value,
            context_template_value=CONTEXT_TEMPLATE,
            render=False,
        )
        rag_mode.change(
            fn=get_rag_settings,
            inputs=[rag_mode, context_template],
            outputs=[k, score_threshold, context_template],
        )

        with gr.Row():
            k.render()
            score_threshold.render()

        # ---------- system prompt and user message with context ----------

        with gr.Accordion('Prompt', open=True):
            system_prompt = get_system_prompt_component(interactive=support_system_role.value)
            context_template.render()
            user_message_with_context = get_user_message_with_context(text='', rag_mode=rag_mode.value)

        # ---------- send, stop and clear ----------

        # Four chained steps: move the message into the chat, compute the
        # message with context, refresh its textbox, then get the LLM response
        generate_event = gr.on(
            triggers=[user_message.submit, user_message_btn.click],
            fn=user_message_to_chatbot,
            inputs=[user_message, chatbot],
            outputs=[user_message, chatbot],
        ).then(
            fn=update_user_message_with_context,
            inputs=[chatbot, rag_mode, db, k, score_threshold, context_template],
            outputs=[user_message_with_context],
        ).then(
            fn=get_user_message_with_context,
            inputs=[user_message_with_context, rag_mode],
            outputs=[user_message_with_context],
        ).then(
            fn=get_llm_response,
            inputs=[chatbot, llm_model, user_message_with_context, rag_mode, system_prompt,
                    support_system_role, history_len, do_sample, *generate_args],
            outputs=[chatbot],
        )

        # Stop runs no function; it only cancels the generation chain
        stop_btn.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=generate_event,
            queue=False,
        )

        # Clear resets the chat and the message-with-context box
        clear_btn.click(
            fn=lambda: (None, ''),
            inputs=None,
            outputs=[chatbot, user_message_with_context],
            queue=False,
        )

    # ==================== LOAD DOCUMENTS TAB ====================

    with gr.Tab(label='Load documents'):
        with gr.Row(variant='compact'):
            upload_files = gr.File(file_count='multiple', label='Loading text files')
            web_links = gr.Textbox(lines=6, label='Links to Web sites or YouTube')

        with gr.Row(variant='compact'):
            chunk_size = gr.Slider(50, 2000, value=500, step=50, label='Chunk size')
            chunk_overlap = gr.Slider(0, 200, value=20, step=10, label='Chunk overlap')

            # NOTE(review): placement inside this row is assumed - confirm
            subtitles_lang = gr.Radio(
                SUBTITLES_LANGUAGES,
                value=SUBTITLES_LANGUAGES[0],
                label='YouTube subtitle language',
            )

        load_documents_btn = gr.Button(value='Upload documents and initialize database')
        load_docs_log = gr.Textbox(label='Status of loading and splitting documents', interactive=False)

        # After a successful load, rebuild the RAG checkbox from the new db state
        load_documents_btn.click(
            fn=load_documents_and_create_db,
            inputs=[upload_files, web_links, subtitles_lang, chunk_size, chunk_overlap, embed_model],
            outputs=[documents, db, load_docs_log],
        ).success(
            fn=get_rag_mode_component,
            inputs=[db],
            outputs=[rag_mode],
        )

        # NOTE(review): placement inside this tab is assumed - confirm
        gr.HTML("""<h3 style='text-align: center'>
        <a href="https://github.com/sergey21000/chatbot-rag" target='_blank'>GitHub Repository</a></h3>
        """)

    # ==================== VIEW DOCUMENTS TAB ====================

    with gr.Tab(label='View documents'):
        view_documents_btn = gr.Button(value='Show downloaded text chunks')
        view_documents_textbox = gr.Textbox(
            lines=1,
            placeholder='To view chunks, load documents in the Load documents tab',
            label='Uploaded chunks',
        )
        sep = '=' * 20
        # Show every chunk's text, separated by a line of '=' characters
        view_documents_btn.click(
            lambda docs: f'\n{sep}\n\n'.join(doc.page_content for doc in docs),
            inputs=[documents],
            outputs=[view_documents_textbox],
        )

    # ==================== LOAD LLM MODEL TAB ====================

    with gr.Tab('Load LLM model'):
        new_llm_model_repo = gr.Textbox(
            value='',
            label='Add repository',
            placeholder='Link to repository of HF models in GGUF format',
        )
        new_llm_model_repo_btn = gr.Button('Add repository')
        curr_llm_model_repo = gr.Dropdown(
            choices=LLM_MODEL_REPOS,
            value=None,
            label='HF Model Repository',
        )
        curr_llm_model_path = gr.Dropdown(
            choices=[],
            value=None,
            label='GGUF model file',
        )
        load_llm_model_btn = gr.Button('Loading and initializing model')
        load_llm_model_log = gr.Textbox(
            value=f'Model {LLM_MODEL_REPOS[0]} loaded at application startup',
            label='Model loading status',
            lines=6,
        )

        with gr.Group():
            gr.Markdown('Free up disk space by deleting all models except the currently selected one')
            clear_llm_folder_btn = gr.Button('Clear folder')

        # Add the repository, then empty the input box on success
        new_llm_model_repo_btn.click(
            fn=add_new_model_repo,
            inputs=[new_llm_model_repo, llm_model_repos],
            outputs=[curr_llm_model_repo, load_llm_model_log],
        ).success(
            fn=lambda: '',
            inputs=None,
            outputs=[new_llm_model_repo],
        )

        # Refresh the GGUF file dropdown when another repository is selected
        curr_llm_model_repo.change(
            fn=get_gguf_model_names,
            inputs=[curr_llm_model_repo],
            outputs=[curr_llm_model_path],
        )

        # Load the model, append memory usage to the log on success,
        # then rebuild the system prompt box from the new support flag
        load_llm_model_btn.click(
            fn=load_llm_model,
            inputs=[curr_llm_model_repo, curr_llm_model_path],
            outputs=[llm_model, support_system_role, load_llm_model_log],
        ).success(
            fn=lambda log: log + get_memory_usage(),
            inputs=[load_llm_model_log],
            outputs=[load_llm_model_log],
        ).then(
            fn=get_system_prompt_component,
            inputs=[support_system_role],
            outputs=[system_prompt],
        )

        # The confirmation text goes to the status box; with no output
        # component the string returned by the lambda would be dropped
        clear_llm_folder_btn.click(
            fn=clear_llm_folder,
            inputs=[curr_llm_model_path],
            outputs=None,
        ).success(
            fn=lambda model_path: f'Models other than {model_path} removed',
            inputs=[curr_llm_model_path],
            outputs=[load_llm_model_log],
        )

    # ==================== LOAD EMBED MODEL TAB ====================

    with gr.Tab('Load embed model'):
        new_embed_model_repo = gr.Textbox(
            value='',
            label='Add repository',
            placeholder='Link to HF model repository',
        )
        new_embed_model_repo_btn = gr.Button('Add repository')
        curr_embed_model_repo = gr.Dropdown(
            choices=EMBED_MODEL_REPOS,
            value=None,
            label='HF model repository',
        )

        load_embed_model_btn = gr.Button('Loading and initializing model')
        load_embed_model_log = gr.Textbox(
            value=f'Model {EMBED_MODEL_REPOS[0]} loaded at application startup',
            label='Model loading status',
            lines=7,
        )
        with gr.Group():
            gr.Markdown('Free up disk space by deleting all models except the currently selected one')
            clear_embed_folder_btn = gr.Button('Clear folder')

        # Add the repository, then empty the input box on success
        new_embed_model_repo_btn.click(
            fn=add_new_model_repo,
            inputs=[new_embed_model_repo, embed_model_repos],
            outputs=[curr_embed_model_repo, load_embed_model_log],
        ).success(
            fn=lambda: '',
            inputs=None,
            outputs=[new_embed_model_repo],
        )

        # Load the model and append memory usage to the log on success
        load_embed_model_btn.click(
            fn=load_embed_model,
            inputs=[curr_embed_model_repo],
            outputs=[embed_model, load_embed_model_log],
        ).success(
            fn=lambda log: log + get_memory_usage(),
            inputs=[load_embed_model_log],
            outputs=[load_embed_model_log],
        )

        # The confirmation text goes to the status box; with no output
        # component the string returned by the lambda would be dropped
        clear_embed_folder_btn.click(
            fn=clear_embed_folder,
            inputs=[curr_embed_model_repo],
            outputs=None,
        ).success(
            fn=lambda model_repo: f'Models other than {model_repo} removed',
            inputs=[curr_embed_model_repo],
            outputs=[load_embed_model_log],
        )
|
|
|
|
|
|
# Start the Gradio server on host '0.0.0.0', port 7860
interface.launch(
    server_name='0.0.0.0',
    server_port=7860,
)