import os import gradio as gr from gradio.components import Textbox, Button, Slider, Checkbox from AinaTheme import theme from urllib.error import HTTPError from rag import RAG from utils import setup MAX_NEW_TOKENS = 700 SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="False") == "True" import logging logging.basicConfig(level=logging.INFO, format='[%(asctime)s][%(levelname)s] - %(message)s') setup() print("Loading RAG model...") print("Show model parameters in UI: ", SHOW_MODEL_PARAMETERS_IN_UI) # Load the RAG model rag = RAG( vs_hf_repo_path=os.getenv("VS_REPO_NAME"), vectorstore_path=os.getenv("VECTORSTORE_PATH"), hf_token=os.getenv("HF_TOKEN"), embeddings_model=os.getenv("EMBEDDINGS"), model_name=os.getenv("MODEL"), rerank_model=os.getenv("RERANK_MODEL"), rerank_number_contexts=int(os.getenv("RERANK_NUMBER_CONTEXTS")) ) def generate(prompt, model_parameters): try: output, context, source = rag.get_response(prompt, model_parameters) return output, context, source except HTTPError as err: if err.code == 400: gr.Warning( "The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET." ) except: gr.Warning( "Inference endpoint is not available right now. Please try again later." ) return None, None, None def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature): """ Function to handle the input and call the RAG model for inference. """ if input_.strip() == "": gr.Warning("Not possible to inference an empty input") return None model_parameters = { "NUM_CHUNKS": num_chunks, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "top_k": top_k, "top_p": top_p, "do_sample": do_sample, "temperature": temperature } print("Model parameters: ", model_parameters) output, context, source = generate(input_, model_parameters) sources_markup = "" for url in source: sources_markup += f'{url}
' return output, sources_markup, context # return output.strip(), sources_markup, context def change_interactive(text): if len(text) == 0: return gr.update(interactive=True), gr.update(interactive=False) return gr.update(interactive=True), gr.update(interactive=True) def clear(): return ( None, None, None, None, gr.Number(value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True) ) def gradio_app(): with gr.Blocks(theme=theme) as demo: # App Description # ===================================================================================================================================== with gr.Row(): with gr.Column(): gr.Markdown("""# Demo de Retrieval (only) Viquipèdia""") with gr.Row(equal_height=False): # User Input # ===================================================================================================================================== with gr.Column(scale=2, variant="panel"): input_ = Textbox( lines=5, label="Input", placeholder="Qui va crear la guerra de les Galaxies ?", ) with gr.Row(variant="default"): clear_btn = Button("Clear",) submit_btn = Button("Submit", variant="primary", interactive=False) with gr.Row(variant="default"): num_chunks = gr.Number(value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True) # Add Examples manually gr.Examples( examples=[ ["Qui va crear la guerra de les Galaxies?"], ["Quin era el nom real de Voltaire?"], ["Què fan al BSC?"], # No existèix aquesta entrada a la VDB # https://ca.wikipedia.org/wiki/Imperi_Gal%C3%A0ctic # ["Què és un Imperi Galàctic?"], # ["Què és l'Imperi Galàctic d'Isaac Asimov?"], # ["Què és l'Imperi Galàctic de la Guerra de les Galàxies?"] ], inputs=[input_], # only inputs ) # Output # ===================================================================================================================================== with gr.Column(scale=10, variant="panel"): output = Textbox( lines=10, max_lines=25, label="Output", interactive=False, show_copy_button=True ) with gr.Accordion("Sources and context:", open=False, visible=False): source_context = gr.Markdown( label="Sources", show_label=False, ) with gr.Accordion("See full context evaluation:", open=False): context_evaluation = gr.Markdown( label="Full context", show_label=False, # interactive=False, # autoscroll=False, # show_copy_button=True ) # Event Handlers # ===================================================================================================================================== input_.change( fn=change_interactive, inputs=[input_], outputs=[clear_btn, submit_btn], api_name=False, ) input_.change( fn=None, inputs=[input_], api_name=False, js="""(i, m) => { document.getElementById('inputlenght').textContent = i.length + ' ' document.getElementById('inputlenght').style.color = (i.length > m) ? "#ef4444" : ""; }""", ) clear_btn.click( fn=clear, inputs=[], outputs=[input_, output, source_context, context_evaluation, num_chunks], # outputs=[input_, output, source_context, context_evaluation] + parameters_compontents, queue=False, api_name=False ) submit_btn.click( fn=submit_input, # inputs=[input_] + parameters_compontents, inputs=[input_] + [num_chunks], outputs=[output, source_context, context_evaluation], api_name="get-results" ) # ===================================================================================================================================== # # Output # with gr.Row(): # with gr.Column(scale=0.5): # gr.Examples( # examples=[["""Qui va crear la guerra de les Galaxies ?"""],], # inputs=input_, # outputs=[output, source_context, context_evaluation], # fn=submit_input, # ) # input_, output, source_context, context_evaluation, num_chunks = clear() demo.launch(show_api=True) if __name__ == "__main__": gradio_app()