daqc commited on
Commit
0a272a9
verified
1 Parent(s): 8ab72fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -75
app.py CHANGED
@@ -1,34 +1,20 @@
1
  import gradio as gr
2
 
3
- # !python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"
4
- import json
5
  import torch
 
 
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer, StoppingCriteria, StoppingCriteriaList, GenerationConfig
7
- # from google.colab import userdata
8
- import os
9
 
10
- model_id = "somosnlp/kuntur-peru-legal-es-gemma-2b-it-merged"
11
- bnb_config = BitsAndBytesConfig(
12
- load_in_4bit=True,
13
- bnb_4bit_quant_type="nf4",
14
- bnb_4bit_compute_dtype=torch.bfloat16
15
- )
16
- max_seq_length=512
17
 
18
- # if torch.cuda.get_device_capability()[0] >= 8:
19
- # # print("Flash Attention")
20
- # attn_implementation="flash_attention_2"
21
- # else:
22
- # attn_implementation=None
23
- attn_implementation=None
24
 
25
- tokenizer = AutoTokenizer.from_pretrained(model_id,
26
- max_length = max_seq_length)
27
- model = AutoModelForCausalLM.from_pretrained(model_id,
28
- # quantization_config=bnb_config,
29
- device_map = {"":0},
30
- attn_implementation = attn_implementation, # A100 o H100
31
- ).eval()
32
 
33
 
34
 
@@ -51,7 +37,7 @@ class ListOfTokensStoppingCriteria(StoppingCriteria):
51
  return False
52
 
53
  # Uso del criterio de parada personalizado
54
- stop_tokens = ["end_of_turn"] # Lista de tokens de parada
55
 
56
  # Inicializa tu criterio de parada con el tokenizer y la lista de tokens de parada
57
  stopping_criteria = ListOfTokensStoppingCriteria(tokenizer, stop_tokens)
@@ -59,13 +45,17 @@ stopping_criteria = ListOfTokensStoppingCriteria(tokenizer, stop_tokens)
59
  # A帽ade tu criterio de parada a una StoppingCriteriaList
60
  stopping_criteria_list = StoppingCriteriaList([stopping_criteria])
61
 
62
- def generate_text(prompt, max_length=2048):
63
- # prompt="""What were the main contributions of Eratosthenes to the development of mathematics in ancient Greece?"""
64
- prompt=prompt.replace("\n", "").replace("驴","").replace("?","")
65
-
66
-
67
- #EXAMPLE
68
- input_text = f'''<start_of_turn>system
 
 
 
 
69
  You are a helpful AI assistant. You only answer in JSON format.
70
  Eres un agente experto en la constituci贸n pol铆tica del per煤 de 1993 que solo responde formato JSON:
71
  {{
@@ -79,54 +69,69 @@ def generate_text(prompt, max_length=2048):
79
  3. tema: Solo escoge los temas de la lista proporcionada, no inventes ni crees un nuevo tema, en caso de considerarse mas de 2 temas se separa con punto y coma, escoge solo los que se adecuen a la respuesta, no consideres todos los temas al mismo tiempo: Educaci贸n, Conflictos sociales, Prevenci贸n de la corrupci贸n, Servicios p煤blicos, Violencia contra la ni帽ez, Desigualdad y violencia hacia las mujeres, Seguridad ciudadana, Discapacidad o Salud.
80
  UNICAMENTE DEBES RESPONDER EN FORMATO JSON, SOLO EN JSON, JSON, JSON
81
  <end_of_turn>
82
- <start_of_turn>user
83
- {prompt}?<end_of_turn>\n<start_of_turn>model\n'''
84
-
85
- inputs = tokenizer.encode(input_text,
86
- return_tensors="pt",
87
- add_special_tokens=False).to("cuda:0")
88
- max_new_tokens=max_length
89
- generation_config = GenerationConfig(
90
- max_new_tokens=max_new_tokens,
91
- temperature=0.15,
92
- top_p=0.75, #0.9,
93
- top_k=40, # 45
94
- num_beams=2, #me
95
- repetition_penalty=1., #1.1
96
- do_sample=True,
97
- )
98
- outputs = model.generate(generation_config=generation_config,
99
- input_ids=inputs,
100
- stopping_criteria=stopping_criteria_list,)
101
- return tokenizer.decode(outputs[0], skip_special_tokens=False) #True
 
 
 
 
 
 
 
 
 
 
102
 
103
 
104
 
105
  def mostrar_respuesta(pregunta):
106
- respuesta_default = "No se pudo generar una respuesta adecuada."
107
- json_obj = {
108
- "respuesta": respuesta_default,
109
- "fuente": respuesta_default,
110
- "tema": respuesta_default
111
- }
112
 
113
- if pregunta:
114
  try:
115
- res = generate_text(pregunta, max_length=512)
116
- inicio_json = res.find('{')
117
- fin_json = res.rfind('}') + 1
118
- json_str = res[inicio_json:fin_json]
119
- json_obj = json.loads(json_str)
120
-
121
- # Verificar si el JSON contiene todas las claves necesarias
122
- if all(key in json_obj for key in ["respuesta", "fuente", "tema"]):
123
- return json_obj["respuesta"], json_obj["fuente"], json_obj["tema"]
124
- else:
125
- return res, respuesta_default, respuesta_default
126
- except Exception as e:
127
- print("Error al procesar la respuesta:", e)
128
-
129
- return json_obj["respuesta"], json_obj["fuente"], json_obj["tema"]
 
 
 
 
 
 
 
130
 
131
  # Ejemplos de preguntas
132
  ejemplos = [
@@ -143,7 +148,7 @@ iface = gr.Interface(
143
  gr.Textbox(label="Fuente", lines=1),
144
  gr.Textbox(label="Tema", lines=1)
145
  ],
146
- title="HolaaaaaaaaaaaaaaaaaConsulta Juridica basada en la Constitucion Politica del Peru",
147
  description="Introduce tu pregunta sobre la Constituci贸n Politica o una situaci贸n donde creas que tus derechos hayan sido vulnerados.",
148
  examples=ejemplos,
149
  )
 
1
  import gradio as gr
2
 
 
 
3
  import torch
4
+ from peft import PeftModel, PeftConfig
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer, GenerationConfig
6
+
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer, StoppingCriteria, StoppingCriteriaList, GenerationConfig
 
 
8
 
9
+ import os
 
 
 
 
 
 
10
 
11
+ peft_model_id = "daqc/kuntur-peru-legal-es-gemma-2b-it"
12
+ config = PeftConfig.from_pretrained(peft_model_id)
13
+ model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, return_dict=True, load_in_8bit=True, device_map={"":0})
14
+ tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
 
 
15
 
16
+ model = PeftModel.from_pretrained(model, peft_model_id)
17
+ model.eval()
 
 
 
 
 
18
 
19
 
20
 
 
37
  return False
38
 
39
  # Uso del criterio de parada personalizado
40
+ stop_tokens = ["<end_of_turn>"] # Lista de tokens de parada
41
 
42
  # Inicializa tu criterio de parada con el tokenizer y la lista de tokens de parada
43
  stopping_criteria = ListOfTokensStoppingCriteria(tokenizer, stop_tokens)
 
45
  # A帽ade tu criterio de parada a una StoppingCriteriaList
46
  stopping_criteria_list = StoppingCriteriaList([stopping_criteria])
47
 
48
+ def generate(
49
+ instruction,
50
+ max_new_tokens=256,
51
+ temperature=0.1,
52
+ top_p=0.75,
53
+ top_k=40,
54
+ num_beams=2,
55
+ **kwargs,
56
+ ):
57
+ instruction = instruction.replace("驴","").replace("?","")
58
+ system = f"""<start_of_turn>system
59
  You are a helpful AI assistant. You only answer in JSON format.
60
  Eres un agente experto en la constituci贸n pol铆tica del per煤 de 1993 que solo responde formato JSON:
61
  {{
 
69
  3. tema: Solo escoge los temas de la lista proporcionada, no inventes ni crees un nuevo tema, en caso de considerarse mas de 2 temas se separa con punto y coma, escoge solo los que se adecuen a la respuesta, no consideres todos los temas al mismo tiempo: Educaci贸n, Conflictos sociales, Prevenci贸n de la corrupci贸n, Servicios p煤blicos, Violencia contra la ni帽ez, Desigualdad y violencia hacia las mujeres, Seguridad ciudadana, Discapacidad o Salud.
70
  UNICAMENTE DEBES RESPONDER EN FORMATO JSON, SOLO EN JSON, JSON, JSON
71
  <end_of_turn>
72
+ """
73
+ prompt = f"""{system} <start_of_turn>user
74
+ {instruction}<end_of_turn> <start_of_turn>model\n
75
+ """
76
+ print(prompt)
77
+ inputs = tokenizer(prompt, return_tensors="pt")
78
+ input_ids = inputs["input_ids"].to("cuda")
79
+ attention_mask = inputs["attention_mask"].to("cuda")
80
+ generation_config = GenerationConfig(
81
+ temperature=temperature,
82
+ top_p=top_p,
83
+ top_k=top_k,
84
+ num_beams=num_beams,
85
+ **kwargs,
86
+ )
87
+ with torch.no_grad():
88
+ generation_output = model.generate(
89
+ input_ids=input_ids,
90
+ attention_mask=attention_mask,
91
+ generation_config=generation_config,
92
+ return_dict_in_generate=True,
93
+ #output_scores=True,
94
+ max_new_tokens=max_new_tokens,
95
+ early_stopping=True
96
+ )
97
+ s = generation_output.sequences[0]
98
+ output = tokenizer.decode(s, skip_special_tokens=True)
99
+ return output.split("model")[1]
100
+
101
+
102
 
103
 
104
 
105
  def mostrar_respuesta(pregunta):
106
+ texto_json = generate(pregunta)
107
+ respuesta = ""
108
+ fuente = ""
109
+ tema = ""
 
 
110
 
111
+ if texto_json.startswith('{'):
112
  try:
113
+ # Busca las posiciones de inicio y fin de cada campo
114
+ inicio_respuesta = texto_json.find('"respuesta":') + len('"respuesta":')
115
+ fin_respuesta = texto_json.find('"fuente":')
116
+ inicio_fuente = texto_json.find('"fuente":') + len('"fuente":')
117
+ fin_fuente = texto_json.find('"tema":')
118
+ inicio_tema = texto_json.find('"tema":') + len('"tema":')
119
+
120
+ # Extrae los valores de cada campo
121
+ respuesta = texto_json[inicio_respuesta:fin_respuesta].strip().strip('"')
122
+ fuente = texto_json[inicio_fuente:fin_fuente].strip().strip('"')
123
+ # Verifica si la clave "tema" existe en el JSON antes de extraer su valor
124
+ if '"tema":' in texto_json:
125
+ tema = texto_json[inicio_tema:].strip().strip('"')
126
+ except ValueError:
127
+ pass
128
+ else:
129
+ respuesta = texto_json.strip().strip('"')
130
+
131
+ return respuesta, fuente, tema
132
+
133
+
134
+
135
 
136
  # Ejemplos de preguntas
137
  ejemplos = [
 
148
  gr.Textbox(label="Fuente", lines=1),
149
  gr.Textbox(label="Tema", lines=1)
150
  ],
151
+ title="Consulta Juridica basada en la Constitucion Politica del Peru",
152
  description="Introduce tu pregunta sobre la Constituci贸n Politica o una situaci贸n donde creas que tus derechos hayan sido vulnerados.",
153
  examples=ejemplos,
154
  )