Spaces:
Runtime error
import json | |
import torch | |
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
def predict_NuExtract(model, tokenizer, text, template, batch_size=1, max_length=10_000, max_new_tokens=4_000):
    """Run NuExtract on a single text and return the raw generated extraction.

    Args:
        model: A causal-LM exposing ``device`` and ``generate``.
        tokenizer: The matching tokenizer (callable, with ``decode``).
        text: The source text to extract from.
        template: JSON text describing the fields to extract. It is parsed and
            re-serialised with 4-space indentation, so invalid JSON raises
            ``json.JSONDecodeError``.
        batch_size: Accepted for signature compatibility; not used here.
        max_length: Token limit applied (with truncation) to the prompt.
            NOTE(review): truncation presumably cuts from the right, which would
            drop the trailing ``<|output|>`` marker for very long inputs.
        max_new_tokens: Generation budget passed to ``model.generate``.

    Returns:
        The decoded text after the first ``<|output|>`` marker (up to a second
        marker, if one appears), or the whole decoded sequence when no marker
        is present.
    """
    pretty_template = json.dumps(json.loads(template), indent=4)
    prompt = f"""<|input|>\n### Template:\n{pretty_template}\n### Text:\n{text}\n\n<|output|>"""

    with torch.no_grad():
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=max_length)
        inputs = inputs.to(model.device)
        generated = model.generate(**inputs, max_new_tokens=max_new_tokens)
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

    # The decoded sequence is expected to echo the prompt, so the completion is
    # the segment that follows the marker.
    segments = decoded.split("<|output|>")
    if len(segments) > 1:
        return segments[1]
    return decoded
def _entity_value(entities, key, default):
    """Return ``entities[key]``, or *default* when it is missing, null, blank or empty."""
    value = entities.get(key)
    if value is None:
        return default
    if isinstance(value, str) and not value.strip():
        return default
    if isinstance(value, (list, dict)) and not value:
        return default
    return value


def _entity_actions(entities, default):
    """Join the ``Action`` entity into one comma-separated phrase.

    A bare string counts as a single action; joining it directly would put a
    comma between every character. Null or blank entries are dropped, and
    *default* is returned when nothing usable remains.
    """
    value = entities.get("Action")
    if value is None:
        items = []
    elif isinstance(value, (list, tuple)):
        items = value
    else:
        items = [value]
    actions = [str(item) for item in items if item is not None and str(item).strip()]
    return ", ".join(actions) if actions else default


def generate_response(extracted_data):
    """Turn extracted entities (JSON text) into a canned support reply.

    Args:
        extracted_data: JSON text whose top level is an object, optionally with
            an ``"Entities"`` object holding ``"App version"``, ``"Issue"``,
            ``"Date"``, ``"Reason"`` and ``"Action"`` (a list or a single
            string). Missing, null or blank fields fall back to placeholder
            wording.

    Returns:
        The reply sentence, or a fixed error message when *extracted_data* is
        not JSON text describing an object.
    """
    error_message = "Error in processing extracted data. Please check the input format."
    try:
        data = json.loads(extracted_data)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers non-text input such as None.
        return error_message
    if not isinstance(data, dict):
        return error_message

    entities = data.get("Entities")
    if not isinstance(entities, dict):
        # Missing, null or malformed "Entities": use every placeholder.
        entities = {}

    app_version = _entity_value(entities, "App version", "Unknown")
    issue = _entity_value(entities, "Issue", "an issue occurred")
    date = _entity_value(entities, "Date", "an unknown date")
    reason = _entity_value(entities, "Reason", "no specific reason provided")
    actions = _entity_actions(entities, "no action required")

    return (f"I checked the logs for the user. This user was accessing the app through our {app_version} app "
            f"(Wind Creek Casino app). {issue} on {date} "
            f"because {reason}. This is working as designed, "
            f"{actions}.")
# Checkpoint identifier shared by the model and tokenizer loaders below.
model_name = "numind/NuExtract-v1.5"

# Prefer the GPU when one is visible to torch; otherwise run on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# trust_remote_code=True lets transformers execute the modelling/tokenizer code
# shipped with the checkpoint. Weights are loaded in bfloat16 and the model is
# put in eval mode for inference.
model = (
    AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    .to(device)
    .eval()
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
def extract_information(text, template):
    """Gradio callback: extract fields from *text* using *template*, then phrase a reply.

    Args:
        text: Free text to extract information from.
        template: NuExtract template entered by the user; must be JSON text
            describing an object.

    Returns:
        The reply produced by ``generate_response``, or an error message when
        *template* is not a JSON object.
    """
    # Validate the template up front: predict_NuExtract calls json.loads on it
    # and would otherwise raise an uncaught exception inside the callback.
    try:
        parsed_template = json.loads(template)
    except (json.JSONDecodeError, TypeError):
        parsed_template = None
    if not isinstance(parsed_template, dict):
        return "Invalid extraction template. Please enter a valid JSON object."

    extracted_data = predict_NuExtract(model, tokenizer, text, template)
    return generate_response(extracted_data)
# Two text inputs (source text, JSON template) feed the two parameters of
# extract_information; its string result is shown in a single textbox.
demo = gr.Interface(
    fn=extract_information,
    inputs=[
        gr.Textbox(
            lines=5,
            label="Enter Text",
            placeholder="Enter text to extract information from...",
        ),
        gr.Textbox(
            lines=10,
            label="Enter Template",
            placeholder="Enter JSON extraction template...",
        ),
    ],
    outputs=gr.Textbox(label="Generated Response"),
    title="NuExtract Information Extractor",
    description=(
        "Enter a text and a JSON template to extract structured information "
        "and generate a response using NuExtract."
    ),
)
# Patch Phi-3.5-mini-instruct's `prepare_inputs_for_generation`.
#
# NOTE(review): the remote Phi-3.5 code loaded with trust_remote_code=True
# presumably still calls ``Cache.get_max_length()`` and reads
# ``Cache.seen_tokens``, which newer ``transformers`` releases dropped. Confirm
# this against the installed version. An earlier version of this wrapper
# computed a replacement into a local variable and discarded it, so the wrapped
# method never saw the fix. This version patches the cache object itself
# before delegating.
def patched_prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
    """Add legacy cache accessors when missing, then delegate to the original method.

    Args:
        self: The model this wrapper is bound to; must expose
            ``_default_prepare_inputs_for_generation`` (the original method).
        input_ids: Forwarded unchanged.
        past_key_values: Generation cache or ``None``. Objects without
            ``get_seq_length`` (e.g. legacy tuples) are passed through untouched.
        **kwargs: Forwarded unchanged.

    Returns:
        Whatever the original ``prepare_inputs_for_generation`` returns.
    """
    if past_key_values is not None and hasattr(past_key_values, "get_seq_length"):
        if not hasattr(past_key_values, "get_max_length"):
            # Assumes the default unbounded (dynamic) cache, which has no
            # maximum length. Reporting the current sequence length instead
            # would presumably make the caller treat the cache as full.
            past_key_values.get_max_length = lambda: None
        if not hasattr(type(past_key_values), "seen_tokens"):
            # Refreshed on every call so it tracks the growing cache.
            past_key_values.seen_tokens = past_key_values.get_seq_length()
    return self._default_prepare_inputs_for_generation(input_ids, past_key_values, **kwargs)
# Install the wrapper on this model instance only (instance attributes shadow
# the class method).
if hasattr(model, "prepare_inputs_for_generation"):
    # Stash the original bound method first: the wrapper delegates to it.
    setattr(
        model,
        "_default_prepare_inputs_for_generation",
        getattr(model, "prepare_inputs_for_generation"),
    )
    # __get__ binds the plain function to this instance, so ``self`` is
    # ``model`` when the wrapper is called.
    setattr(
        model,
        "prepare_inputs_for_generation",
        patched_prepare_inputs_for_generation.__get__(model),
    )
# Start the web server only when this file is run as a script, so importing
# the module does not block on a running server.
if __name__ == "__main__":
    demo.launch(share=True)