Reham721 committed on
Commit
6f86460
·
verified ·
1 Parent(s): eac6be3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2
+ from arabert.preprocess import ArabertPreprocessor
3
+ import unicodedata
4
+ import arabic_reshaper
5
+ from bidi.algorithm import get_display
6
+ import torch
7
+ import random
8
+ import re
9
+ import gradio as gr
10
+
# --- Model setup (runs at import time; downloads weights on first use) ---

# Tokenizers: one paired with the subjective question-generation model, and the
# base google/mt5-base tokenizer for the distractor model (presumably MCQs_QG is
# an mT5 fine-tune that reuses the base tokenizer — TODO confirm on the model card).
tokenizer1 = AutoTokenizer.from_pretrained("Reham721/Subjective_QG")
tokenizer2 = AutoTokenizer.from_pretrained("google/mt5-base")

# Seq2seq models: question generation (model1) and MCQ distractor generation (model2).
model1 = AutoModelForSeq2SeqLM.from_pretrained("Reham721/Subjective_QG")
model2 = AutoModelForSeq2SeqLM.from_pretrained("Reham721/MCQs_QG")

# AraBERT preprocessor used to normalise the context before QA scoring.
prep = ArabertPreprocessor("aubmindlab/araelectra-base-discriminator")
# Arabic extractive-QA pipeline; its answer score ranks the generated questions.
qa_pipe = pipeline("question-answering", model="wissamantoun/araelectra-base-artydiqa")
def generate_questions(model, tokenizer, input_sequence):
    """Generate candidate questions for *input_sequence* with beam search.

    Runs the given seq2seq *model* (tokenised by *tokenizer*) and returns a
    list of three decoded question strings.
    """
    encoded = tokenizer.encode(input_sequence, return_tensors='pt')
    generated = model.generate(
        input_ids=encoded,
        max_length=200,
        num_beams=3,
        no_repeat_ngram_size=3,
        early_stopping=True,
        temperature=1,
        num_return_sequences=3,
    )
    questions = []
    for sequence in generated:
        questions.append(tokenizer.decode(sequence, skip_special_tokens=True))
    return questions
def get_sorted_questions(questions, context):
    """Score each candidate question against *context* and rank them.

    The context is normalised with the AraBERT preprocessor, then each
    question is run through the extractive-QA pipeline; the pipeline's answer
    confidence is used as the question's score.

    Returns a dict mapping question -> score, ordered by score descending.
    Questions the pipeline cannot process are kept with a score of 0 so
    callers still see every candidate.
    """
    scores = {}
    context = prep.preprocess(context)
    for question in questions:
        try:
            result = qa_pipe(question=question, context=context)
            scores[question] = result["score"]
        # Was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt; narrow it to ordinary runtime failures.
        except Exception:
            scores[question] = 0
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))
def is_arabic(text):
    """Return True iff every alphabetic character in *text* is Arabic.

    Non-alphabetic characters (digits, punctuation, whitespace) are ignored,
    so mixed text like Arabic plus ASCII digits still counts as Arabic.

    Fixes: the original called unicodedata.name(char) without a default,
    which raises ValueError for characters that have no Unicode name; such
    characters are now simply treated as non-Arabic. The original also
    round-tripped the text through arabic_reshaper + python-bidi first, but
    those transforms only reorder characters or map Arabic letters to Arabic
    presentation forms (whose names still start with 'ARABIC'), so they
    cannot change the outcome and were dropped.
    """
    for char in text:
        if char.isalpha() and not unicodedata.name(char, '').startswith('ARABIC'):
            return False
    return True
def _split_distractor_output(decoded):
    """Split one decoded model output on <...> tags / literal 'None' markers,
    strip leftover tags, and keep only non-empty, all-Arabic chunks."""
    chunks = [re.sub(r'<[^>]*>', '', piece.strip())
              for piece in re.split(r'(<[^>]*>)|(?:None)', decoded) if piece]
    return [c for c in chunks if c and is_arabic(c)]


def generate_distractors(question, answer, context, num_distractors=3, k=10):
    """Generate *num_distractors* wrong-answer options for an MCQ.

    Samples candidate distractors from the MCQs_QG model conditioned on
    "question <sep> answer <sep> context", de-duplicates them, drops any
    candidate equal to the correct *answer*, and resamples until enough
    unique distractors exist. If more than *k* candidates accumulate, a
    random subset of *k* is kept before the final draw.

    Returns a list of up to *num_distractors* distinct distractor strings.

    Fixes vs the original: the resampling loop was unbounded and hung
    forever when the model could not produce enough unique candidates; it is
    now capped, and the final random.sample is guarded so it cannot raise
    when fewer candidates than requested were found.
    """
    input_sequence = f'{question} <sep> {answer} <sep> {context}'
    input_ids = tokenizer2.encode(input_sequence, return_tensors='pt')

    def sample_candidates(n):
        # One stochastic generation pass returning cleaned Arabic chunks.
        outputs = model2.generate(
            input_ids,
            do_sample=True,
            max_length=50,
            top_k=50,
            top_p=0.95,
            num_return_sequences=n,
            no_repeat_ngram_size=2,
        )
        candidates = []
        for output in outputs:
            decoded = tokenizer2.decode(output, skip_special_tokens=True)
            candidates.extend(_split_distractor_output(decoded))
        return candidates

    unique_distractors = []
    for cand in sample_candidates(num_distractors):
        if cand not in unique_distractors and cand != answer:
            unique_distractors.append(cand)

    # Resample for the shortfall; cap retries so a weak model cannot hang the app.
    attempts = 0
    while len(unique_distractors) < num_distractors and attempts < 10:
        attempts += 1
        for cand in sample_candidates(num_distractors - len(unique_distractors)):
            if cand not in unique_distractors and cand != answer:
                unique_distractors.append(cand)
            if len(unique_distractors) >= num_distractors:
                break

    if len(unique_distractors) > k:
        # Random subset of k (shuffle-by-random-key, as in the original).
        unique_distractors = sorted(unique_distractors, key=lambda x: random.random())[:k]
    # min() keeps random.sample from raising if the capped retries fell short.
    return random.sample(unique_distractors, min(num_distractors, len(unique_distractors)))
# --- Gradio UI components (labels and placeholders are in Arabic) ---
context = gr.Textbox(lines=5, placeholder="أدخل الفقرة هنا", label="النص")  # input: source paragraph
answer = gr.Textbox(lines=3, placeholder="أدخل الإجابة هنا", label="الإجابة")  # input: gold answer
question_type = gr.Radio(choices=["سؤال مقالي", "سؤال اختيار من متعدد"], label="نوع السؤال")  # essay vs multiple-choice
question = gr.Textbox(type="text", label="السؤال الناتج")  # output: generated question (+ options for MCQ)
def generate_question(context, answer, question_type):
    """Gradio callback: produce the best question for *context*/*answer*.

    Generates candidates, ranks them by QA confidence, and returns either the
    top question alone (essay type) or the question followed by shuffled
    multiple-choice options that include the correct answer.
    """
    article = answer + "<sep>" + context
    candidates = generate_questions(model1, tokenizer1, article)
    ranked = get_sorted_questions(candidates, context)
    best_question = next(iter(ranked)) if ranked else "لم يتم توليد سؤال مناسب"
    if question_type != "سؤال مقالي":
        options = generate_distractors(best_question, answer, context)
        options.append(answer)
        random.shuffle(options)
        return best_question + "\n" + "\n".join("- " + opt for opt in options)
    return best_question
# Wire the UI components to the generator callback.
iface = gr.Interface(
    fn=generate_question,
    inputs=[context, answer, question_type],
    outputs=question
)

# debug=True surfaces tracebacks in the console; share=False keeps the app local.
iface.launch(debug=True, share=False)