from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
import lightning as L
import random
import gradio as gr

MODEL_NAME: str = "google/flan-t5-small"

# Markers the fine-tuned models are expected to emit in their decoded output.
QUESTION_MARKER: str = "[QUESTION] "
ANSWER_MARKER: str = "[ANSWER] "
OPTION_MARKERS: list[str] = ["[OPTION 1]", "[OPTION 2]", "[OPTION 3]"]


def load_tokenizer(tokenizer_path: str) -> T5Tokenizer:
    return T5Tokenizer.from_pretrained(tokenizer_path, local_files_only=True)


def qa_preprocess_data(context: str, tokenizer: T5Tokenizer):
    """Build the question/answer-generation prompt and tokenize it."""
    input_prefix: str = "Generate relevant question and answer for this paragraph:\n "
    return tokenizer(input_prefix + context, return_tensors="pt")


def distractor_preprocess_data(context: str, question: str, answer: str, tokenizer: T5Tokenizer):
    """Build the distractor-generation prompt and tokenize it."""
    input_prefix: str = (
        "Generate 3 plausible but incorrect answer options (distractors) for the "
        "given question and correct answer, based on the provided context:"
    )
    inputs = f"{input_prefix}\nCONTEXT:\n{context}\nQUESTION: {question}\nANSWER: {answer}"
    return tokenizer(inputs, return_tensors="pt")


class DistractorTrained(L.LightningModule):
    """Lightning wrapper around the fine-tuned distractor-generation T5."""

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

    def forward(self, input_ids, attention_mask):
        return self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_beams=4,
            max_new_tokens=80,
            do_sample=True,
            temperature=1.2,
        )


class QATrained(L.LightningModule):
    """Lightning wrapper around the fine-tuned question/answer-generation T5."""

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                num_beams: int = 4, max_new_tokens: int = 65, temperature: float = 1.2):
        return self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,  # was hard-coded to 65, silently ignoring the parameter
            do_sample=True,
            temperature=temperature,
        )


def load_qa_model(model_path: str) -> QATrained:
    return QATrained.load_from_checkpoint(model_path)


def load_distractor_model(model_path: str) -> DistractorTrained:
    return DistractorTrained.load_from_checkpoint(model_path)


def predict(model: L.LightningModule, tokenizer: T5Tokenizer, model_inputs, device: str = "cpu") -> list[str]:
    """Run generation and decode the output (shared by both models)."""
    model.to(device)
    model.eval()
    with torch.inference_mode():
        generated_ids = model(
            input_ids=model_inputs["input_ids"].to(device),
            attention_mask=model_inputs["attention_mask"].to(device),
        )
    generated_ids = generated_ids.cpu()
    return [tokenizer.decode(ids, skip_special_tokens=True) for ids in generated_ids]


def main(user_input):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer_path: str = "./t5_tokenizer"
    qa_model_path: str = "./qa-t5-small.ckpt"
    distractor_model_path: str = "./distractor_t5-small.ckpt"

    # Note: the tokenizer and both checkpoints are reloaded on every call;
    # caching them at module level would make repeated requests much faster.
    tokenizer = load_tokenizer(tokenizer_path)
    qa_model = load_qa_model(qa_model_path)
    distractor_model = load_distractor_model(distractor_model_path)

    # Generate "[QUESTION] ... [ANSWER] ..." text from the context paragraph.
    qa_model_inputs = qa_preprocess_data(user_input, tokenizer)
    qa_text = predict(qa_model, tokenizer, qa_model_inputs, device=device)[0]

    # Collect every occurrence of the answer marker.
    indices: list[int] = []
    start = 0
    while True:
        index = qa_text.find(ANSWER_MARKER, start)
        if index == -1:
            break
        indices.append(index)
        start = index + 1
    if not indices:
        raise ValueError(f"QA model output is missing the {ANSWER_MARKER!r} marker: {qa_text!r}")

    question = qa_text[len(QUESTION_MARKER):indices[0]].rstrip()
    if len(indices) == 1:
        answer = qa_text[indices[0] + len(ANSWER_MARKER):].rstrip()
    else:
        # If the marker appears more than once, keep only the first answer span.
        answer = qa_text[indices[0] + len(ANSWER_MARKER):indices[1] - 1].rstrip()
    filtered_ans = answer.replace("?", ".")

    # Generate "[OPTION 1] ... [OPTION 2] ... [OPTION 3] ..." distractor text.
    distractor_model_inputs = distractor_preprocess_data(user_input, question, filtered_ans, tokenizer)
    distractor_text = predict(distractor_model, tokenizer, distractor_model_inputs, device=device)[0]

    option_indices: list[int] = [distractor_text.find(option) for option in OPTION_MARKERS]
    if -1 in option_indices:
        raise ValueError(f"Distractor output is missing an option marker: {distractor_text!r}")

    # Slice out the text between consecutive option markers.
    marker_len = len(OPTION_MARKERS[0])
    option1: str = distractor_text[option_indices[0] + marker_len:option_indices[1]].strip()
    option2: str = distractor_text[option_indices[1] + marker_len:option_indices[2]].strip()
    option3: str = distractor_text[option_indices[2] + marker_len:].strip()
    option4: str = answer  # the correct answer is always stored as option4

    return {
        "question": question,
        "option1": option1,
        "option2": option2,
        "option3": option3,
        "option4": option4,
    }


def shuffle_options(question_data):
    options = [
        question_data["option1"],
        question_data["option2"],
        question_data["option3"],
        question_data["option4"],
    ]
    correct_answer = question_data["option4"]
    random.shuffle(options)
    return options, correct_answer


def process_input(context):
    question_data = main(context)
    options, correct_answer = shuffle_options(question_data)
    # Return an updated Radio component so the shuffled options become its choices.
    return question_data["question"], gr.Radio(choices=options, label="Options"), correct_answer
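
# A minimal sketch of exercising the pipeline headlessly, without the Gradio UI
# (illustrative only; it assumes the tokenizer directory and checkpoint files
# hard-coded in main() exist locally):
#
#   data = main("The mitochondrion produces most of the cell's ATP.")
#   options, correct = shuffle_options(data)
#   print(data["question"], options, correct)
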
def check_answer(choice, correct_answer):
    # The original HTML markup was stripped from this file; this is a minimal
    # reconstruction styled for the gr.HTML output component below.
    if choice == correct_answer:
        return "<p style='color:green'>Correct!</p>"
    return "<p style='color:red'>Incorrect! Try again.</p>"


with gr.Blocks() as demo:
    gr.Markdown("# MCQ Generator")
    with gr.Row():
        context_input = gr.Textbox(label="Context Paragraph", lines=5)
    generate_button = gr.Button("Generate Question")
    question_output = gr.Textbox(label="Question")
    options_radio = gr.Radio(label="Options", choices=[])
    submit_button = gr.Button("Submit Answer")
    result_output = gr.HTML()
    correct_answer = gr.State()

    # process_input already returns an updated Radio component, so no extra
    # .then() step is needed to refresh the choices.
    generate_button.click(
        process_input,
        inputs=[context_input],
        outputs=[question_output, options_radio, correct_answer],
    )

    submit_button.click(
        check_answer,
        inputs=[options_radio, correct_answer],
        outputs=[result_output],
    )

if __name__ == "__main__":
    demo.launch(debug=True)
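
# Dependency note (an assumption about the intended environment): running this
# script needs transformers, torch, lightning, gradio, and sentencepiece (which
# the slow T5Tokenizer requires) installed, plus the local ./t5_tokenizer
# directory and the two .ckpt files referenced in main(). demo.launch() prints
# the local URL to open in a browser.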