pujanpaudel committed on
Commit
5026e83
·
verified ·
1 Parent(s): e0f8c1f

main app uploaded

Files changed (1)
inference.py +213 -0
inference.py ADDED
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
import lightning as L
import numpy as np
import random

import gradio as gr

MODEL_NAME: str = "google/flan-t5-small"

def load_tokenizer(tokenizer_path: str):
    tokenizer = T5Tokenizer.from_pretrained(tokenizer_path, local_files_only=True)
    return tokenizer

def qa_preprocess_data(context: str, tokenizer: T5Tokenizer):
    input_prefix: str = "Generate relevant question and answer for this paragraph:\n "
    inputs = input_prefix + context
    model_inputs = tokenizer(inputs, return_tensors="pt")  # a BatchEncoding holding input_ids / attention_mask
    return model_inputs

def distractor_preprocess_data(context: str, question: str,
                               answer: str, tokenizer: T5Tokenizer):
    input_prefix: str = "Generate 3 plausible but incorrect answer options (distractors) for the given question and correct answer, based on the provided context:"
    inputs = f"{input_prefix}\nCONTEXT:\n{context}\nQUESTION: {question}\nANSWER: {answer}"
    model_inputs = tokenizer(inputs, return_tensors="pt")
    return model_inputs

class DistractorTrained(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

    def forward(self, input_ids, attention_mask):
        return self.model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                   num_beams=4, max_new_tokens=80,
                                   do_sample=True, temperature=1.2)

class QATrained(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                num_beams: int = 4, max_new_tokens: int = 65,
                temperature: float = 1.2):
        return self.model.generate(
            input_ids=input_ids, attention_mask=attention_mask,
            num_beams=num_beams, max_new_tokens=max_new_tokens,
            do_sample=True, temperature=temperature
        )

def load_qa_model(model_path: str):
    model = QATrained.load_from_checkpoint(model_path)
    return model

def load_distractor_model(model_path: str):
    model = DistractorTrained.load_from_checkpoint(model_path)
    return model

def predict_qa(model: QATrained, tokenizer: T5Tokenizer, model_inputs,
               device: str = "cpu"):
    model.to(device)
    model.eval()
    with torch.inference_mode():
        generated_ids = model(input_ids=model_inputs["input_ids"].to(device),
                              attention_mask=model_inputs["attention_mask"].to(device))
    generated_ids = generated_ids.cpu()
    decoded_predictions = [tokenizer.decode(ids, skip_special_tokens=True) for ids in generated_ids]
    return decoded_predictions

def predict_distractor(model: DistractorTrained, tokenizer: T5Tokenizer,
                       model_inputs, device: str = "cpu"):
    model.to(device)
    model.eval()
    with torch.inference_mode():
        generated_ids = model(input_ids=model_inputs["input_ids"].to(device),
                              attention_mask=model_inputs["attention_mask"].to(device))
    generated_ids = generated_ids.cpu()
    decoded_predictions = [tokenizer.decode(ids, skip_special_tokens=True) for ids in generated_ids]
    return decoded_predictions

def main(user_input):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer_path: str = "./t5_tokenizer"
    qa_model_path: str = "./qa_trained_model/qa-t5-small.ckpt"
    distractor_model_path: str = "./distractor_trained_model/distractor_t5-small.ckpt"
    tokenizer = load_tokenizer(tokenizer_path)
    qa_model = load_qa_model(qa_model_path)
    distractor_model = load_distractor_model(distractor_model_path)
    qa_model_inputs = qa_preprocess_data(user_input, tokenizer)
    qa_decoded_predictions = predict_qa(qa_model, tokenizer, qa_model_inputs, device=device)
    qa_decoded_predictions = qa_decoded_predictions[0]

    # The QA model emits text shaped like "[QUESTION] ... [ANSWER] ...";
    # collect the position of every "[ANSWER] " tag (assumes at least one is present).
    indices = []
    start = 0
    while True:
        index = qa_decoded_predictions.find("[ANSWER] ", start)
        if index == -1:
            break
        indices.append(index)
        start = index + 1

    # 11 == len("[QUESTION] "): the question sits between the two tags.
    question = qa_decoded_predictions[11:indices[0]].rstrip()

    # 9 == len("[ANSWER] "): keep only the first answer if several were generated.
    if len(indices) == 1:
        answer = qa_decoded_predictions[indices[0] + 9:].rstrip()
    elif len(indices) > 1:
        answer = qa_decoded_predictions[indices[0] + 9:indices[1] - 1].rstrip()

    filtered_ans = answer.replace("?", ".")

    distractor_model_inputs = distractor_preprocess_data(user_input, question, filtered_ans, tokenizer)
    distractor_decoded_predictions = predict_distractor(distractor_model, tokenizer, distractor_model_inputs, device=device)
    distractor_decoded_predictions = distractor_decoded_predictions[0]

    # The distractor model emits "[OPTION 1] ... [OPTION 2] ... [OPTION 3] ...".
    option_strings = ["[OPTION 1]", "[OPTION 2]", "[OPTION 3]"]
    option_indices: list[int] = [distractor_decoded_predictions.find(option) for option in option_strings]

    # 11 == len("[OPTION n] "): slice each distractor out between consecutive tags.
    option1: str = distractor_decoded_predictions[11:option_indices[1]].strip()
    option2: str = distractor_decoded_predictions[option_indices[1] + 11:option_indices[2]].strip()
    option3: str = distractor_decoded_predictions[option_indices[2] + 11:].strip()

    option4: str = answer

    return {"question": question,
            "option1": option1,
            "option2": option2,
            "option3": option3,
            "option4": option4}

def shuffle_options(question_data):
    options = [
        question_data["option1"],
        question_data["option2"],
        question_data["option3"],
        question_data["option4"]
    ]
    correct_answer = question_data["option4"]  # option4 is the model's correct answer
    random.shuffle(options)
    return options, correct_answer

def process_input(context):
    question_data = main(context)
    options, correct_answer = shuffle_options(question_data)
    return question_data["question"], options, correct_answer

def check_answer(choice, correct_answer):
    if choice == correct_answer:
        return '<p style="color: #28a745;">Correct!</p>'
    else:
        return '<p style="color: #dc3545;">Incorrect! Try again.</p>'

with gr.Blocks() as demo:
    gr.Markdown("# MCQ Generator")

    with gr.Row():
        context_input = gr.Textbox(label="Context Paragraph", lines=5)
        generate_button = gr.Button("Generate Question")

    question_output = gr.Textbox(label="Question")
    options_radio = gr.Radio(label="Options", choices=[])
    submit_button = gr.Button("Submit Answer")
    result_output = gr.HTML()
    correct_answer = gr.State()

    def update_interface(question, options, correct):
        return {
            question_output: question,
            options_radio: gr.Radio(choices=options, label="Options"),
            correct_answer: correct
        }

    generate_button.click(
        process_input,
        inputs=[context_input],
        outputs=[question_output, options_radio, correct_answer]
    ).then(
        update_interface,
        inputs=[question_output, options_radio, correct_answer],
        outputs=[question_output, options_radio, correct_answer]
    )

    submit_button.click(
        check_answer,
        inputs=[options_radio, correct_answer],
        outputs=[result_output]
    )

if __name__ == "__main__":
    demo.launch(debug=True)