# -*- coding: utf-8 -*-
"""app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1XRKQ-ICJVg5oXXPNinjrj1VGGr8F3VYE
"""
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset
# Step 1: Load the pre-trained model and tokenizer
model_name = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Falcon's tokenizer ships without a pad token; reuse EOS so that
# padding="max_length" below does not raise an error.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
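# Loading note: Falcon-7B in full fp32 needs roughly 28 GB of memory. A common
# alternative (a sketch assuming a CUDA device and the accelerate package, not
# part of the original script) is to load in bfloat16 with automatic placement:
#
#   import torch
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, trust_remote_code=True,
#       torch_dtype=torch.bfloat16, device_map="auto",
#   )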
# Step 2: Load the legal dataset
dataset = load_dataset("casehold/casehold", "all")
# Step 3: Preprocess the dataset
# CaseHOLD rows carry a `citing_prompt` plus five candidate holdings
# (`holding_0`..`holding_4`), with `label` indexing the correct one. The
# original code referenced `context`/`question`/`answer` fields, which this
# dataset does not have, so the prompt/answer pair is built from the real
# columns instead.
def preprocess_data(example):
    prompt = example["citing_prompt"]
    answer = example[f"holding_{int(example['label'])}"]
    input_ids = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=512,
    )["input_ids"]
    labels = tokenizer(
        answer,
        truncation=True,
        padding="max_length",
        max_length=512,
    )["input_ids"]
    return {"input_ids": input_ids, "labels": labels}
tokenized_dataset = dataset.map(preprocess_data)  # per-example: the function above is not batch-aware
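# Optional sanity check before committing to a long run: decode one tokenized
# training example and eyeball the prompt (uncomment to use).
# print(tokenizer.decode(tokenized_dataset["train"][0]["input_ids"][:64]))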
# Step 4: Fine-tune the model
training_args = TrainingArguments(
    output_dir="./legal_gpt",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    num_train_epochs=3,
    save_steps=1000,
    save_total_limit=2,
    fp16=True,  # Mixed precision for faster training
    logging_dir="./logs",
    logging_steps=500,
)
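# Caveat: a full fine-tune of a 7B-parameter model at batch size 4 exceeds the
# memory of most single GPUs. A common workaround (a sketch assuming the peft
# library is installed; not part of the original script) is LoRA:
#
#   from peft import LoraConfig, get_peft_model
#   peft_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05,
#                            task_type="CAUSAL_LM")
#   model = get_peft_model(model, peft_config)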
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
)
trainer.train()
model.save_pretrained("./legal_gpt")
tokenizer.save_pretrained("./legal_gpt")
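# Optionally publish the fine-tuned weights (assumes you are authenticated via
# `huggingface-cli login`; the repo id below is a placeholder):
# model.push_to_hub("your-username/legal_gpt")
# tokenizer.push_to_hub("your-username/legal_gpt")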
import gradio as gr
# Load the fine-tuned model
model_path = "./legal_gpt"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
def generate_response(prompt):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    # Pass the attention mask along with input_ids, and cap newly generated
    # tokens rather than total sequence length.
    outputs = model.generate(**inputs, max_new_tokens=200, do_sample=True,
                             top_k=10, pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
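# Smoke test outside Gradio (hypothetical query):
# print(generate_response("What was the holding on qualified immunity?"))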
# Gradio Interface
interface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=5, placeholder="Enter your legal query here..."),
    outputs="text",
    title="Legal Advice GPT",
    description="Ask your legal questions and receive advice from a fine-tuned GPT!",
)
interface.launch()
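# On a Hugging Face Space the bare launch() above is what the platform runs;
# for local testing, interface.launch(server_name="0.0.0.0", server_port=7860)
# binds an explicit host/port (both are standard gradio launch() parameters).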