File size: 1,343 Bytes
3ff674d
 
 
 
 
 
 
 
 
 
 
 
054849f
3ff674d
9fbf01e
3ff674d
 
 
 
77563ba
3ff674d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4694dc8
3ff674d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import sys
import os

# Put the directory two levels above this file on the import path.
# NOTE(review): presumably this lets project-local modules resolve when the
# script is launched directly -- confirm against the repository layout.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
sys.path.append(project_root)

import gradio as gr
from cohere import Client
from groq import Groq
from predictor import Predictor

# API credentials are read from the environment so that no secret lives in the
# source tree. NOTE(review): the literal keys that used to be committed on
# these lines must be treated as leaked and revoked with the providers.
COHERE_API_KEY = os.environ.get("COHERE_API_KEY", "")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")


def main():
    """Build the predictor and serve it through a Gradio web interface.

    Creates the Cohere and Groq clients from ``COHERE_API_KEY`` and
    ``GROQ_API_KEY``, hands them to a ``Predictor`` configured with
    ``dataset.pkl``, calls ``predictor.setup()``, and launches a Gradio
    ``Interface`` with one query textbox and three output textboxes.

    Raises:
        RuntimeError: if either API key is empty, so a misconfigured
            environment is reported before any client is created.
    """
    # Fail fast with an actionable message instead of a later auth error.
    missing = [
        name
        for name, value in (
            ("COHERE_API_KEY", COHERE_API_KEY),
            ("GROQ_API_KEY", GROQ_API_KEY),
        )
        if not value
    ]
    if missing:
        raise RuntimeError(
            "Missing required environment variable(s): " + ", ".join(missing)
        )

    embedding_client = Client(api_key=COHERE_API_KEY)
    gen_client = Groq(api_key=GROQ_API_KEY)

    # The keyword is spelled `embeding_client` because that is the name this
    # call site has always passed to Predictor.
    predictor = Predictor(
        dataset_path="dataset.pkl",
        embeding_client=embedding_client,
        QA_boosted=True,
        generative_client=gen_client,
    )
    predictor.setup()

    def make_prediction(query):
        """Return the answer, link and content fields of the prediction for *query*."""
        answer = predictor.make_prediction(query)
        return answer.answer, answer.link, answer.content

    iface = gr.Interface(
        fn=make_prediction,
        inputs=gr.Textbox(label="query", value="Alain Delon connait-il Anne-Elisabeth Lemoine?"),
        # Output order matches the tuple returned by make_prediction.
        outputs=[
            gr.Textbox(label="reponse"),
            gr.Textbox(label="link"),
            gr.Textbox(label="context"),
        ],
        title="Gossip answering",
        description="Derrière chaque question, il y a un gossip qui attend d’être découvert.",
    )

    iface.launch()

# Only start the app when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()