"""app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1XRKQ-ICJVg5oXXPNinjrj1VGGr8F3VYE
"""

from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset

model_name = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

# Falcon's tokenizer ships without a pad token; padding="max_length" below requires one.
tokenizer.pad_token = tokenizer.eos_token
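
# Note: loading Falcon-7B in full precision is heavy for a typical Colab GPU.
# A possible alternative (a sketch, assuming `accelerate` is installed) is to
# load the weights in half precision and let Accelerate place them on devices:
#
#     import torch
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name,
#         trust_remote_code=True,
#         torch_dtype=torch.float16,
#         device_map="auto",
#     )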

dataset = load_dataset("casehold/casehold", "all")
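
# Quick sanity check (a sketch): print the splits and column names so the field
# names used in preprocess_data below can be verified against the actual
# CaseHOLD schema before training.
print(dataset)
print(dataset["train"].column_names)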

def preprocess_data(examples):
    # With batched=True, each field arrives as a list of strings, so build one
    # prompt per example. The field names ("context", "question", "answer") are
    # kept from the original code; verify they match the dataset's actual columns.
    prompts = [c + " " + q for c, q in zip(examples["context"], examples["question"])]
    model_inputs = tokenizer(prompts, truncation=True, padding="max_length", max_length=512)
    labels = tokenizer(examples["answer"], truncation=True, padding="max_length", max_length=512)
    # (Optionally, pad positions in labels can be set to -100 so the loss ignores them.)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_dataset = dataset.map(preprocess_data, batched=True)
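
# Optional check (a sketch): confirm the mapped dataset now carries the
# tokenized fields the Trainer expects.
print(tokenized_dataset["train"].column_names)  # should include "input_ids" and "labels"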

training_args = TrainingArguments(
    output_dir="./legal_gpt",
    evaluation_strategy="epoch",  # evaluate at the end of every epoch
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    num_train_epochs=3,
    save_steps=1000,
    save_total_limit=2,  # keep only the two most recent checkpoints
    fp16=True,  # mixed precision; requires a CUDA GPU
    logging_dir="./logs",
    logging_steps=500,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
)

trainer.train()

model.save_pretrained("./legal_gpt")
tokenizer.save_pretrained("./legal_gpt")

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

model_path = "./legal_gpt"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

def generate_response(prompt):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    # max_new_tokens bounds the continuation; the original max_length=200 could be
    # shorter than a 512-token prompt and stop generation immediately.
    outputs = model.generate(
        inputs["input_ids"], attention_mask=inputs["attention_mask"],
        max_new_tokens=200, num_return_sequences=1, do_sample=True, top_k=10,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
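
# Quick local check before wiring up the UI (the prompt below is only an
# illustrative example).
print(generate_response("What does the court mean by 'holding' in a cited case?"))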

interface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=5, placeholder="Enter your legal query here..."),
    outputs="text",
    title="Legal Advice GPT",
    description="Ask your legal questions and receive advice from a fine-tuned GPT model.",
)
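
# In Colab, launch(share=True) can be used instead to expose the app through a
# public link if the inline interface does not render.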

interface.launch()