Swaleed committed on
Commit 53abbe8 · verified · 1 Parent(s): 8695292

Upload app.py

Files changed (1)
app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
+ # -*- coding: utf-8 -*-
+ """app.ipynb
+
+ Automatically generated by Colab.
+
+ Original file is located at
+     https://colab.research.google.com/drive/1XRKQ-ICJVg5oXXPNinjrj1VGGr8F3VYE
+ """
+
+ # A bare `pip install` is a shell command, not valid Python; run it in a
+ # terminal or a Colab cell (prefixed with "!") before executing this script:
+ # pip install transformers datasets torch gradio
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
+ from datasets import load_dataset
+
+ # Step 1: Load the pre-trained model and tokenizer
+ model_name = "tiiuae/falcon-7b"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+ # Falcon's tokenizer ships without a pad token; reuse EOS so the
+ # padding="max_length" calls below do not fail
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Step 2: Load the legal dataset
+ dataset = load_dataset("casehold/casehold", "all")
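+
+ # NOTE: the preprocessing below expects "context"/"question"/"answer" columns,
+ # but CaseHOLD's published schema exposes "citing_prompt", "holding_0".."holding_4",
+ # and "label". A minimal, hypothetical mapping into the expected fields
+ # (names chosen for illustration) would look like:
+ # def to_qa(example):
+ #     return {
+ #         "context": example["citing_prompt"],
+ #         "question": "Which holding follows from the citing text?",
+ #         "answer": example[f"holding_{int(example['label'])}"],
+ #     }
+ # dataset = dataset.map(to_qa)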
+
+ # Step 3: Preprocess the dataset
+ def preprocess_data(examples):
+     # Combine context and question into a single input; with batched=True,
+     # each column arrives as a list of values rather than a single string
+     inputs = [c + " " + q for c, q in zip(examples["context"], examples["question"])]
+     return {
+         "input_ids": tokenizer(
+             inputs,
+             truncation=True,
+             padding="max_length",
+             max_length=512,
+         )["input_ids"],
+         "labels": tokenizer(
+             examples["answer"],
+             truncation=True,
+             padding="max_length",
+             max_length=512,
+         )["input_ids"],
+     }
+
+ # Drop the original string columns so the default collator only sees token ids
+ tokenized_dataset = dataset.map(
+     preprocess_data, batched=True, remove_columns=dataset["train"].column_names
+ )
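+
+ # NOTE: padded label positions still contribute to the loss as written. A
+ # common refinement (not in the original script) is to mask them out, since
+ # Hugging Face models ignore label ids of -100 when computing the loss:
+ # labels = [tok if tok != tokenizer.pad_token_id else -100 for tok in labels]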
+
+ # Step 4: Fine-tune the model
+ training_args = TrainingArguments(
+     output_dir="./legal_gpt",
+     evaluation_strategy="epoch",
+     learning_rate=5e-5,
+     per_device_train_batch_size=4,
+     num_train_epochs=3,
+     save_steps=1000,
+     save_total_limit=2,
+     fp16=True,  # Mixed precision for faster training
+     logging_dir="./logs",
+     logging_steps=500,
+ )
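+
+ # NOTE: fully fine-tuning a 7B-parameter model at batch size 4 needs far more
+ # GPU memory than a single consumer card provides; parameter-efficient methods
+ # (e.g. LoRA via the peft library) or gradient checkpointing are the usual
+ # workarounds if this configuration does not fit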
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=tokenized_dataset["train"],
+     eval_dataset=tokenized_dataset["validation"],
+ )
+
+ trainer.train()
+ model.save_pretrained("./legal_gpt")
+ tokenizer.save_pretrained("./legal_gpt")
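+
+ # The Gradio section below reloads the fine-tuned weights from disk, so it can
+ # also be run on its own once ./legal_gpt exists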
+
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # Load the fine-tuned model
+ model_path = "./legal_gpt"
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model = AutoModelForCausalLM.from_pretrained(model_path)
+
+ def generate_response(prompt):
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+     # Pass the attention mask along with the input ids, and budget new tokens
+     # explicitly so a long prompt is not clipped by an overall max_length
+     outputs = model.generate(
+         **inputs, max_new_tokens=200, num_return_sequences=1, do_sample=True, top_k=10
+     )
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return response
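+
+ # Illustrative call (hypothetical prompt):
+ # print(generate_response("What does 'res judicata' mean in civil procedure?"))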
+
+ # Gradio Interface
+ interface = gr.Interface(
+     fn=generate_response,
+     inputs=gr.Textbox(lines=5, placeholder="Enter your legal query here..."),
+     outputs="text",
+     title="Legal Advice GPT",
+     description="Ask your legal questions and receive advice based on a fine-tuned GPT!",
+ )
+
+ interface.launch()