TALLER_1_3 / app.py
thejarll's picture
Update app.py
a32597b verified
import gradio as gr
import pickle
import os
import json
from transformers import AutoTokenizer, AutoModel, pipeline
import torch
import faiss
import numpy as np
from spaces import GPU # IMPORTANTE para ZeroGPU
# Token para modelos privados si se requiere
hf_token = os.getenv("HF_KEY")
# Cargar índice FAISS y los chunks
if os.path.exists("index.pkl"):
with open("index.pkl", "rb") as f:
index, chunks = pickle.load(f)
else:
raise FileNotFoundError("No se encontró el archivo 'index.pkl'.")
# Cargar diccionario de sinónimos
with open("sinonimos.json", "r", encoding="utf-8") as f:
diccionario_sinonimos = json.load(f)
# Función para expandir palabras clave con sinónimos
def expandir_con_sinonimos(palabras, diccionario):
resultado = set(palabras)
for palabra in palabras:
for clave, sinonimos in diccionario.items():
if palabra == clave or palabra in sinonimos:
resultado.update([clave] + sinonimos)
return list(resultado)
# Modelo de embeddings
model_id = "jinaai/jina-embeddings-v2-base-es"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
def generar_embedding(texto):
inputs = tokenizer(texto, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
outputs = model(**inputs)
last_hidden = outputs.last_hidden_state
mask = inputs["attention_mask"].unsqueeze(-1).expand(last_hidden.size()).float()
summed = torch.sum(last_hidden * mask, 1)
counted = torch.clamp(mask.sum(1), min=1e-9)
mean_pooled = summed / counted
return mean_pooled.numpy()
# LLM para generar respuesta final
llm = pipeline(
"text-generation",
model="meta-llama/Llama-3.2-3B-Instruct",
token=hf_token,
trust_remote_code=True
)
@GPU
def responder(pregunta):
if not pregunta:
return "Por favor ingresa una pregunta."
pregunta_embedding = generar_embedding(pregunta)
distances, indices = index.search(pregunta_embedding.reshape(1, -1), k=20)
result_chunks = [chunks[i] for i in indices[0]]
# Expandir sinónimos (opcional)
palabras_clave = pregunta.lower().split()
palabras_expandidas = expandir_con_sinonimos(palabras_clave, diccionario_sinonimos)
# Aplicar filtro, pero dejar más chunks como respaldo
filtrados = [c for c in result_chunks if any(p in c.lower() for p in palabras_expandidas)]
contexto_final = "\n\n".join(filtrados[:15]) if filtrados else "\n\n".join(result_chunks[:15])
# (El prompt vendría después de esto...)
prompt = f"""
Eres un abogado colombiano experto en normativas legales.
Vas a responder la siguiente pregunta basándote exclusivamente en el TEXTO LEGAL proporcionado.
Analiza si la pregunta está relacionada con el Código de Tránsito, el Código de Policía o el Código Penal,
y responde desde el enfoque legal más pertinente. Si el texto incluye varios códigos,
elige únicamente el que sea relevante para la pregunta.
No inventes sanciones ni menciones artículos que no estén explícitamente en el texto.
No incluyas enlaces ni notas aclaratorias. Sé claro, breve, profesional y directo.
CONTEXTO LEGAL:
{contexto_final}
PREGUNTA:
{pregunta}
RESPUESTA:
"""
resultado = llm(
prompt,
max_new_tokens=350,
temperature=0.4,
top_p=0.9,
repetition_penalty=1.2
)[0]["generated_text"]
if "RESPUESTA:" in resultado:
solo_respuesta = resultado.split("RESPUESTA:")[-1].strip()
else:
solo_respuesta = resultado.strip()
return solo_respuesta
# Interfaz de usuario con Gradio
demo = gr.Interface(
fn=responder,
inputs=gr.Textbox(label="Escribe tu pregunta"),
outputs=gr.Textbox(label="Respuesta generada"),
title="Asistente Legal Colombiano",
description="Consulta el Código de Tránsito, Código de Policía y Código Penal colombiano."
)
if __name__ == "__main__":
demo.launch()