minsuas committed on
Commit
a3055d2
·
verified ·
1 Parent(s): 0ae7414

Upload module1.py

Browse files
Files changed (1) hide show
  1. src/FisrtModule/module1.py +87 -0
src/FisrtModule/module1.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """module1.py
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1AYXXKXRzUU4DWKWbJqvyjSwQ0dVQMS7Y
8
+ """
9
+
10
+ import pandas as pd
11
+ import numpy as np
12
+ from sklearn.metrics.pairwise import cosine_similarity
13
+ from sentence_transformers import SentenceTransformer
14
+
15
class MisconceptionModel:
    """Rank candidate misconceptions for incorrect multiple-choice answers.

    Each incorrect answer is turned into a text prompt, embedded with a
    sentence-transformer model, and compared by cosine similarity against
    one or more precomputed misconception embedding matrices. The
    per-matrix ranks are then blended into a single ordering.
    """

    def __init__(self, model_name, misconception_mapping_path, misconception_embs_paths):
        """Load the encoder, the misconception mapping and the embeddings.

        Args:
            model_name: Name or path handed to ``SentenceTransformer``.
            misconception_mapping_path: Parquet file that contains the
                ``MisconceptionId`` and ``MisconceptionName`` columns.
            misconception_embs_paths: Iterable of paths readable by
                ``np.load``; one embedding matrix per path.
        """
        self.model = SentenceTransformer(model_name)
        self.misconception_mapping = pd.read_parquet(misconception_mapping_path)
        # Series indexed by MisconceptionId giving the misconception name.
        self.misconception_names = self.misconception_mapping.set_index(
            "MisconceptionId"
        )["MisconceptionName"]
        self.misconception_embs = [np.load(path) for path in misconception_embs_paths]

    def preprocess(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a cleaned copy of ``df``; the input is left untouched.

        Surrounding whitespace is stripped from every text column, and the
        literal ``"Only\\n"`` inside the four answer-text columns is
        replaced by ``"Only "``.

        Raises:
            KeyError: If any ``Answer{A,B,C,D}Text`` column is missing.
        """
        df_new = df.copy()
        # Besides classic ``object`` columns, also accept pandas' dedicated
        # string dtype; comparing against "object" alone silently skips
        # string-typed columns on pandas versions that infer that dtype.
        text_columns = [
            col
            for col, dtype in df.dtypes.items()
            if dtype == "object" or isinstance(dtype, pd.StringDtype)
        ]
        for col in text_columns:
            df_new[col] = df_new[col].str.strip()
        for option in ["A", "B", "C", "D"]:
            col = f"Answer{option}Text"
            # regex=False: the pattern is a plain literal, and the default
            # value of ``regex`` differs between pandas versions.
            df_new[col] = df_new[col].str.replace("Only\n", "Only ", regex=False)
        return df_new

    def wide_to_long(self, df: pd.DataFrame) -> pd.DataFrame:
        """Convert one-row-per-question data to one row per incorrect answer.

        Every output row keeps the input columns up to and including
        ``QuestionText`` and adds ``CorrectAnswerText``, ``Answer``,
        ``AnswerText`` and, when the matching ``Misconception{X}Id`` value
        is present and not NaN, ``MisconceptionId``. A ``QuestionId_Answer``
        key column is inserted as the first column.

        An empty input yields an empty frame that still carries the output
        columns (except ``MisconceptionId``).
        """
        options = ["A", "B", "C", "D"]
        rows = []
        for _, row in df.iterrows():
            correct_option = row["CorrectAnswer"]
            correct_text = row[f"Answer{correct_option}Text"]
            for option in options:
                if option == correct_option:
                    continue
                misconception_id = row.get(f"Misconception{option}Id", np.nan)
                # Copy the slice so the new keys are written to an
                # independent Series rather than to a slice of ``row``.
                row_new = row.loc[:"QuestionText"].copy()
                row_new["CorrectAnswerText"] = correct_text
                row_new["Answer"] = option
                row_new["AnswerText"] = row[f"Answer{option}Text"]
                if not pd.isna(misconception_id):
                    row_new["MisconceptionId"] = int(misconception_id)
                rows.append(row_new)

        if not rows:
            # Without rows, DataFrame(rows) has no columns at all and the
            # QuestionId lookup below would raise KeyError.
            if "QuestionText" in df.columns:
                base_columns = list(df.loc[:, :"QuestionText"].columns)
            else:
                base_columns = list(df.columns)
            return pd.DataFrame(
                columns=[
                    "QuestionId_Answer",
                    *base_columns,
                    "CorrectAnswerText",
                    "Answer",
                    "AnswerText",
                ]
            )

        df_long = pd.DataFrame(rows).reset_index(drop=True)
        df_long.insert(
            0,
            "QuestionId_Answer",
            df_long["QuestionId"].astype(str) + "_" + df_long["Answer"],
        )
        return df_long

    def predict(self, test_df: pd.DataFrame, top_k: int = 25) -> pd.DataFrame:
        """Predict the most likely misconceptions for each incorrect answer.

        Args:
            test_df: Wide-format question data accepted by ``wide_to_long``.
            top_k: Number of predictions kept per row (default 25).

        Returns:
            The long-format frame with two extra columns: ``anchor`` (the
            prompt that was embedded) and ``PredictedMisconceptions`` (a
            list of row indices into the embedding matrices, best first).
            NOTE(review): these indices presumably coincide with
            ``MisconceptionId`` -- confirm against how the embedding files
            were built.
        """
        test_df_long = self.wide_to_long(test_df)
        if test_df_long.empty:
            # Nothing to encode or rank; keep the output schema consistent.
            test_df_long["anchor"] = pd.Series(dtype=object)
            test_df_long["PredictedMisconceptions"] = pd.Series(dtype=object)
            return test_df_long

        prompt = (
            "Subject: {SubjectName}\n"
            "Construct: {ConstructName}\n"
            "Question: {QuestionText}\n"
            "Incorrect Answer: {AnswerText}"
        )
        test_df_long["anchor"] = [
            prompt.format(
                SubjectName=subject,
                ConstructName=construct,
                QuestionText=question,
                AnswerText=answer,
            )
            for subject, construct, question, answer in zip(
                test_df_long["SubjectName"],
                test_df_long["ConstructName"],
                test_df_long["QuestionText"],
                test_df_long["AnswerText"],
            )
        ]

        # Embed the prompts. A plain list of strings is passed instead of a
        # pandas Series so the encoder never indexes into a Series.
        embs_test_query = self.model.encode(
            test_df_long["anchor"].tolist(), normalize_embeddings=True
        )

        # Similarity -> per-matrix rank of every candidate (0 = most similar).
        per_model_ranks = []
        for embs_misconception in self.misconception_embs:
            similarity = cosine_similarity(embs_test_query, embs_misconception)
            # The stable sort belongs on this first argsort, where equal
            # similarities can occur; ties then resolve by candidate index.
            order = np.argsort(-similarity, axis=1, kind="stable")
            # argsort of a permutation is its inverse, i.e. the rank of
            # each candidate; no ties are possible here.
            per_model_ranks.append(np.argsort(order, axis=1))
        rank_test = np.array(per_model_ranks)

        # Blend the ranks across embedding matrices by averaging their
        # fourth roots, then order candidates by the blended score.
        rank_ave_test = np.mean(rank_test ** (1 / 4), axis=0)
        argsort_test = np.argsort(rank_ave_test, axis=1, kind="stable")

        test_df_long["PredictedMisconceptions"] = argsort_test[:, :top_k].tolist()
        return test_df_long
87
+