Ticio committed on
Commit 18a943c (verified)
1 parent: 775263f

Create app.py

Files changed (1): app.py (+150, -0)
app.py ADDED
@@ -0,0 +1,150 @@
+ from supabase import create_client
+ from supabase.lib.client_options import ClientOptions
+ from sentence_transformers import SentenceTransformer
+ import replicate
+ import vecs
+ import gradio as gr
+ import time
+ from dotenv import load_dotenv
+ import os
+ 
+ # Load environment variables from .env file
+ load_dotenv()
+ 
+ # Supabase client, scoped to the "vecs" schema where the embeddings live
+ key = os.getenv("key")
+ url = os.getenv("url")
+ opts = ClientOptions().replace(schema="vecs")
+ client = create_client(url, key, options=opts)
+ 
+ # Postgres connection string for the vecs vector store
+ user = os.getenv("user")
+ password = os.getenv("password")
+ host = os.getenv("host")
+ port = os.getenv("port")
+ db_name = "postgres"
+ DB_CONNECTION = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
+ vx = vecs.create_client(DB_CONNECTION)
+ 
+ # 384-dimension embedding model, matching the dimension of the vecs collections
+ model = SentenceTransformer('Snowflake/snowflake-arctic-embed-xs')
+ 
+ def find_prediction(prediction_list, id):
+     # Return the prediction with the given id as a dict
+     for item in prediction_list:
+         temp_dict = dict(item)
+         if temp_dict.get("id") == id:
+             return temp_dict
+ 
+ def find_id(prediction_list, inp):
+     # Return the id of the prediction whose input prompt matches inp
+     id = ""
+     for item in prediction_list:
+         temp_dict = dict(item)
+         if temp_dict.get("input").get("prompt") == inp:
+             id = temp_dict.get("id")
+             break
+     return id
+ 
+ def live_inference(prompt, max_new_tokens=1024, top_k=50, top_p=0.9, temperature=0.6, frequency_penalty=1.15):
+     # Start generation on Replicate; the text is collected by polling below
+     output = replicate.run(
+         "meta/meta-llama-3-8b-instruct",
+         input={
+             "top_k": top_k,
+             "top_p": top_p,
+             "prompt": prompt,
+             "temperature": temperature,
+             "max_new_tokens": max_new_tokens,
+             "prompt_template": "{prompt}",
+             "frequency_penalty": frequency_penalty,
+             "seed": 61001
+         })
+ 
+     # Poll the predictions list until this prompt's generation has succeeded
+     id = find_id(replicate.predictions.list(), prompt)
+     generating = True
+     out = ""
+     while generating:
+         prediction_list = replicate.predictions.list()
+         prediction = find_prediction(prediction_list, id)
+         try:
+             if prediction.get("status") == "succeeded":
+                 for item in prediction.get("output"):
+                     out += item
+                 generating = False
+         except AttributeError:
+             # The prediction is not visible in the list yet
+             pass
+         if generating:
+             time.sleep(1)
+     return out
+ 
+ 
+ def query_db(query, limit=5, filters=None, measure="cosine_distance", include_value=False, include_metadata=False, table="2023"):
+     # Each year of jurisprudence lives in its own 384-dimension collection
+     query_embeds = vx.get_or_create_collection(name=table, dimension=384)
+     ans = query_embeds.query(
+         data=query,
+         limit=limit,
+         filters=filters or {},
+         measure=measure,
+         include_value=include_value,
+         include_metadata=include_metadata,
+     )
+     return ans
+ 
+ def construct_result(ans):
+     # cosine_distance is smaller for closer matches, so sort ascending
+     ans.sort(key=sort_by_score)
+     results = ""
+     for i in range(0, len(ans)):
+         a, b = ans[i][2].get("sentencia"), ans[i][2].get("fragmento")
+         results += f"En la sentencia {a}, se dijo {b}\n"
+     return results
+ 
+ def sort_by_score(item):
+     # Records are (id, value, metadata) tuples; item[1] is the distance
+     return item[1]
+ 
+ def referencias(results):
+     # List each cited ruling once
+     references = 'Sentencias encontradas: \n'
+     enlistadas = []
+     for item in results:
+         if item[2].get('sentencia') not in enlistadas:
+             references += item[2].get('sentencia') + ' '
+             enlistadas.append(item[2].get('sentencia'))
+     return references
+ 
+ def inference(prompt):
+     # Embed the query and search every yearly collection from 2020 to 2024
+     encoded_prompt = model.encode(prompt)
+     years = range(2020, 2025)
+     results = []
+     for year in years:
+         results.extend(query_db(encoded_prompt, include_metadata=True, table=str(year), include_value=True, limit=3))
+     # Ascending cosine distance puts the closest fragments first
+     results.sort(key=sort_by_score)
+     context = f"""
+ <|begin_of_text|>
+ <|start_header_id|>system<|end_header_id|>
+ Eres Ticio, un asistente de investigación jurídica. Tu deber es organizar el contenido de las sentencias de la jurisprudencia de acuerdo
+ a las necesidades del usuario. Debes responder solo en español. Debes responder solo en base a la información del contexto a continuación.
+ Siempre debes mencionar la fuente en tu escrito, debe tener un estilo formal y jurídico.
+ Contexto:
+ {construct_result(results)}
+ <|eot_id|>
+ <|start_header_id|>user<|end_header_id|>
+ {prompt}
+ <|eot_id|>
+ <|start_header_id|>assistant<|end_header_id|>
+ """
+     return live_inference(context, max_new_tokens=512) + '\n' + referencias(results)
+ 
+ theme = gr.themes.Base(
+     primary_hue="red",
+     secondary_hue="red",
+     neutral_hue="neutral",
+ ).set(
+     button_primary_background_fill='#910A0A',
+     button_primary_border_color='*primary_300',
+     button_primary_text_color='*primary_50'
+ )
+ 
+ with gr.Blocks(theme=theme) as demo:
+     output = gr.Textbox(label="Ticio", lines=15, show_label=True, show_copy_button=True)
+     name = gr.Textbox(label="Name", show_label=False, container=True, placeholder="¿Qué quieres buscar?")
+     greet_btn = gr.Button("Preguntar", variant="primary")
+     greet_btn.click(fn=inference, inputs=name, outputs=output, api_name=False)
+ 
+ if __name__ == "__main__":
+     demo.queue(default_concurrency_limit=60)
+     demo.launch(show_api=False)