"""Arabic question-generation Gradio app.

Given an Arabic passage and an answer, generates either a subjective
(essay-style) question or a multiple-choice question with distractors,
using fine-tuned seq2seq models. Candidate questions are ranked by an
Arabic extractive-QA pipeline's confidence score.
"""

import random
import re
import unicodedata

import arabic_reshaper
import gradio as gr
import torch  # noqa: F401 -- imported in the original file; kept for compatibility
from bidi.algorithm import get_display
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from arabert.preprocess import ArabertPreprocessor

# Model/tokenizer setup (weights are downloaded on first run).
tokenizer1 = AutoTokenizer.from_pretrained("Reham721/Subjective_QG")
tokenizer2 = AutoTokenizer.from_pretrained("google/mt5-base")

model1 = AutoModelForSeq2SeqLM.from_pretrained("Reham721/Subjective_QG")
model2 = AutoModelForSeq2SeqLM.from_pretrained("Reham721/MCQs_QG")

prep = ArabertPreprocessor("aubmindlab/araelectra-base-discriminator")
qa_pipe = pipeline(
    "question-answering", model="wissamantoun/araelectra-base-artydiqa"
)


def generate_questions(model, tokenizer, input_sequence):
    """Return three candidate questions for *input_sequence* via beam search.

    Args:
        model: a seq2seq model with a ``generate`` method.
        tokenizer: the tokenizer matching *model*.
        input_sequence: the raw text (answer + passage) to condition on.

    Returns:
        list[str]: three decoded candidate questions.
    """
    input_ids = tokenizer.encode(input_sequence, return_tensors="pt")
    outputs = model.generate(
        input_ids=input_ids,
        max_length=200,
        num_beams=3,
        no_repeat_ngram_size=3,
        early_stopping=True,
        temperature=1,
        num_return_sequences=3,
    )
    return [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]


def get_sorted_questions(questions, context):
    """Score each question against *context* and sort by confidence.

    Returns:
        dict[str, float]: {question: QA-pipeline score}, highest score first.
        Questions the pipeline fails on get a score of 0.
    """
    scores = {}
    context = prep.preprocess(context)
    for question in questions:
        try:
            result = qa_pipe(question=question, context=context)
            scores[question] = result["score"]
        # Was a bare ``except:`` -- narrowed so SystemExit/KeyboardInterrupt
        # still propagate; any pipeline failure just scores the question 0.
        except Exception:
            scores[question] = 0
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))


def is_arabic(text):
    """Heuristically check that every alphabetic char in *text* is Arabic."""
    reshaped_text = arabic_reshaper.reshape(text)
    bidi_text = get_display(reshaped_text)
    for char in bidi_text:
        # unicodedata.name() raises ValueError for unnamed codepoints; the
        # "" default treats those as non-Arabic instead of crashing.
        if char.isalpha() and not unicodedata.name(char, "").startswith("ARABIC"):
            return False
    return True


def _extract_candidates(decoded_output, answer, seen):
    """Split one decoded model output into clean, novel Arabic distractors.

    Strips ``<...>`` tag markup, drops empties / non-Arabic fragments, and
    filters out the correct *answer* and anything already in *seen*.
    """
    elements = [
        re.sub(r"<[^>]*>", "", e.strip())
        for e in re.split(r"(<[^>]*>)|(?:None)", decoded_output)
        if e
    ]
    return [
        e for e in elements
        if e and is_arabic(e) and e != answer and e not in seen
    ]


def generate_distractors(question, answer, context, num_distractors=3, k=10,
                         max_rounds=10):
    """Sample wrong-answer options for a multiple-choice question.

    Args:
        question: the generated question text.
        answer: the correct answer (excluded from the distractors).
        context: the source passage.
        num_distractors: how many distractors to return.
        k: cap on the candidate pool before the final sample.
        max_rounds: generation retries before giving up (guards against the
            original's potential infinite loop when the model keeps
            producing duplicates or non-Arabic text).

    Returns:
        list[str]: up to *num_distractors* unique distractors.
    """
    input_sequence = f"{question} {answer} {context}"
    input_ids = tokenizer2.encode(input_sequence, return_tensors="pt")

    unique_distractors = []
    for _ in range(max_rounds):
        outputs = model2.generate(
            input_ids,
            do_sample=True,
            max_length=50,
            top_k=50,
            top_p=0.95,
            num_return_sequences=max(
                num_distractors - len(unique_distractors), 1
            ),
            no_repeat_ngram_size=2,
        )
        for output in outputs:
            decoded = tokenizer2.decode(output, skip_special_tokens=True)
            unique_distractors.extend(
                _extract_candidates(decoded, answer, unique_distractors)
            )
        if len(unique_distractors) >= num_distractors:
            break

    # Randomly trim an oversized pool before the final draw.
    if len(unique_distractors) > k:
        random.shuffle(unique_distractors)
        unique_distractors = unique_distractors[:k]

    # min() keeps this from raising if the retry budget ran out short --
    # the original would either hang or crash in that situation.
    return random.sample(
        unique_distractors, min(num_distractors, len(unique_distractors))
    )


# --- Gradio UI ------------------------------------------------------------

context = gr.Textbox(lines=5, placeholder="أدخل الفقرة هنا", label="النص")
answer = gr.Textbox(lines=3, placeholder="أدخل الإجابة هنا", label="الإجابة")
question_type = gr.Radio(
    choices=["سؤال مقالي", "سؤال اختيار من متعدد"], label="نوع السؤال"
)
question = gr.Textbox(type="text", label="السؤال الناتج")


def generate_question(context, answer, question_type):
    """Gradio callback: build the requested question from passage + answer."""
    # NOTE(review): the original concatenated with an empty string
    # (answer + "" + context), which glues the answer to the passage with
    # no separator -- almost certainly a typo for a space.
    article = answer + " " + context
    candidates = generate_questions(model1, tokenizer1, article)
    ranked = get_sorted_questions(candidates, context)
    best_question = next(iter(ranked)) if ranked else "لم يتم توليد سؤال مناسب"
    if question_type == "سؤال مقالي":
        return best_question
    mcqs = generate_distractors(best_question, answer, context)
    mcqs.append(answer)
    random.shuffle(mcqs)
    return best_question + "\n" + "\n".join("- " + opt for opt in mcqs)


iface = gr.Interface(
    fn=generate_question,
    inputs=[context, answer, question_type],
    outputs=question,
)
iface.launch(debug=True, share=False)