# -*- coding: utf-8 -*-
"""app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1XRKQ-ICJVg5oXXPNinjrj1VGGr8F3VYE
"""

from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset

# Step 1: Load the pre-trained model and tokenizer
# Falcon-7B needs a large GPU even in fp16; see the optional 8-bit loading
# sketch at the bottom of this file if you hit out-of-memory errors.
model_name = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

# Falcon's tokenizer ships without a pad token, which padding="max_length"
# below requires; reuse the EOS token for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Step 2: Load the legal dataset
# CaseHOLD rows carry a `citing_prompt`, five candidate holdings
# (`holding_0` .. `holding_4`), and a `label` indexing the correct one; there
# are no `context`/`question`/`answer` columns, so the preprocessing below
# maps the actual fields instead.
dataset = load_dataset("casehold/casehold", "all")

# Step 3: Preprocess the dataset
def preprocess_data(batch):
    # Build one prompt-plus-correct-holding string per example. This function
    # is called with batched=True, so every field arrives as a list.
    texts = []
    for i in range(len(batch["citing_prompt"])):
        correct = int(batch["label"][i])  # index of the correct holding
        texts.append(
            batch["citing_prompt"][i] + " Holding: " + batch[f"holding_{correct}"][i]
        )
    tokenized = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    # For causal LM fine-tuning the labels are the same token sequence as the
    # inputs; pad positions are set to -100 so the loss ignores them.
    tokenized["labels"] = [
        [tok if tok != tokenizer.pad_token_id else -100 for tok in ids]
        for ids in tokenized["input_ids"]
    ]
    return tokenized

tokenized_dataset = dataset.map(
    preprocess_data,
    batched=True,
    remove_columns=dataset["train"].column_names,  # drop the raw text columns
)

# Step 4: Fine-tune the model
training_args = TrainingArguments(
    output_dir="./legal_gpt",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    num_train_epochs=3,
    save_steps=1000,
    save_total_limit=2,
    fp16=True,  # mixed precision for faster training
    logging_dir="./logs",
    logging_steps=500,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
)

trainer.train()

model.save_pretrained("./legal_gpt")
tokenizer.save_pretrained("./legal_gpt")

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the fine-tuned model
model_path = "./legal_gpt"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

def generate_response(prompt):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=200,  # max_length would count the prompt tokens too
        num_return_sequences=1,
        do_sample=True,
        top_k=10,
        pad_token_id=tokenizer.pad_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Gradio interface
interface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=5, placeholder="Enter your legal query here..."),
    outputs="text",
    title="Legal Advice GPT",
    description="Ask your legal questions and receive advice from a fine-tuned Falcon model!",
)

interface.launch()
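
# ---------------------------------------------------------------------------
# Optional: 8-bit quantized loading (a sketch, not part of the original app).
# Falcon-7B in full or half precision can exceed a typical Colab GPU; loading
# it in 8-bit via bitsandbytes is one common workaround. This assumes the
# `bitsandbytes` and `accelerate` packages are installed. The helper is never
# called above; swap it in for the plain `from_pretrained` call in Step 1.
# ---------------------------------------------------------------------------
def load_model_8bit(name="tiiuae/falcon-7b"):
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    return AutoModelForCausalLM.from_pretrained(
        name,
        trust_remote_code=True,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",  # let accelerate place layers across available devices
    )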
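
# ---------------------------------------------------------------------------
# Optional: programmatic smoke test (a sketch). Assumes the `gradio_client`
# package is installed and the interface above is running on the default
# local port; `/predict` is the default endpoint name for a gr.Interface.
# Handy for verifying the app without opening the browser UI.
# ---------------------------------------------------------------------------
def query_app(prompt, url="http://127.0.0.1:7860/"):
    from gradio_client import Client

    return Client(url).predict(prompt, api_name="/predict")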