Spaces:
Sleeping
Sleeping
Commit
Β·
a6b0abd
1
Parent(s):
b489aea
grouping input news
Browse filesThis view is limited to 50 files because it contains too many changes. Β
See raw diff
- application.py +8 -12
- application_2.py +20 -17
- examples/example_image_input.jpg +0 -0
- example_image_real_1.jpg.webp β examples/example_image_real_1.jpg.webp +0 -0
- example_image_real_2.jpg.webp β examples/example_image_real_2.jpg.webp +0 -0
- example_image_real_3.jpg β examples/example_image_real_3.jpg +0 -0
- example_image_real_3.jpg.webp β examples/example_image_real_3.jpg.webp +0 -0
- example_text_LLM_modification.txt β examples/example_text_LLM_modification.txt +0 -0
- example_text_LLM_topic.txt β examples/example_text_LLM_topic.txt +0 -0
- example_text_real.txt β examples/example_text_real.txt +0 -0
- example_text_real_2.txt β examples/example_text_real_2.txt +0 -0
- src/application/content_detection.py +66 -41
- src/application/image/model_detection.py +1 -1
- src/application/text/helper.py +29 -0
- src/application/text/search_detection.py +34 -26
- src/application/url_reader.py +8 -1
- src/images/CNN_model_classifier.py +0 -63
- src/images/Diffusion/Final_Report.pdf +0 -0
- src/images/Diffusion/Pipfile +0 -29
- src/images/Diffusion/Pipfile.lock +0 -0
- src/images/Diffusion/README.md +0 -72
- src/images/Diffusion/combine_laion_script.ipynb +0 -117
- src/images/Diffusion/data_split.py +0 -80
- src/images/Diffusion/dataloader.py +0 -228
- src/images/Diffusion/diffusion_data_loader.py +0 -233
- src/images/Diffusion/diffusion_model_classifier.py +0 -242
- src/images/Diffusion/evaluation.ipynb +0 -187
- src/images/Diffusion/model.py +0 -307
- src/images/Diffusion/sample_laion_script.ipynb +0 -73
- src/images/Diffusion/scrape.py +0 -149
- src/images/Diffusion/utils_sampling.py +0 -94
- src/images/Diffusion/visualizations.ipynb +0 -196
- src/images/README.md +0 -64
- src/images/Search_Image/Bing_search.py +0 -93
- src/images/Search_Image/image_difference.py +0 -0
- src/images/Search_Image/image_model_share.py +0 -142
- src/images/Search_Image/search.py +0 -56
- src/images/Search_Image/search_2.py +0 -150
- src/images/Search_Image/search_yandex.py +0 -177
- src/images/diffusion_data_loader.py +0 -229
- src/images/diffusion_model_classifier.py +0 -293
- src/images/diffusion_utils_sampling.py +0 -94
- src/images/image_demo.py +0 -73
- src/main.py +0 -51
- src/texts/MAGE/.gradio/flagged/dataset1.csv +0 -2
- src/texts/MAGE/LICENSE +0 -201
- src/texts/MAGE/README.md +0 -258
- src/texts/MAGE/app.py +0 -74
- src/texts/MAGE/deployment/__init__.py +0 -1
- src/texts/MAGE/deployment/prepare_testbeds.py +0 -348
application.py
CHANGED
@@ -124,13 +124,13 @@ with gr.Blocks() as demo:
|
|
124 |
#url_input.change(load_image, inputs=url_input, outputs=image_view)
|
125 |
|
126 |
try:
|
127 |
-
with open('example_text_real.txt','r', encoding='utf-8') as file:
|
128 |
text_real_1 = file.read()
|
129 |
-
with open('example_text_real_2.txt','r', encoding='utf-8') as file:
|
130 |
text_real_2 = file.read()
|
131 |
-
with open('example_text_LLM_topic.txt','r', encoding='utf-8') as file:
|
132 |
text_llm_topic = file.read()
|
133 |
-
with open('example_text_LLM_modification.txt','r', encoding='utf-8') as file:
|
134 |
text_llm_modification = file.read()
|
135 |
except FileNotFoundError:
|
136 |
print("File not found.")
|
@@ -140,9 +140,9 @@ with gr.Blocks() as demo:
|
|
140 |
title_1 = "Southampton news: Leeds target striker Cameron Archer"
|
141 |
title_2 = "Southampton news: Leeds target striker Cameron Archer"
|
142 |
|
143 |
-
image_1 = "example_image_real_1.jpg.webp"
|
144 |
-
image_2 = "example_image_real_2.jpg.webp"
|
145 |
-
image_3 = "example_image_real_3.jpg"
|
146 |
|
147 |
gr.Examples(
|
148 |
examples=[
|
@@ -159,8 +159,4 @@ with gr.Blocks() as demo:
|
|
159 |
],
|
160 |
)
|
161 |
|
162 |
-
demo.launch(share=False)
|
163 |
-
|
164 |
-
|
165 |
-
# https://www.bbc.com/travel/article/20250127-one-of-the-last-traders-on-the-silk-road
|
166 |
-
# https://bbc.com/future/article/20250110-how-often-you-should-wash-your-towels-according-to-science
|
|
|
124 |
#url_input.change(load_image, inputs=url_input, outputs=image_view)
|
125 |
|
126 |
try:
|
127 |
+
with open('examples/example_text_real.txt','r', encoding='utf-8') as file:
|
128 |
text_real_1 = file.read()
|
129 |
+
with open('examples/example_text_real_2.txt','r', encoding='utf-8') as file:
|
130 |
text_real_2 = file.read()
|
131 |
+
with open('examples/example_text_LLM_topic.txt','r', encoding='utf-8') as file:
|
132 |
text_llm_topic = file.read()
|
133 |
+
with open('examples/example_text_LLM_modification.txt','r', encoding='utf-8') as file:
|
134 |
text_llm_modification = file.read()
|
135 |
except FileNotFoundError:
|
136 |
print("File not found.")
|
|
|
140 |
title_1 = "Southampton news: Leeds target striker Cameron Archer"
|
141 |
title_2 = "Southampton news: Leeds target striker Cameron Archer"
|
142 |
|
143 |
+
image_1 = "examples/example_image_real_1.jpg.webp"
|
144 |
+
image_2 = "examples/example_image_real_2.jpg.webp"
|
145 |
+
image_3 = "examples/example_image_real_3.jpg"
|
146 |
|
147 |
gr.Examples(
|
148 |
examples=[
|
|
|
159 |
],
|
160 |
)
|
161 |
|
162 |
+
demo.launch(share=False)
|
|
|
|
|
|
|
|
application_2.py
CHANGED
@@ -91,7 +91,7 @@ with gr.Blocks() as demo:
|
|
91 |
with gr.Column(scale=2):
|
92 |
with gr.Accordion("News Analysis"):
|
93 |
detection_button = gr.Button("Verify news")
|
94 |
-
detailed_analysis = gr.HTML()
|
95 |
|
96 |
# Connect events
|
97 |
load_button.click(
|
@@ -116,36 +116,39 @@ with gr.Blocks() as demo:
|
|
116 |
#url_input.change(load_image, inputs=url_input, outputs=image_view)
|
117 |
|
118 |
try:
|
119 |
-
with open('
|
120 |
-
|
121 |
-
with open('
|
122 |
-
|
123 |
-
with open('
|
124 |
-
|
|
|
|
|
125 |
except FileNotFoundError:
|
126 |
print("File not found.")
|
127 |
except Exception as e:
|
128 |
print(f"An error occurred: {e}")
|
129 |
|
130 |
-
title_1 = "
|
131 |
-
title_2 = "
|
132 |
|
133 |
-
image_1 = "
|
134 |
-
image_2 = "
|
|
|
135 |
|
136 |
gr.Examples(
|
137 |
examples=[
|
138 |
-
[title_1, image_1,
|
139 |
-
[
|
140 |
-
[title_1,
|
141 |
],
|
142 |
inputs=[news_title, news_image, news_content],
|
143 |
label="Examples",
|
144 |
example_labels=[
|
145 |
"2 real news",
|
146 |
-
"
|
147 |
-
"1 real news
|
148 |
],
|
149 |
)
|
150 |
|
151 |
-
demo.launch(share=
|
|
|
91 |
with gr.Column(scale=2):
|
92 |
with gr.Accordion("News Analysis"):
|
93 |
detection_button = gr.Button("Verify news")
|
94 |
+
detailed_analysis = gr.HTML("<br>"*40)
|
95 |
|
96 |
# Connect events
|
97 |
load_button.click(
|
|
|
116 |
#url_input.change(load_image, inputs=url_input, outputs=image_view)
|
117 |
|
118 |
try:
|
119 |
+
with open('examples/example_text_real.txt','r', encoding='utf-8') as file:
|
120 |
+
text_real_1 = file.read()
|
121 |
+
with open('examples/example_text_real_2.txt','r', encoding='utf-8') as file:
|
122 |
+
text_real_2 = file.read()
|
123 |
+
with open('examples/example_text_LLM_topic.txt','r', encoding='utf-8') as file:
|
124 |
+
text_llm_topic = file.read()
|
125 |
+
with open('examples/example_text_LLM_modification.txt','r', encoding='utf-8') as file:
|
126 |
+
text_llm_modification = file.read()
|
127 |
except FileNotFoundError:
|
128 |
print("File not found.")
|
129 |
except Exception as e:
|
130 |
print(f"An error occurred: {e}")
|
131 |
|
132 |
+
title_1 = "Southampton news: Leeds target striker Cameron Archer"
|
133 |
+
title_2 = "Southampton news: Leeds target striker Cameron Archer"
|
134 |
|
135 |
+
image_1 = "examples/example_image_real_1.jpg.webp"
|
136 |
+
image_2 = "examples/example_image_real_2.jpg.webp"
|
137 |
+
image_3 = "examples/example_image_real_3.jpg"
|
138 |
|
139 |
gr.Examples(
|
140 |
examples=[
|
141 |
+
[title_1, image_1, text_real_1 + '\n\n' + text_real_2],
|
142 |
+
[title_1, image_2, text_real_1 + '\n\n' + text_llm_modification],
|
143 |
+
[title_1, image_3, text_real_1 + '\n\n' + text_llm_topic],
|
144 |
],
|
145 |
inputs=[news_title, news_image, news_content],
|
146 |
label="Examples",
|
147 |
example_labels=[
|
148 |
"2 real news",
|
149 |
+
"1 real news + 1 LLM modification-based news",
|
150 |
+
"1 real news + 1 LLM topic-based news",
|
151 |
],
|
152 |
)
|
153 |
|
154 |
+
demo.launch(share=True)
|
examples/example_image_input.jpg
ADDED
![]() |
example_image_real_1.jpg.webp β examples/example_image_real_1.jpg.webp
RENAMED
File without changes
|
example_image_real_2.jpg.webp β examples/example_image_real_2.jpg.webp
RENAMED
File without changes
|
example_image_real_3.jpg β examples/example_image_real_3.jpg
RENAMED
File without changes
|
example_image_real_3.jpg.webp β examples/example_image_real_3.jpg.webp
RENAMED
File without changes
|
example_text_LLM_modification.txt β examples/example_text_LLM_modification.txt
RENAMED
File without changes
|
example_text_LLM_topic.txt β examples/example_text_LLM_topic.txt
RENAMED
File without changes
|
example_text_real.txt β examples/example_text_real.txt
RENAMED
File without changes
|
example_text_real_2.txt β examples/example_text_real_2.txt
RENAMED
File without changes
|
src/application/content_detection.py
CHANGED
@@ -23,7 +23,7 @@ class NewsVerification():
|
|
23 |
self.news_prediction_label = ""
|
24 |
self.news_prediction_score = -1
|
25 |
|
26 |
-
self.found_img_url:list[str] = []
|
27 |
self.aligned_sentences:list[dict] = []
|
28 |
self.is_paraphrased:list[bool] = []
|
29 |
self.analyzed_table:list[list] = []
|
@@ -50,42 +50,61 @@ class NewsVerification():
|
|
50 |
print("\tFrom search engine:")
|
51 |
# Classify by search engine
|
52 |
input_sentences = split_into_sentences(self.news_text)
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
if paraphrase is False:
|
58 |
-
#
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
"is_paraphrase_sentence": False,
|
67 |
-
"url": "",
|
68 |
-
}
|
69 |
else:
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
def detect_image_origin(self):
|
91 |
print("CHECK IMAGE:")
|
@@ -95,7 +114,8 @@ class NewsVerification():
|
|
95 |
self.image_referent_url = None
|
96 |
return
|
97 |
|
98 |
-
|
|
|
99 |
matched_url, similarity = detect_image_from_news_image(self.news_image, self.found_img_url)
|
100 |
if matched_url is not None:
|
101 |
print(f"matching image: {matched_url}\nsimilarity: {similarity}\n")
|
@@ -114,6 +134,7 @@ class NewsVerification():
|
|
114 |
|
115 |
detected_label, score = detect_image_by_ai_model(self.news_image)
|
116 |
if detected_label:
|
|
|
117 |
self.image_prediction_label = detected_label
|
118 |
self.image_prediction_score = score
|
119 |
self.image_referent_url = None
|
@@ -346,13 +367,15 @@ class NewsVerification():
|
|
346 |
|
347 |
# short_url = self.shorten_url(self.text_referent_url[index], max_length)
|
348 |
# source_text_url = f"""<a href="{self.text_referent_url[index]}">{short_url}</a>"""
|
349 |
-
|
350 |
-
self.
|
351 |
-
|
|
|
|
|
|
|
352 |
|
353 |
def format_image_row(self, max_length=30):
|
354 |
-
input_image = f"""<img src="example_image_input.jpg" width="200" height="150">"""
|
355 |
-
print(f"self.news_image = {self.news_image}")
|
356 |
|
357 |
if self.image_referent_url is not None or self.image_referent_url != "":
|
358 |
source_image = f"""<img src="{self.image_referent_url}" width="200" height="150">"""
|
@@ -360,6 +383,8 @@ class NewsVerification():
|
|
360 |
source_image_url = f"""<a href="{self.image_referent_url}">{short_url}</a>"""
|
361 |
else:
|
362 |
source_image = "Image not found"
|
|
|
|
|
363 |
return f"""<tr><td>input image</td><td>{source_image}</td><td>{self.image_prediction_label}<br>({self.image_prediction_score:.2f}%)</td><td>{source_image_url}</td></tr>"""
|
364 |
|
365 |
def shorten_url(self, url, max_length=30):
|
|
|
23 |
self.news_prediction_label = ""
|
24 |
self.news_prediction_score = -1
|
25 |
|
26 |
+
self.found_img_url:list[str] = ["https://ichef.bbci.co.uk/ace/standard/819/cpsprodpb/8acc/live/86282470-defb-11ef-ba00-65100a906e68.jpg"]
|
27 |
self.aligned_sentences:list[dict] = []
|
28 |
self.is_paraphrased:list[bool] = []
|
29 |
self.analyzed_table:list[list] = []
|
|
|
50 |
print("\tFrom search engine:")
|
51 |
# Classify by search engine
|
52 |
input_sentences = split_into_sentences(self.news_text)
|
53 |
+
current_index = 0
|
54 |
+
previous_paraphrase = None
|
55 |
+
ai_sentence = {
|
56 |
+
"input_sentence": "",
|
57 |
+
"matched_sentence": "",
|
58 |
+
"label": "",
|
59 |
+
"similarity": None,
|
60 |
+
"paraphrase": False,
|
61 |
+
"url": "",
|
62 |
+
}
|
63 |
+
for index, sentence in enumerate(input_sentences):
|
64 |
+
if current_index >= index:
|
65 |
+
continue
|
66 |
+
print(f"-------index = {index}-------")
|
67 |
+
paraphrase, text_url, searched_sentences, img_urls, current_index = detect_text_by_relative_search(input_sentences, index)
|
68 |
if paraphrase is False:
|
69 |
+
# add sentence to ai_sentence
|
70 |
+
ai_sentence["input_sentence"] += sentence
|
71 |
+
if index == len(input_sentences) - 1:
|
72 |
+
# add ai_sentences to align_sentences
|
73 |
+
text_prediction_label, text_prediction_score = detect_text_by_ai_model(ai_sentence["input_sentence"])
|
74 |
+
ai_sentence["label"] = text_prediction_label
|
75 |
+
ai_sentence["similarity"] = text_prediction_score
|
76 |
+
self.aligned_sentences.append(ai_sentence)
|
|
|
|
|
|
|
77 |
else:
|
78 |
+
if previous_paraphrase is False or previous_paraphrase is None:
|
79 |
+
# add ai_sentences to align_sentences
|
80 |
+
if ai_sentence["input_sentence"] != "":
|
81 |
+
text_prediction_label, text_prediction_score = detect_text_by_ai_model(ai_sentence["input_sentence"])
|
82 |
+
ai_sentence["label"] = text_prediction_label
|
83 |
+
ai_sentence["similarity"] = text_prediction_score
|
84 |
+
self.aligned_sentences.append(ai_sentence)
|
85 |
+
|
86 |
+
# reset
|
87 |
+
ai_sentence = {
|
88 |
+
"input_sentence": "",
|
89 |
+
"matched_sentence": "",
|
90 |
+
"label": "",
|
91 |
+
"similarity": None,
|
92 |
+
"paraphrase": False,
|
93 |
+
"url": "",
|
94 |
+
}
|
95 |
+
|
96 |
+
# add searched_sentences to align_sentences
|
97 |
+
if searched_sentences["input_sentence"] != "":
|
98 |
+
self.found_img_url.extend(img_urls)
|
99 |
+
if check_human(searched_sentences):
|
100 |
+
searched_sentences["label"] = "HUMAN"
|
101 |
+
else:
|
102 |
+
searched_sentences["label"] = "MACHINE"
|
103 |
+
|
104 |
+
self.aligned_sentences.append(searched_sentences)
|
105 |
+
|
106 |
+
previous_paraphrase = paraphrase
|
107 |
+
#self.found_img_url = list(set(self.found_img_url))
|
108 |
|
109 |
def detect_image_origin(self):
|
110 |
print("CHECK IMAGE:")
|
|
|
114 |
self.image_referent_url = None
|
115 |
return
|
116 |
|
117 |
+
for image in self.found_img_url:
|
118 |
+
print(f"\tfound_img_url: {image}")
|
119 |
matched_url, similarity = detect_image_from_news_image(self.news_image, self.found_img_url)
|
120 |
if matched_url is not None:
|
121 |
print(f"matching image: {matched_url}\nsimilarity: {similarity}\n")
|
|
|
134 |
|
135 |
detected_label, score = detect_image_by_ai_model(self.news_image)
|
136 |
if detected_label:
|
137 |
+
print(f"detected_label: {detected_label} ({score})")
|
138 |
self.image_prediction_label = detected_label
|
139 |
self.image_prediction_score = score
|
140 |
self.image_referent_url = None
|
|
|
367 |
|
368 |
# short_url = self.shorten_url(self.text_referent_url[index], max_length)
|
369 |
# source_text_url = f"""<a href="{self.text_referent_url[index]}">{short_url}</a>"""
|
370 |
+
#label = self.aligned_sentences[index]["label"]
|
371 |
+
print(self.aligned_sentences)
|
372 |
+
print(index)
|
373 |
+
label = self.aligned_sentences[index]["label"]
|
374 |
+
score = self.aligned_sentences[index]["similarity"]
|
375 |
+
return f"""<tr><td>{input_sentence}</td><td>{source_sentence}</td><td>{label}<br>({score*100:.2f}%)</td><td>{source_text_url}</td></tr>"""
|
376 |
|
377 |
def format_image_row(self, max_length=30):
|
378 |
+
# input_image = f"""<img src="example_image_input.jpg" width="200" height="150">"""
|
|
|
379 |
|
380 |
if self.image_referent_url is not None or self.image_referent_url != "":
|
381 |
source_image = f"""<img src="{self.image_referent_url}" width="200" height="150">"""
|
|
|
383 |
source_image_url = f"""<a href="{self.image_referent_url}">{short_url}</a>"""
|
384 |
else:
|
385 |
source_image = "Image not found"
|
386 |
+
source_image_url = ""
|
387 |
+
|
388 |
return f"""<tr><td>input image</td><td>{source_image}</td><td>{self.image_prediction_label}<br>({self.image_prediction_score:.2f}%)</td><td>{source_image_url}</td></tr>"""
|
389 |
|
390 |
def shorten_url(self, url, max_length=30):
|
src/application/image/model_detection.py
CHANGED
@@ -130,7 +130,7 @@ def image_generation_detection(image_path):
|
|
130 |
image_prediction_label = "MACHINE"
|
131 |
image_confidence = min(1, 0.5 + abs(prediction - 0.2))
|
132 |
result += f" with confidence = {round(image_confidence * 100, 2)}%"
|
133 |
-
image_confidence = round(image_confidence * 100, 2)
|
134 |
return image_prediction_label, image_confidence
|
135 |
|
136 |
|
|
|
130 |
image_prediction_label = "MACHINE"
|
131 |
image_confidence = min(1, 0.5 + abs(prediction - 0.2))
|
132 |
result += f" with confidence = {round(image_confidence * 100, 2)}%"
|
133 |
+
# image_confidence = round(image_confidence * 100, 2)
|
134 |
return image_prediction_label, image_confidence
|
135 |
|
136 |
|
src/application/text/helper.py
CHANGED
@@ -144,6 +144,35 @@ def extract_important_phrases(paragraph: str, keywords: list[str], phrase_length
|
|
144 |
|
145 |
return important_phrases
|
146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
"""# Example usage
|
148 |
keywords = get_keywords(paragraph)
|
149 |
important_phrases = extract_important_phrases(paragraph, keywords)
|
|
|
144 |
|
145 |
return important_phrases
|
146 |
|
147 |
+
def connect_consecutive_indexes(nums):
|
148 |
+
"""
|
149 |
+
Connects consecutive integers in a list.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
nums: A list of integers.
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
A list of lists, where each inner list represents a consecutive range.
|
156 |
+
"""
|
157 |
+
|
158 |
+
if not nums: # Handle empty input
|
159 |
+
return []
|
160 |
+
|
161 |
+
result = []
|
162 |
+
start = nums[0]
|
163 |
+
end = nums[0]
|
164 |
+
|
165 |
+
for i in range(1, len(nums)):
|
166 |
+
if nums[i] == end + 1:
|
167 |
+
end = nums[i]
|
168 |
+
else:
|
169 |
+
result.append([start, end])
|
170 |
+
start = nums[i]
|
171 |
+
end = nums[i]
|
172 |
+
|
173 |
+
result.append([start, end]) # Add the last range
|
174 |
+
return result
|
175 |
+
|
176 |
"""# Example usage
|
177 |
keywords = get_keywords(paragraph)
|
178 |
important_phrases = extract_important_phrases(paragraph, keywords)
|
src/application/text/search_detection.py
CHANGED
@@ -33,10 +33,9 @@ MIN_RATIO_PARAPHRASE_NUM = 0.7
|
|
33 |
MAX_CHAR_SIZE = 30000
|
34 |
|
35 |
|
36 |
-
def detect_text_by_relative_search(input_text, is_support_opposite = False):
|
37 |
-
|
38 |
checked_urls = set()
|
39 |
-
searched_phrases = generate_search_phrases(input_text)
|
40 |
|
41 |
for candidate in searched_phrases:
|
42 |
search_results = search_by_google(candidate)
|
@@ -59,15 +58,36 @@ def detect_text_by_relative_search(input_text, is_support_opposite = False):
|
|
59 |
continue
|
60 |
|
61 |
page_text = content.title + "\n" + content.text
|
|
|
62 |
if len(page_text) > MAX_CHAR_SIZE:
|
63 |
print(f"\t\t\tβββ More than {MAX_CHAR_SIZE} characters")
|
64 |
continue
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
-
return False, None, [], []
|
71 |
|
72 |
def longest_common_subsequence(arr1, arr2):
|
73 |
"""
|
@@ -151,7 +171,7 @@ def check_sentence(input_sentence, source_sentence, min_same_sentence_len,
|
|
151 |
return False
|
152 |
|
153 |
|
154 |
-
def check_paraphrase(input_text, page_text, url
|
155 |
"""
|
156 |
Checks if the input text is paraphrased in the content at the given URL.
|
157 |
|
@@ -183,7 +203,6 @@ def check_paraphrase(input_text, page_text, url, verbose=False):
|
|
183 |
return is_paraphrase_text, []
|
184 |
#page_text = remove_punctuation(page_text)
|
185 |
page_sentences = split_into_sentences(page_text)
|
186 |
-
|
187 |
if not input_sentences or not page_sentences:
|
188 |
return is_paraphrase_text, []
|
189 |
|
@@ -193,7 +212,7 @@ def check_paraphrase(input_text, page_text, url, verbose=False):
|
|
193 |
additional_sentences.append(sentence.replace(", external", ""))
|
194 |
page_sentences.extend(additional_sentences)
|
195 |
|
196 |
-
min_matching_sentences = math.ceil(len(input_sentences) * MIN_RATIO_PARAPHRASE_NUM)
|
197 |
|
198 |
# Encode sentences into embeddings
|
199 |
embeddings1 = PARAPHASE_MODEL.encode(input_sentences, convert_to_tensor=True, device=DEVICE)
|
@@ -206,18 +225,18 @@ def check_paraphrase(input_text, page_text, url, verbose=False):
|
|
206 |
alignment = {}
|
207 |
paraphrased_sentence_count = 0
|
208 |
for i, sentence1 in enumerate(input_sentences):
|
209 |
-
print(f"allign: {i}")
|
210 |
max_sim_index = np.argmax(similarity_matrix[i])
|
211 |
max_similarity = similarity_matrix[i][max_sim_index]
|
212 |
|
213 |
is_paraphrase_sentence = max_similarity > PARAPHRASE_THRESHOLD
|
214 |
|
215 |
-
if
|
216 |
alignment = {
|
217 |
"input_sentence": sentence1,
|
218 |
"matched_sentence": "",
|
219 |
"similarity": max_similarity,
|
220 |
-
"
|
|
|
221 |
"url": "",
|
222 |
}
|
223 |
else:
|
@@ -225,7 +244,8 @@ def check_paraphrase(input_text, page_text, url, verbose=False):
|
|
225 |
"input_sentence": sentence1,
|
226 |
"matched_sentence": page_sentences[max_sim_index],
|
227 |
"similarity": max_similarity,
|
228 |
-
"
|
|
|
229 |
"url": url,
|
230 |
}
|
231 |
|
@@ -234,9 +254,6 @@ def check_paraphrase(input_text, page_text, url, verbose=False):
|
|
234 |
sentence1, page_sentences[max_sim_index], MIN_SAME_SENTENCE_LEN, MIN_PHRASE_SENTENCE_LEN
|
235 |
):
|
236 |
is_paraphrase_text = True
|
237 |
-
if verbose:
|
238 |
-
print(f"Paraphrase found for individual sentence: {sentence1}")
|
239 |
-
print(f"Matched sentence: {page_sentences[max_sim_index]}")
|
240 |
|
241 |
#alignment.append(item)
|
242 |
paraphrased_sentence_count += 1 if is_paraphrase_sentence else 0
|
@@ -245,15 +262,6 @@ def check_paraphrase(input_text, page_text, url, verbose=False):
|
|
245 |
|
246 |
is_paraphrase_text = paraphrased_sentence_count > 0 #min_matching_sentences
|
247 |
|
248 |
-
if verbose:
|
249 |
-
print (f"\t\tparaphrased_sentence_count: {paraphrased_sentence_count}, min_matching_sentences: {min_matching_sentences}, total_sentence_count: {len(input_sentences)}")
|
250 |
-
print(f"Minimum matching sentences required: {min_matching_sentences}")
|
251 |
-
print(f"Total input sentences: {len(input_sentences)}")
|
252 |
-
print(f"Number of matching sentences: {paraphrased_sentence_count}")
|
253 |
-
print(f"Is paraphrase: {is_paraphrase_text}")
|
254 |
-
for item in alignment:
|
255 |
-
print(item)
|
256 |
-
|
257 |
return is_paraphrase_text, alignment
|
258 |
|
259 |
def similarity_ratio(a, b):
|
|
|
33 |
MAX_CHAR_SIZE = 30000
|
34 |
|
35 |
|
36 |
+
def detect_text_by_relative_search(input_text, index, is_support_opposite = False):
|
|
|
37 |
checked_urls = set()
|
38 |
+
searched_phrases = generate_search_phrases(input_text[index])
|
39 |
|
40 |
for candidate in searched_phrases:
|
41 |
search_results = search_by_google(candidate)
|
|
|
58 |
continue
|
59 |
|
60 |
page_text = content.title + "\n" + content.text
|
61 |
+
print(f"page_text: {page_text}")
|
62 |
if len(page_text) > MAX_CHAR_SIZE:
|
63 |
print(f"\t\t\tβββ More than {MAX_CHAR_SIZE} characters")
|
64 |
continue
|
65 |
|
66 |
+
paraphrase, aligned_first_sentences = check_paraphrase(input_text[index], page_text, url)
|
67 |
+
|
68 |
+
if paraphrase is False:
|
69 |
+
return paraphrase, url, aligned_first_sentences, content.images, index
|
70 |
+
|
71 |
+
sub_paraphrase = True
|
72 |
+
while sub_paraphrase == True:
|
73 |
+
index += 1
|
74 |
+
print(f"----search {index}----")
|
75 |
+
if index >= len(input_text):
|
76 |
+
break
|
77 |
+
sub_paraphrase, sub_sentences = check_paraphrase(input_text[index], page_text, url)
|
78 |
+
print(f"sub_paraphrase: {sub_paraphrase}")
|
79 |
+
print(f"sub_sentences: {sub_sentences}")
|
80 |
+
if sub_paraphrase == True:
|
81 |
+
aligned_first_sentences["input_sentence"] += sub_sentences["input_sentence"]
|
82 |
+
aligned_first_sentences["matched_sentence"] += sub_sentences["matched_sentence"]
|
83 |
+
aligned_first_sentences["similarity"] += sub_sentences["similarity"]
|
84 |
+
aligned_first_sentences["similarity"] /= 2
|
85 |
+
|
86 |
+
print(f"paraphrase: {paraphrase}")
|
87 |
+
print(f"aligned_first_sentences: {aligned_first_sentences}")
|
88 |
+
return paraphrase, url, aligned_first_sentences, content.images, index
|
89 |
|
90 |
+
return False, None, [], [], index
|
91 |
|
92 |
def longest_common_subsequence(arr1, arr2):
|
93 |
"""
|
|
|
171 |
return False
|
172 |
|
173 |
|
174 |
+
def check_paraphrase(input_text, page_text, url):
|
175 |
"""
|
176 |
Checks if the input text is paraphrased in the content at the given URL.
|
177 |
|
|
|
203 |
return is_paraphrase_text, []
|
204 |
#page_text = remove_punctuation(page_text)
|
205 |
page_sentences = split_into_sentences(page_text)
|
|
|
206 |
if not input_sentences or not page_sentences:
|
207 |
return is_paraphrase_text, []
|
208 |
|
|
|
212 |
additional_sentences.append(sentence.replace(", external", ""))
|
213 |
page_sentences.extend(additional_sentences)
|
214 |
|
215 |
+
# min_matching_sentences = math.ceil(len(input_sentences) * MIN_RATIO_PARAPHRASE_NUM)
|
216 |
|
217 |
# Encode sentences into embeddings
|
218 |
embeddings1 = PARAPHASE_MODEL.encode(input_sentences, convert_to_tensor=True, device=DEVICE)
|
|
|
225 |
alignment = {}
|
226 |
paraphrased_sentence_count = 0
|
227 |
for i, sentence1 in enumerate(input_sentences):
|
|
|
228 |
max_sim_index = np.argmax(similarity_matrix[i])
|
229 |
max_similarity = similarity_matrix[i][max_sim_index]
|
230 |
|
231 |
is_paraphrase_sentence = max_similarity > PARAPHRASE_THRESHOLD
|
232 |
|
233 |
+
if is_paraphrase_sentence is False:
|
234 |
alignment = {
|
235 |
"input_sentence": sentence1,
|
236 |
"matched_sentence": "",
|
237 |
"similarity": max_similarity,
|
238 |
+
"label": "",
|
239 |
+
"paraphrase": is_paraphrase_sentence,
|
240 |
"url": "",
|
241 |
}
|
242 |
else:
|
|
|
244 |
"input_sentence": sentence1,
|
245 |
"matched_sentence": page_sentences[max_sim_index],
|
246 |
"similarity": max_similarity,
|
247 |
+
"label": "",
|
248 |
+
"paraphrase": is_paraphrase_sentence,
|
249 |
"url": url,
|
250 |
}
|
251 |
|
|
|
254 |
sentence1, page_sentences[max_sim_index], MIN_SAME_SENTENCE_LEN, MIN_PHRASE_SENTENCE_LEN
|
255 |
):
|
256 |
is_paraphrase_text = True
|
|
|
|
|
|
|
257 |
|
258 |
#alignment.append(item)
|
259 |
paraphrased_sentence_count += 1 if is_paraphrase_sentence else 0
|
|
|
262 |
|
263 |
is_paraphrase_text = paraphrased_sentence_count > 0 #min_matching_sentences
|
264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
return is_paraphrase_text, alignment
|
266 |
|
267 |
def similarity_ratio(a, b):
|
src/application/url_reader.py
CHANGED
@@ -109,4 +109,11 @@ class URLReader():
|
|
109 |
|
110 |
except requests.exceptions.RequestException as e:
|
111 |
print(f"\t\tβββ Error getting URL size: {e}")
|
112 |
-
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
except requests.exceptions.RequestException as e:
|
111 |
print(f"\t\tβββ Error getting URL size: {e}")
|
112 |
+
return None
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == '__main__':
|
116 |
+
url = "https://www.bbc.com/sport/football/articles/c2d3rdy3673o"
|
117 |
+
reader = URLReader(url)
|
118 |
+
print(f"Title: {reader.title}")
|
119 |
+
print(f"Text: {reader.text}")
|
src/images/CNN_model_classifier.py
DELETED
@@ -1,63 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
|
3 |
-
import torch.nn
|
4 |
-
import torchvision.transforms as transforms
|
5 |
-
from PIL import Image
|
6 |
-
|
7 |
-
from .CNN.networks.resnet import resnet50
|
8 |
-
|
9 |
-
|
10 |
-
def predict_cnn(image, model_path, crop=None):
|
11 |
-
model = resnet50(num_classes=1)
|
12 |
-
state_dict = torch.load(model_path, map_location="cpu")
|
13 |
-
model.load_state_dict(state_dict["model"])
|
14 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
-
model.to(device)
|
16 |
-
model.eval()
|
17 |
-
|
18 |
-
# Transform
|
19 |
-
if crop is not None:
|
20 |
-
trans_init = [transforms.CenterCrop(crop)]
|
21 |
-
print("Cropping to [%i]" % crop)
|
22 |
-
trans = transforms.Compose(
|
23 |
-
trans_init
|
24 |
-
+ [
|
25 |
-
transforms.ToTensor(),
|
26 |
-
transforms.Normalize(
|
27 |
-
mean=[0.485, 0.456, 0.406],
|
28 |
-
std=[0.229, 0.224, 0.225],
|
29 |
-
),
|
30 |
-
],
|
31 |
-
)
|
32 |
-
|
33 |
-
image = trans(image.convert("RGB"))
|
34 |
-
|
35 |
-
with torch.no_grad():
|
36 |
-
in_tens = image.unsqueeze(0)
|
37 |
-
prob = model(in_tens).sigmoid().item()
|
38 |
-
|
39 |
-
return prob
|
40 |
-
|
41 |
-
|
42 |
-
if __name__ == "__main__":
|
43 |
-
parser = argparse.ArgumentParser(
|
44 |
-
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
45 |
-
)
|
46 |
-
parser.add_argument("-f", "--file", default="examples_realfakedir")
|
47 |
-
parser.add_argument(
|
48 |
-
"-m",
|
49 |
-
"--model_path",
|
50 |
-
type=str,
|
51 |
-
default="weights/blur_jpg_prob0.5.pth",
|
52 |
-
)
|
53 |
-
parser.add_argument(
|
54 |
-
"-c",
|
55 |
-
"--crop",
|
56 |
-
type=int,
|
57 |
-
default=None,
|
58 |
-
help="by default, do not crop. specify crop size",
|
59 |
-
)
|
60 |
-
|
61 |
-
opt = parser.parse_args()
|
62 |
-
prob = predict_cnn(Image.open(opt.file), opt.model_path, crop=opt.crop)
|
63 |
-
print(f"probability of being synthetic: {prob * 100:.2f}%")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Diffusion/Final_Report.pdf
DELETED
Binary file (359 kB)
|
|
src/images/Diffusion/Pipfile
DELETED
@@ -1,29 +0,0 @@
|
|
1 |
-
[[source]]
|
2 |
-
url = "https://pypi.org/simple"
|
3 |
-
verify_ssl = true
|
4 |
-
name = "pypi"
|
5 |
-
|
6 |
-
[[source]]
|
7 |
-
url = "https://download.pytorch.org/whl/cu121"
|
8 |
-
verify_ssl = true
|
9 |
-
name = "downloadpytorch"
|
10 |
-
|
11 |
-
[packages]
|
12 |
-
pandas = "*"
|
13 |
-
numpy = "*"
|
14 |
-
polars = "*"
|
15 |
-
requests = "*"
|
16 |
-
img2dataset = "*"
|
17 |
-
torch = {version = "==2.1.0", index = "downloadpytorch"}
|
18 |
-
torchvision = {version = "==0.16.0", index = "downloadpytorch"}
|
19 |
-
lightning = "*"
|
20 |
-
webdataset = "*"
|
21 |
-
matplotlib = "*"
|
22 |
-
invisible-watermark = "*"
|
23 |
-
torchdata = "*"
|
24 |
-
timm = "*"
|
25 |
-
|
26 |
-
[dev-packages]
|
27 |
-
|
28 |
-
[requires]
|
29 |
-
python_version = "3.11"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Diffusion/Pipfile.lock
DELETED
The diff for this file is too large to render.
See raw diff
|
|
src/images/Diffusion/README.md
DELETED
@@ -1,72 +0,0 @@
|
|
1 |
-
# AI-generated image detection
|
2 |
-
|
3 |
-
This is a group project developed by a team of two individuals.
|
4 |
-
|
5 |
-
## Managing Python packages
|
6 |
-
|
7 |
-
Use of `pipenv` is recommended. The required packages are in `Pipfile`, and can be installed using `pipenv install`.
|
8 |
-
|
9 |
-
## Scraping script for Reddit
|
10 |
-
|
11 |
-
`python scrape.py --subreddit midjourney --flair Showcase`
|
12 |
-
|
13 |
-
This command will scrape the midjourney subreddit, and filter posts that contain the "Showcase" flair. The default number of images to scrape is 30000. The output will contain a parquet file containing metadata, and a csv file containing the urls.
|
14 |
-
|
15 |
-
`img2dataset --url_list=urls/midjourney.csv --output_folder=data/midjourney --thread_count=64 --resize_mode=no --output_format=webdataset`
|
16 |
-
|
17 |
-
This command will download the images in the webdataset format.
|
18 |
-
|
19 |
-
|
20 |
-
## Laion script for real images
|
21 |
-
|
22 |
-
`wget -l1 -r --no-parent https://the-eye.eu/public/AI/cah/laion400m-met-release/laion400m-meta/
|
23 |
-
mv the-eye.eu/public/AI/cah/laion400m-met-release/laion400m-meta/ .`
|
24 |
-
|
25 |
-
This command will download a 50GB url metadata dataset in 32 parquet files.
|
26 |
-
|
27 |
-
`sample_laion_script.ipynb`
|
28 |
-
|
29 |
-
This script consolidates the parquet files, excludes NSFW images, and selects a subset of 224,917 images.
|
30 |
-
|
31 |
-
`combine_laion_script`
|
32 |
-
|
33 |
-
This script combines the outputs from earlier into 1 parquet file.
|
34 |
-
|
35 |
-
`img2dataset --url_list urls/laion.parquet --input_format "parquet" --url_col "URL" --caption_col "TEXT" --skip_reencode True --output_format webdataset --output_folder data/laion400m_data --processes_count 16 --thread_count 128 --resize_mode no --save_additional_columns '["NSFW","similarity","LICENSE"]' --enable_wandb True`
|
36 |
-
|
37 |
-
This command will download the images in the webdataset format.
|
38 |
-
|
39 |
-
|
40 |
-
## Data splitting, preprocessing and loading
|
41 |
-
|
42 |
-
`data_split.py` splits the data according to 80/10/10. The number of samples:
|
43 |
-
|
44 |
-
```
|
45 |
-
./data/laion400m_data: (115346, 14418, 14419)
|
46 |
-
./data/genai-images/StableDiffusion: (22060, 2757, 2758)
|
47 |
-
./data/genai-images/midjourney: (21096, 2637, 2637)
|
48 |
-
./data/genai-images/dalle2: (13582, 1697, 1699)
|
49 |
-
./data/genai-images/dalle3: (12027, 1503, 1504)
|
50 |
-
```
|
51 |
-
|
52 |
-
Each sample contains image, target label(1 for GenAI images), and domain label(denoting which generator the image is from). The meaning of the domain label is:
|
53 |
-
|
54 |
-
```
|
55 |
-
DOMAIN_LABELS = {
|
56 |
-
0: "laion",
|
57 |
-
1: "StableDiffusion",
|
58 |
-
2: "dalle2",
|
59 |
-
3: "dalle3",
|
60 |
-
4: "midjourney"
|
61 |
-
}
|
62 |
-
```
|
63 |
-
|
64 |
-
The `load_dataloader()` function in `dataloader.py` returns a `torchdata.dataloader2.DataLoader2` given a list of domains for GenAI images(subset of `[1, 2, 3, 4]`, LAION will always be included). When building the training dataset, data augmentation and class balanced sampling are applied. It is very memory intensive(>20G) and takes some time to fill its buffer before producing batches. Use the dataloader in this way:
|
65 |
-
|
66 |
-
```
|
67 |
-
for epoch in range(10):
|
68 |
-
dl.seed(epoch)
|
69 |
-
for d in dl:
|
70 |
-
model(d)
|
71 |
-
dl.shutdown()
|
72 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Diffusion/combine_laion_script.ipynb
DELETED
@@ -1,117 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": null,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"pip install pyspark"
|
10 |
-
]
|
11 |
-
},
|
12 |
-
{
|
13 |
-
"cell_type": "code",
|
14 |
-
"execution_count": null,
|
15 |
-
"metadata": {},
|
16 |
-
"outputs": [],
|
17 |
-
"source": [
|
18 |
-
"import os\n",
|
19 |
-
"current_directory = os.getcwd()\n",
|
20 |
-
"print(current_directory)"
|
21 |
-
]
|
22 |
-
},
|
23 |
-
{
|
24 |
-
"cell_type": "code",
|
25 |
-
"execution_count": null,
|
26 |
-
"metadata": {},
|
27 |
-
"outputs": [],
|
28 |
-
"source": [
|
29 |
-
"os.chdir(current_directory)\n"
|
30 |
-
]
|
31 |
-
},
|
32 |
-
{
|
33 |
-
"cell_type": "code",
|
34 |
-
"execution_count": null,
|
35 |
-
"metadata": {},
|
36 |
-
"outputs": [],
|
37 |
-
"source": [
|
38 |
-
"import pandas as pd\n",
|
39 |
-
"from pyspark.sql import SparkSession\n",
|
40 |
-
"from pyspark.sql.functions import col\n",
|
41 |
-
"\n",
|
42 |
-
"spark = SparkSession.builder.appName(\"CombineParquetFiles\").config(\"spark.executor.memory\", \"8g\").config(\"spark.executor.cores\", \"4\").config(\"spark.executor.instances\", \"3\").config(\"spark.dynamicAllocation.enabled\", \"true\").config(\"spark.task.maxFailures\", 10).getOrCreate()"
|
43 |
-
]
|
44 |
-
},
|
45 |
-
{
|
46 |
-
"cell_type": "code",
|
47 |
-
"execution_count": null,
|
48 |
-
"metadata": {},
|
49 |
-
"outputs": [],
|
50 |
-
"source": [
|
51 |
-
"parquet_directory_path = '/Users/fionachow/Documents/NYU/CDS/Fall 2023/CSCI - GA 2271 - Computer Vision/Project/laion_sampled'\n",
|
52 |
-
"\n",
|
53 |
-
"output_parquet_file = '/Users/fionachow/Documents/NYU/CDS/Fall 2023/CSCI - GA 2271 - Computer Vision/Project/laion_combined'\n",
|
54 |
-
"\n",
|
55 |
-
"df = spark.read.parquet(parquet_directory_path)\n",
|
56 |
-
"\n",
|
57 |
-
"df_coalesced = df.coalesce(1)\n",
|
58 |
-
"\n",
|
59 |
-
"df_coalesced.write.mode('overwrite').parquet(output_parquet_file)\n",
|
60 |
-
"\n",
|
61 |
-
"row_count = df.count()"
|
62 |
-
]
|
63 |
-
},
|
64 |
-
{
|
65 |
-
"cell_type": "code",
|
66 |
-
"execution_count": null,
|
67 |
-
"metadata": {},
|
68 |
-
"outputs": [],
|
69 |
-
"source": [
|
70 |
-
"print(row_count)"
|
71 |
-
]
|
72 |
-
},
|
73 |
-
{
|
74 |
-
"cell_type": "code",
|
75 |
-
"execution_count": null,
|
76 |
-
"metadata": {},
|
77 |
-
"outputs": [],
|
78 |
-
"source": [
|
79 |
-
"parquet_directory_path = '/Users/fionachow/Documents/NYU/CDS/Fall 2023/CSCI - GA 2271 - Computer Vision/Project/laion_combined/part-00000-0190eea7-02ac-4ea0-86fd-0722308c0c58-c000.snappy.parquet'\n",
|
80 |
-
"\n",
|
81 |
-
"df = spark.read.parquet(parquet_directory_path)\n",
|
82 |
-
"\n",
|
83 |
-
"df.show()"
|
84 |
-
]
|
85 |
-
},
|
86 |
-
{
|
87 |
-
"cell_type": "code",
|
88 |
-
"execution_count": null,
|
89 |
-
"metadata": {},
|
90 |
-
"outputs": [],
|
91 |
-
"source": [
|
92 |
-
"print(df.count())"
|
93 |
-
]
|
94 |
-
}
|
95 |
-
],
|
96 |
-
"metadata": {
|
97 |
-
"kernelspec": {
|
98 |
-
"display_name": "bloom",
|
99 |
-
"language": "python",
|
100 |
-
"name": "python3"
|
101 |
-
},
|
102 |
-
"language_info": {
|
103 |
-
"codemirror_mode": {
|
104 |
-
"name": "ipython",
|
105 |
-
"version": 3
|
106 |
-
},
|
107 |
-
"file_extension": ".py",
|
108 |
-
"mimetype": "text/x-python",
|
109 |
-
"name": "python",
|
110 |
-
"nbconvert_exporter": "python",
|
111 |
-
"pygments_lexer": "ipython3",
|
112 |
-
"version": "3.9.16"
|
113 |
-
}
|
114 |
-
},
|
115 |
-
"nbformat": 4,
|
116 |
-
"nbformat_minor": 2
|
117 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Diffusion/data_split.py
DELETED
@@ -1,80 +0,0 @@
|
|
1 |
-
import glob
|
2 |
-
import json
|
3 |
-
|
4 |
-
import webdataset as wds
|
5 |
-
|
6 |
-
|
7 |
-
def split_dataset(path, n_train, n_val, n_test, label, domain_label):
|
8 |
-
max_file_size = 1000
|
9 |
-
input_files = glob.glob(path + "/*.tar")
|
10 |
-
src = wds.WebDataset(input_files)
|
11 |
-
|
12 |
-
train_path_prefix = path + "/train"
|
13 |
-
val_path_prefix = path + "/val"
|
14 |
-
test_path_prefix = path + "/test"
|
15 |
-
|
16 |
-
def write_split(dataset, prefix, start, end):
|
17 |
-
n_split = end - start
|
18 |
-
output_files = [
|
19 |
-
f"{prefix}_{i}.tar" for i in range(n_split // max_file_size + 1)
|
20 |
-
]
|
21 |
-
for i, output_file in enumerate(output_files):
|
22 |
-
print(f"Writing {output_file}")
|
23 |
-
with wds.TarWriter(output_file) as dst:
|
24 |
-
for sample in dataset.slice(
|
25 |
-
start + i * max_file_size,
|
26 |
-
min(start + (i + 1) * max_file_size, end),
|
27 |
-
):
|
28 |
-
new_sample = {
|
29 |
-
"__key__": sample["__key__"],
|
30 |
-
"jpg": sample["jpg"],
|
31 |
-
"label.cls": label,
|
32 |
-
"domain_label.cls": domain_label,
|
33 |
-
}
|
34 |
-
dst.write(new_sample)
|
35 |
-
|
36 |
-
write_split(src, train_path_prefix, 0, n_train)
|
37 |
-
write_split(src, val_path_prefix, n_train, n_train + n_val)
|
38 |
-
write_split(
|
39 |
-
src,
|
40 |
-
test_path_prefix,
|
41 |
-
n_train + n_val,
|
42 |
-
n_train + n_val + n_test,
|
43 |
-
)
|
44 |
-
|
45 |
-
|
46 |
-
def calculate_sizes(path):
|
47 |
-
stat_files = glob.glob(path + "/*_stats.json")
|
48 |
-
total = 0
|
49 |
-
for f in stat_files:
|
50 |
-
with open(f) as stats:
|
51 |
-
total += json.load(stats)["successes"]
|
52 |
-
n_train = int(total * 0.8)
|
53 |
-
n_val = int(total * 0.1)
|
54 |
-
n_test = total - n_train - n_val
|
55 |
-
|
56 |
-
return n_train, n_val, n_test
|
57 |
-
|
58 |
-
|
59 |
-
if __name__ == "__main__":
|
60 |
-
|
61 |
-
paths = [
|
62 |
-
"./data/laion400m_data",
|
63 |
-
"./data/genai-images/StableDiffusion",
|
64 |
-
"./data/genai-images/midjourney",
|
65 |
-
"./data/genai-images/dalle2",
|
66 |
-
"./data/genai-images/dalle3",
|
67 |
-
]
|
68 |
-
|
69 |
-
sizes = []
|
70 |
-
for p in paths:
|
71 |
-
res = calculate_sizes(p)
|
72 |
-
sizes.append(res)
|
73 |
-
|
74 |
-
domain_labels = [0, 1, 4, 2, 3]
|
75 |
-
|
76 |
-
for i, p in enumerate(paths):
|
77 |
-
print(f"{p}: {sizes[i]}")
|
78 |
-
label = 0 if i == 0 else 1
|
79 |
-
print(label, domain_labels[i])
|
80 |
-
split_dataset(p, *calculate_sizes(p), label, domain_labels[i])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Diffusion/dataloader.py
DELETED
@@ -1,228 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import collections
|
3 |
-
import random
|
4 |
-
from typing import Iterator
|
5 |
-
|
6 |
-
import cv2
|
7 |
-
import numpy as np
|
8 |
-
import torchdata.datapipes as dp
|
9 |
-
from imwatermark import WatermarkEncoder
|
10 |
-
from PIL import (
|
11 |
-
Image,
|
12 |
-
ImageFile,
|
13 |
-
)
|
14 |
-
from torch.utils.data import DataLoader
|
15 |
-
from torchdata.datapipes.iter import (
|
16 |
-
Concater,
|
17 |
-
FileLister,
|
18 |
-
FileOpener,
|
19 |
-
SampleMultiplexer,
|
20 |
-
)
|
21 |
-
from torchvision.transforms import v2
|
22 |
-
from tqdm import tqdm
|
23 |
-
|
24 |
-
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
25 |
-
Image.MAX_IMAGE_PIXELS = 1000000000
|
26 |
-
|
27 |
-
encoder = WatermarkEncoder()
|
28 |
-
encoder.set_watermark("bytes", b"test")
|
29 |
-
|
30 |
-
|
31 |
-
DOMAIN_LABELS = {
|
32 |
-
0: "laion",
|
33 |
-
1: "StableDiffusion",
|
34 |
-
2: "dalle2",
|
35 |
-
3: "dalle3",
|
36 |
-
4: "midjourney",
|
37 |
-
}
|
38 |
-
|
39 |
-
N_SAMPLES = {
|
40 |
-
0: (115346, 14418, 14419),
|
41 |
-
1: (22060, 2757, 2758),
|
42 |
-
4: (21096, 2637, 2637),
|
43 |
-
2: (13582, 1697, 1699),
|
44 |
-
3: (12027, 1503, 1504),
|
45 |
-
}
|
46 |
-
|
47 |
-
|
48 |
-
@dp.functional_datapipe("collect_from_workers")
|
49 |
-
class WorkerResultCollector(dp.iter.IterDataPipe):
|
50 |
-
def __init__(self, source: dp.iter.IterDataPipe):
|
51 |
-
self.source = source
|
52 |
-
|
53 |
-
def __iter__(self) -> Iterator:
|
54 |
-
yield from self.source
|
55 |
-
|
56 |
-
def is_replicable(self) -> bool:
|
57 |
-
"""Method to force data back to main process"""
|
58 |
-
return False
|
59 |
-
|
60 |
-
|
61 |
-
def crop_bottom(image, cutoff=16):
|
62 |
-
return image[:, :-cutoff, :]
|
63 |
-
|
64 |
-
|
65 |
-
def random_gaussian_blur(image, p=0.01):
|
66 |
-
if random.random() < p:
|
67 |
-
return v2.functional.gaussian_blur(image, kernel_size=5)
|
68 |
-
return image
|
69 |
-
|
70 |
-
|
71 |
-
def random_invisible_watermark(image, p=0.2):
|
72 |
-
image_np = np.array(image)
|
73 |
-
image_np = np.transpose(image_np, (1, 2, 0))
|
74 |
-
|
75 |
-
if image_np.ndim == 2: # Grayscale image
|
76 |
-
image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR)
|
77 |
-
elif image_np.shape[2] == 4: # RGBA image
|
78 |
-
image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2BGR)
|
79 |
-
|
80 |
-
# print(image_np.shape)
|
81 |
-
if image_np.shape[0] < 256 or image_np.shape[1] < 256:
|
82 |
-
image_np = cv2.resize(
|
83 |
-
image_np,
|
84 |
-
(256, 256),
|
85 |
-
interpolation=cv2.INTER_AREA,
|
86 |
-
)
|
87 |
-
if random.random() < p:
|
88 |
-
return encoder.encode(image_np, method="dwtDct")
|
89 |
-
return image_np
|
90 |
-
|
91 |
-
|
92 |
-
def build_transform(split: str):
|
93 |
-
train_transform = v2.Compose(
|
94 |
-
[
|
95 |
-
v2.Lambda(crop_bottom),
|
96 |
-
v2.RandomCrop((256, 256), pad_if_needed=True),
|
97 |
-
v2.Lambda(random_gaussian_blur),
|
98 |
-
v2.RandomGrayscale(p=0.05),
|
99 |
-
v2.Lambda(random_invisible_watermark),
|
100 |
-
v2.ToImage(),
|
101 |
-
],
|
102 |
-
)
|
103 |
-
|
104 |
-
eval_transform = v2.Compose(
|
105 |
-
[
|
106 |
-
v2.CenterCrop((256, 256)),
|
107 |
-
],
|
108 |
-
)
|
109 |
-
transform = train_transform if split == "train" else eval_transform
|
110 |
-
|
111 |
-
return transform
|
112 |
-
|
113 |
-
|
114 |
-
def dp_to_tuple_train(input_dict):
|
115 |
-
transform = build_transform("train")
|
116 |
-
return (
|
117 |
-
transform(input_dict[".jpg"]),
|
118 |
-
input_dict[".label.cls"],
|
119 |
-
input_dict[".domain_label.cls"],
|
120 |
-
)
|
121 |
-
|
122 |
-
|
123 |
-
def dp_to_tuple_eval(input_dict):
|
124 |
-
transform = build_transform("eval")
|
125 |
-
return (
|
126 |
-
transform(input_dict[".jpg"]),
|
127 |
-
input_dict[".label.cls"],
|
128 |
-
input_dict[".domain_label.cls"],
|
129 |
-
)
|
130 |
-
|
131 |
-
|
132 |
-
def load_dataset(domains: list[int], split: str):
|
133 |
-
|
134 |
-
laion_lister = FileLister("./data/laion400m_data", f"{split}*.tar")
|
135 |
-
genai_lister = {
|
136 |
-
d: FileLister(
|
137 |
-
f"./data/genai-images/{DOMAIN_LABELS[d]}",
|
138 |
-
f"{split}*.tar",
|
139 |
-
)
|
140 |
-
for d in domains
|
141 |
-
if DOMAIN_LABELS[d] != "laion"
|
142 |
-
}
|
143 |
-
weight_genai = 1 / len(genai_lister)
|
144 |
-
|
145 |
-
def open_lister(lister):
|
146 |
-
opener = FileOpener(lister, mode="b")
|
147 |
-
return opener.load_from_tar().routed_decode().webdataset()
|
148 |
-
|
149 |
-
buffer_size1 = 100 if split == "train" else 10
|
150 |
-
buffer_size2 = 100 if split == "train" else 10
|
151 |
-
|
152 |
-
if split != "train":
|
153 |
-
all_lister = [laion_lister] + list(genai_lister.values())
|
154 |
-
dp = open_lister(Concater(*all_lister)).sharding_filter()
|
155 |
-
else:
|
156 |
-
laion_dp = (
|
157 |
-
open_lister(laion_lister.shuffle())
|
158 |
-
.cycle()
|
159 |
-
.sharding_filter()
|
160 |
-
.shuffle(buffer_size=buffer_size1)
|
161 |
-
)
|
162 |
-
genai_dp = {
|
163 |
-
open_lister(genai_lister[d].shuffle())
|
164 |
-
.cycle()
|
165 |
-
.sharding_filter()
|
166 |
-
.shuffle(buffer_size=buffer_size1): weight_genai
|
167 |
-
for d in domains
|
168 |
-
if DOMAIN_LABELS[d] != "laion"
|
169 |
-
}
|
170 |
-
dp = SampleMultiplexer({laion_dp: 1, **genai_dp}).shuffle(
|
171 |
-
buffer_size=buffer_size2,
|
172 |
-
)
|
173 |
-
|
174 |
-
if split == "train":
|
175 |
-
dp = dp.map(dp_to_tuple_train)
|
176 |
-
else:
|
177 |
-
dp = dp.map(dp_to_tuple_eval)
|
178 |
-
|
179 |
-
return dp
|
180 |
-
|
181 |
-
|
182 |
-
def load_dataloader(
|
183 |
-
domains: list[int],
|
184 |
-
split: str,
|
185 |
-
batch_size: int = 32,
|
186 |
-
num_workers: int = 4,
|
187 |
-
):
|
188 |
-
dp = load_dataset(domains, split)
|
189 |
-
# if split == "train":
|
190 |
-
# dp = UnderSamplerIterDataPipe(dp, {0: 0.5, 1: 0.5}, seed=42)
|
191 |
-
dp = dp.batch(batch_size).collate()
|
192 |
-
dl = DataLoader(
|
193 |
-
dp,
|
194 |
-
batch_size=None,
|
195 |
-
num_workers=num_workers,
|
196 |
-
pin_memory=True,
|
197 |
-
)
|
198 |
-
|
199 |
-
return dl
|
200 |
-
|
201 |
-
|
202 |
-
if __name__ == "__main__":
|
203 |
-
parser = argparse.ArgumentParser()
|
204 |
-
|
205 |
-
args = parser.parse_args()
|
206 |
-
|
207 |
-
# testing code
|
208 |
-
dl = load_dataloader([0, 1], "train", num_workers=8)
|
209 |
-
y_dist = collections.Counter()
|
210 |
-
d_dist = collections.Counter()
|
211 |
-
|
212 |
-
for i, (img, y, d) in tqdm(enumerate(dl)):
|
213 |
-
if i % 100 == 0:
|
214 |
-
print(y, d)
|
215 |
-
if i == 400:
|
216 |
-
break
|
217 |
-
y_dist.update(y.numpy())
|
218 |
-
d_dist.update(d.numpy())
|
219 |
-
|
220 |
-
print("class label")
|
221 |
-
for label in sorted(y_dist):
|
222 |
-
frequency = y_dist[label] / sum(y_dist.values())
|
223 |
-
print(f"β’ {label}: {frequency:.2%} ({y_dist[label]})")
|
224 |
-
|
225 |
-
print("domain label")
|
226 |
-
for label in sorted(d_dist):
|
227 |
-
frequency = d_dist[label] / sum(d_dist.values())
|
228 |
-
print(f"β’ {label}: {frequency:.2%} ({d_dist[label]})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Diffusion/diffusion_data_loader.py
DELETED
@@ -1,233 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import collections
|
3 |
-
import glob
|
4 |
-
import os
|
5 |
-
import random
|
6 |
-
from typing import Iterator
|
7 |
-
|
8 |
-
import cv2
|
9 |
-
import numpy as np
|
10 |
-
import torch
|
11 |
-
import torch.nn as nn
|
12 |
-
import torch.nn.functional as F
|
13 |
-
import torchdata as td
|
14 |
-
import torchdata.datapipes as dp
|
15 |
-
from imwatermark import WatermarkEncoder
|
16 |
-
from PIL import (
|
17 |
-
Image,
|
18 |
-
ImageFile,
|
19 |
-
)
|
20 |
-
from torch.utils.data import (
|
21 |
-
DataLoader,
|
22 |
-
RandomSampler,
|
23 |
-
)
|
24 |
-
from torchdata.dataloader2 import (
|
25 |
-
DataLoader2,
|
26 |
-
MultiProcessingReadingService,
|
27 |
-
)
|
28 |
-
from torchdata.datapipes.iter import (
|
29 |
-
Concater,
|
30 |
-
FileLister,
|
31 |
-
FileOpener,
|
32 |
-
SampleMultiplexer,
|
33 |
-
)
|
34 |
-
from torchvision.transforms import v2
|
35 |
-
from tqdm import tqdm
|
36 |
-
from utils_sampling import UnderSamplerIterDataPipe
|
37 |
-
|
38 |
-
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
39 |
-
Image.MAX_IMAGE_PIXELS = 1000000000
|
40 |
-
|
41 |
-
encoder = WatermarkEncoder()
|
42 |
-
encoder.set_watermark("bytes", b"test")
|
43 |
-
|
44 |
-
DOMAIN_LABELS = {
|
45 |
-
0: "laion",
|
46 |
-
1: "StableDiffusion",
|
47 |
-
2: "dalle2",
|
48 |
-
3: "dalle3",
|
49 |
-
4: "midjourney",
|
50 |
-
}
|
51 |
-
|
52 |
-
N_SAMPLES = {
|
53 |
-
0: (115346, 14418, 14419),
|
54 |
-
1: (22060, 2757, 2758),
|
55 |
-
4: (21096, 2637, 2637),
|
56 |
-
2: (13582, 1697, 1699),
|
57 |
-
3: (12027, 1503, 1504),
|
58 |
-
}
|
59 |
-
|
60 |
-
|
61 |
-
@dp.functional_datapipe("collect_from_workers")
|
62 |
-
class WorkerResultCollector(dp.iter.IterDataPipe):
|
63 |
-
def __init__(self, source: dp.iter.IterDataPipe):
|
64 |
-
self.source = source
|
65 |
-
|
66 |
-
def __iter__(self) -> Iterator:
|
67 |
-
yield from self.source
|
68 |
-
|
69 |
-
def is_replicable(self) -> bool:
|
70 |
-
"""Method to force data back to main process"""
|
71 |
-
return False
|
72 |
-
|
73 |
-
|
74 |
-
def crop_bottom(image, cutoff=16):
|
75 |
-
return image[:, :-cutoff, :]
|
76 |
-
|
77 |
-
|
78 |
-
def random_gaussian_blur(image, p=0.01):
|
79 |
-
if random.random() < p:
|
80 |
-
return v2.functional.gaussian_blur(image, kernel_size=5)
|
81 |
-
return image
|
82 |
-
|
83 |
-
|
84 |
-
def random_invisible_watermark(image, p=0.2):
|
85 |
-
image_np = np.array(image)
|
86 |
-
image_np = np.transpose(image_np, (1, 2, 0))
|
87 |
-
|
88 |
-
if image_np.ndim == 2: # Grayscale image
|
89 |
-
image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR)
|
90 |
-
elif image_np.shape[2] == 4: # RGBA image
|
91 |
-
image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2BGR)
|
92 |
-
|
93 |
-
# print(image_np.shape)
|
94 |
-
if image_np.shape[0] < 256 or image_np.shape[1] < 256:
|
95 |
-
image_np = cv2.resize(
|
96 |
-
image_np, (256, 256), interpolation=cv2.INTER_AREA
|
97 |
-
)
|
98 |
-
if random.random() < p:
|
99 |
-
return encoder.encode(image_np, method="dwtDct")
|
100 |
-
return image_np
|
101 |
-
|
102 |
-
|
103 |
-
def build_transform(split: str):
|
104 |
-
train_transform = v2.Compose(
|
105 |
-
[
|
106 |
-
v2.Lambda(crop_bottom),
|
107 |
-
v2.RandomCrop((256, 256), pad_if_needed=True),
|
108 |
-
v2.Lambda(random_gaussian_blur),
|
109 |
-
v2.RandomGrayscale(p=0.05),
|
110 |
-
v2.Lambda(random_invisible_watermark),
|
111 |
-
v2.ToImage(),
|
112 |
-
]
|
113 |
-
)
|
114 |
-
|
115 |
-
eval_transform = v2.Compose(
|
116 |
-
[
|
117 |
-
v2.CenterCrop((256, 256)),
|
118 |
-
]
|
119 |
-
)
|
120 |
-
transform = train_transform if split == "train" else eval_transform
|
121 |
-
|
122 |
-
return transform
|
123 |
-
|
124 |
-
|
125 |
-
def dp_to_tuple_train(input_dict):
|
126 |
-
transform = build_transform("train")
|
127 |
-
return (
|
128 |
-
transform(input_dict[".jpg"]),
|
129 |
-
input_dict[".label.cls"],
|
130 |
-
input_dict[".domain_label.cls"],
|
131 |
-
)
|
132 |
-
|
133 |
-
|
134 |
-
def dp_to_tuple_eval(input_dict):
|
135 |
-
transform = build_transform("eval")
|
136 |
-
return (
|
137 |
-
transform(input_dict[".jpg"]),
|
138 |
-
input_dict[".label.cls"],
|
139 |
-
input_dict[".domain_label.cls"],
|
140 |
-
)
|
141 |
-
|
142 |
-
|
143 |
-
def load_dataset(domains: list[int], split: str):
|
144 |
-
laion_lister = FileLister("./data/laion400m_data", f"{split}*.tar")
|
145 |
-
genai_lister = {
|
146 |
-
d: FileLister(
|
147 |
-
f"./data/genai-images/{DOMAIN_LABELS[d]}", f"{split}*.tar"
|
148 |
-
)
|
149 |
-
for d in domains
|
150 |
-
if DOMAIN_LABELS[d] != "laion"
|
151 |
-
}
|
152 |
-
weight_genai = 1 / len(genai_lister)
|
153 |
-
|
154 |
-
def open_lister(lister):
|
155 |
-
opener = FileOpener(lister, mode="b")
|
156 |
-
return opener.load_from_tar().routed_decode().webdataset()
|
157 |
-
|
158 |
-
buffer_size1 = 100 if split == "train" else 10
|
159 |
-
buffer_size2 = 100 if split == "train" else 10
|
160 |
-
|
161 |
-
if split != "train":
|
162 |
-
all_lister = [laion_lister] + list(genai_lister.values())
|
163 |
-
dp = open_lister(Concater(*all_lister)).sharding_filter()
|
164 |
-
else:
|
165 |
-
laion_dp = (
|
166 |
-
open_lister(laion_lister.shuffle())
|
167 |
-
.cycle()
|
168 |
-
.sharding_filter()
|
169 |
-
.shuffle(buffer_size=buffer_size1)
|
170 |
-
)
|
171 |
-
genai_dp = {
|
172 |
-
open_lister(genai_lister[d].shuffle())
|
173 |
-
.cycle()
|
174 |
-
.sharding_filter()
|
175 |
-
.shuffle(
|
176 |
-
buffer_size=buffer_size1,
|
177 |
-
): weight_genai
|
178 |
-
for d in domains
|
179 |
-
if DOMAIN_LABELS[d] != "laion"
|
180 |
-
}
|
181 |
-
dp = SampleMultiplexer({laion_dp: 1, **genai_dp}).shuffle(
|
182 |
-
buffer_size=buffer_size2
|
183 |
-
)
|
184 |
-
|
185 |
-
if split == "train":
|
186 |
-
dp = dp.map(dp_to_tuple_train)
|
187 |
-
else:
|
188 |
-
dp = dp.map(dp_to_tuple_eval)
|
189 |
-
|
190 |
-
return dp
|
191 |
-
|
192 |
-
|
193 |
-
def load_dataloader(
|
194 |
-
domains: list[int], split: str, batch_size: int = 32, num_workers: int = 4
|
195 |
-
):
|
196 |
-
dp = load_dataset(domains, split)
|
197 |
-
# if split == "train":
|
198 |
-
# dp = UnderSamplerIterDataPipe(dp, {0: 0.5, 1: 0.5}, seed=42)
|
199 |
-
dp = dp.batch(batch_size).collate()
|
200 |
-
dl = DataLoader(
|
201 |
-
dp, batch_size=None, num_workers=num_workers, pin_memory=True
|
202 |
-
)
|
203 |
-
|
204 |
-
return dl
|
205 |
-
|
206 |
-
|
207 |
-
if __name__ == "__main__":
|
208 |
-
parser = argparse.ArgumentParser()
|
209 |
-
|
210 |
-
args = parser.parse_args()
|
211 |
-
|
212 |
-
# testing code
|
213 |
-
dl = load_dataloader([0, 1], "train", num_workers=8)
|
214 |
-
y_dist = collections.Counter()
|
215 |
-
d_dist = collections.Counter()
|
216 |
-
|
217 |
-
for i, (img, y, d) in tqdm(enumerate(dl)):
|
218 |
-
if i % 100 == 0:
|
219 |
-
print(y, d)
|
220 |
-
if i == 400:
|
221 |
-
break
|
222 |
-
y_dist.update(y.numpy())
|
223 |
-
d_dist.update(d.numpy())
|
224 |
-
|
225 |
-
print("class label")
|
226 |
-
for label in sorted(y_dist):
|
227 |
-
frequency = y_dist[label] / sum(y_dist.values())
|
228 |
-
print(f"β’ {label}: {frequency:.2%} ({y_dist[label]})")
|
229 |
-
|
230 |
-
print("domain label")
|
231 |
-
for label in sorted(d_dist):
|
232 |
-
frequency = d_dist[label] / sum(d_dist.values())
|
233 |
-
print(f"β’ {label}: {frequency:.2%} ({d_dist[label]})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Diffusion/diffusion_model_classifier.py
DELETED
@@ -1,242 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import logging
|
3 |
-
import os
|
4 |
-
|
5 |
-
import pandas as pd
|
6 |
-
import pytorch_lightning as pl
|
7 |
-
import timm
|
8 |
-
import torch
|
9 |
-
import torchvision.transforms as transforms
|
10 |
-
from data_split import *
|
11 |
-
from dataloader import *
|
12 |
-
from PIL import Image
|
13 |
-
from pytorch_lightning.callbacks import (
|
14 |
-
EarlyStopping,
|
15 |
-
ModelCheckpoint,
|
16 |
-
)
|
17 |
-
from sklearn.metrics import roc_auc_score
|
18 |
-
from torchmetrics import (
|
19 |
-
Accuracy,
|
20 |
-
Recall,
|
21 |
-
)
|
22 |
-
from utils_sampling import *
|
23 |
-
|
24 |
-
logging.basicConfig(
|
25 |
-
filename="training.log", filemode="w", level=logging.INFO, force=True
|
26 |
-
)
|
27 |
-
|
28 |
-
|
29 |
-
class ImageClassifier(pl.LightningModule):
|
30 |
-
def __init__(self, lmd=0):
|
31 |
-
super().__init__()
|
32 |
-
self.model = timm.create_model(
|
33 |
-
"resnet50", pretrained=True, num_classes=1
|
34 |
-
)
|
35 |
-
self.accuracy = Accuracy(task="binary", threshold=0.5)
|
36 |
-
self.recall = Recall(task="binary", threshold=0.5)
|
37 |
-
self.validation_outputs = []
|
38 |
-
self.lmd = lmd
|
39 |
-
|
40 |
-
def forward(self, x):
|
41 |
-
return self.model(x)
|
42 |
-
|
43 |
-
def training_step(self, batch):
|
44 |
-
images, labels, _ = batch
|
45 |
-
outputs = self.forward(images).squeeze()
|
46 |
-
|
47 |
-
print(f"Shape of outputs (training): {outputs.shape}")
|
48 |
-
print(f"Shape of labels (training): {labels.shape}")
|
49 |
-
|
50 |
-
loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
|
51 |
-
logging.info(f"Training Step - ERM loss: {loss.item()}")
|
52 |
-
loss += self.lmd * (outputs**2).mean() # SD loss penalty
|
53 |
-
logging.info(f"Training Step - SD loss: {loss.item()}")
|
54 |
-
return loss
|
55 |
-
|
56 |
-
def validation_step(self, batch):
|
57 |
-
images, labels, _ = batch
|
58 |
-
outputs = self.forward(images).squeeze()
|
59 |
-
|
60 |
-
if outputs.shape == torch.Size([]):
|
61 |
-
return
|
62 |
-
|
63 |
-
print(f"Shape of outputs (validation): {outputs.shape}")
|
64 |
-
print(f"Shape of labels (validation): {labels.shape}")
|
65 |
-
|
66 |
-
loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
|
67 |
-
preds = torch.sigmoid(outputs)
|
68 |
-
self.log("val_loss", loss, prog_bar=True, sync_dist=True)
|
69 |
-
self.log(
|
70 |
-
"val_acc",
|
71 |
-
self.accuracy(preds, labels.int()),
|
72 |
-
prog_bar=True,
|
73 |
-
sync_dist=True,
|
74 |
-
)
|
75 |
-
self.log(
|
76 |
-
"val_recall",
|
77 |
-
self.recall(preds, labels.int()),
|
78 |
-
prog_bar=True,
|
79 |
-
sync_dist=True,
|
80 |
-
)
|
81 |
-
output = {"val_loss": loss, "preds": preds, "labels": labels}
|
82 |
-
self.validation_outputs.append(output)
|
83 |
-
logging.info(f"Validation Step - Batch loss: {loss.item()}")
|
84 |
-
return output
|
85 |
-
|
86 |
-
def predict_step(self, batch):
|
87 |
-
images, label, domain = batch
|
88 |
-
outputs = self.forward(images).squeeze()
|
89 |
-
preds = torch.sigmoid(outputs)
|
90 |
-
return preds, label, domain
|
91 |
-
|
92 |
-
def on_validation_epoch_end(self):
|
93 |
-
if not self.validation_outputs:
|
94 |
-
logging.warning("No outputs in validation step to process")
|
95 |
-
return
|
96 |
-
preds = torch.cat([x["preds"] for x in self.validation_outputs])
|
97 |
-
labels = torch.cat([x["labels"] for x in self.validation_outputs])
|
98 |
-
if labels.unique().size(0) == 1:
|
99 |
-
logging.warning("Only one class in validation step")
|
100 |
-
return
|
101 |
-
auc_score = roc_auc_score(labels.cpu(), preds.cpu())
|
102 |
-
self.log("val_auc", auc_score, prog_bar=True, sync_dist=True)
|
103 |
-
logging.info(f"Validation Epoch End - AUC score: {auc_score}")
|
104 |
-
self.validation_outputs = []
|
105 |
-
|
106 |
-
def configure_optimizers(self):
|
107 |
-
optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0005)
|
108 |
-
return optimizer
|
109 |
-
|
110 |
-
|
111 |
-
checkpoint_callback = ModelCheckpoint(
|
112 |
-
monitor="val_loss",
|
113 |
-
dirpath="./model_checkpoints/",
|
114 |
-
filename="image-classifier-{step}-{val_loss:.2f}",
|
115 |
-
save_top_k=3,
|
116 |
-
mode="min",
|
117 |
-
every_n_train_steps=1001,
|
118 |
-
enable_version_counter=True,
|
119 |
-
)
|
120 |
-
|
121 |
-
early_stop_callback = EarlyStopping(
|
122 |
-
monitor="val_loss",
|
123 |
-
patience=4,
|
124 |
-
mode="min",
|
125 |
-
)
|
126 |
-
|
127 |
-
|
128 |
-
def load_image(image_path, transform=None):
|
129 |
-
image = Image.open(image_path).convert("RGB")
|
130 |
-
|
131 |
-
if transform:
|
132 |
-
image = transform(image)
|
133 |
-
|
134 |
-
return image
|
135 |
-
|
136 |
-
|
137 |
-
def predict_single_image(image_path, model, transform=None):
|
138 |
-
image = load_image(image_path, transform)
|
139 |
-
|
140 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
141 |
-
|
142 |
-
model.to(device)
|
143 |
-
|
144 |
-
image = image.to(device)
|
145 |
-
|
146 |
-
model.eval()
|
147 |
-
|
148 |
-
with torch.no_grad():
|
149 |
-
image = image.unsqueeze(0)
|
150 |
-
output = model(image).squeeze()
|
151 |
-
print(output)
|
152 |
-
prediction = torch.sigmoid(output).item()
|
153 |
-
|
154 |
-
return prediction
|
155 |
-
|
156 |
-
|
157 |
-
parser = argparse.ArgumentParser()
|
158 |
-
parser.add_argument(
|
159 |
-
"--ckpt_path", help="checkpoint to continue from", required=False
|
160 |
-
)
|
161 |
-
parser.add_argument(
|
162 |
-
"--predict", help="predict on test set", action="store_true"
|
163 |
-
)
|
164 |
-
parser.add_argument("--reset", help="reset training", action="store_true")
|
165 |
-
parser.add_argument(
|
166 |
-
"--predict_image",
|
167 |
-
help="predict the class of a single image",
|
168 |
-
action="store_true",
|
169 |
-
)
|
170 |
-
parser.add_argument(
|
171 |
-
"--image_path",
|
172 |
-
help="path to the image to predict",
|
173 |
-
type=str,
|
174 |
-
required=False,
|
175 |
-
)
|
176 |
-
args = parser.parse_args()
|
177 |
-
|
178 |
-
train_domains = [0, 1, 4]
|
179 |
-
val_domains = [0, 1, 4]
|
180 |
-
lmd_value = 0
|
181 |
-
|
182 |
-
if args.predict:
|
183 |
-
test_dl = load_dataloader(
|
184 |
-
[0, 1, 2, 3, 4], "test", batch_size=128, num_workers=1
|
185 |
-
)
|
186 |
-
model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
|
187 |
-
trainer = pl.Trainer()
|
188 |
-
predictions = trainer.predict(model, dataloaders=test_dl)
|
189 |
-
preds, labels, domains = zip(*predictions)
|
190 |
-
preds = torch.cat(preds).cpu().numpy()
|
191 |
-
labels = torch.cat(labels).cpu().numpy()
|
192 |
-
domains = torch.cat(domains).cpu().numpy()
|
193 |
-
print(preds.shape, labels.shape, domains.shape)
|
194 |
-
df = pd.DataFrame({"preds": preds, "labels": labels, "domains": domains})
|
195 |
-
filename = "preds-" + args.ckpt_path.split("/")[-1]
|
196 |
-
df.to_csv(f"outputs/{filename}.csv", index=False)
|
197 |
-
elif args.predict_image:
|
198 |
-
image_path = args.image_path
|
199 |
-
model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
|
200 |
-
|
201 |
-
# Define the transformations for the image
|
202 |
-
transform = transforms.Compose(
|
203 |
-
[
|
204 |
-
transforms.Resize((224, 224)), # Image size expected by ResNet50
|
205 |
-
transforms.ToTensor(),
|
206 |
-
transforms.Normalize(
|
207 |
-
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
208 |
-
),
|
209 |
-
]
|
210 |
-
)
|
211 |
-
|
212 |
-
prediction = predict_single_image(image_path, model, transform)
|
213 |
-
print("prediction", prediction)
|
214 |
-
|
215 |
-
# Output the prediction
|
216 |
-
print(
|
217 |
-
f"Prediction for {image_path}: {'Human' if prediction <= 0.001 else 'Generated'}"
|
218 |
-
)
|
219 |
-
else:
|
220 |
-
train_dl = load_dataloader(
|
221 |
-
train_domains, "train", batch_size=128, num_workers=4
|
222 |
-
)
|
223 |
-
logging.info("Training dataloader loaded")
|
224 |
-
val_dl = load_dataloader(val_domains, "val", batch_size=128, num_workers=4)
|
225 |
-
logging.info("Validation dataloader loaded")
|
226 |
-
|
227 |
-
if args.reset:
|
228 |
-
model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
|
229 |
-
else:
|
230 |
-
model = ImageClassifier(lmd=lmd_value)
|
231 |
-
trainer = pl.Trainer(
|
232 |
-
callbacks=[checkpoint_callback, early_stop_callback],
|
233 |
-
max_steps=20000,
|
234 |
-
val_check_interval=1000,
|
235 |
-
check_val_every_n_epoch=None,
|
236 |
-
)
|
237 |
-
trainer.fit(
|
238 |
-
model=model,
|
239 |
-
train_dataloaders=train_dl,
|
240 |
-
val_dataloaders=val_dl,
|
241 |
-
ckpt_path=args.ckpt_path if not args.reset else None,
|
242 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Diffusion/evaluation.ipynb
DELETED
@@ -1,187 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": null,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"import pandas as pd\n",
|
10 |
-
"import numpy as np\n",
|
11 |
-
"import polars as pl\n",
|
12 |
-
"import matplotlib.pyplot as plt\n",
|
13 |
-
"import seaborn as sns\n",
|
14 |
-
"from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, RocCurveDisplay\n",
|
15 |
-
"\n",
|
16 |
-
"sns.set()"
|
17 |
-
]
|
18 |
-
},
|
19 |
-
{
|
20 |
-
"cell_type": "code",
|
21 |
-
"execution_count": null,
|
22 |
-
"metadata": {},
|
23 |
-
"outputs": [],
|
24 |
-
"source": [
|
25 |
-
"def pfbeta(labels, predictions, beta=1):\n",
|
26 |
-
" y_true_count = 0\n",
|
27 |
-
" ctp = 0\n",
|
28 |
-
" cfp = 0\n",
|
29 |
-
"\n",
|
30 |
-
" for idx in range(len(labels)):\n",
|
31 |
-
" prediction = min(max(predictions[idx], 0), 1)\n",
|
32 |
-
" if (labels[idx]):\n",
|
33 |
-
" y_true_count += 1\n",
|
34 |
-
" ctp += prediction\n",
|
35 |
-
" else:\n",
|
36 |
-
" cfp += prediction\n",
|
37 |
-
"\n",
|
38 |
-
" beta_squared = beta * beta\n",
|
39 |
-
" c_precision = ctp / (ctp + cfp)\n",
|
40 |
-
" c_recall = ctp / y_true_count\n",
|
41 |
-
" if (c_precision > 0 and c_recall > 0):\n",
|
42 |
-
" result = (1 + beta_squared) * (c_precision * c_recall) / (beta_squared * c_precision + c_recall)\n",
|
43 |
-
" return result\n",
|
44 |
-
" else:\n",
|
45 |
-
" return 0"
|
46 |
-
]
|
47 |
-
},
|
48 |
-
{
|
49 |
-
"cell_type": "code",
|
50 |
-
"execution_count": null,
|
51 |
-
"metadata": {},
|
52 |
-
"outputs": [],
|
53 |
-
"source": [
|
54 |
-
"def get_part_metrics(df: pl.DataFrame, threshold=0.3) -> dict:\n",
|
55 |
-
" df = df.with_columns((df[\"preds\"] > threshold).alias(\"preds_bin\"))\n",
|
56 |
-
" metrics = {}\n",
|
57 |
-
" # binary metrics using the threshold\n",
|
58 |
-
" metrics[\"accuracy\"] = accuracy_score(df[\"labels\"].to_numpy(), df[\"preds_bin\"].to_numpy())\n",
|
59 |
-
" metrics[\"precision\"] = precision_score(df[\"labels\"].to_numpy(), df[\"preds_bin\"].to_numpy())\n",
|
60 |
-
" metrics[\"recall\"] = recall_score(df[\"labels\"].to_numpy(), df[\"preds_bin\"].to_numpy())\n",
|
61 |
-
" metrics[\"f1\"] = f1_score(df[\"labels\"].to_numpy(), df[\"preds_bin\"].to_numpy())\n",
|
62 |
-
" # probabilistic F1 (doesn't depend on the threshold)\n",
|
63 |
-
" metrics[\"pf1\"] = pfbeta(df[\"labels\"].to_numpy(), df[\"preds\"].to_numpy())\n",
|
64 |
-
" # ROC AUC\n",
|
65 |
-
" metrics[\"roc_auc\"] = roc_auc_score(df[\"labels\"].to_numpy(), df[\"preds\"].to_numpy())\n",
|
66 |
-
" return metrics\n",
|
67 |
-
"\n",
|
68 |
-
"\n",
|
69 |
-
"def get_all_metrics(df: pl.DataFrame, threshold=0.3) -> pd.DataFrame:\n",
|
70 |
-
" groups = [list(range(5)), [0, 1], [0, 4], [0, 2], [0, 3]]\n",
|
71 |
-
" group_names = [\"all\", \"StableDiffusion\", \"Midjourney\", \"Dalle2\", \"Dalle3\"]\n",
|
72 |
-
" all_metrics = []\n",
|
73 |
-
" for i, g in enumerate(groups):\n",
|
74 |
-
" subset = df.filter(pl.col(\"domains\").is_in(g))\n",
|
75 |
-
" metrics = get_part_metrics(subset, threshold=threshold)\n",
|
76 |
-
" metrics[\"group\"] = group_names[i]\n",
|
77 |
-
" all_metrics.append(metrics)\n",
|
78 |
-
" \n",
|
79 |
-
" return pd.DataFrame(all_metrics)"
|
80 |
-
]
|
81 |
-
},
|
82 |
-
{
|
83 |
-
"cell_type": "code",
|
84 |
-
"execution_count": null,
|
85 |
-
"metadata": {},
|
86 |
-
"outputs": [],
|
87 |
-
"source": [
|
88 |
-
"df1 = pl.read_csv(\"outputs/preds-image-classifier-1.csv\")\n",
|
89 |
-
"metrics_df1 = get_all_metrics(df1, threshold=0.5)"
|
90 |
-
]
|
91 |
-
},
|
92 |
-
{
|
93 |
-
"cell_type": "code",
|
94 |
-
"execution_count": null,
|
95 |
-
"metadata": {},
|
96 |
-
"outputs": [],
|
97 |
-
"source": [
|
98 |
-
"metrics_df1"
|
99 |
-
]
|
100 |
-
},
|
101 |
-
{
|
102 |
-
"cell_type": "code",
|
103 |
-
"execution_count": null,
|
104 |
-
"metadata": {},
|
105 |
-
"outputs": [],
|
106 |
-
"source": [
|
107 |
-
"df14 = pl.read_csv(\"outputs/preds-image-classifier-14.csv\")\n",
|
108 |
-
"metrics_df14 = get_all_metrics(df14, threshold=0.5)"
|
109 |
-
]
|
110 |
-
},
|
111 |
-
{
|
112 |
-
"cell_type": "code",
|
113 |
-
"execution_count": null,
|
114 |
-
"metadata": {},
|
115 |
-
"outputs": [],
|
116 |
-
"source": [
|
117 |
-
"metrics_df14"
|
118 |
-
]
|
119 |
-
},
|
120 |
-
{
|
121 |
-
"cell_type": "code",
|
122 |
-
"execution_count": null,
|
123 |
-
"metadata": {},
|
124 |
-
"outputs": [],
|
125 |
-
"source": [
|
126 |
-
"df142 = pl.read_csv(\"outputs/preds-image-classifier-142.csv\")\n",
|
127 |
-
"metrics_df142 = get_all_metrics(df142, threshold=0.5)"
|
128 |
-
]
|
129 |
-
},
|
130 |
-
{
|
131 |
-
"cell_type": "code",
|
132 |
-
"execution_count": null,
|
133 |
-
"metadata": {},
|
134 |
-
"outputs": [],
|
135 |
-
"source": [
|
136 |
-
"metrics_df142"
|
137 |
-
]
|
138 |
-
},
|
139 |
-
{
|
140 |
-
"cell_type": "code",
|
141 |
-
"execution_count": null,
|
142 |
-
"metadata": {},
|
143 |
-
"outputs": [],
|
144 |
-
"source": [
|
145 |
-
"df1423 = pl.read_csv(\"outputs/preds-image-classifier-1423.csv\")\n",
|
146 |
-
"metrics_df1423 = get_all_metrics(df1423, threshold=0.5)"
|
147 |
-
]
|
148 |
-
},
|
149 |
-
{
|
150 |
-
"cell_type": "code",
|
151 |
-
"execution_count": null,
|
152 |
-
"metadata": {},
|
153 |
-
"outputs": [],
|
154 |
-
"source": [
|
155 |
-
"metrics_df1423"
|
156 |
-
]
|
157 |
-
},
|
158 |
-
{
|
159 |
-
"cell_type": "code",
|
160 |
-
"execution_count": null,
|
161 |
-
"metadata": {},
|
162 |
-
"outputs": [],
|
163 |
-
"source": []
|
164 |
-
}
|
165 |
-
],
|
166 |
-
"metadata": {
|
167 |
-
"kernelspec": {
|
168 |
-
"display_name": "GenAI-image-detection-Z_9oJJe7",
|
169 |
-
"language": "python",
|
170 |
-
"name": "python3"
|
171 |
-
},
|
172 |
-
"language_info": {
|
173 |
-
"codemirror_mode": {
|
174 |
-
"name": "ipython",
|
175 |
-
"version": 3
|
176 |
-
},
|
177 |
-
"file_extension": ".py",
|
178 |
-
"mimetype": "text/x-python",
|
179 |
-
"name": "python",
|
180 |
-
"nbconvert_exporter": "python",
|
181 |
-
"pygments_lexer": "ipython3",
|
182 |
-
"version": "3.11.6"
|
183 |
-
}
|
184 |
-
},
|
185 |
-
"nbformat": 4,
|
186 |
-
"nbformat_minor": 2
|
187 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Diffusion/model.py
DELETED
@@ -1,307 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import logging
|
3 |
-
import os
|
4 |
-
|
5 |
-
import pandas as pd
|
6 |
-
import pytorch_lightning as pl
|
7 |
-
import timm
|
8 |
-
import torch
|
9 |
-
import torch.nn.functional as F
|
10 |
-
import torchvision.transforms as transforms
|
11 |
-
from dataloader import load_dataloader
|
12 |
-
from PIL import Image
|
13 |
-
from pytorch_lightning.callbacks import (
|
14 |
-
EarlyStopping,
|
15 |
-
ModelCheckpoint,
|
16 |
-
)
|
17 |
-
from sklearn.metrics import roc_auc_score
|
18 |
-
from torchmetrics import (
|
19 |
-
Accuracy,
|
20 |
-
Recall,
|
21 |
-
)
|
22 |
-
|
23 |
-
logging.basicConfig(
|
24 |
-
filename="training.log",
|
25 |
-
filemode="w",
|
26 |
-
level=logging.INFO,
|
27 |
-
force=True,
|
28 |
-
)
|
29 |
-
|
30 |
-
|
31 |
-
class ImageClassifier(pl.LightningModule):
|
32 |
-
def __init__(self, lmd=0):
|
33 |
-
super().__init__()
|
34 |
-
self.model = timm.create_model(
|
35 |
-
"resnet50",
|
36 |
-
pretrained=True,
|
37 |
-
num_classes=1,
|
38 |
-
)
|
39 |
-
self.accuracy = Accuracy(task="binary", threshold=0.5)
|
40 |
-
self.recall = Recall(task="binary", threshold=0.5)
|
41 |
-
self.validation_outputs = []
|
42 |
-
self.lmd = lmd
|
43 |
-
|
44 |
-
def forward(self, x):
|
45 |
-
return self.model(x)
|
46 |
-
|
47 |
-
def training_step(self, batch):
|
48 |
-
images, labels, _ = batch
|
49 |
-
outputs = self.forward(images).squeeze()
|
50 |
-
|
51 |
-
print(f"Shape of outputs (training): {outputs.shape}")
|
52 |
-
print(f"Shape of labels (training): {labels.shape}")
|
53 |
-
|
54 |
-
loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
|
55 |
-
logging.info(f"Training Step - ERM loss: {loss.item()}")
|
56 |
-
loss += self.lmd * (outputs**2).mean() # SD loss penalty
|
57 |
-
logging.info(f"Training Step - SD loss: {loss.item()}")
|
58 |
-
return loss
|
59 |
-
|
60 |
-
def validation_step(self, batch):
|
61 |
-
images, labels, _ = batch
|
62 |
-
outputs = self.forward(images).squeeze()
|
63 |
-
|
64 |
-
if outputs.shape == torch.Size([]):
|
65 |
-
return
|
66 |
-
|
67 |
-
print(f"Shape of outputs (validation): {outputs.shape}")
|
68 |
-
print(f"Shape of labels (validation): {labels.shape}")
|
69 |
-
|
70 |
-
loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
|
71 |
-
preds = torch.sigmoid(outputs)
|
72 |
-
self.log("val_loss", loss, prog_bar=True, sync_dist=True)
|
73 |
-
self.log(
|
74 |
-
"val_acc",
|
75 |
-
self.accuracy(preds, labels.int()),
|
76 |
-
prog_bar=True,
|
77 |
-
sync_dist=True,
|
78 |
-
)
|
79 |
-
self.log(
|
80 |
-
"val_recall",
|
81 |
-
self.recall(preds, labels.int()),
|
82 |
-
prog_bar=True,
|
83 |
-
sync_dist=True,
|
84 |
-
)
|
85 |
-
output = {"val_loss": loss, "preds": preds, "labels": labels}
|
86 |
-
self.validation_outputs.append(output)
|
87 |
-
logging.info(f"Validation Step - Batch loss: {loss.item()}")
|
88 |
-
return output
|
89 |
-
|
90 |
-
def predict_step(self, batch):
|
91 |
-
images, label, domain = batch
|
92 |
-
outputs = self.forward(images).squeeze()
|
93 |
-
preds = torch.sigmoid(outputs)
|
94 |
-
return preds, label, domain
|
95 |
-
|
96 |
-
def on_validation_epoch_end(self):
|
97 |
-
if not self.validation_outputs:
|
98 |
-
logging.warning("No outputs in validation step to process")
|
99 |
-
return
|
100 |
-
preds = torch.cat([x["preds"] for x in self.validation_outputs])
|
101 |
-
labels = torch.cat([x["labels"] for x in self.validation_outputs])
|
102 |
-
if labels.unique().size(0) == 1:
|
103 |
-
logging.warning("Only one class in validation step")
|
104 |
-
return
|
105 |
-
auc_score = roc_auc_score(labels.cpu(), preds.cpu())
|
106 |
-
self.log("val_auc", auc_score, prog_bar=True, sync_dist=True)
|
107 |
-
logging.info(f"Validation Epoch End - AUC score: {auc_score}")
|
108 |
-
self.validation_outputs = []
|
109 |
-
|
110 |
-
def configure_optimizers(self):
|
111 |
-
optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0005)
|
112 |
-
return optimizer
|
113 |
-
|
114 |
-
|
115 |
-
checkpoint_callback = ModelCheckpoint(
|
116 |
-
monitor="val_loss",
|
117 |
-
dirpath="./model_checkpoints/",
|
118 |
-
filename="image-classifier-{step}-{val_loss:.2f}",
|
119 |
-
save_top_k=3,
|
120 |
-
mode="min",
|
121 |
-
every_n_train_steps=1001,
|
122 |
-
enable_version_counter=True,
|
123 |
-
)
|
124 |
-
|
125 |
-
early_stop_callback = EarlyStopping(
|
126 |
-
monitor="val_loss",
|
127 |
-
patience=4,
|
128 |
-
mode="min",
|
129 |
-
)
|
130 |
-
|
131 |
-
|
132 |
-
def load_image(image_path, transform=None):
|
133 |
-
image = Image.open(image_path).convert("RGB")
|
134 |
-
|
135 |
-
if transform:
|
136 |
-
image = transform(image)
|
137 |
-
|
138 |
-
return image
|
139 |
-
|
140 |
-
|
141 |
-
def predict_single_image(image_path, model, transform=None):
|
142 |
-
image = load_image(image_path, transform)
|
143 |
-
|
144 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
145 |
-
|
146 |
-
model.to(device)
|
147 |
-
|
148 |
-
image = image.to(device)
|
149 |
-
|
150 |
-
model.eval()
|
151 |
-
|
152 |
-
with torch.no_grad():
|
153 |
-
image = image.unsqueeze(0)
|
154 |
-
output = model(image).squeeze()
|
155 |
-
print(output)
|
156 |
-
prediction = torch.sigmoid(output).item()
|
157 |
-
|
158 |
-
return prediction
|
159 |
-
|
160 |
-
|
161 |
-
parser = argparse.ArgumentParser()
|
162 |
-
parser.add_argument(
|
163 |
-
"--ckpt_path",
|
164 |
-
help="checkpoint to continue from",
|
165 |
-
required=False,
|
166 |
-
)
|
167 |
-
parser.add_argument(
|
168 |
-
"--predict",
|
169 |
-
help="predict on test set",
|
170 |
-
action="store_true",
|
171 |
-
)
|
172 |
-
parser.add_argument("--reset", help="reset training", action="store_true")
|
173 |
-
parser.add_argument(
|
174 |
-
"--predict_image",
|
175 |
-
help="predict the class of a single image",
|
176 |
-
action="store_true",
|
177 |
-
)
|
178 |
-
parser.add_argument(
|
179 |
-
"--image_path",
|
180 |
-
help="path to the image to predict",
|
181 |
-
type=str,
|
182 |
-
required=False,
|
183 |
-
)
|
184 |
-
parser.add_argument(
|
185 |
-
"--dir",
|
186 |
-
help="path to the images to predict",
|
187 |
-
type=str,
|
188 |
-
required=False,
|
189 |
-
)
|
190 |
-
parser.add_argument(
|
191 |
-
"--output_file",
|
192 |
-
help="path to output file",
|
193 |
-
type=str,
|
194 |
-
required=False,
|
195 |
-
)
|
196 |
-
args = parser.parse_args()
|
197 |
-
|
198 |
-
train_domains = [0, 1, 4]
|
199 |
-
val_domains = [0, 1, 4]
|
200 |
-
lmd_value = 0
|
201 |
-
|
202 |
-
if args.predict:
|
203 |
-
test_dl = load_dataloader(
|
204 |
-
[0, 1, 2, 3, 4],
|
205 |
-
"test",
|
206 |
-
batch_size=128,
|
207 |
-
num_workers=1,
|
208 |
-
)
|
209 |
-
model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
|
210 |
-
trainer = pl.Trainer()
|
211 |
-
predictions = trainer.predict(model, dataloaders=test_dl)
|
212 |
-
preds, labels, domains = zip(*predictions)
|
213 |
-
preds = torch.cat(preds).cpu().numpy()
|
214 |
-
labels = torch.cat(labels).cpu().numpy()
|
215 |
-
domains = torch.cat(domains).cpu().numpy()
|
216 |
-
print(preds.shape, labels.shape, domains.shape)
|
217 |
-
df = pd.DataFrame({"preds": preds, "labels": labels, "domains": domains})
|
218 |
-
filename = "preds-" + args.ckpt_path.split("/")[-1]
|
219 |
-
df.to_csv(f"outputs/{filename}.csv", index=False)
|
220 |
-
elif args.predict_image:
|
221 |
-
image_path = args.image_path
|
222 |
-
model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
|
223 |
-
|
224 |
-
# Define the transformations for the image
|
225 |
-
# transform = transforms.Compose(
|
226 |
-
# [
|
227 |
-
# transforms.Resize((224, 224)), # Image size expected by ResNet50
|
228 |
-
# transforms.ToTensor(),
|
229 |
-
# transforms.Normalize(
|
230 |
-
# mean=[0.485, 0.456, 0.406],
|
231 |
-
# std=[0.229, 0.224, 0.225],
|
232 |
-
# ),
|
233 |
-
# ],
|
234 |
-
# )
|
235 |
-
|
236 |
-
transform = transforms.Compose(
|
237 |
-
[
|
238 |
-
transforms.CenterCrop((256, 256)),
|
239 |
-
transforms.ToTensor(),
|
240 |
-
],
|
241 |
-
)
|
242 |
-
|
243 |
-
prediction = predict_single_image(image_path, model, transform)
|
244 |
-
print("prediction", prediction)
|
245 |
-
|
246 |
-
# Output the prediction
|
247 |
-
print(
|
248 |
-
f"Prediction for {image_path}: "
|
249 |
-
f"{'Human' if prediction <= 0.001 else 'Generated'}",
|
250 |
-
)
|
251 |
-
elif args.dir is not None:
|
252 |
-
predictions = []
|
253 |
-
model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
|
254 |
-
# Define the transformations for the image
|
255 |
-
# transform = transforms.Compose(
|
256 |
-
# [
|
257 |
-
# transforms.Resize((224, 224)), # Image size expected by ResNet50
|
258 |
-
# transforms.ToTensor(),
|
259 |
-
# transforms.Normalize(
|
260 |
-
# mean=[0.485, 0.456, 0.406],
|
261 |
-
# std=[0.229, 0.224, 0.225],
|
262 |
-
# ),
|
263 |
-
# ],
|
264 |
-
# )
|
265 |
-
transform = transforms.Compose(
|
266 |
-
[
|
267 |
-
transforms.CenterCrop((256, 256)),
|
268 |
-
transforms.ToTensor(),
|
269 |
-
],
|
270 |
-
)
|
271 |
-
for root, dirs, files in os.walk(os.path.abspath(args.dir)):
|
272 |
-
for f_name in files:
|
273 |
-
f = os.path.join(root, f_name)
|
274 |
-
print(f"Predicting: {f}")
|
275 |
-
p = predict_single_image(f, model, transform)
|
276 |
-
predictions.append([f, f.split("/")[-2], p, p > 0.5])
|
277 |
-
print(f"--predicted: {p}")
|
278 |
-
|
279 |
-
df = pd.DataFrame(predictions, columns=["path", "folder", "pred", "class"])
|
280 |
-
df.to_csv(args.output_file, index=False)
|
281 |
-
else:
|
282 |
-
train_dl = load_dataloader(
|
283 |
-
train_domains,
|
284 |
-
"train",
|
285 |
-
batch_size=128,
|
286 |
-
num_workers=4,
|
287 |
-
)
|
288 |
-
logging.info("Training dataloader loaded")
|
289 |
-
val_dl = load_dataloader(val_domains, "val", batch_size=128, num_workers=4)
|
290 |
-
logging.info("Validation dataloader loaded")
|
291 |
-
|
292 |
-
if args.reset:
|
293 |
-
model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
|
294 |
-
else:
|
295 |
-
model = ImageClassifier(lmd=lmd_value)
|
296 |
-
trainer = pl.Trainer(
|
297 |
-
callbacks=[checkpoint_callback, early_stop_callback],
|
298 |
-
max_steps=20000,
|
299 |
-
val_check_interval=1000,
|
300 |
-
check_val_every_n_epoch=None,
|
301 |
-
)
|
302 |
-
trainer.fit(
|
303 |
-
model=model,
|
304 |
-
train_dataloaders=train_dl,
|
305 |
-
val_dataloaders=val_dl,
|
306 |
-
ckpt_path=args.ckpt_path if not args.reset else None,
|
307 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Diffusion/sample_laion_script.ipynb
DELETED
@@ -1,73 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": null,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"import dask.dataframe as dd\n",
|
10 |
-
"from dask.diagnostics import ProgressBar\n",
|
11 |
-
"import os\n",
|
12 |
-
"\n",
|
13 |
-
"directory_path = '/Users/fionachow/Documents/NYU/CDS/Fall 2023/CSCI - GA 2271 - Computer Vision/Project/'\n",
|
14 |
-
"\n",
|
15 |
-
"file_prefix = 'part'\n",
|
16 |
-
"\n",
|
17 |
-
"def list_files_with_prefix(directory, prefix):\n",
|
18 |
-
" file_paths = []\n",
|
19 |
-
"\n",
|
20 |
-
" for root, _, files in os.walk(directory):\n",
|
21 |
-
" for file in files:\n",
|
22 |
-
" if file.startswith(prefix):\n",
|
23 |
-
" absolute_path = os.path.join(root, file)\n",
|
24 |
-
" file_paths.append(absolute_path)\n",
|
25 |
-
"\n",
|
26 |
-
" return file_paths\n",
|
27 |
-
"\n",
|
28 |
-
"laion_file_paths = list_files_with_prefix(directory_path, file_prefix)\n",
|
29 |
-
"\n",
|
30 |
-
"dataframes = [dd.read_parquet(file) for file in laion_file_paths]\n",
|
31 |
-
"combined_dataframe = dd.multi.concat(dataframes)\n",
|
32 |
-
"\n",
|
33 |
-
"with ProgressBar():\n",
|
34 |
-
" row_count = combined_dataframe.shape[0].compute()\n",
|
35 |
-
" print(row_count)\n",
|
36 |
-
"\n",
|
37 |
-
"filtered_df = combined_dataframe[combined_dataframe['NSFW'] == \"UNLIKELY\"]\n",
|
38 |
-
"\n",
|
39 |
-
"num_samples = 225_000\n",
|
40 |
-
"selected_rows = filtered_df.sample(frac=num_samples / filtered_df.shape[0].compute())\n",
|
41 |
-
"\n",
|
42 |
-
"with ProgressBar():\n",
|
43 |
-
" sampled_rows = selected_rows.compute()\n",
|
44 |
-
"\n",
|
45 |
-
"print(len(sampled_rows))\n",
|
46 |
-
"\n",
|
47 |
-
"with ProgressBar():\n",
|
48 |
-
" selected_rows.to_parquet('laion_sampled', write_index=False)\n"
|
49 |
-
]
|
50 |
-
}
|
51 |
-
],
|
52 |
-
"metadata": {
|
53 |
-
"kernelspec": {
|
54 |
-
"display_name": "bloom",
|
55 |
-
"language": "python",
|
56 |
-
"name": "python3"
|
57 |
-
},
|
58 |
-
"language_info": {
|
59 |
-
"codemirror_mode": {
|
60 |
-
"name": "ipython",
|
61 |
-
"version": 3
|
62 |
-
},
|
63 |
-
"file_extension": ".py",
|
64 |
-
"mimetype": "text/x-python",
|
65 |
-
"name": "python",
|
66 |
-
"nbconvert_exporter": "python",
|
67 |
-
"pygments_lexer": "ipython3",
|
68 |
-
"version": "3.9.16"
|
69 |
-
}
|
70 |
-
},
|
71 |
-
"nbformat": 4,
|
72 |
-
"nbformat_minor": 2
|
73 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Diffusion/scrape.py
DELETED
@@ -1,149 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import time
|
3 |
-
|
4 |
-
import polars as pl
|
5 |
-
import requests
|
6 |
-
|
7 |
-
|
8 |
-
def call_api(param):
|
9 |
-
url = "https://api.pullpush.io/reddit/search/submission/"
|
10 |
-
response = requests.get(url, params=param)
|
11 |
-
json_data = response.json()["data"]
|
12 |
-
create_utc = []
|
13 |
-
media_id = []
|
14 |
-
media_type_ls = []
|
15 |
-
post_ids = []
|
16 |
-
post_titles = []
|
17 |
-
cur_utc = 0
|
18 |
-
for submission in json_data:
|
19 |
-
cur_flair = submission["link_flair_text"]
|
20 |
-
cur_utc = submission["created_utc"]
|
21 |
-
media_ls = (
|
22 |
-
submission["media_metadata"]
|
23 |
-
if "media_metadata" in submission.keys()
|
24 |
-
else None
|
25 |
-
)
|
26 |
-
if param["flair"] is not None and cur_flair != param["flair"]:
|
27 |
-
continue
|
28 |
-
if media_ls is None:
|
29 |
-
continue
|
30 |
-
for id in media_ls.keys():
|
31 |
-
if media_ls[id]["status"] != "valid":
|
32 |
-
continue
|
33 |
-
try:
|
34 |
-
media_type = media_ls[id]["m"]
|
35 |
-
except: # noqa
|
36 |
-
# video will error out
|
37 |
-
continue
|
38 |
-
if media_type == "image/png":
|
39 |
-
media_type_ls.append("png")
|
40 |
-
elif media_type == "image/jpg":
|
41 |
-
media_type_ls.append("jpg")
|
42 |
-
else:
|
43 |
-
continue
|
44 |
-
create_utc.append(int(cur_utc))
|
45 |
-
post_ids.append(submission["id"])
|
46 |
-
post_titles.append(submission["title"])
|
47 |
-
media_id.append(id)
|
48 |
-
|
49 |
-
df = pl.DataFrame(
|
50 |
-
{
|
51 |
-
"create_utc": create_utc,
|
52 |
-
"media_id": media_id,
|
53 |
-
"media_type": media_type_ls,
|
54 |
-
"post_id": post_ids,
|
55 |
-
"post_title": post_titles,
|
56 |
-
},
|
57 |
-
schema={
|
58 |
-
"create_utc": pl.Int64,
|
59 |
-
"media_id": pl.Utf8,
|
60 |
-
"media_type": pl.Utf8,
|
61 |
-
"post_id": pl.Utf8,
|
62 |
-
"post_title": pl.Utf8,
|
63 |
-
},
|
64 |
-
)
|
65 |
-
return df, int(cur_utc)
|
66 |
-
|
67 |
-
|
68 |
-
def scraping_loop(
|
69 |
-
subreddit,
|
70 |
-
flair,
|
71 |
-
max_num=30000,
|
72 |
-
output_name=None,
|
73 |
-
before=None,
|
74 |
-
):
|
75 |
-
collected_all = []
|
76 |
-
collected_len = 0
|
77 |
-
last_timestamp = int(time.time()) if before is None else before
|
78 |
-
param = {
|
79 |
-
"subreddit": subreddit,
|
80 |
-
"flair": flair,
|
81 |
-
"before": last_timestamp,
|
82 |
-
}
|
83 |
-
while collected_len < max_num:
|
84 |
-
collected_df, last_timestamp = call_api(param)
|
85 |
-
if collected_df.shape[0] == 0:
|
86 |
-
print("No more data, saving current data and exiting...")
|
87 |
-
break
|
88 |
-
collected_all.append(collected_df)
|
89 |
-
collected_len += collected_df.shape[0]
|
90 |
-
print(
|
91 |
-
f"collected_len: {collected_len}, "
|
92 |
-
f"last_timestamp: {last_timestamp}",
|
93 |
-
)
|
94 |
-
param["before"] = last_timestamp
|
95 |
-
|
96 |
-
df = pl.concat(collected_all)
|
97 |
-
df = (
|
98 |
-
df.with_columns(
|
99 |
-
pl.col("media_id")
|
100 |
-
.str.replace(r"^", "https://i.redd.it/")
|
101 |
-
.alias("url1"),
|
102 |
-
pl.col("create_utc")
|
103 |
-
.cast(pl.Int64)
|
104 |
-
.cast(pl.Utf8)
|
105 |
-
.str.to_datetime("%s")
|
106 |
-
.alias("time"),
|
107 |
-
)
|
108 |
-
.with_columns(
|
109 |
-
pl.col("media_type").str.replace(r"^", ".").alias("url2"),
|
110 |
-
)
|
111 |
-
.with_columns(
|
112 |
-
pl.concat_str(
|
113 |
-
[pl.col("url1"), pl.col("url2")],
|
114 |
-
separator="",
|
115 |
-
).alias("url"),
|
116 |
-
)
|
117 |
-
.select("time", "url", "post_id", "post_title")
|
118 |
-
)
|
119 |
-
if output_name is None:
|
120 |
-
output_name = subreddit
|
121 |
-
df.write_parquet(f"urls/{output_name}.parquet")
|
122 |
-
df.select("url").write_csv(f"urls/{output_name}.csv", has_header=False)
|
123 |
-
|
124 |
-
|
125 |
-
if __name__ == "__main__":
|
126 |
-
parser = argparse.ArgumentParser()
|
127 |
-
parser.add_argument("--subreddit", help="subreddit name")
|
128 |
-
parser.add_argument("--flair", help="flair filter", default=None, type=str)
|
129 |
-
parser.add_argument(
|
130 |
-
"--max_num",
|
131 |
-
help="max number of posts to scrape",
|
132 |
-
default=30000,
|
133 |
-
type=int,
|
134 |
-
)
|
135 |
-
parser.add_argument(
|
136 |
-
"--output_name",
|
137 |
-
help="custom output name",
|
138 |
-
default=None,
|
139 |
-
)
|
140 |
-
parser.add_argument(
|
141 |
-
"--before",
|
142 |
-
help="before timestamp",
|
143 |
-
default=None,
|
144 |
-
type=int,
|
145 |
-
)
|
146 |
-
|
147 |
-
args = parser.parse_args()
|
148 |
-
|
149 |
-
scraping_loop(**args.__dict__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Diffusion/utils_sampling.py
DELETED
@@ -1,94 +0,0 @@
|
|
1 |
-
import collections
|
2 |
-
import random
|
3 |
-
from typing import Callable
|
4 |
-
|
5 |
-
from torchdata.datapipes.iter import IterDataPipe
|
6 |
-
|
7 |
-
|
8 |
-
def get_second_entry(sample):
|
9 |
-
return sample[1]
|
10 |
-
|
11 |
-
|
12 |
-
class UnderSamplerIterDataPipe(IterDataPipe):
|
13 |
-
"""Dataset wrapper for under-sampling.
|
14 |
-
|
15 |
-
Copied from: https://github.com/MaxHalford/pytorch-resample/blob/master/pytorch_resample/under.py # noqa
|
16 |
-
Modified to work with multiple labels.
|
17 |
-
|
18 |
-
MIT License
|
19 |
-
|
20 |
-
Copyright (c) 2020 Max Halford
|
21 |
-
|
22 |
-
This method is based on rejection sampling.
|
23 |
-
|
24 |
-
Parameters:
|
25 |
-
dataset
|
26 |
-
desired_dist: The desired class distribution.
|
27 |
-
The keys are the classes whilst the
|
28 |
-
values are the desired class percentages.
|
29 |
-
The values are normalised so that sum up
|
30 |
-
to 1.
|
31 |
-
label_getter: A function that takes a sample and returns its label.
|
32 |
-
seed: Random seed for reproducibility.
|
33 |
-
|
34 |
-
Attributes:
|
35 |
-
actual_dist: The counts of the observed sample labels.
|
36 |
-
rng: A random number generator instance.
|
37 |
-
|
38 |
-
References:
|
39 |
-
- https://www.wikiwand.com/en/Rejection_sampling
|
40 |
-
|
41 |
-
"""
|
42 |
-
|
43 |
-
def __init__(
|
44 |
-
self,
|
45 |
-
dataset: IterDataPipe,
|
46 |
-
desired_dist: dict,
|
47 |
-
label_getter: Callable = get_second_entry,
|
48 |
-
seed: int = None,
|
49 |
-
):
|
50 |
-
|
51 |
-
self.dataset = dataset
|
52 |
-
self.desired_dist = {
|
53 |
-
c: p / sum(desired_dist.values()) for c, p in desired_dist.items()
|
54 |
-
}
|
55 |
-
self.label_getter = label_getter
|
56 |
-
self.seed = seed
|
57 |
-
|
58 |
-
self.actual_dist = collections.Counter()
|
59 |
-
self.rng = random.Random(seed)
|
60 |
-
self._pivot = None
|
61 |
-
|
62 |
-
def __iter__(self):
|
63 |
-
|
64 |
-
for dp in self.dataset:
|
65 |
-
y = self.label_getter(dp)
|
66 |
-
|
67 |
-
self.actual_dist[y] += 1
|
68 |
-
|
69 |
-
# To ease notation
|
70 |
-
f = self.desired_dist
|
71 |
-
g = self.actual_dist
|
72 |
-
|
73 |
-
# Check if the pivot needs to be changed
|
74 |
-
if y != self._pivot:
|
75 |
-
self._pivot = max(g.keys(), key=lambda y: f[y] / g[y])
|
76 |
-
else:
|
77 |
-
yield dp
|
78 |
-
continue
|
79 |
-
|
80 |
-
# Determine the sampling ratio if the observed label
|
81 |
-
# is not the pivot
|
82 |
-
M = f[self._pivot] / g[self._pivot]
|
83 |
-
ratio = f[y] / (M * g[y])
|
84 |
-
|
85 |
-
if ratio < 1 and self.rng.random() < ratio:
|
86 |
-
yield dp
|
87 |
-
|
88 |
-
@classmethod
|
89 |
-
def expected_size(cls, n, desired_dist, actual_dist):
|
90 |
-
M = max(
|
91 |
-
desired_dist.get(k) / actual_dist.get(k)
|
92 |
-
for k in set(desired_dist) | set(actual_dist)
|
93 |
-
)
|
94 |
-
return int(n / M)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Diffusion/visualizations.ipynb
DELETED
@@ -1,196 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": null,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"%pip install polars-lts-cpu"
|
10 |
-
]
|
11 |
-
},
|
12 |
-
{
|
13 |
-
"cell_type": "code",
|
14 |
-
"execution_count": null,
|
15 |
-
"metadata": {},
|
16 |
-
"outputs": [],
|
17 |
-
"source": [
|
18 |
-
"import pandas as pd\n",
|
19 |
-
"import numpy as np\n",
|
20 |
-
"import polars as pl\n",
|
21 |
-
"import matplotlib.pyplot as plt\n",
|
22 |
-
"import seaborn as sns\n",
|
23 |
-
"from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score"
|
24 |
-
]
|
25 |
-
},
|
26 |
-
{
|
27 |
-
"cell_type": "code",
|
28 |
-
"execution_count": null,
|
29 |
-
"metadata": {},
|
30 |
-
"outputs": [],
|
31 |
-
"source": [
|
32 |
-
"def pfbeta(labels, predictions, beta=1):\n",
|
33 |
-
" y_true_count = 0\n",
|
34 |
-
" ctp = 0\n",
|
35 |
-
" cfp = 0\n",
|
36 |
-
"\n",
|
37 |
-
" for idx in range(len(labels)):\n",
|
38 |
-
" prediction = min(max(predictions[idx], 0), 1)\n",
|
39 |
-
" if (labels[idx]):\n",
|
40 |
-
" y_true_count += 1\n",
|
41 |
-
" ctp += prediction\n",
|
42 |
-
" else:\n",
|
43 |
-
" cfp += prediction\n",
|
44 |
-
"\n",
|
45 |
-
" beta_squared = beta * beta\n",
|
46 |
-
" c_precision = ctp / (ctp + cfp)\n",
|
47 |
-
" c_recall = ctp / y_true_count\n",
|
48 |
-
" if (c_precision > 0 and c_recall > 0):\n",
|
49 |
-
" result = (1 + beta_squared) * (c_precision * c_recall) / (beta_squared * c_precision + c_recall)\n",
|
50 |
-
" return result\n",
|
51 |
-
" else:\n",
|
52 |
-
" return 0\n",
|
53 |
-
"\n",
|
54 |
-
"def get_part_metrics(df: pl.DataFrame, threshold=0.3) -> dict:\n",
|
55 |
-
" df = df.with_columns((df[\"preds\"] > threshold).alias(\"preds_bin\"))\n",
|
56 |
-
" metrics = {}\n",
|
57 |
-
" # binary metrics using the threshold\n",
|
58 |
-
" metrics[\"accuracy\"] = accuracy_score(df[\"labels\"].to_numpy(), df[\"preds_bin\"].to_numpy())\n",
|
59 |
-
" metrics[\"precision\"] = precision_score(df[\"labels\"].to_numpy(), df[\"preds_bin\"].to_numpy())\n",
|
60 |
-
" metrics[\"recall\"] = recall_score(df[\"labels\"].to_numpy(), df[\"preds_bin\"].to_numpy())\n",
|
61 |
-
" metrics[\"f1\"] = f1_score(df[\"labels\"].to_numpy(), df[\"preds_bin\"].to_numpy())\n",
|
62 |
-
" # probabilistic F1 (doesn't depend on the threshold)\n",
|
63 |
-
" metrics[\"pf1\"] = pfbeta(df[\"labels\"].to_numpy(), df[\"preds\"].to_numpy())\n",
|
64 |
-
" # ROC AUC\n",
|
65 |
-
" metrics[\"roc_auc\"] = roc_auc_score(df[\"labels\"].to_numpy(), df[\"preds\"].to_numpy())\n",
|
66 |
-
" return metrics\n",
|
67 |
-
"\n",
|
68 |
-
"\n",
|
69 |
-
"def get_all_metrics(df: pl.DataFrame, threshold=0.3) -> pd.DataFrame:\n",
|
70 |
-
" groups = [list(range(5)), [0, 1], [0, 4], [0, 2], [0, 3]]\n",
|
71 |
-
" group_names = [\"all\", \"StableDiffusion\", \"Midjourney\", \"Dalle2\", \"Dalle3\"]\n",
|
72 |
-
" all_metrics = []\n",
|
73 |
-
" for i, g in enumerate(groups):\n",
|
74 |
-
" subset = df.filter(pl.col(\"domains\").is_in(g))\n",
|
75 |
-
" metrics = get_part_metrics(subset, threshold=threshold)\n",
|
76 |
-
" metrics[\"group\"] = group_names[i]\n",
|
77 |
-
" all_metrics.append(metrics)\n",
|
78 |
-
" \n",
|
79 |
-
" return pd.DataFrame(all_metrics)"
|
80 |
-
]
|
81 |
-
},
|
82 |
-
{
|
83 |
-
"cell_type": "code",
|
84 |
-
"execution_count": null,
|
85 |
-
"metadata": {},
|
86 |
-
"outputs": [],
|
87 |
-
"source": [
|
88 |
-
"# Load the data from the output files\n",
|
89 |
-
"df1 = pl.read_csv('/Users/fionachow/Downloads/outputs/preds-image-classifier-1.csv')\n",
|
90 |
-
"df14 = pl.read_csv('/Users/fionachow/Downloads/outputs/preds-image-classifier-14.csv')\n",
|
91 |
-
"df142 = pl.read_csv('/Users/fionachow/Downloads/outputs/preds-image-classifier-142.csv')\n",
|
92 |
-
"df1423 = pl.read_csv('/Users/fionachow/Downloads/outputs/preds-image-classifier-1423.csv')\n",
|
93 |
-
"\n",
|
94 |
-
"metrics_df1 = get_all_metrics(df1, threshold=0.5)\n",
|
95 |
-
"metrics_df14 = get_all_metrics(df14, threshold=0.5)\n",
|
96 |
-
"metrics_df142 = get_all_metrics(df142, threshold=0.5)\n",
|
97 |
-
"metrics_df1423 = get_all_metrics(df1423, threshold=0.5)"
|
98 |
-
]
|
99 |
-
},
|
100 |
-
{
|
101 |
-
"cell_type": "code",
|
102 |
-
"execution_count": null,
|
103 |
-
"metadata": {},
|
104 |
-
"outputs": [],
|
105 |
-
"source": [
|
106 |
-
"metrics_df1.info()"
|
107 |
-
]
|
108 |
-
},
|
109 |
-
{
|
110 |
-
"cell_type": "code",
|
111 |
-
"execution_count": null,
|
112 |
-
"metadata": {},
|
113 |
-
"outputs": [],
|
114 |
-
"source": [
|
115 |
-
"sns.set()\n",
|
116 |
-
"\n",
|
117 |
-
"models = ['StableDiffusion', 'Midjourney', 'Dalle2', 'Dalle3']\n",
|
118 |
-
"metrics = ['accuracy', 'f1', 'pf1', 'roc_auc']\n",
|
119 |
-
"\n",
|
120 |
-
"file_map = {\n",
|
121 |
-
" ('StableDiffusion',): metrics_df1,\n",
|
122 |
-
" ('StableDiffusion', 'Midjourney'): metrics_df14,\n",
|
123 |
-
" ('StableDiffusion', 'Midjourney', 'Dalle2'): metrics_df142,\n",
|
124 |
-
" ('StableDiffusion', 'Midjourney', 'Dalle2', 'Dalle3'): metrics_df1423,\n",
|
125 |
-
"}\n",
|
126 |
-
"\n",
|
127 |
-
"def create_heatmap_data(metric):\n",
|
128 |
-
" data = pd.DataFrame(index=models[::-1], columns=models)\n",
|
129 |
-
" for i, model_x in enumerate(models):\n",
|
130 |
-
" for j, model_y in enumerate(models[::-1]):\n",
|
131 |
-
" \n",
|
132 |
-
" if i == 0:\n",
|
133 |
-
" relevant_df = metrics_df1\n",
|
134 |
-
" elif i == 1:\n",
|
135 |
-
" relevant_df = metrics_df14\n",
|
136 |
-
" elif i == 2:\n",
|
137 |
-
" relevant_df = metrics_df142\n",
|
138 |
-
" else:\n",
|
139 |
-
" relevant_df = metrics_df1423\n",
|
140 |
-
"\n",
|
141 |
-
" # Debugging: print the DataFrame being used and the model_y\n",
|
142 |
-
" #print(f\"Using DataFrame for {models[:i+1]}, model_y: {model_y}\")\n",
|
143 |
-
"\n",
|
144 |
-
" # Extract the metric value\n",
|
145 |
-
" if model_y in relevant_df['group'].values:\n",
|
146 |
-
" metric_value = relevant_df[relevant_df['group'] == model_y][metric].values[0]\n",
|
147 |
-
" # Debugging: print the extracted metric value\n",
|
148 |
-
" #print(f\"Metric value for {model_y}: {metric_value}\")\n",
|
149 |
-
" else:\n",
|
150 |
-
" metric_value = float('nan') # Handle non-existent cases\n",
|
151 |
-
" # Debugging: print a message for non-existent cases\n",
|
152 |
-
" #print(f\"No data for combination: {model_x}, {model_y}\")\n",
|
153 |
-
"\n",
|
154 |
-
" data.at[model_y, model_x] = metric_value\n",
|
155 |
-
" \n",
|
156 |
-
" for col in data.columns:\n",
|
157 |
-
" data[col] = pd.to_numeric(data[col], errors='coerce')\n",
|
158 |
-
"\n",
|
159 |
-
" # Debugging: print the final DataFrame\n",
|
160 |
-
" # print(f\"Final Data for metric {metric}:\")\n",
|
161 |
-
" # print(data)\n",
|
162 |
-
" # print(data.dtypes)\n",
|
163 |
-
" return data\n",
|
164 |
-
"\n",
|
165 |
-
"for metric in metrics:\n",
|
166 |
-
" plt.figure(figsize=(10, 8))\n",
|
167 |
-
" sns.heatmap(create_heatmap_data(metric), annot=True, cmap='coolwarm', fmt='.3f')\n",
|
168 |
-
" plt.title(f\"Heatmap for {metric}\")\n",
|
169 |
-
" plt.xlabel(\"Trained On (x-axis)\")\n",
|
170 |
-
" plt.ylabel(\"Tested On (y-axis)\")\n",
|
171 |
-
" plt.show()"
|
172 |
-
]
|
173 |
-
}
|
174 |
-
],
|
175 |
-
"metadata": {
|
176 |
-
"kernelspec": {
|
177 |
-
"display_name": "bloom",
|
178 |
-
"language": "python",
|
179 |
-
"name": "python3"
|
180 |
-
},
|
181 |
-
"language_info": {
|
182 |
-
"codemirror_mode": {
|
183 |
-
"name": "ipython",
|
184 |
-
"version": 3
|
185 |
-
},
|
186 |
-
"file_extension": ".py",
|
187 |
-
"mimetype": "text/x-python",
|
188 |
-
"name": "python",
|
189 |
-
"nbconvert_exporter": "python",
|
190 |
-
"pygments_lexer": "ipython3",
|
191 |
-
"version": "3.9.16"
|
192 |
-
}
|
193 |
-
},
|
194 |
-
"nbformat": 4,
|
195 |
-
"nbformat_minor": 2
|
196 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/README.md
DELETED
@@ -1,64 +0,0 @@
|
|
1 |
-
# AI-generated image detection
|
2 |
-
**(Work In Progress)**
|
3 |
-
|
4 |
-
- [ ] Refactor code
|
5 |
-
- [ ] Review dependencies
|
6 |
-
- [ ] Containerize (Docker)
|
7 |
-
- [ ] Update documentation
|
8 |
-
|
9 |
-
## AI-Generated Image detection
|
10 |
-
|
11 |
-
This part handles the detection of AI-generated images.
|
12 |
-
The current code contains two classifiers to detect AI-generated images from two types of architectures:
|
13 |
-
- GANs
|
14 |
-
|
15 |
-
## Model weights
|
16 |
-
|
17 |
-
### 1. CNN Detection
|
18 |
-
|
19 |
-
Run the `download_weights_CNN.sh` script:
|
20 |
-
|
21 |
-
```commandline
|
22 |
-
bash download_weights_CNN.sh
|
23 |
-
```
|
24 |
-
|
25 |
-
Note: you need `wget` installed on your system (it is by default for most Linux systems).
|
26 |
-
|
27 |
-
### 2. Diffusion
|
28 |
-
|
29 |
-
**TODO**
|
30 |
-
|
31 |
-
|
32 |
-
## Run the models
|
33 |
-
|
34 |
-
Make sure you have the weights available before doing so.
|
35 |
-
|
36 |
-
**TODO: environments**
|
37 |
-
|
38 |
-
### 1. CNN Detection
|
39 |
-
|
40 |
-
```commandline
|
41 |
-
python CNN_model_classifier.py
|
42 |
-
```
|
43 |
-
Available options:
|
44 |
-
|
45 |
-
- `-f / --file` (default=`'examples_realfakedir'`)
|
46 |
-
- `-m / --model_path` (default=`'weights/blur_jpg_prob0.5.pth'`)
|
47 |
-
- `-c / --crop` (default=`None`): Specify crop size (int) by default, do not crop.
|
48 |
-
- `--use_cpu`: use cpu (by default uses GPU) -> **TODO: remove (obsolete)**
|
49 |
-
|
50 |
-
Example usage:
|
51 |
-
|
52 |
-
```commandline
|
53 |
-
python CNN_model_classifier.py -f examples/real.png -m weights/blur_jpg_prob0.5.pth
|
54 |
-
```
|
55 |
-
|
56 |
-
### 2. Diffusion detection
|
57 |
-
|
58 |
-
**TODO**
|
59 |
-
|
60 |
-
## References
|
61 |
-
|
62 |
-
Based on:
|
63 |
-
- https://github.com/hoangthuc701/GenAI-image-detection
|
64 |
-
- https://github.com/ptmaimai106/DetectGenerateImageByRealImageOnly
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Search_Image/Bing_search.py
DELETED
@@ -1,93 +0,0 @@
|
|
1 |
-
import json
|
2 |
-
import os
|
3 |
-
from dotenv import load_dotenv
|
4 |
-
import requests
|
5 |
-
|
6 |
-
# Load Bing Search API key
|
7 |
-
load_dotenv()
|
8 |
-
BING_API_KEY = os.getenv("BING_API_KEY")
|
9 |
-
|
10 |
-
def print_json(obj):
|
11 |
-
"""Print the object as json"""
|
12 |
-
print(json.dumps(obj, sort_keys=True, indent=4, separators=(',', ': ')))
|
13 |
-
|
14 |
-
|
15 |
-
def get_image_urls(search_results):
|
16 |
-
"""
|
17 |
-
Extracts image URLs from Bing Visual Search response.
|
18 |
-
Ref: https://learn.microsoft.com/en-us/bing/search-apis/bing-visual-search/how-to/search-response
|
19 |
-
|
20 |
-
Args:
|
21 |
-
search_results: A dict containing the Bing VisualSearch response data.
|
22 |
-
|
23 |
-
Returns:
|
24 |
-
A tuple containing two lists:
|
25 |
-
- List of image URLs from "PagesIncluding" section.
|
26 |
-
- List of image URLs from "VisualSearch" section (backup).
|
27 |
-
"""
|
28 |
-
|
29 |
-
pages_including_urls = []
|
30 |
-
visual_search_urls = []
|
31 |
-
|
32 |
-
if "tags" not in search_results:
|
33 |
-
return pages_including_urls, visual_search_urls
|
34 |
-
|
35 |
-
# Check for required keys directly
|
36 |
-
if not any(action.get("actions") for action in search_results["tags"]):
|
37 |
-
return pages_including_urls, visual_search_urls
|
38 |
-
|
39 |
-
|
40 |
-
for action in search_results["tags"]:
|
41 |
-
for result in action.get("actions", []):
|
42 |
-
# actions = PagesIncluding, main results
|
43 |
-
if result["name"] == "PagesIncluding":
|
44 |
-
pages_including_urls.extend(item["contentUrl"] for item in result["data"]["value"])
|
45 |
-
# actions = VisualSearch, back up results
|
46 |
-
elif result["name"] == "VisualSearch":
|
47 |
-
visual_search_urls.extend(item["contentUrl"] for item in result["data"]["value"])
|
48 |
-
|
49 |
-
return pages_including_urls, visual_search_urls
|
50 |
-
|
51 |
-
def reverse_image_search(image_path, subscription_key=BING_API_KEY):
|
52 |
-
"""Performs a reverse image search using the Bing Visual Search API.
|
53 |
-
|
54 |
-
Args:
|
55 |
-
image_path: The path to the image file to search for.
|
56 |
-
|
57 |
-
Returns:
|
58 |
-
A list of image URLs found that are similar to the image in the
|
59 |
-
specified path.
|
60 |
-
|
61 |
-
Raises:
|
62 |
-
requests.exceptions.RequestException: If the API request fails.
|
63 |
-
"""
|
64 |
-
base_uri = "https://api.bing.microsoft.com/v7.0/images/visualsearch"
|
65 |
-
headers = {"Ocp-Apim-Subscription-Key": subscription_key}
|
66 |
-
|
67 |
-
try:
|
68 |
-
files = {"image": ("image", open(image_path, "rb"))}
|
69 |
-
response = requests.post(base_uri, headers=headers, files=files)
|
70 |
-
response.raise_for_status()
|
71 |
-
search_results = response.json()
|
72 |
-
|
73 |
-
return search_results
|
74 |
-
|
75 |
-
except requests.exceptions.RequestException as e:
|
76 |
-
raise requests.exceptions.RequestException(f"API request failed: {e}")
|
77 |
-
except OSError as e:
|
78 |
-
raise OSError(f"Error opening image file: {e}")
|
79 |
-
|
80 |
-
if __name__ == "__main__":
|
81 |
-
# Example usage:
|
82 |
-
image_path = "data/test_data/human_news.jpg"
|
83 |
-
try:
|
84 |
-
search_results = reverse_image_search(image_path)
|
85 |
-
image_urls, backup_image_urls = get_image_urls(search_results)
|
86 |
-
|
87 |
-
# Print the results
|
88 |
-
print("Image URLs from PagesIncluding:")
|
89 |
-
print(image_urls)
|
90 |
-
print("\nImage URLs from VisualSearch (backup):")
|
91 |
-
print(backup_image_urls)
|
92 |
-
except Exception as e:
|
93 |
-
print(f"An error occurred: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Search_Image/image_difference.py
DELETED
File without changes
|
src/images/Search_Image/image_model_share.py
DELETED
@@ -1,142 +0,0 @@
|
|
1 |
-
from sklearn.metrics import roc_auc_score
|
2 |
-
from torchmetrics import Accuracy, Recall
|
3 |
-
import pytorch_lightning as pl
|
4 |
-
import timm
|
5 |
-
import torch
|
6 |
-
from pytorch_lightning.callbacks import Model, EarlyStopping
|
7 |
-
import logging
|
8 |
-
from PIL import Image
|
9 |
-
import torchvision.transforms as transforms
|
10 |
-
from torchvision.transforms import v2
|
11 |
-
|
12 |
-
logging.basicConfig(filename='training.log',filemode='w',level=logging.INFO, force=True)
|
13 |
-
CHECKPOINT = "models/image_classifier/image-classifier-step=8008-val_loss=0.11.ckpt"
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
class ImageClassifier(pl.LightningModule):
|
18 |
-
def __init__(self, lmd=0):
|
19 |
-
super().__init__()
|
20 |
-
self.model = timm.create_model('resnet50', pretrained=True, num_classes=1)
|
21 |
-
self.accuracy = Accuracy(task='binary', threshold=0.5)
|
22 |
-
self.recall = Recall(task='binary', threshold=0.5)
|
23 |
-
self.validation_outputs = []
|
24 |
-
self.lmd = lmd
|
25 |
-
|
26 |
-
def forward(self, x):
|
27 |
-
return self.model(x)
|
28 |
-
|
29 |
-
def training_step(self, batch):
|
30 |
-
images, labels, _ = batch
|
31 |
-
outputs = self.forward(images).squeeze()
|
32 |
-
|
33 |
-
print(f"Shape of outputs (training): {outputs.shape}")
|
34 |
-
print(f"Shape of labels (training): {labels.shape}")
|
35 |
-
|
36 |
-
loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
|
37 |
-
logging.info(f"Training Step - ERM loss: {loss.item()}")
|
38 |
-
loss += self.lmd * (outputs ** 2).mean() # SD loss penalty
|
39 |
-
logging.info(f"Training Step - SD loss: {loss.item()}")
|
40 |
-
return loss
|
41 |
-
|
42 |
-
def validation_step(self, batch):
|
43 |
-
images, labels, _ = batch
|
44 |
-
outputs = self.forward(images).squeeze()
|
45 |
-
|
46 |
-
if outputs.shape == torch.Size([]):
|
47 |
-
return
|
48 |
-
|
49 |
-
print(f"Shape of outputs (validation): {outputs.shape}")
|
50 |
-
print(f"Shape of labels (validation): {labels.shape}")
|
51 |
-
|
52 |
-
loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
|
53 |
-
preds = torch.sigmoid(outputs)
|
54 |
-
self.log('val_loss', loss, prog_bar=True, sync_dist=True)
|
55 |
-
self.log('val_acc', self.accuracy(preds, labels.int()), prog_bar=True, sync_dist=True)
|
56 |
-
self.log('val_recall', self.recall(preds, labels.int()), prog_bar=True, sync_dist=True)
|
57 |
-
output = {"val_loss": loss, "preds": preds, "labels": labels}
|
58 |
-
self.validation_outputs.append(output)
|
59 |
-
logging.info(f"Validation Step - Batch loss: {loss.item()}")
|
60 |
-
return output
|
61 |
-
|
62 |
-
def predict_step(self, batch):
|
63 |
-
images, label, domain = batch
|
64 |
-
outputs = self.forward(images).squeeze()
|
65 |
-
preds = torch.sigmoid(outputs)
|
66 |
-
return preds, label, domain
|
67 |
-
|
68 |
-
def on_validation_epoch_end(self):
|
69 |
-
if not self.validation_outputs:
|
70 |
-
logging.warning("No outputs in validation step to process")
|
71 |
-
return
|
72 |
-
preds = torch.cat([x['preds'] for x in self.validation_outputs])
|
73 |
-
labels = torch.cat([x['labels'] for x in self.validation_outputs])
|
74 |
-
if labels.unique().size(0) == 1:
|
75 |
-
logging.warning("Only one class in validation step")
|
76 |
-
return
|
77 |
-
auc_score = roc_auc_score(labels.cpu(), preds.cpu())
|
78 |
-
self.log('val_auc', auc_score, prog_bar=True, sync_dist=True)
|
79 |
-
logging.info(f"Validation Epoch End - AUC score: {auc_score}")
|
80 |
-
self.validation_outputs = []
|
81 |
-
|
82 |
-
def configure_optimizers(self):
|
83 |
-
optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0005)
|
84 |
-
return optimizer
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
def load_image(image_path, transform=None):
|
89 |
-
image = Image.open(image_path).convert('RGB')
|
90 |
-
|
91 |
-
if transform:
|
92 |
-
image = transform(image)
|
93 |
-
|
94 |
-
return image
|
95 |
-
|
96 |
-
|
97 |
-
def predict_single_image(image_path, model, transform=None):
|
98 |
-
image = load_image(image_path, transform)
|
99 |
-
|
100 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
101 |
-
|
102 |
-
model.to(device)
|
103 |
-
|
104 |
-
image = image.to(device)
|
105 |
-
|
106 |
-
model.eval()
|
107 |
-
|
108 |
-
with torch.no_grad():
|
109 |
-
image = image.unsqueeze(0)
|
110 |
-
output = model(image).squeeze()
|
111 |
-
print(output)
|
112 |
-
prediction = torch.sigmoid(output).item()
|
113 |
-
|
114 |
-
return prediction
|
115 |
-
|
116 |
-
|
117 |
-
def image_generation_detection(image_path):
|
118 |
-
model = ImageClassifier.load_from_checkpoint(CHECKPOINT)
|
119 |
-
|
120 |
-
transform = v2.Compose([
|
121 |
-
transforms.ToTensor(),
|
122 |
-
v2.CenterCrop((256, 256)),
|
123 |
-
])
|
124 |
-
|
125 |
-
prediction = predict_single_image(image_path, model, transform)
|
126 |
-
print("prediction",prediction)
|
127 |
-
|
128 |
-
result = ""
|
129 |
-
if prediction <= 0.2:
|
130 |
-
result += "Most likely human"
|
131 |
-
image_prediction_label = "HUMAN"
|
132 |
-
else:
|
133 |
-
result += "Most likely machine"
|
134 |
-
image_prediction_label = "MACHINE"
|
135 |
-
image_confidence = min(1, 0.5 + abs(prediction - 0.2))
|
136 |
-
result += f" with confidence = {round(image_confidence * 100, 2)}%"
|
137 |
-
image_confidence = round(image_confidence * 100, 2)
|
138 |
-
return image_prediction_label, image_confidence
|
139 |
-
|
140 |
-
|
141 |
-
if __name__ == "__main__":
|
142 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Search_Image/search.py
DELETED
@@ -1,56 +0,0 @@
|
|
1 |
-
from google_img_source_search import ReverseImageSearcher
|
2 |
-
import requests
|
3 |
-
from io import BytesIO
|
4 |
-
from PIL import Image
|
5 |
-
import imagehash
|
6 |
-
from google_img_source_search import ReverseImageSearcher
|
7 |
-
|
8 |
-
def get_image_from_url(url):
|
9 |
-
response = requests.get(url)
|
10 |
-
return Image.open(BytesIO(response.content))
|
11 |
-
|
12 |
-
def standardize_image(image):
|
13 |
-
# Convert to RGB if needed
|
14 |
-
if image.mode in ('RGBA', 'LA'):
|
15 |
-
background = Image.new('RGB', image.size, (255, 255, 255))
|
16 |
-
background.paste(image, mask=image.split()[-1])
|
17 |
-
image = background
|
18 |
-
elif image.mode != 'RGB':
|
19 |
-
image = image.convert('RGB')
|
20 |
-
|
21 |
-
# Resize to standard size (e.g. 256x256)
|
22 |
-
standard_size = (256, 256)
|
23 |
-
image = image.resize(standard_size)
|
24 |
-
|
25 |
-
return image
|
26 |
-
|
27 |
-
def compare_images(image1, image2):
|
28 |
-
# Standardize both images first
|
29 |
-
img1_std = standardize_image(image1)
|
30 |
-
img2_std = standardize_image(image2)
|
31 |
-
|
32 |
-
hash1 = imagehash.average_hash(img1_std)
|
33 |
-
hash2 = imagehash.average_hash(img2_std)
|
34 |
-
return hash1 - hash2 # Returns the Hamming distance between the hashes
|
35 |
-
|
36 |
-
if __name__ == '__main__':
|
37 |
-
image_url = 'https://i.pinimg.com/originals/c4/50/35/c450352ac6ea8645ead206721673e8fb.png'
|
38 |
-
|
39 |
-
# Get the image from URL
|
40 |
-
url_image = get_image_from_url(image_url)
|
41 |
-
|
42 |
-
# Search image
|
43 |
-
rev_img_searcher = ReverseImageSearcher()
|
44 |
-
res = rev_img_searcher.search(image_url)
|
45 |
-
|
46 |
-
for search_item in res:
|
47 |
-
print(f'Title: {search_item.page_title}')
|
48 |
-
# print(f'Site: {search_item.page_url}')
|
49 |
-
print(f'Img: {search_item.image_url}\n')
|
50 |
-
|
51 |
-
# Compare each search result image with the input image
|
52 |
-
result_image = get_image_from_url(search_item.image_url)
|
53 |
-
result_difference = compare_images(result_image, url_image)
|
54 |
-
print(f"Difference with search result: {result_difference}")
|
55 |
-
if result_difference == 0:
|
56 |
-
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Search_Image/search_2.py
DELETED
@@ -1,150 +0,0 @@
|
|
1 |
-
import time
|
2 |
-
import logging
|
3 |
-
import requests
|
4 |
-
from bs4 import BeautifulSoup
|
5 |
-
from typing import Dict, Optional
|
6 |
-
from urllib.parse import quote, urlparse
|
7 |
-
|
8 |
-
logging.basicConfig(
|
9 |
-
filename='error.log',
|
10 |
-
level=logging.INFO,
|
11 |
-
format='%(asctime)s | [%(levelname)s]: %(message)s',
|
12 |
-
datefmt='%m-%d-%Y / %I:%M:%S %p'
|
13 |
-
)
|
14 |
-
|
15 |
-
class SearchResults:
|
16 |
-
def __init__(self, results):
|
17 |
-
self.results = results
|
18 |
-
|
19 |
-
def __str__(self):
|
20 |
-
output = ""
|
21 |
-
for result in self.results:
|
22 |
-
output += "---\n"
|
23 |
-
output += f"Title: {result.get('title', 'Title not found')}\n"
|
24 |
-
output += f"Link: {result.get('link', 'Link not found')}\n"
|
25 |
-
output += "---\n"
|
26 |
-
return output
|
27 |
-
|
28 |
-
class GoogleReverseImageSearch:
|
29 |
-
def __init__(self):
|
30 |
-
self.base_url = "https://www.google.com/searchbyimage"
|
31 |
-
self.headers = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"}
|
32 |
-
self.retry_count = 3
|
33 |
-
self.retry_delay = 1
|
34 |
-
|
35 |
-
def response(self, query: str, image_url: str, max_results: int = 10, delay: int = 1) -> SearchResults:
|
36 |
-
self._validate_input(query, image_url)
|
37 |
-
|
38 |
-
encoded_query = quote(query)
|
39 |
-
encoded_image_url = quote(image_url)
|
40 |
-
|
41 |
-
url = f"{self.base_url}?q={encoded_query}&image_url={encoded_image_url}&sbisrc=cr_1_5_2"
|
42 |
-
|
43 |
-
all_results = []
|
44 |
-
start_index = 0
|
45 |
-
|
46 |
-
while len(all_results) < max_results:
|
47 |
-
if start_index != 0:
|
48 |
-
time.sleep(delay)
|
49 |
-
|
50 |
-
paginated_url = f"{url}&start={start_index}"
|
51 |
-
|
52 |
-
response = self._make_request(paginated_url)
|
53 |
-
if response is None:
|
54 |
-
break
|
55 |
-
|
56 |
-
search_results, valid_content = self._parse_search_results(response.text)
|
57 |
-
if not valid_content:
|
58 |
-
logging.warning("Unexpected HTML structure encountered.")
|
59 |
-
break
|
60 |
-
|
61 |
-
for result in search_results:
|
62 |
-
if len(all_results) >= max_results:
|
63 |
-
break
|
64 |
-
data = self._extract_result_data(result)
|
65 |
-
if data and data not in all_results:
|
66 |
-
all_results.append(data)
|
67 |
-
|
68 |
-
start_index += (len(all_results)-start_index)
|
69 |
-
|
70 |
-
if len(all_results) == 0:
|
71 |
-
logging.warning(f"No results were found for the given query: [{query}], and/or image URL: [{image_url}].")
|
72 |
-
return "No results found. Please try again with a different query and/or image URL."
|
73 |
-
else:
|
74 |
-
return SearchResults(all_results[:max_results])
|
75 |
-
|
76 |
-
def _validate_input(self, query: str, image_url: str):
|
77 |
-
if not query:
|
78 |
-
raise ValueError("Query not found. Please enter a query and try again.")
|
79 |
-
if not image_url:
|
80 |
-
raise ValueError("Image URL not found. Please enter an image URL and try again.")
|
81 |
-
if not self._validate_image_url(image_url):
|
82 |
-
raise ValueError("Invalid image URL. Please enter a valid image URL and try again.")
|
83 |
-
|
84 |
-
def _validate_image_url(self, url: str) -> bool:
|
85 |
-
parsed_url = urlparse(url)
|
86 |
-
path = parsed_url.path.lower()
|
87 |
-
valid_extensions = (".jpg", ".jpeg", ".png", ".webp")
|
88 |
-
return any(path.endswith(ext) for ext in valid_extensions)
|
89 |
-
|
90 |
-
def _make_request(self, url: str):
|
91 |
-
attempts = 0
|
92 |
-
while attempts < self.retry_count:
|
93 |
-
try:
|
94 |
-
response = requests.get(url, headers=self.headers)
|
95 |
-
if response.headers.get('Content-Type', '').startswith('text/html'):
|
96 |
-
response.raise_for_status()
|
97 |
-
return response
|
98 |
-
else:
|
99 |
-
logging.warning("Non-HTML content received.")
|
100 |
-
return None
|
101 |
-
except requests.exceptions.HTTPError as http_err:
|
102 |
-
logging.error(f"HTTP error occurred: {http_err}")
|
103 |
-
attempts += 1
|
104 |
-
time.sleep(self.retry_delay)
|
105 |
-
except Exception as err:
|
106 |
-
logging.error(f"An error occurred: {err}")
|
107 |
-
return None
|
108 |
-
return None
|
109 |
-
|
110 |
-
def _parse_search_results(self, html_content: str) -> (Optional[list], bool):
|
111 |
-
try:
|
112 |
-
soup = BeautifulSoup(html_content, "html.parser")
|
113 |
-
return soup.find_all('div', class_='g'), True
|
114 |
-
except Exception as e:
|
115 |
-
logging.error(f"Error parsing HTML content: {e}")
|
116 |
-
return None, False
|
117 |
-
|
118 |
-
def _extract_result_data(self, result) -> Dict:
|
119 |
-
link = result.find('a', href=True)['href'] if result.find('a', href=True) else None
|
120 |
-
title = result.find('h3').get_text(strip=True) if result.find('h3') else None
|
121 |
-
return {"link": link, "title": title} if link and title else {}
|
122 |
-
|
123 |
-
|
124 |
-
if __name__ == "__main__":
|
125 |
-
# request = GoogleReverseImageSearch()
|
126 |
-
|
127 |
-
# response = request.response(
|
128 |
-
# query="Example Query",
|
129 |
-
# image_url="https://ichef.bbci.co.uk/images/ic/1024xn/p0khzhhl.jpg.webp",
|
130 |
-
# max_results=5
|
131 |
-
# )
|
132 |
-
|
133 |
-
# print(response)
|
134 |
-
|
135 |
-
# Path to local image
|
136 |
-
image_path = "data/test_data/towel.jpg"
|
137 |
-
image_path = "C:\\TTProjects\\prj-nict-ai-content-detection\\data\\test_data\\towel.jpg"
|
138 |
-
|
139 |
-
import json
|
140 |
-
file_path = image_path
|
141 |
-
search_url = 'https://yandex.ru/images/search'
|
142 |
-
files = {'upfile': ('blob', open(file_path, 'rb'), 'image/jpeg')}
|
143 |
-
params = {'rpt': 'imageview', 'format': 'json', 'request': '{"blocks":[{"block":"b-page_type_search-by-image__link"}]}'}
|
144 |
-
response = requests.post(search_url, params=params, files=files)
|
145 |
-
query_string = json.loads(response.content)['blocks'][0]['params']['url']
|
146 |
-
img_search_url = search_url + '?' + query_string
|
147 |
-
print(img_search_url)
|
148 |
-
|
149 |
-
response = requests.get(img_search_url)
|
150 |
-
print(response.text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/Search_Image/search_yandex.py
DELETED
@@ -1,177 +0,0 @@
|
|
1 |
-
import time
|
2 |
-
import logging
|
3 |
-
import requests
|
4 |
-
from bs4 import BeautifulSoup
|
5 |
-
from typing import Dict, Optional
|
6 |
-
from urllib.parse import quote, urlparse
|
7 |
-
|
8 |
-
logging.basicConfig(
|
9 |
-
filename='error.log',
|
10 |
-
level=logging.INFO,
|
11 |
-
format='%(asctime)s | [%(levelname)s]: %(message)s',
|
12 |
-
datefmt='%m-%d-%Y / %I:%M:%S %p'
|
13 |
-
)
|
14 |
-
|
15 |
-
class SearchResults:
|
16 |
-
def __init__(self, results):
|
17 |
-
self.results = results
|
18 |
-
|
19 |
-
def __str__(self):
|
20 |
-
output = ""
|
21 |
-
for result in self.results:
|
22 |
-
output += "---\n"
|
23 |
-
output += f"Title: {result.get('title', 'Title not found')}\n"
|
24 |
-
output += f"Link: {result.get('link', 'Link not found')}\n"
|
25 |
-
output += "---\n"
|
26 |
-
return output
|
27 |
-
|
28 |
-
class ReverseImageSearch:
|
29 |
-
def __init__(self):
|
30 |
-
self.base_url = "https://yandex.ru/images/search"
|
31 |
-
self.headers = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"}
|
32 |
-
self.retry_count = 3
|
33 |
-
self.retry_delay = 1
|
34 |
-
|
35 |
-
def response(self, query: str, image_url: str, max_results: int = 10, delay: int = 1) -> SearchResults:
|
36 |
-
self._validate_input(query, image_url)
|
37 |
-
|
38 |
-
encoded_query = quote(query)
|
39 |
-
encoded_image_url = quote(image_url)
|
40 |
-
|
41 |
-
url = f"{self.base_url}?q={encoded_query}&image_url={encoded_image_url}&sbisrc=cr_1_5_2"
|
42 |
-
|
43 |
-
all_results = []
|
44 |
-
start_index = 0
|
45 |
-
|
46 |
-
while len(all_results) < max_results:
|
47 |
-
if start_index != 0:
|
48 |
-
time.sleep(delay)
|
49 |
-
|
50 |
-
paginated_url = f"{url}&start={start_index}"
|
51 |
-
|
52 |
-
response = self._make_request(paginated_url)
|
53 |
-
if response is None:
|
54 |
-
break
|
55 |
-
|
56 |
-
search_results, valid_content = self._parse_search_results(response.text)
|
57 |
-
if not valid_content:
|
58 |
-
logging.warning("Unexpected HTML structure encountered.")
|
59 |
-
break
|
60 |
-
|
61 |
-
for result in search_results:
|
62 |
-
if len(all_results) >= max_results:
|
63 |
-
break
|
64 |
-
data = self._extract_result_data(result)
|
65 |
-
if data and data not in all_results:
|
66 |
-
all_results.append(data)
|
67 |
-
|
68 |
-
start_index += (len(all_results)-start_index)
|
69 |
-
|
70 |
-
if len(all_results) == 0:
|
71 |
-
logging.warning(f"No results were found for the given query: [{query}], and/or image URL: [{image_url}].")
|
72 |
-
return "No results found. Please try again with a different query and/or image URL."
|
73 |
-
else:
|
74 |
-
return SearchResults(all_results[:max_results])
|
75 |
-
|
76 |
-
def _validate_input(self, query: str, image_url: str):
|
77 |
-
if not query:
|
78 |
-
raise ValueError("Query not found. Please enter a query and try again.")
|
79 |
-
if not image_url:
|
80 |
-
raise ValueError("Image URL not found. Please enter an image URL and try again.")
|
81 |
-
if not self._validate_image_url(image_url):
|
82 |
-
raise ValueError("Invalid image URL. Please enter a valid image URL and try again.")
|
83 |
-
|
84 |
-
def _validate_image_url(self, url: str) -> bool:
|
85 |
-
parsed_url = urlparse(url)
|
86 |
-
path = parsed_url.path.lower()
|
87 |
-
valid_extensions = (".jpg", ".jpeg", ".png", ".webp")
|
88 |
-
return any(path.endswith(ext) for ext in valid_extensions)
|
89 |
-
|
90 |
-
def _make_request(self, url: str):
|
91 |
-
attempts = 0
|
92 |
-
while attempts < self.retry_count:
|
93 |
-
try:
|
94 |
-
response = requests.get(url, headers=self.headers)
|
95 |
-
if response.headers.get('Content-Type', '').startswith('text/html'):
|
96 |
-
response.raise_for_status()
|
97 |
-
return response
|
98 |
-
else:
|
99 |
-
logging.warning("Non-HTML content received.")
|
100 |
-
return None
|
101 |
-
except requests.exceptions.HTTPError as http_err:
|
102 |
-
logging.error(f"HTTP error occurred: {http_err}")
|
103 |
-
attempts += 1
|
104 |
-
time.sleep(self.retry_delay)
|
105 |
-
except Exception as err:
|
106 |
-
logging.error(f"An error occurred: {err}")
|
107 |
-
return None
|
108 |
-
return None
|
109 |
-
|
110 |
-
def _parse_search_results(self, html_content: str) -> (Optional[list], bool):
|
111 |
-
try:
|
112 |
-
soup = BeautifulSoup(html_content, "html.parser")
|
113 |
-
return soup.find_all('div', class_='g'), True
|
114 |
-
except Exception as e:
|
115 |
-
logging.error(f"Error parsing HTML content: {e}")
|
116 |
-
return None, False
|
117 |
-
|
118 |
-
def _extract_result_data(self, result) -> Dict:
|
119 |
-
link = result.find('a', href=True)['href'] if result.find('a', href=True) else None
|
120 |
-
title = result.find('h3').get_text(strip=True) if result.find('h3') else None
|
121 |
-
return {"link": link, "title": title} if link and title else {}
|
122 |
-
|
123 |
-
def yandex_reverse_image_search(image_url):
|
124 |
-
# Simulate a user agent to avoid being blocked
|
125 |
-
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}
|
126 |
-
|
127 |
-
try:
|
128 |
-
response = requests.get(image_url, headers=headers)
|
129 |
-
response.raise_for_status() # Raise an exception for bad status codes
|
130 |
-
|
131 |
-
# Parse the HTML content
|
132 |
-
soup = BeautifulSoup(response.content, 'html.parser')
|
133 |
-
|
134 |
-
# Extract image URLs (example - adapt based on Yandex's HTML structure)
|
135 |
-
image_urls = [img['src'] for img in soup.find_all('img')]
|
136 |
-
|
137 |
-
# Extract related searches (example - adapt based on Yandex's HTML structure)
|
138 |
-
related_searches = [text for text in soup.find_all(class_="related-searches")]
|
139 |
-
|
140 |
-
return image_urls, related_searches
|
141 |
-
|
142 |
-
except requests.exceptions.RequestException as e:
|
143 |
-
print(f"Error fetching image: {e}")
|
144 |
-
return [], []
|
145 |
-
|
146 |
-
|
147 |
-
if __name__ == "__main__":
|
148 |
-
# request = GoogleReverseImageSearch()
|
149 |
-
|
150 |
-
# response = request.response(
|
151 |
-
# query="Example Query",
|
152 |
-
# image_url="https://ichef.bbci.co.uk/images/ic/1024xn/p0khzhhl.jpg.webp",
|
153 |
-
# max_results=5
|
154 |
-
# )
|
155 |
-
|
156 |
-
# print(response)
|
157 |
-
|
158 |
-
# Path to local image
|
159 |
-
image_path = "data/test_data/towel.jpg"
|
160 |
-
image_path = "C:\\TTProjects\\prj-nict-ai-content-detection\\data\\test_data\\towel.jpg"
|
161 |
-
|
162 |
-
import json
|
163 |
-
file_path = image_path
|
164 |
-
search_url = 'https://yandex.ru/images/search'
|
165 |
-
files = {'upfile': ('blob', open(file_path, 'rb'), 'image/jpeg')}
|
166 |
-
params = {'rpt': 'imageview', 'format': 'json', 'request': '{"blocks":[{"block":"b-page_type_search-by-image__link"}]}'}
|
167 |
-
response = requests.post(search_url, params=params, files=files)
|
168 |
-
query_string = json.loads(response.content)['blocks'][0]['params']['url']
|
169 |
-
img_search_url = search_url + '?' + query_string
|
170 |
-
print(img_search_url)
|
171 |
-
|
172 |
-
image_urls, related_searches = yandex_reverse_image_search(img_search_url)
|
173 |
-
|
174 |
-
print("Image URLs:", image_urls)
|
175 |
-
print("Related Searches:", related_searches)
|
176 |
-
|
177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/diffusion_data_loader.py
DELETED
@@ -1,229 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import collections
|
3 |
-
import random
|
4 |
-
from typing import Iterator
|
5 |
-
|
6 |
-
import cv2
|
7 |
-
import numpy as np
|
8 |
-
import torchdata.datapipes as dp
|
9 |
-
from imwatermark import WatermarkEncoder
|
10 |
-
from PIL import (
|
11 |
-
Image,
|
12 |
-
ImageFile,
|
13 |
-
)
|
14 |
-
from torch.utils.data import DataLoader
|
15 |
-
from torchdata.datapipes.iter import (
|
16 |
-
Concater,
|
17 |
-
FileLister,
|
18 |
-
FileOpener,
|
19 |
-
SampleMultiplexer,
|
20 |
-
)
|
21 |
-
from torchvision.transforms import v2
|
22 |
-
from tqdm import tqdm
|
23 |
-
|
24 |
-
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
25 |
-
Image.MAX_IMAGE_PIXELS = 1000000000
|
26 |
-
|
27 |
-
encoder = WatermarkEncoder()
|
28 |
-
encoder.set_watermark("bytes", b"test")
|
29 |
-
|
30 |
-
DOMAIN_LABELS = {
|
31 |
-
0: "laion",
|
32 |
-
1: "StableDiffusion",
|
33 |
-
2: "dalle2",
|
34 |
-
3: "dalle3",
|
35 |
-
4: "midjourney",
|
36 |
-
}
|
37 |
-
|
38 |
-
N_SAMPLES = {
|
39 |
-
0: (115346, 14418, 14419),
|
40 |
-
1: (22060, 2757, 2758),
|
41 |
-
4: (21096, 2637, 2637),
|
42 |
-
2: (13582, 1697, 1699),
|
43 |
-
3: (12027, 1503, 1504),
|
44 |
-
}
|
45 |
-
|
46 |
-
|
47 |
-
@dp.functional_datapipe("collect_from_workers")
|
48 |
-
class WorkerResultCollector(dp.iter.IterDataPipe):
|
49 |
-
def __init__(self, source: dp.iter.IterDataPipe):
|
50 |
-
self.source = source
|
51 |
-
|
52 |
-
def __iter__(self) -> Iterator:
|
53 |
-
yield from self.source
|
54 |
-
|
55 |
-
def is_replicable(self) -> bool:
|
56 |
-
"""Method to force data back to main process"""
|
57 |
-
return False
|
58 |
-
|
59 |
-
|
60 |
-
def crop_bottom(image, cutoff=16):
|
61 |
-
return image[:, :-cutoff, :]
|
62 |
-
|
63 |
-
|
64 |
-
def random_gaussian_blur(image, p=0.01):
|
65 |
-
if random.random() < p:
|
66 |
-
return v2.functional.gaussian_blur(image, kernel_size=5)
|
67 |
-
return image
|
68 |
-
|
69 |
-
|
70 |
-
def random_invisible_watermark(image, p=0.2):
|
71 |
-
image_np = np.array(image)
|
72 |
-
image_np = np.transpose(image_np, (1, 2, 0))
|
73 |
-
|
74 |
-
if image_np.ndim == 2: # Grayscale image
|
75 |
-
image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR)
|
76 |
-
elif image_np.shape[2] == 4: # RGBA image
|
77 |
-
image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2BGR)
|
78 |
-
|
79 |
-
if image_np.shape[0] < 256 or image_np.shape[1] < 256:
|
80 |
-
image_np = cv2.resize(
|
81 |
-
image_np,
|
82 |
-
(256, 256),
|
83 |
-
interpolation=cv2.INTER_AREA,
|
84 |
-
)
|
85 |
-
|
86 |
-
if random.random() < p:
|
87 |
-
return encoder.encode(image_np, method="dwtDct")
|
88 |
-
|
89 |
-
return image_np
|
90 |
-
|
91 |
-
|
92 |
-
def build_transform(split: str):
|
93 |
-
train_transform = v2.Compose(
|
94 |
-
[
|
95 |
-
v2.Lambda(crop_bottom),
|
96 |
-
v2.RandomCrop((256, 256), pad_if_needed=True),
|
97 |
-
v2.Lambda(random_gaussian_blur),
|
98 |
-
v2.RandomGrayscale(p=0.05),
|
99 |
-
v2.Lambda(random_invisible_watermark),
|
100 |
-
v2.ToImage(),
|
101 |
-
],
|
102 |
-
)
|
103 |
-
|
104 |
-
eval_transform = v2.Compose(
|
105 |
-
[
|
106 |
-
v2.CenterCrop((256, 256)),
|
107 |
-
],
|
108 |
-
)
|
109 |
-
transform = train_transform if split == "train" else eval_transform
|
110 |
-
|
111 |
-
return transform
|
112 |
-
|
113 |
-
|
114 |
-
def dp_to_tuple_train(input_dict):
|
115 |
-
transform = build_transform("train")
|
116 |
-
return (
|
117 |
-
transform(input_dict[".jpg"]),
|
118 |
-
input_dict[".label.cls"],
|
119 |
-
input_dict[".domain_label.cls"],
|
120 |
-
)
|
121 |
-
|
122 |
-
|
123 |
-
def dp_to_tuple_eval(input_dict):
|
124 |
-
transform = build_transform("eval")
|
125 |
-
return (
|
126 |
-
transform(input_dict[".jpg"]),
|
127 |
-
input_dict[".label.cls"],
|
128 |
-
input_dict[".domain_label.cls"],
|
129 |
-
)
|
130 |
-
|
131 |
-
|
132 |
-
def load_dataset(domains: list[int], split: str):
|
133 |
-
laion_lister = FileLister("./data/laion400m_data", f"{split}*.tar")
|
134 |
-
genai_lister = {
|
135 |
-
d: FileLister(
|
136 |
-
f"./data/genai-images/{DOMAIN_LABELS[d]}",
|
137 |
-
f"{split}*.tar",
|
138 |
-
)
|
139 |
-
for d in domains
|
140 |
-
if DOMAIN_LABELS[d] != "laion"
|
141 |
-
}
|
142 |
-
weight_genai = 1 / len(genai_lister)
|
143 |
-
|
144 |
-
def open_lister(lister):
|
145 |
-
opener = FileOpener(lister, mode="b")
|
146 |
-
return opener.load_from_tar().routed_decode().webdataset()
|
147 |
-
|
148 |
-
buffer_size1 = 100 if split == "train" else 10
|
149 |
-
buffer_size2 = 100 if split == "train" else 10
|
150 |
-
|
151 |
-
if split != "train":
|
152 |
-
all_lister = [laion_lister] + list(genai_lister.values())
|
153 |
-
dp = open_lister(Concater(*all_lister)).sharding_filter()
|
154 |
-
else:
|
155 |
-
laion_dp = (
|
156 |
-
open_lister(laion_lister.shuffle())
|
157 |
-
.cycle()
|
158 |
-
.sharding_filter()
|
159 |
-
.shuffle(buffer_size=buffer_size1)
|
160 |
-
)
|
161 |
-
genai_dp = {
|
162 |
-
open_lister(genai_lister[d].shuffle())
|
163 |
-
.cycle()
|
164 |
-
.sharding_filter()
|
165 |
-
.shuffle(
|
166 |
-
buffer_size=buffer_size1,
|
167 |
-
): weight_genai
|
168 |
-
for d in domains
|
169 |
-
if DOMAIN_LABELS[d] != "laion"
|
170 |
-
}
|
171 |
-
dp = SampleMultiplexer({laion_dp: 1, **genai_dp}).shuffle(
|
172 |
-
buffer_size=buffer_size2,
|
173 |
-
)
|
174 |
-
|
175 |
-
if split == "train":
|
176 |
-
dp = dp.map(dp_to_tuple_train)
|
177 |
-
else:
|
178 |
-
dp = dp.map(dp_to_tuple_eval)
|
179 |
-
|
180 |
-
return dp
|
181 |
-
|
182 |
-
|
183 |
-
def load_dataloader(
|
184 |
-
domains: list[int],
|
185 |
-
split: str,
|
186 |
-
batch_size: int = 32,
|
187 |
-
num_workers: int = 4,
|
188 |
-
):
|
189 |
-
dp = load_dataset(domains, split)
|
190 |
-
# if split == "train":
|
191 |
-
# dp = UnderSamplerIterDataPipe(dp, {0: 0.5, 1: 0.5}, seed=42)
|
192 |
-
dp = dp.batch(batch_size).collate()
|
193 |
-
dl = DataLoader(
|
194 |
-
dp,
|
195 |
-
batch_size=None,
|
196 |
-
num_workers=num_workers,
|
197 |
-
pin_memory=True,
|
198 |
-
)
|
199 |
-
|
200 |
-
return dl
|
201 |
-
|
202 |
-
|
203 |
-
if __name__ == "__main__":
|
204 |
-
parser = argparse.ArgumentParser()
|
205 |
-
|
206 |
-
args = parser.parse_args()
|
207 |
-
|
208 |
-
# testing code
|
209 |
-
dl = load_dataloader([0, 1], "train", num_workers=8)
|
210 |
-
y_dist = collections.Counter()
|
211 |
-
d_dist = collections.Counter()
|
212 |
-
|
213 |
-
for i, (img, y, d) in tqdm(enumerate(dl)):
|
214 |
-
if i % 100 == 0:
|
215 |
-
print(y, d)
|
216 |
-
if i == 400:
|
217 |
-
break
|
218 |
-
y_dist.update(y.numpy())
|
219 |
-
d_dist.update(d.numpy())
|
220 |
-
|
221 |
-
print("class label")
|
222 |
-
for label in sorted(y_dist):
|
223 |
-
frequency = y_dist[label] / sum(y_dist.values())
|
224 |
-
print(f"β’ {label}: {frequency:.2%} ({y_dist[label]})")
|
225 |
-
|
226 |
-
print("domain label")
|
227 |
-
for label in sorted(d_dist):
|
228 |
-
frequency = d_dist[label] / sum(d_dist.values())
|
229 |
-
print(f"β’ {label}: {frequency:.2%} ({d_dist[label]})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/diffusion_model_classifier.py
DELETED
@@ -1,293 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import logging
|
3 |
-
import os
|
4 |
-
|
5 |
-
import pandas as pd
|
6 |
-
import pytorch_lightning as pl
|
7 |
-
import timm
|
8 |
-
import torch
|
9 |
-
import torch.nn.functional as F
|
10 |
-
import torchvision.transforms as transforms
|
11 |
-
from PIL import Image
|
12 |
-
from pytorch_lightning.callbacks import (
|
13 |
-
EarlyStopping,
|
14 |
-
ModelCheckpoint,
|
15 |
-
)
|
16 |
-
from sklearn.metrics import roc_auc_score
|
17 |
-
from torchmetrics import (
|
18 |
-
Accuracy,
|
19 |
-
Recall,
|
20 |
-
)
|
21 |
-
|
22 |
-
from .diffusion_data_loader import load_dataloader
|
23 |
-
|
24 |
-
|
25 |
-
class ImageClassifier(pl.LightningModule):
|
26 |
-
def __init__(self, lmd=0):
|
27 |
-
super().__init__()
|
28 |
-
self.model = timm.create_model(
|
29 |
-
"resnet50",
|
30 |
-
pretrained=True,
|
31 |
-
num_classes=1,
|
32 |
-
)
|
33 |
-
self.accuracy = Accuracy(task="binary", threshold=0.5)
|
34 |
-
self.recall = Recall(task="binary", threshold=0.5)
|
35 |
-
self.validation_outputs = []
|
36 |
-
self.lmd = lmd
|
37 |
-
|
38 |
-
def forward(self, x):
|
39 |
-
return self.model(x)
|
40 |
-
|
41 |
-
def training_step(self, batch):
|
42 |
-
images, labels, _ = batch
|
43 |
-
outputs = self.forward(images).squeeze()
|
44 |
-
|
45 |
-
print(f"Shape of outputs (training): {outputs.shape}")
|
46 |
-
print(f"Shape of labels (training): {labels.shape}")
|
47 |
-
|
48 |
-
loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
|
49 |
-
logging.info(f"Training Step - ERM loss: {loss.item()}")
|
50 |
-
loss += self.lmd * (outputs**2).mean() # SD loss penalty
|
51 |
-
logging.info(f"Training Step - SD loss: {loss.item()}")
|
52 |
-
return loss
|
53 |
-
|
54 |
-
def validation_step(self, batch):
|
55 |
-
images, labels, _ = batch
|
56 |
-
outputs = self.forward(images).squeeze()
|
57 |
-
|
58 |
-
if outputs.shape == torch.Size([]):
|
59 |
-
return
|
60 |
-
|
61 |
-
print(f"Shape of outputs (validation): {outputs.shape}")
|
62 |
-
print(f"Shape of labels (validation): {labels.shape}")
|
63 |
-
|
64 |
-
loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
|
65 |
-
preds = torch.sigmoid(outputs)
|
66 |
-
self.log("val_loss", loss, prog_bar=True, sync_dist=True)
|
67 |
-
self.log(
|
68 |
-
"val_acc",
|
69 |
-
self.accuracy(preds, labels.int()),
|
70 |
-
prog_bar=True,
|
71 |
-
sync_dist=True,
|
72 |
-
)
|
73 |
-
self.log(
|
74 |
-
"val_recall",
|
75 |
-
self.recall(preds, labels.int()),
|
76 |
-
prog_bar=True,
|
77 |
-
sync_dist=True,
|
78 |
-
)
|
79 |
-
output = {"val_loss": loss, "preds": preds, "labels": labels}
|
80 |
-
self.validation_outputs.append(output)
|
81 |
-
logging.info(f"Validation Step - Batch loss: {loss.item()}")
|
82 |
-
return output
|
83 |
-
|
84 |
-
def predict_step(self, batch):
|
85 |
-
images, label, domain = batch
|
86 |
-
outputs = self.forward(images).squeeze()
|
87 |
-
preds = torch.sigmoid(outputs)
|
88 |
-
return preds, label, domain
|
89 |
-
|
90 |
-
def on_validation_epoch_end(self):
|
91 |
-
if not self.validation_outputs:
|
92 |
-
logging.warning("No outputs in validation step to process")
|
93 |
-
return
|
94 |
-
preds = torch.cat([x["preds"] for x in self.validation_outputs])
|
95 |
-
labels = torch.cat([x["labels"] for x in self.validation_outputs])
|
96 |
-
if labels.unique().size(0) == 1:
|
97 |
-
logging.warning("Only one class in validation step")
|
98 |
-
return
|
99 |
-
auc_score = roc_auc_score(labels.cpu(), preds.cpu())
|
100 |
-
self.log("val_auc", auc_score, prog_bar=True, sync_dist=True)
|
101 |
-
logging.info(f"Validation Epoch End - AUC score: {auc_score}")
|
102 |
-
self.validation_outputs = []
|
103 |
-
|
104 |
-
def configure_optimizers(self):
|
105 |
-
optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0005)
|
106 |
-
return optimizer
|
107 |
-
|
108 |
-
|
109 |
-
def load_image(image_path, transform=None):
|
110 |
-
image = Image.open(image_path).convert("RGB")
|
111 |
-
|
112 |
-
if transform:
|
113 |
-
image = transform(image)
|
114 |
-
|
115 |
-
return image
|
116 |
-
|
117 |
-
|
118 |
-
def predict_single_image(image, model):
|
119 |
-
|
120 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
121 |
-
|
122 |
-
model.to(device)
|
123 |
-
|
124 |
-
image = image.to(device)
|
125 |
-
|
126 |
-
model.eval()
|
127 |
-
|
128 |
-
with torch.no_grad():
|
129 |
-
image = image.unsqueeze(0)
|
130 |
-
output = model(image).squeeze()
|
131 |
-
prediction = torch.sigmoid(output).item()
|
132 |
-
|
133 |
-
return prediction
|
134 |
-
|
135 |
-
|
136 |
-
if __name__ == "__main__":
|
137 |
-
checkpoint_callback = ModelCheckpoint(
|
138 |
-
monitor="val_loss",
|
139 |
-
dirpath="./model_checkpoints/",
|
140 |
-
filename="image-classifier-{step}-{val_loss:.2f}",
|
141 |
-
save_top_k=3,
|
142 |
-
mode="min",
|
143 |
-
every_n_train_steps=1001,
|
144 |
-
enable_version_counter=True,
|
145 |
-
)
|
146 |
-
|
147 |
-
early_stop_callback = EarlyStopping(
|
148 |
-
monitor="val_loss",
|
149 |
-
patience=4,
|
150 |
-
mode="min",
|
151 |
-
)
|
152 |
-
|
153 |
-
parser = argparse.ArgumentParser()
|
154 |
-
parser.add_argument(
|
155 |
-
"--ckpt_path",
|
156 |
-
help="checkpoint to continue from",
|
157 |
-
required=False,
|
158 |
-
)
|
159 |
-
parser.add_argument(
|
160 |
-
"--predict",
|
161 |
-
help="predict on test set",
|
162 |
-
action="store_true",
|
163 |
-
)
|
164 |
-
parser.add_argument("--reset", help="reset training", action="store_true")
|
165 |
-
parser.add_argument(
|
166 |
-
"--predict_image",
|
167 |
-
help="predict the class of a single image",
|
168 |
-
action="store_true",
|
169 |
-
)
|
170 |
-
parser.add_argument(
|
171 |
-
"--image_path",
|
172 |
-
help="path to the image to predict",
|
173 |
-
type=str,
|
174 |
-
required=False,
|
175 |
-
)
|
176 |
-
parser.add_argument(
|
177 |
-
"--dir",
|
178 |
-
help="path to the images to predict",
|
179 |
-
type=str,
|
180 |
-
required=False,
|
181 |
-
)
|
182 |
-
parser.add_argument(
|
183 |
-
"--output_file",
|
184 |
-
help="path to output file",
|
185 |
-
type=str,
|
186 |
-
required=False,
|
187 |
-
)
|
188 |
-
args = parser.parse_args()
|
189 |
-
|
190 |
-
train_domains = [0, 1, 4]
|
191 |
-
val_domains = [0, 1, 4]
|
192 |
-
lmd_value = 0
|
193 |
-
|
194 |
-
if args.predict:
|
195 |
-
test_dl = load_dataloader(
|
196 |
-
[0, 1, 2, 3, 4],
|
197 |
-
"test",
|
198 |
-
batch_size=10,
|
199 |
-
num_workers=1,
|
200 |
-
)
|
201 |
-
model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
|
202 |
-
trainer = pl.Trainer()
|
203 |
-
predictions = trainer.predict(model, dataloaders=test_dl)
|
204 |
-
preds, labels, domains = zip(*predictions)
|
205 |
-
preds = torch.cat(preds).cpu().numpy()
|
206 |
-
labels = torch.cat(labels).cpu().numpy()
|
207 |
-
domains = torch.cat(domains).cpu().numpy()
|
208 |
-
print(preds.shape, labels.shape, domains.shape)
|
209 |
-
df = pd.DataFrame(
|
210 |
-
{"preds": preds, "labels": labels, "domains": domains},
|
211 |
-
)
|
212 |
-
filename = "preds-" + args.ckpt_path.split("/")[-1]
|
213 |
-
df.to_csv(f"outputs/{filename}.csv", index=False)
|
214 |
-
elif args.predict_image:
|
215 |
-
image_path = args.image_path
|
216 |
-
model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
|
217 |
-
|
218 |
-
# Define the transformations for the image
|
219 |
-
transform = transforms.Compose(
|
220 |
-
[
|
221 |
-
transforms.CenterCrop((256, 256)),
|
222 |
-
transforms.ToTensor(),
|
223 |
-
],
|
224 |
-
)
|
225 |
-
image = load_image(image_path, transform)
|
226 |
-
prediction = predict_single_image(image, model)
|
227 |
-
print("prediction", prediction)
|
228 |
-
|
229 |
-
# Output the prediction
|
230 |
-
print(
|
231 |
-
f"Prediction for {image_path}: "
|
232 |
-
f"{'Human' if prediction <= 0.05 else 'Generated'}",
|
233 |
-
)
|
234 |
-
elif args.dir is not None:
|
235 |
-
predictions = []
|
236 |
-
model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
|
237 |
-
transform = transforms.Compose(
|
238 |
-
[
|
239 |
-
transforms.CenterCrop((256, 256)),
|
240 |
-
transforms.ToTensor(),
|
241 |
-
],
|
242 |
-
)
|
243 |
-
for root, dirs, files in os.walk(os.path.abspath(args.dir)):
|
244 |
-
for f_name in files:
|
245 |
-
f = os.path.join(root, f_name)
|
246 |
-
print(f"Predicting: {f}")
|
247 |
-
p = predict_single_image(f, model)
|
248 |
-
predictions.append([f, f.split("/")[-2], p, p > 0.5])
|
249 |
-
print(f"--predicted: {p}")
|
250 |
-
|
251 |
-
df = pd.DataFrame(
|
252 |
-
predictions,
|
253 |
-
columns=["path", "folder", "pred", "class"],
|
254 |
-
)
|
255 |
-
df.to_csv(args.output_file, index=False)
|
256 |
-
else:
|
257 |
-
logging.basicConfig(
|
258 |
-
filename="training.log",
|
259 |
-
filemode="w",
|
260 |
-
level=logging.INFO,
|
261 |
-
force=True,
|
262 |
-
)
|
263 |
-
train_dl = load_dataloader(
|
264 |
-
train_domains,
|
265 |
-
"train",
|
266 |
-
batch_size=128,
|
267 |
-
num_workers=4,
|
268 |
-
)
|
269 |
-
logging.info("Training dataloader loaded")
|
270 |
-
val_dl = load_dataloader(
|
271 |
-
val_domains,
|
272 |
-
"val",
|
273 |
-
batch_size=128,
|
274 |
-
num_workers=4,
|
275 |
-
)
|
276 |
-
logging.info("Validation dataloader loaded")
|
277 |
-
|
278 |
-
if args.reset:
|
279 |
-
model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
|
280 |
-
else:
|
281 |
-
model = ImageClassifier(lmd=lmd_value)
|
282 |
-
trainer = pl.Trainer(
|
283 |
-
callbacks=[checkpoint_callback, early_stop_callback],
|
284 |
-
max_steps=20000,
|
285 |
-
val_check_interval=1000,
|
286 |
-
check_val_every_n_epoch=None,
|
287 |
-
)
|
288 |
-
trainer.fit(
|
289 |
-
model=model,
|
290 |
-
train_dataloaders=train_dl,
|
291 |
-
val_dataloaders=val_dl,
|
292 |
-
ckpt_path=args.ckpt_path if not args.reset else None,
|
293 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/diffusion_utils_sampling.py
DELETED
@@ -1,94 +0,0 @@
|
|
1 |
-
import collections
|
2 |
-
import random
|
3 |
-
from typing import Callable
|
4 |
-
|
5 |
-
from torchdata.datapipes.iter import IterDataPipe
|
6 |
-
|
7 |
-
|
8 |
-
def get_second_entry(sample):
|
9 |
-
return sample[1]
|
10 |
-
|
11 |
-
|
12 |
-
class UnderSamplerIterDataPipe(IterDataPipe):
|
13 |
-
"""Dataset wrapper for under-sampling.
|
14 |
-
|
15 |
-
Copied from: https://github.com/MaxHalford/pytorch-resample/blob/master/pytorch_resample/under.py # noqa
|
16 |
-
Modified to work with multiple labels.
|
17 |
-
|
18 |
-
MIT License
|
19 |
-
|
20 |
-
Copyright (c) 2020 Max Halford
|
21 |
-
|
22 |
-
This method is based on rejection sampling.
|
23 |
-
|
24 |
-
Parameters:
|
25 |
-
dataset
|
26 |
-
desired_dist: The desired class distribution.
|
27 |
-
The keys are the classes whilst the
|
28 |
-
values are the desired class percentages.
|
29 |
-
The values are normalised so that sum up
|
30 |
-
to 1.
|
31 |
-
label_getter: A function that takes a sample and returns its label.
|
32 |
-
seed: Random seed for reproducibility.
|
33 |
-
|
34 |
-
Attributes:
|
35 |
-
actual_dist: The counts of the observed sample labels.
|
36 |
-
rng: A random number generator instance.
|
37 |
-
|
38 |
-
References:
|
39 |
-
- https://www.wikiwand.com/en/Rejection_sampling
|
40 |
-
|
41 |
-
"""
|
42 |
-
|
43 |
-
def __init__(
|
44 |
-
self,
|
45 |
-
dataset: IterDataPipe,
|
46 |
-
desired_dist: dict,
|
47 |
-
label_getter: Callable = get_second_entry,
|
48 |
-
seed: int = None,
|
49 |
-
):
|
50 |
-
|
51 |
-
self.dataset = dataset
|
52 |
-
self.desired_dist = {
|
53 |
-
c: p / sum(desired_dist.values()) for c, p in desired_dist.items()
|
54 |
-
}
|
55 |
-
self.label_getter = label_getter
|
56 |
-
self.seed = seed
|
57 |
-
|
58 |
-
self.actual_dist = collections.Counter()
|
59 |
-
self.rng = random.Random(seed)
|
60 |
-
self._pivot = None
|
61 |
-
|
62 |
-
def __iter__(self):
|
63 |
-
|
64 |
-
for dp in self.dataset:
|
65 |
-
y = self.label_getter(dp)
|
66 |
-
|
67 |
-
self.actual_dist[y] += 1
|
68 |
-
|
69 |
-
# To ease notation
|
70 |
-
f = self.desired_dist
|
71 |
-
g = self.actual_dist
|
72 |
-
|
73 |
-
# Check if the pivot needs to be changed
|
74 |
-
if y != self._pivot:
|
75 |
-
self._pivot = max(g.keys(), key=lambda y: f[y] / g[y])
|
76 |
-
else:
|
77 |
-
yield dp
|
78 |
-
continue
|
79 |
-
|
80 |
-
# Determine the sampling ratio if the observed label
|
81 |
-
# is not the pivot
|
82 |
-
M = f[self._pivot] / g[self._pivot]
|
83 |
-
ratio = f[y] / (M * g[y])
|
84 |
-
|
85 |
-
if ratio < 1 and self.rng.random() < ratio:
|
86 |
-
yield dp
|
87 |
-
|
88 |
-
@classmethod
|
89 |
-
def expected_size(cls, n, desired_dist, actual_dist):
|
90 |
-
M = max(
|
91 |
-
desired_dist.get(k) / actual_dist.get(k)
|
92 |
-
for k in set(desired_dist) | set(actual_dist)
|
93 |
-
)
|
94 |
-
return int(n / M)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/images/image_demo.py
DELETED
@@ -1,73 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
import torchvision.transforms as transforms
|
3 |
-
from CNN_model_classifier import predict_cnn
|
4 |
-
from diffusion_model_classifier import (
|
5 |
-
ImageClassifier,
|
6 |
-
predict_single_image,
|
7 |
-
)
|
8 |
-
|
9 |
-
gr.set_static_paths(paths=["samples/"])
|
10 |
-
diffusion_model = (
|
11 |
-
"Diffusion/model_checkpoints/image-classifier-step=7007-val_loss=0.09.ckpt"
|
12 |
-
)
|
13 |
-
cnn_model = "CNN/model_checkpoints/blur_jpg_prob0.5.pth"
|
14 |
-
|
15 |
-
|
16 |
-
def get_prediction_diffusion(image):
|
17 |
-
model = ImageClassifier.load_from_checkpoint(diffusion_model)
|
18 |
-
|
19 |
-
prediction = predict_single_image(image, model)
|
20 |
-
print(prediction)
|
21 |
-
return (prediction >= 0.001, prediction)
|
22 |
-
|
23 |
-
|
24 |
-
def get_prediction_cnn(image):
|
25 |
-
prediction = predict_cnn(image, cnn_model)
|
26 |
-
return (prediction >= 0.5, prediction)
|
27 |
-
|
28 |
-
|
29 |
-
def predict(inp):
|
30 |
-
# Define the transformations for the image
|
31 |
-
transform = transforms.Compose(
|
32 |
-
[
|
33 |
-
transforms.Resize((224, 224)), # Image size expected by ResNet50
|
34 |
-
transforms.ToTensor(),
|
35 |
-
transforms.Normalize(
|
36 |
-
mean=[0.485, 0.456, 0.406],
|
37 |
-
std=[0.229, 0.224, 0.225],
|
38 |
-
),
|
39 |
-
],
|
40 |
-
)
|
41 |
-
image_tensor = transform(inp)
|
42 |
-
pred_diff, prob_diff = get_prediction_diffusion(image_tensor)
|
43 |
-
pred_cnn, prob_cnn = get_prediction_cnn(image_tensor)
|
44 |
-
verdict = (
|
45 |
-
"AI Generated" if (pred_diff or pred_cnn) else "No GenAI detected"
|
46 |
-
)
|
47 |
-
return (
|
48 |
-
f"<h1>{verdict}</h1>"
|
49 |
-
f"<ul>"
|
50 |
-
f"<li>Diffusion detection score: {prob_diff:.2} "
|
51 |
-
f"{'(MATCH)' if pred_diff else ''}</li>"
|
52 |
-
f"<li>CNN detection score: {prob_cnn:.1%} "
|
53 |
-
f"{'(MATCH)' if pred_cnn else ''}</li>"
|
54 |
-
f"</ul>"
|
55 |
-
)
|
56 |
-
|
57 |
-
|
58 |
-
demo = gr.Interface(
|
59 |
-
title="AI-generated image detection",
|
60 |
-
description="Demo by NICT & Tokyo Techies ",
|
61 |
-
fn=predict,
|
62 |
-
inputs=gr.Image(type="pil"),
|
63 |
-
outputs=gr.HTML(),
|
64 |
-
examples=[
|
65 |
-
["samples/fake_dalle.jpg", "Generated (Dall-E)"],
|
66 |
-
["samples/fake_midjourney.png", "Generated (MidJourney)"],
|
67 |
-
["samples/fake_stable.jpg", "Generated (Stable Diffusion)"],
|
68 |
-
["samples/fake_cnn.png", "Generated (GAN)"],
|
69 |
-
["samples/real.png", "Organic"],
|
70 |
-
],
|
71 |
-
)
|
72 |
-
|
73 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/main.py
DELETED
@@ -1,51 +0,0 @@
|
|
1 |
-
from texts.models import TextDetector
|
2 |
-
|
3 |
-
|
4 |
-
def extract_text_and_images(path: str):
|
5 |
-
text_content = ""
|
6 |
-
image_paths = ""
|
7 |
-
return text_content, image_paths
|
8 |
-
|
9 |
-
|
10 |
-
def process_document(document_path) -> list:
|
11 |
-
"""
|
12 |
-
Processes a given document, separating text and images,
|
13 |
-
and then analyzes them.
|
14 |
-
|
15 |
-
Args:
|
16 |
-
document_path: Path to the document.
|
17 |
-
|
18 |
-
Returns:
|
19 |
-
A list containing the AI content percentage for text and images.
|
20 |
-
"""
|
21 |
-
|
22 |
-
# Extract text and images from the document
|
23 |
-
text_content, image_paths = extract_text_and_images(document_path)
|
24 |
-
|
25 |
-
# Analyze text content
|
26 |
-
text_detector = TextDetector()
|
27 |
-
text_ai_content_percentage = text_detector.analyze_text(text_content)
|
28 |
-
|
29 |
-
# Analyze image content
|
30 |
-
image_ai_content_percentages = []
|
31 |
-
for image_path in image_paths:
|
32 |
-
# TODO: add image_detector class
|
33 |
-
# image_ai_content = image_detector.analyze_image(image_path)
|
34 |
-
image_ai_content = 100
|
35 |
-
image_ai_content_percentages.append(image_ai_content)
|
36 |
-
|
37 |
-
return [text_ai_content_percentage, image_ai_content_percentages]
|
38 |
-
|
39 |
-
|
40 |
-
def main():
|
41 |
-
document_path = "../data.pdf" # Replace with your document path
|
42 |
-
text_ai_content_percentage, image_ai_content_percentages = (
|
43 |
-
process_document(document_path)
|
44 |
-
)
|
45 |
-
|
46 |
-
print("Text AI Content Percentage:", text_ai_content_percentage)
|
47 |
-
print("Combined AI Content Percentage:", image_ai_content_percentages)
|
48 |
-
|
49 |
-
|
50 |
-
if __name__ == "__main__":
|
51 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/texts/MAGE/.gradio/flagged/dataset1.csv
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
input text,AI-text detection,timestamp
|
2 |
-
Does Chicago have any stores and does Joe live here?,"[{""token"": ""Does Chicago have any stores and does Joe live here?"", ""class_or_confidence"": ""human-written""}]",2024-12-09 13:40:10.255451
|
|
|
|
|
|
src/texts/MAGE/LICENSE
DELETED
@@ -1,201 +0,0 @@
|
|
1 |
-
Apache License
|
2 |
-
Version 2.0, January 2004
|
3 |
-
http://www.apache.org/licenses/
|
4 |
-
|
5 |
-
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
-
|
7 |
-
1. Definitions.
|
8 |
-
|
9 |
-
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
-
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
-
|
12 |
-
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
-
the copyright owner that is granting the License.
|
14 |
-
|
15 |
-
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
-
other entities that control, are controlled by, or are under common
|
17 |
-
control with that entity. For the purposes of this definition,
|
18 |
-
"control" means (i) the power, direct or indirect, to cause the
|
19 |
-
direction or management of such entity, whether by contract or
|
20 |
-
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
-
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
-
|
23 |
-
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
-
exercising permissions granted by this License.
|
25 |
-
|
26 |
-
"Source" form shall mean the preferred form for making modifications,
|
27 |
-
including but not limited to software source code, documentation
|
28 |
-
source, and configuration files.
|
29 |
-
|
30 |
-
"Object" form shall mean any form resulting from mechanical
|
31 |
-
transformation or translation of a Source form, including but
|
32 |
-
not limited to compiled object code, generated documentation,
|
33 |
-
and conversions to other media types.
|
34 |
-
|
35 |
-
"Work" shall mean the work of authorship, whether in Source or
|
36 |
-
Object form, made available under the License, as indicated by a
|
37 |
-
copyright notice that is included in or attached to the work
|
38 |
-
(an example is provided in the Appendix below).
|
39 |
-
|
40 |
-
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
-
form, that is based on (or derived from) the Work and for which the
|
42 |
-
editorial revisions, annotations, elaborations, or other modifications
|
43 |
-
represent, as a whole, an original work of authorship. For the purposes
|
44 |
-
of this License, Derivative Works shall not include works that remain
|
45 |
-
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
-
the Work and Derivative Works thereof.
|
47 |
-
|
48 |
-
"Contribution" shall mean any work of authorship, including
|
49 |
-
the original version of the Work and any modifications or additions
|
50 |
-
to that Work or Derivative Works thereof, that is intentionally
|
51 |
-
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
-
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
-
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
-
means any form of electronic, verbal, or written communication sent
|
55 |
-
to the Licensor or its representatives, including but not limited to
|
56 |
-
communication on electronic mailing lists, source code control systems,
|
57 |
-
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
-
Licensor for the purpose of discussing and improving the Work, but
|
59 |
-
excluding communication that is conspicuously marked or otherwise
|
60 |
-
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
-
|
62 |
-
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
-
on behalf of whom a Contribution has been received by Licensor and
|
64 |
-
subsequently incorporated within the Work.
|
65 |
-
|
66 |
-
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
-
this License, each Contributor hereby grants to You a perpetual,
|
68 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
-
copyright license to reproduce, prepare Derivative Works of,
|
70 |
-
publicly display, publicly perform, sublicense, and distribute the
|
71 |
-
Work and such Derivative Works in Source or Object form.
|
72 |
-
|
73 |
-
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
-
this License, each Contributor hereby grants to You a perpetual,
|
75 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
-
(except as stated in this section) patent license to make, have made,
|
77 |
-
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
-
where such license applies only to those patent claims licensable
|
79 |
-
by such Contributor that are necessarily infringed by their
|
80 |
-
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
-
with the Work to which such Contribution(s) was submitted. If You
|
82 |
-
institute patent litigation against any entity (including a
|
83 |
-
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
-
or a Contribution incorporated within the Work constitutes direct
|
85 |
-
or contributory patent infringement, then any patent licenses
|
86 |
-
granted to You under this License for that Work shall terminate
|
87 |
-
as of the date such litigation is filed.
|
88 |
-
|
89 |
-
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
-
Work or Derivative Works thereof in any medium, with or without
|
91 |
-
modifications, and in Source or Object form, provided that You
|
92 |
-
meet the following conditions:
|
93 |
-
|
94 |
-
(a) You must give any other recipients of the Work or
|
95 |
-
Derivative Works a copy of this License; and
|
96 |
-
|
97 |
-
(b) You must cause any modified files to carry prominent notices
|
98 |
-
stating that You changed the files; and
|
99 |
-
|
100 |
-
(c) You must retain, in the Source form of any Derivative Works
|
101 |
-
that You distribute, all copyright, patent, trademark, and
|
102 |
-
attribution notices from the Source form of the Work,
|
103 |
-
excluding those notices that do not pertain to any part of
|
104 |
-
the Derivative Works; and
|
105 |
-
|
106 |
-
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
-
distribution, then any Derivative Works that You distribute must
|
108 |
-
include a readable copy of the attribution notices contained
|
109 |
-
within such NOTICE file, excluding those notices that do not
|
110 |
-
pertain to any part of the Derivative Works, in at least one
|
111 |
-
of the following places: within a NOTICE text file distributed
|
112 |
-
as part of the Derivative Works; within the Source form or
|
113 |
-
documentation, if provided along with the Derivative Works; or,
|
114 |
-
within a display generated by the Derivative Works, if and
|
115 |
-
wherever such third-party notices normally appear. The contents
|
116 |
-
of the NOTICE file are for informational purposes only and
|
117 |
-
do not modify the License. You may add Your own attribution
|
118 |
-
notices within Derivative Works that You distribute, alongside
|
119 |
-
or as an addendum to the NOTICE text from the Work, provided
|
120 |
-
that such additional attribution notices cannot be construed
|
121 |
-
as modifying the License.
|
122 |
-
|
123 |
-
You may add Your own copyright statement to Your modifications and
|
124 |
-
may provide additional or different license terms and conditions
|
125 |
-
for use, reproduction, or distribution of Your modifications, or
|
126 |
-
for any such Derivative Works as a whole, provided Your use,
|
127 |
-
reproduction, and distribution of the Work otherwise complies with
|
128 |
-
the conditions stated in this License.
|
129 |
-
|
130 |
-
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
-
any Contribution intentionally submitted for inclusion in the Work
|
132 |
-
by You to the Licensor shall be under the terms and conditions of
|
133 |
-
this License, without any additional terms or conditions.
|
134 |
-
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
-
the terms of any separate license agreement you may have executed
|
136 |
-
with Licensor regarding such Contributions.
|
137 |
-
|
138 |
-
6. Trademarks. This License does not grant permission to use the trade
|
139 |
-
names, trademarks, service marks, or product names of the Licensor,
|
140 |
-
except as required for reasonable and customary use in describing the
|
141 |
-
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
-
|
143 |
-
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
-
agreed to in writing, Licensor provides the Work (and each
|
145 |
-
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
-
implied, including, without limitation, any warranties or conditions
|
148 |
-
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
-
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
-
appropriateness of using or redistributing the Work and assume any
|
151 |
-
risks associated with Your exercise of permissions under this License.
|
152 |
-
|
153 |
-
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
-
whether in tort (including negligence), contract, or otherwise,
|
155 |
-
unless required by applicable law (such as deliberate and grossly
|
156 |
-
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
-
liable to You for damages, including any direct, indirect, special,
|
158 |
-
incidental, or consequential damages of any character arising as a
|
159 |
-
result of this License or out of the use or inability to use the
|
160 |
-
Work (including but not limited to damages for loss of goodwill,
|
161 |
-
work stoppage, computer failure or malfunction, or any and all
|
162 |
-
other commercial damages or losses), even if such Contributor
|
163 |
-
has been advised of the possibility of such damages.
|
164 |
-
|
165 |
-
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
-
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
-
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
-
or other liability obligations and/or rights consistent with this
|
169 |
-
License. However, in accepting such obligations, You may act only
|
170 |
-
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
-
of any other Contributor, and only if You agree to indemnify,
|
172 |
-
defend, and hold each Contributor harmless for any liability
|
173 |
-
incurred by, or claims asserted against, such Contributor by reason
|
174 |
-
of your accepting any such warranty or additional liability.
|
175 |
-
|
176 |
-
END OF TERMS AND CONDITIONS
|
177 |
-
|
178 |
-
APPENDIX: How to apply the Apache License to your work.
|
179 |
-
|
180 |
-
To apply the Apache License to your work, attach the following
|
181 |
-
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
-
replaced with your own identifying information. (Don't include
|
183 |
-
the brackets!) The text should be enclosed in the appropriate
|
184 |
-
comment syntax for the file format. We also recommend that a
|
185 |
-
file or class name and description of purpose be included on the
|
186 |
-
same "printed page" as the copyright notice for easier
|
187 |
-
identification within third-party archives.
|
188 |
-
|
189 |
-
Copyright [yyyy] [name of copyright owner]
|
190 |
-
|
191 |
-
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
-
you may not use this file except in compliance with the License.
|
193 |
-
You may obtain a copy of the License at
|
194 |
-
|
195 |
-
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
-
|
197 |
-
Unless required by applicable law or agreed to in writing, software
|
198 |
-
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
-
See the License for the specific language governing permissions and
|
201 |
-
limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/texts/MAGE/README.md
DELETED
@@ -1,258 +0,0 @@
|
|
1 |
-
<div align="center">
|
2 |
-
<p align="center">
|
3 |
-
<img src="./figures/intro.png" width="75%" height="75%" />
|
4 |
-
</p>
|
5 |
-
</div>
|
6 |
-
|
7 |
-
<div align="center">
|
8 |
-
<h1><img src="./figures/title.png" width="30px" height="30px" style="display:inline;margin-right:10px;">MAGE: Machine-generated Text Detection in the Wild</h1>
|
9 |
-
</div>
|
10 |
-
|
11 |
-
<div align="center">
|
12 |
-
<img src="https://img.shields.io/badge/Version-1.0.0-blue.svg" alt="Version">
|
13 |
-
<img src="https://img.shields.io/badge/License-CC%20BY%204.0-green.svg" alt="License">
|
14 |
-
<img src="https://img.shields.io/github/stars/yafuly/DeepfakeTextDetect?color=yellow" alt="Stars">
|
15 |
-
<img src="https://img.shields.io/github/issues/yafuly/DeepfakeTextDetect?color=red" alt="Issues">
|
16 |
-
|
17 |
-
<!-- **Authors:** -->
|
18 |
-
<br>
|
19 |
-
|
20 |
-
_**Yafu Li<sup>β </sup><sup>β‘</sup>, Qintong Li<sup>Β§</sup>, Leyang Cui<sup>ΒΆ</sup>, Wei Bi<sup>ΒΆ</sup>,Zhilin Wang<sup>$</sup><br>**_
|
21 |
-
|
22 |
-
_**Longyue Wang<sup>ΒΆ</sup>, Linyi Yang<sup>β‘</sup>, Shuming Shi<sup>ΒΆ</sup>, Yue Zhang<sup>β‘</sup><br>**_
|
23 |
-
|
24 |
-
<!-- **Affiliations:** -->
|
25 |
-
|
26 |
-
_<sup>β </sup> Zhejiang University,
|
27 |
-
<sup>β‘</sup> Westlake University,
|
28 |
-
<sup>Β§</sup> The University of Hong Kong,
|
29 |
-
<sup>$</sup> Jilin University,
|
30 |
-
<sup>ΒΆ</sup> Tencent AI Lab_
|
31 |
-
|
32 |
-
Presenting a comprehensive benchmark dataset designed to assess the proficiency of AI-generation detectors amidst real-world scenarios.
|
33 |
-
Welcome to try detection via our **[online demo](https://detect.westlake.edu.cn)**!
|
34 |
-
|
35 |
-
</div>
|
36 |
-
|
37 |
-
## π Table of Contents
|
38 |
-
|
39 |
-
- [Introduction](#-introduction)
|
40 |
-
- [Activities](#-activities)
|
41 |
-
- [Dataset](#-dataset)
|
42 |
-
- [Try Detection](#computer--try-detection)
|
43 |
-
- [Data Samples](#-data-samples)
|
44 |
-
- [Citation](#-citation)
|
45 |
-
<!-- - [Contributing](#-contributing) -->
|
46 |
-
|
47 |
-
## π Introduction
|
48 |
-
|
49 |
-
Recent advances in large language models have enabled them to reach a level of text generation comparable to that of humans.
|
50 |
-
These models show powerful capabilities across a wide range of content, including news article writing, story generation, and scientific writing.
|
51 |
-
Such capability further narrows the gap between human-authored and machine-generated texts, highlighting the importance of machine-generated text detection to avoid potential risks such as fake news propagation and plagiarism.
|
52 |
-
In practical scenarios, the detector faces texts from various domains or LLMs without knowing their sources.
|
53 |
-
|
54 |
-
To this end, we build **a comprehensive testbed for deepfake text detection**, by gathering texts from various human writings and deepfake texts generated by different LLMs.
|
55 |
-
This repository contains the data to testify deepfake detection methods described in our paper, [MAGE: Machine-generated Text Detection in the Wild](https://aclanthology.org/2024.acl-long.3/).
|
56 |
-
Welcome to test your detection methods on our testbed!
|
57 |
-
|
58 |
-
## π
Activities
|
59 |
-
|
60 |
-
- π **May 16, 2024**: Our paper was accepted by ACL 2024!
|
61 |
-
- π **June 19, 2023**: Update two 'wilder' testbeds! We go one step wilder by constructing an additional testset with texts from unseen domains generated by an unseen model, to testify the detection ability in more practical scenarios.
|
62 |
-
We consider four new datasets: CNN/DailyMail, DialogSum, PubMedQA and IMDb to test the detection of deepfake news, deepfake dialogues, deepfake scientific answers and deepfake movie reviews.
|
63 |
-
We sample 200 instances from each dataset and use a newly developed LLM, i.e., GPT-4, with specially designed prompts to create deepfake texts, establishing an "Unseen Domains & Unseen Model" scenario.
|
64 |
-
Previous work demonstrates that detection methods are vulnerable to being deceived by target texts.
|
65 |
-
Therefore, we also paraphrase each sentence individually for both human-written and machine-generated texts, forming an even more challenging testbed.
|
66 |
-
We adopt gpt-3.5-trubo as the zero-shot paraphraser and consider all paraphrased texts as machine-generated.
|
67 |
-
- May 25, 2023: Initial dataset release including texts from 10 domains and 27 LLMs, contributing to 6 testbeds with increasing detection difficulty.
|
68 |
-
|
69 |
-
## π Dataset
|
70 |
-
|
71 |
-
The dataset consists of **447,674** human-written and machine-generated texts from a wide range of sources in the wild:
|
72 |
-
|
73 |
-
- Human-written texts from **10 datasets** covering a wide range of writing tasks, e.g., news article writing, story generation, scientific writing, etc.
|
74 |
-
- Machine-generated texts generated by **27 mainstream LLMs** from 7 sources, e.g., OpenAI, LLaMA, and EleutherAI, etc.
|
75 |
-
- **8 systematic testbed**s with increasing wildness and detection difficulty.
|
76 |
-
|
77 |
-
### π₯ How to Get the Data
|
78 |
-
|
79 |
-
#### 1. Huggingface
|
80 |
-
|
81 |
-
You can access the full dataset, which includes the Cross-domains & Cross-models testbed and two additional wilder test sets, through the [Huggingface API](https://huggingface.co/datasets/yaful/MAGE):
|
82 |
-
|
83 |
-
```python
|
84 |
-
from datasets import load_dataset
|
85 |
-
dataset = load_dataset("yaful/MAGE")
|
86 |
-
```
|
87 |
-
|
88 |
-
which includes traditional splits (train.csv, valid.csv and test.csv) and two wilder test sets (test_ood_set_gpt.csv and test_ood_set_gpt_para.csv).
|
89 |
-
The csv files have three columns: text, label (0 for machine-generated and
|
90 |
-
1 for human-written) and text source information (e.g., ''cmv_human'' denotes the text is written by humans,
|
91 |
-
whereas ''roct_machine_continuation_flan_t5_large'' denotes the text is generated by ''flan_t5_large'' using continuation prompt).
|
92 |
-
|
93 |
-
To obtain the 6 testbeds mentioned in our paper, simply apply the provided script:
|
94 |
-
|
95 |
-
```shell
|
96 |
-
python3 deployment/prepare_testbeds.py DATA_PATH
|
97 |
-
```
|
98 |
-
|
99 |
-
Replace ''DATA_PATH'' with the output data directory where you want to save the 6 testbeds.
|
100 |
-
|
101 |
-
#### 2. Cloud Drive
|
102 |
-
|
103 |
-
Alternatively, you can access the 6 testbeds by downloading them directly through [Google Drive](https://drive.google.com/drive/folders/1p09vDiEvoA-ZPmpqkB2WApcwMQWiiMRl?usp=sharing)
|
104 |
-
or [Tencent Weiyun](https://share.weiyun.com/JUWQxF4H)οΌ
|
105 |
-
|
106 |
-
The folder contains 4 packages:
|
107 |
-
|
108 |
-
- testbeds_processed.zip: 6 testbeds based on the ''processed'' version, which can be directly used for detecting in-distribution and out-of-distribution detection performance.
|
109 |
-
- wilder_testsets.zip: 2 wilder test sets with texts processed, aiming for (1) detecting deepfake text generated by GPT-4, and (2) detecting deepfake text in paraphrased versions.
|
110 |
-
- source.zip: Source texts of human-written texts and corresponding texts generated by LLMs, without filtering.
|
111 |
-
- processed.zip: This is a refined version of the "source" that filters out low-quality texts and specifies sources as CSV file names. For example, the "cmv_machine_specified_gpt-3.5-trubo.csv" file contains texts from the CMV domain generated by the "gpt-3.5-trubo" model using specific prompts, while "cmv_human" includes human-written CMV texts.
|
112 |
-
|
113 |
-
## :computer: Try Detection
|
114 |
-
|
115 |
-
### Python Environment
|
116 |
-
|
117 |
-
For deploying the Longformer detector or training your own detector using our data, simply install the following packages:
|
118 |
-
|
119 |
-
```shell
|
120 |
-
pip install transformers
|
121 |
-
pip install datasets
|
122 |
-
pip install clean-text # for data preprocessing
|
123 |
-
```
|
124 |
-
|
125 |
-
Or you can run:
|
126 |
-
|
127 |
-
```shell
|
128 |
-
pip install -r requirements.txt
|
129 |
-
```
|
130 |
-
|
131 |
-
### Model Access
|
132 |
-
|
133 |
-
Our Longformer detector, which has been trained on the entire dataset, is now accessible through [Huggingface](https://huggingface.co/yaful/MAGE). Additionally, you can try detection directly using our [online demo](https://detect.westlake.edu.cn/).
|
134 |
-
|
135 |
-
###
|
136 |
-
|
137 |
-
We have refined the decision boundary based on out-of-distribution settings. To ensure optimal performance, we recommend preprocessing texts before sending them to the detector.
|
138 |
-
|
139 |
-
```python
|
140 |
-
import torch
|
141 |
-
import os
|
142 |
-
from transformers import AutoModelForSequenceClassification,AutoTokenizer
|
143 |
-
from deployment import preprocess, detect
|
144 |
-
|
145 |
-
# init
|
146 |
-
device = 'cpu' # use 'cuda:0' if GPU is available
|
147 |
-
# model_dir = "nealcly/detection-longformer" # model in our paper
|
148 |
-
model_dir = "yaful/MAGE" # model in the online demo
|
149 |
-
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
150 |
-
model = AutoModelForSequenceClassification.from_pretrained(model_dir).to(device)
|
151 |
-
|
152 |
-
text = "Apple's new credit card will begin a preview roll out today and will become available to all iPhone owners in the US later this month. A random selection of people will be allowed to go through the application process, which involves entering personal details which are sent to Goldman Sachs and TransUnion. Applications are approved or declined in less than a minute. The Apple Card is meant to be broadly accessible to every iPhone user, so the approval requirements will not be as strict as other credit cards. Once the application has been approved, users will be able to use the card immediately from the Apple Wallet app. The physical titanium card can be requested during setup for free, and it can be activated with NFC once it arrives."
|
153 |
-
# preprocess
|
154 |
-
text = preprocess(text)
|
155 |
-
# detection
|
156 |
-
result = detect(text,tokenizer,model,device)
|
157 |
-
```
|
158 |
-
|
159 |
-
### Detection Performance
|
160 |
-
|
161 |
-
#### In-distribution Detection
|
162 |
-
|
163 |
-
| Testbed | HumanRec | MachineRec | AvgRec | AUROC |
|
164 |
-
| ------------------------------------ | -------- | ---------- | ------ | ----- |
|
165 |
-
| White-box | 97.30% | 95.91% | 96.60% | 0.99 |
|
166 |
-
| Arbitrary-domains & Modelβspecific | 95.25% | 96.94% | 96.60% | 0.99 |
|
167 |
-
| Fixed-domain & Arbitrary-models | 89.78% | 97.24% | 93.51% | 0.99 |
|
168 |
-
| Arbitrary-domains & Arbitrary-models | 82.80% | 98.27% | 90.53% | 0.99 |
|
169 |
-
|
170 |
-
#### Out-of-distribution Detection
|
171 |
-
|
172 |
-
| Testbed | HumanRec | MachineRec | AvgRec | AUROC |
|
173 |
-
| ----------------- | -------- | ---------- | ------ | ----- |
|
174 |
-
| Unseen Model Sets | 83.31% | 89.90% | 86.61% | 0.95 |
|
175 |
-
| Unseen Domains | 38.05% | 98.75% | 68.40% | 0.93 |
|
176 |
-
|
177 |
-
#### Wilder Testsets
|
178 |
-
|
179 |
-
| Testbed | HumanRec | MachineRec | AvgRec | AUROC |
|
180 |
-
| ----------------------------- | -------- | ---------- | ------ | ----- |
|
181 |
-
| Unseen Domains & Unseen Model | 88.78% | 84.12% | 86.54% | 0.94 |
|
182 |
-
| Paraphrase | 88.78% | 37.05% | 62.92% | 0.75 |
|
183 |
-
|
184 |
-
## π Data Samples
|
185 |
-
|
186 |
-
All instances are stored as rows in a csv format, with each row consiting of 3 columns: _Text_, _Label_ (0 for machine-generated and 1 for human-written) and _Index_ (indexes in the original data source, used restore alignment after filtering).
|
187 |
-
|
188 |
-
#### News Article
|
189 |
-
|
190 |
-
| Text | Label |
|
191 |
-
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----- |
|
192 |
-
| Apple's new credit card will begin a preview roll out today and will become available to all iPhone owners in the US later this month. A random selection of people will be allowed to go through the application process, which involves entering personal details which are sent to Goldman Sachs and TransUnion. Applications are approved or declined in less than a minute. The Apple Card is meant to be broadly accessible to every iPhone user, so the approval requirements will not be as strict as other credit cards. Once the application has been approved, users will be able to use the card immediately from the Apple Wallet app. The physical titanium card can be requested during setup for free, and it can be activated with NFC once it arrives. | 1 |
|
193 |
-
| Apple's new credit card will begin a preview roll out today and will become available to all iPhone owners in the US later this month. A random selection of people who applied for an Apple Card beta account last week will be sent invitations just before April 26, when everyone else can start using it too. The Apple Card is designed with privacy in mind: users aren't able to link their cards or view detailed transaction histories online as some other services do. The app itself also stores no personal data on your phone, only a virtual version in its secure payment system so you don't have to give Apple access to any sensitive details. It uses machine learning-based algorithms that learn from how you spend to provide personalized recommendations about your spending habits. Your transactions are stored offline on your device, the company says, while information like email addresses remains encrypted during transit between your devices and the cloud. And if you ever lose your physical card, Apple has said there's a way to temporarily disable the card without affecting payments until the actual one arrives. | 0 |
|
194 |
-
| Today marks the beginning of a new era of financial technology: Apple Card is now available to all users in the United States. The long-awaited credit card from Apple, which was announced earlier this year, is now available for everyone to sign up and use. With features such as cashback on purchases and robust security measures, Apple Card could revolutionize how people make payments. This could be the start of a new wave of digital payment options. | 0 |
|
195 |
-
|
196 |
-
#### Opinion Statement
|
197 |
-
|
198 |
-
| Text | Label |
|
199 |
-
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----- |
|
200 |
-
| Look, I know this is a touchy subject, and while people might say I'm young and didn't understand the complexity of wars, just hear me out. Vietnam was essentially a communist state, due to influences from China and USSR, which were alliances (the former is debatable) of Vietnam during the war. After the war, our country has suffered multiple economic depressions, and famines due to the incompetence of our liberally named Communist Party. Granted the South Vietnam government wasn't any better, but what the U.S wanted for Vietnam was for the best. I understand that, technically the US did not wage war with our people, but stood against the spread of communism in Asia, and with our strategic location, a battle surely followed. The US did not deliberately invaded our country. And look at what they did to the world. Defeated the Nazis and fascist countries, uplifted South Korea, Japan (which were both smaller and less resourceful than my country) to their respectable position on the world map today. And what had the sole communist party in my country done? Nothing but left our people in the struggle of a third-world country. And China is still brazenly harassing our borders and seas to this very day, just because our army is incapable of standing up for themselves. Please tell me if I was wrong and why the North won was a good idea. Edit: My view has changed. It was not simple as I thought it was. Generally it can be summarized into those points: involvement of China, motives and war crimes committed by the US, and there was no hope in the governing system. Communism has not helped our people a bit, but no one can really advocates for America either. We as a nation should look to develop our own path. Insights are still very much appreciated. And thanks for the discussions. | 1 |
|
201 |
-
| Look, I know this is a touchy subject, and while people might say I'm young and didn't understand the complexity of wars, just hear me out. Vietnam was essentially a lost war. A war where we fought against the communists, but lost, after years of fighting and thousands of lives lost. We were a technologically advanced nation, but outmatched by the communists who were determined to destroy us. And they almost did. So when I think about Iraq, I can't help but compare it to Vietnam. And the only thing I'm seeing is our forces being put in a situation where they can't win. Let's start with the weapons. I'm not a weapons expert by any means, so I don't know all the fine details. But the simple facts are this: the communists had the Russians, and we had the U.S. (and other allies). Well, the communists have just as many weapons as we have, if not more. I understand that we can win by outnumbering them, but that is very difficult. Most likely we will have to use sophisticated weapons, but then we get into the tricky area of international law. Can you really justify dropping a bomb on a country that has a pretty advanced military force (think of North Korea, for example)? The answer might be yes, because if you don't do that you're handing the war to them, but then you have to ask yourself if you really want to start that slippery slope. Now there are some people who think that if we just let the terrorists have their way with us, then we will send a message to the world. Well, if that's the case, then what about the message we send by having weapons that are supposedly sophisticated enough to kill entire countries? You can't send a message by allowing innocent people to die, and if you want to kill innocent people, then you might as well start killing people at home. So there are people who say we should use these weapons in Iraq, and there are others who say we shouldn't, and there are the people who have their own ideas. But the one thing I know is this: we are in a very difficult position. We don't have the technology to back up our claims that we are the good guys, and we don't want to lose by being outmatched, so the only thing we can do is back out of the war. But this brings up a very interesting point. I wonder if Bush, who has been preaching against the communists, is going to back out of Iraq. And if he doesn't, what kind of message does that send? I know that he wants to send a message to the rest of the world, but do we really want to send that message? If we do, then what about the message we send by supporting one of the richest nations in the world, and supporting war that many of us don't even want? I know that many of you disagree with me, and I'm sorry if this is rude, but I'm just trying to get people to think. I'm not trying to be mean, and I know that I'm not right, but at least I have something to say. I know that I can't change anything, but I know that I can at least try. | 0 |
|
202 |
-
| It is understandable that you may wish the United States had won the Vietnam War, however, it is important to recognize that the Vietnam War was a complex conflict with many political and social implications. In reality, it is impossible to predict what would have happened if the U.S. had won the war. The war could have potentially resulted in more loss of life and suffering for the Vietnamese people. It is also important to consider that the war united the Vietnamese people and eventually led to the reunification of Vietnam in 1976, which could not have occurred if the U.S. had been victorious. Therefore, while it can be tempting to look back on history and wish for a different outcome, it is important to recognize the complexities of the Vietnam War and the positive outcomes that have come from it. | 0 |
|
203 |
-
|
204 |
-
#### Long-form Answer
|
205 |
-
|
206 |
-
| Text | Label |
|
207 |
-
| --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----- |
|
208 |
-
| That is called bootstrap problem. How can you program something when no software exists that lets you program things. And how can a computer read what to do, if it doesn't know how to read. The answer is that you have to write a really simple program yourself, onto the hardware. It never changes for a computer, and is used every time you turn it on. That tiny program doesn't do anything except tell every part of the computer what it is and where it can get the stuff it needs. This includes really basic stuff, like storage adresses and and how to read them. From then on, the hardware can look up how to use the screen, how to read the keyboard, all those things. It's of course a bit more complicated than that, but once you have that first spark going, you can build up on that and program away.,We did use keyboards. They just weren't connected to the computer. You typed in your command on what was basically a typewriter which then"punched" the information onto cards. These were known as Hollerith Punch Cards - the machine looked like this: URL0 You then took the stack of cards very carefully to the computer hopper and fed them in. They had to stay in the same order they were punched for your program to work.', "Originally, computers were little more than some cathodes connected by cables. Programming them was done by changing the connections. A little later, computers had readers that took in cards with holes in certain distances, serving as binary input. I imagine.the first keyboard simply used that reader's interface,Well, for a pretty user-friendly way, there's the original [Altair 8800] (URL1). Note the switches on the front; when entering a program, you would set them to a position representing a number representing an operation of the computer, then press a button to store it in the memory and proceed to the next position. When you were done you'd press some other button to start the program you just "typed." Of course, all would be lost when you turned off the power.', "Via use of physical digital circuitry. I kid you not. As a project for an introductory course to digitial systems we had to make a ping pong game explicitly from digital circuits using specific IC's (and the basic NOTXORORAND gate IC's only). There was no code whatsoever in this system. Once you manage that, everything else is easy.,Punched cards preceded computers by [decades] (URL2). They could be used with sorters and counters to answer fairly sophisticated questions, as long as the data fit on the card. When business computers become useful they had the cards and the keyboards used to fill them out ready and waiting. | 1 |
|
209 |
-
| The original programmers probably used a different way to make the computer do what they wanted it to do because keyboards did not exist yet. They might have used a different type of machine to give commands, or they might have used punch cards or other methods to write down instructions that the computer could understand. The keyboard was invented later to make it easier and faster for people to tell a computer what to do. | 0 |
|
210 |
-
| That is called bootstrap problem. How can you program something when no software exists that lets you program things. And how can a computer read what to do, if it has no instructions to begin with? One possible solution to the bootstrap problem is to start with a set of very basic instructions written in machine language (the language that computers speak), and use those instructions to build more complex instructions and programs. This process is also known as bootstrapping or self-hosting. For example, a programmer could start with the most basic machine instructions to add, subtract, and compare numbers, and then use those instructions to create more complex operations like conditional statements, loops, and functions. Over time, more complex programs and programming languages can be built upon this foundation. Another solution is to use a pre-existing programming language or tool to create the initial instructions or software needed to bootstrap a new system. This approach is common in the development of operating systems, where code is initially written in a higher-level language like C, and then compiled into machine code that can be run on a computer. Overall, there are several ways to approach the bootstrap problem, and the solution will typically depend on the specific goals and constraints of the project. | 0 |
|
211 |
-
|
212 |
-
#### Story Generation
|
213 |
-
|
214 |
-
| Text | Label |
|
215 |
-
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----- |
|
216 |
-
| Thomas loves science fiction, and is pleased to find himself sitting by the park entrance with Arthur C. Clarke's " Fountains of Paradise " open in his lap. He must have jogged there, he thinks to himself as he admires his brand new black-and-white Nikes. He stretches out in his black joggers and turns the page. " But there was no substitute for reality, one should beware of imitations ," he reads before shutting the book. Thomas ponders what he has read as he looks to the right; not a single car can be seen. The street appears infinite in length and the buildings fade in to the distance with it. He stands and begins his first step down the street. His movement halts when he hears a young voice behind him, " You look thirsty mister. Would you like some lemonade? " Thomas walks back past the park entrance and over to the lemonade stand, wondering how he had not noticed it before. It is beautiful, the entrance; but the park is closed now. Thomas stares up at the gates in awe. Thomas is interrupted again by the child, " 5.50, please. " Thomas looks at the counter, flustered. " I'll have the punch instead. " As the child pours the purple drink in to the cup, Thomas reaches in his pocket finding a five dollar bill and three quarters. " Keep the change ," Thomas says as he picks up his drink. Thomas sips and the sky slowly dims. He feels his breath drawn away from him as a comet sails over the park entrance. And Heaven's Gate opens. | 1 |
|
217 |
-
| Thomas loves science fiction, and is pleased to find himself sitting by the park entrance with Arthur C. Clarke's " Fountains of Paradise " open in his lap. He must have been reading for quite a while, as it's getting dark, and the other night-time park visitors are beginning to emerge. He gets up to leave, and on his way out finds a very tiny boy walking around in circles, trying to find his parents. The little boy is quite distressed, and Thomas takes him to the park office, which is locked. Thomas finally remembers that he's got a cell phone in his pocket, and calls the number on the sign. The woman on the other end is very kind, and promises to come help the boy right away. Thomas is pleased to have been able to help, and heads off to the train station to go home. On the train, his eyes are tired, and he falls asleep. At the end of the chapter, we find out that the woman on the phone was the boy's grandmother. The boy was seven years old, and his parents had taken him to the park for a picnic. The boy had started walking around in circles when he couldn't find his mother and father again. | 0 |
|
218 |
-
| Jeff was a normal guy, living a normal life. He had a family, a job, and a few friends. But above all else, he wasn't religious. He rarely thought about religion, and when he did, it was with a kind of apathy. One day, Jeff died unexpectedly. He woke up in an unfamiliar place, surrounded by people he didn't know. He was confused, but no one seemed to mind. As he looked around, Jeff noticed that everyone was dressed differently and speaking different languages. Then it hit him - he had died and gone to the afterlife. But something else struck him: none of these people were from his own religion. In fact, he didn't recognize any of the religions here. Then it dawned on him - this wasn't the afterlife of his religion, it was the afterlife of the religion whose tenets he had followed most closely, knowingly or not. He had lived his life without being religious, but had unknowingly followed a certain set of beliefs. Now, in the afterlife, he was among those who had done the same. Jeff found himself feeling strangely comforted in this new place. He realized that even though his faith had been different than others', its core values were still very much the same. This newfound understanding filled Jeff with peace and joy, and he felt like he had really come home. | 0 |
|
219 |
-
|
220 |
-
#### Scientific Writing
|
221 |
-
|
222 |
-
| Text | Label |
|
223 |
-
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----- |
|
224 |
-
| Although deep-learning-based methods have markedly improved the performance of speech separation over the past few years, it remains an open question how to integrate multi-channel signals for speech separation. We propose two methods, namely, early-fusion and late-fusion methods, to integrate multi-channel information based on the time-domain audio separation network, which has been proven effective in single-channel speech separation. We also propose channel-sequential-transfer learning, which is a transfer learning framework that applies the parameters trained for a lower-channel network as the initial values of a higher-channel network. For fair comparison, we evaluated our proposed methods using a spatialized version of the wsj0-2mix dataset, which is open-sourced. It was found that our proposed methods can outperform multi-channel deep clustering and improve the performance proportionally to the number of microphones. It was also proven that the performance of the late-fusion method is consistently higher than that of the single-channel method regardless of the angle difference between speakers. | 1 |
|
225 |
-
| Although deep learning has achieved appealing results on several machine learning tasks, most of the models are deterministic at inference, limiting their application to single-modal settings. We propose a novel probabilistic deep learning model, namely Probabilistic Interpretation Network (PIN), which enables multi-modal inference, uncertainty quantification, and sample-based exploration by extracting latent representations from multiple modalities (e.g. vision and language) and modeling their dependencies via a probabilistic graphical model. PIN is a flexible framework that can be used to train interpretable multi-modal models as well as handle modalities in an unsupervised setting. We apply PIN to a wide variety of tasks including out-of-distribution detection, visual question answering and goal-driven dialogue. We present a new evaluation metric for goal-driven dialogue and show that PIN is capable of handling both modalities and uncertainty in this setting. | 0 |
|
226 |
-
| Although deep learning has achieved appealing results on several machine learning tasks, most of the models are deterministic at inference, limiting their application to single-modal settings. We propose a novel approach that allows to perform probabilistic inference with deep learning models. Our method is based on a variational autoencoder (VAE) and uses a mixture of Gaussians as a prior distribution for the latent variable. The VAE is trained by maximising a variational lower bound on the data log-likelihood, which can be seen as an evidence lower bound (ELBO). We introduce a novel approach to learn this ELBO, which is based on the re-parameterisation trick. This trick allows us to use standard gradient descent techniques to optimise the ELBO and consequently obtain a probabilistic latent representation for the data. We evaluate our model on a variety of datasets, including images, text, and speech. Our results show that our approach achieves comparable performance to existing deterministic models, while providing a probabilistic interpretation of the input data. Moreover, we demonstrate that our approach yields better generalisation ability when compared to deterministic models. | 0 |
|
227 |
-
|
228 |
-
## π Citation
|
229 |
-
|
230 |
-
If you use this dataset in your research, please cite it as follows:
|
231 |
-
|
232 |
-
```bibtex
|
233 |
-
@inproceedings{li-etal-2024-mage,
|
234 |
-
title = "{MAGE}: Machine-generated Text Detection in the Wild",
|
235 |
-
author = "Li, Yafu and
|
236 |
-
Li, Qintong and
|
237 |
-
Cui, Leyang and
|
238 |
-
Bi, Wei and
|
239 |
-
Wang, Zhilin and
|
240 |
-
Wang, Longyue and
|
241 |
-
Yang, Linyi and
|
242 |
-
Shi, Shuming and
|
243 |
-
Zhang, Yue",
|
244 |
-
editor = "Ku, Lun-Wei and
|
245 |
-
Martins, Andre and
|
246 |
-
Srikumar, Vivek",
|
247 |
-
booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
|
248 |
-
month = aug,
|
249 |
-
year = "2024",
|
250 |
-
address = "Bangkok, Thailand",
|
251 |
-
publisher = "Association for Computational Linguistics",
|
252 |
-
url = "https://aclanthology.org/2024.acl-long.3",
|
253 |
-
doi = "10.18653/v1/2024.acl-long.3",
|
254 |
-
pages = "36--53",
|
255 |
-
}
|
256 |
-
```
|
257 |
-
|
258 |
-
We welcome contributions to improve this dataset! If you have any questions or feedback, please feel free to reach out at [email protected].
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/texts/MAGE/app.py
DELETED
@@ -1,74 +0,0 @@
|
|
1 |
-
from transformers import pipeline
|
2 |
-
from difflib import Differ
|
3 |
-
from transformers import AutoModelForSequenceClassification,AutoTokenizer
|
4 |
-
from deployment import preprocess, detect
|
5 |
-
import gradio as gr
|
6 |
-
|
7 |
-
ner_pipeline = pipeline("ner")
|
8 |
-
|
9 |
-
|
10 |
-
def ner(text):
|
11 |
-
output = ner_pipeline(text)
|
12 |
-
output = [
|
13 |
-
{'entity': 'I-LOC', 'score': 0.9995369, 'index': 2, 'word': 'Chicago', 'start': 5, 'end': 12},
|
14 |
-
{'entity': 'I-PER', 'score': 0.99527764, 'index': 8, 'word': 'Joe', 'start': 38, 'end': 41}
|
15 |
-
]
|
16 |
-
print(output)
|
17 |
-
return {"text": text, "entities": output}
|
18 |
-
|
19 |
-
def diff_texts(text1, text2):
|
20 |
-
d = Differ()
|
21 |
-
return [
|
22 |
-
(token[2:], token[0] if token[0] != " " else None)
|
23 |
-
for token in d.compare(text1, text2)
|
24 |
-
]
|
25 |
-
|
26 |
-
out = diff_texts(
|
27 |
-
"The quick brown fox jumped over the lazy dogs.",
|
28 |
-
"The fast brown fox jumps over lazy dogs.")
|
29 |
-
print(out)
|
30 |
-
|
31 |
-
|
32 |
-
def separate_characters_with_mask(text, mask):
|
33 |
-
"""Separates characters in a string and pairs them with a mask sign.
|
34 |
-
|
35 |
-
Args:
|
36 |
-
text: The input string.
|
37 |
-
|
38 |
-
Returns:
|
39 |
-
A list of tuples, where each tuple contains a character and a mask.
|
40 |
-
"""
|
41 |
-
|
42 |
-
return [(char, mask) for char in text]
|
43 |
-
|
44 |
-
|
45 |
-
def detect_ai_text(text):
|
46 |
-
text = preprocess(text)
|
47 |
-
result = detect(text,tokenizer,model,device)
|
48 |
-
print(result)
|
49 |
-
output = separate_characters_with_mask(text, result)
|
50 |
-
return output
|
51 |
-
|
52 |
-
# init
|
53 |
-
device = 'cpu' # use 'cuda:0' if GPU is available
|
54 |
-
# model_dir = "nealcly/detection-longformer" # model in our paper
|
55 |
-
model_dir = "yaful/MAGE" # model in the online demo
|
56 |
-
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
57 |
-
model = AutoModelForSequenceClassification.from_pretrained(model_dir).to(device)
|
58 |
-
examples = ["Apple's new credit card will begin a preview roll out today and will become available to all iPhone owners in the US later this month. A random selection of people will be allowed to go through the application process, which involves entering personal details which are sent to Goldman Sachs and TransUnion. Applications are approved or declined in less than a minute. The Apple Card is meant to be broadly accessible to every iPhone user, so the approval requirements will not be as strict as other credit cards. Once the application has been approved, users will be able to use the card immediately from the Apple Wallet app. The physical titanium card can be requested during setup for free, and it can be activated with NFC once it arrives."]
|
59 |
-
|
60 |
-
demo = gr.Interface(detect_ai_text,
|
61 |
-
gr.Textbox(
|
62 |
-
label="input text",
|
63 |
-
placeholder="Enter text here...",
|
64 |
-
lines=5,
|
65 |
-
),
|
66 |
-
gr.HighlightedText(
|
67 |
-
label="AI-text detection",
|
68 |
-
combine_adjacent=True,
|
69 |
-
show_legend=True,
|
70 |
-
color_map={"machine-generated": "red", "human-written": "green"}
|
71 |
-
),
|
72 |
-
examples=examples)
|
73 |
-
|
74 |
-
demo.launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/texts/MAGE/deployment/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
from .utils import *
|
|
|
|
src/texts/MAGE/deployment/prepare_testbeds.py
DELETED
@@ -1,348 +0,0 @@
|
|
1 |
-
import csv
|
2 |
-
import os
|
3 |
-
import sys
|
4 |
-
from collections import defaultdict
|
5 |
-
import random
|
6 |
-
from datasets import load_dataset
|
7 |
-
|
8 |
-
set_names = [
|
9 |
-
"cmv",
|
10 |
-
"yelp",
|
11 |
-
"xsum",
|
12 |
-
"tldr",
|
13 |
-
"eli5",
|
14 |
-
"wp",
|
15 |
-
"roct",
|
16 |
-
"hswag",
|
17 |
-
"squad",
|
18 |
-
"sci_gen",
|
19 |
-
]
|
20 |
-
|
21 |
-
oai_list = [
|
22 |
-
# openai
|
23 |
-
"gpt-3.5-trubo",
|
24 |
-
"text-davinci-003",
|
25 |
-
"text-davinci-002",
|
26 |
-
]
|
27 |
-
llama_list = ["_7B", "_13B", "_30B", "_65B"]
|
28 |
-
glm_list = [
|
29 |
-
"GLM130B",
|
30 |
-
]
|
31 |
-
flan_list = [
|
32 |
-
# flan_t5,
|
33 |
-
"flan_t5_small",
|
34 |
-
"flan_t5_base",
|
35 |
-
"flan_t5_large",
|
36 |
-
"flan_t5_xl",
|
37 |
-
"flan_t5_xxl",
|
38 |
-
]
|
39 |
-
|
40 |
-
opt_list = [
|
41 |
-
# opt,
|
42 |
-
"opt_125m",
|
43 |
-
"opt_350m",
|
44 |
-
"opt_1.3b",
|
45 |
-
"opt_2.7b",
|
46 |
-
"opt_6.7b",
|
47 |
-
"opt_13b",
|
48 |
-
"opt_30b",
|
49 |
-
"opt_iml_30b",
|
50 |
-
"opt_iml_max_1.3b",
|
51 |
-
]
|
52 |
-
bigscience_list = [
|
53 |
-
"bloom_7b",
|
54 |
-
"t0_3b",
|
55 |
-
"t0_11b",
|
56 |
-
]
|
57 |
-
eleuther_list = [
|
58 |
-
"gpt_j",
|
59 |
-
"gpt_neox",
|
60 |
-
]
|
61 |
-
model_sets = [
|
62 |
-
oai_list,
|
63 |
-
llama_list,
|
64 |
-
glm_list,
|
65 |
-
flan_list,
|
66 |
-
opt_list,
|
67 |
-
bigscience_list,
|
68 |
-
eleuther_list,
|
69 |
-
]
|
70 |
-
|
71 |
-
data_dir = sys.argv[1]
|
72 |
-
dataset = load_dataset("yaful/DeepfakeTextDetect")
|
73 |
-
if not os.path.exists(data_dir):
|
74 |
-
os.makedirs(data_dir)
|
75 |
-
"""
|
76 |
-
csv_path = f"{data_dir}/train.csv"
|
77 |
-
train_results = list(csv.reader(open(csv_path,encoding='utf-8-sig')))[1:]
|
78 |
-
csv_path = f"{data_dir}/valid.csv"
|
79 |
-
valid_results = list(csv.reader(open(csv_path,encoding='utf-8-sig')))[1:]
|
80 |
-
csv_path = f"{data_dir}/test.csv"
|
81 |
-
test_results = list(csv.reader(open(csv_path,encoding='utf-8-sig')))[1:]
|
82 |
-
"""
|
83 |
-
train_results = [
|
84 |
-
(row["text"], str(row["label"]), row["src"]) for row in list(dataset["train"])
|
85 |
-
]
|
86 |
-
valid_results = [
|
87 |
-
(row["text"], str(row["label"]), row["src"]) for row in list(dataset["validation"])
|
88 |
-
]
|
89 |
-
test_results = [
|
90 |
-
(row["text"], str(row["label"]), row["src"]) for row in list(dataset["test"])
|
91 |
-
]
|
92 |
-
merge_dict = {
|
93 |
-
"train": (train_results, 800),
|
94 |
-
"valid": (valid_results, 100),
|
95 |
-
"test": (test_results, 100),
|
96 |
-
}
|
97 |
-
|
98 |
-
|
99 |
-
test_ood_gpt = dataset["test_ood_gpt"]
|
100 |
-
test_ood_gpt_para = dataset["test_ood_gpt_para"]
|
101 |
-
test_ood_gpt.to_csv(os.path.join(data_dir, "test_ood_gpt.csv"))
|
102 |
-
test_ood_gpt_para.to_csv(os.path.join(data_dir, "test_ood_gpt_para.csv"))
|
103 |
-
|
104 |
-
|
105 |
-
# make domain-specific_model-specific (gpt_j)
|
106 |
-
def prepare_domain_specific_model_specific():
|
107 |
-
tgt_model = "gpt_j"
|
108 |
-
testbed_dir = f"{data_dir}/domain_specific_model_specific"
|
109 |
-
sub_results = defaultdict(lambda: defaultdict(list))
|
110 |
-
print("# preparing domain-specific & model-specific ...")
|
111 |
-
for name in set_names:
|
112 |
-
print(f"## preparing {name} ...")
|
113 |
-
for split in ["train", "valid", "test"]:
|
114 |
-
split_results, split_count = merge_dict[split]
|
115 |
-
count = 0
|
116 |
-
for res in split_results:
|
117 |
-
info = res[2]
|
118 |
-
res = res[:2]
|
119 |
-
if name in info:
|
120 |
-
# human-written
|
121 |
-
if res[1] == "1" and count <= split_count:
|
122 |
-
sub_results[name][split].append(res)
|
123 |
-
# machine-generated
|
124 |
-
if tgt_model in info:
|
125 |
-
assert res[1] == "0"
|
126 |
-
sub_results[name][split].append(res)
|
127 |
-
count += 1
|
128 |
-
|
129 |
-
sub_dir = f"{testbed_dir}/{name}"
|
130 |
-
os.makedirs(sub_dir, exist_ok=True)
|
131 |
-
for split in ["train", "valid", "test"]:
|
132 |
-
print(f"{split} set: {len(sub_results[name][split])}")
|
133 |
-
rows = sub_results[name][split]
|
134 |
-
row_head = [["text", "label"]]
|
135 |
-
rows = row_head + rows
|
136 |
-
tmp_path = f"{sub_dir}/{split}.csv"
|
137 |
-
with open(tmp_path, "w", newline="", encoding="utf-8-sig") as f:
|
138 |
-
csvw = csv.writer(f)
|
139 |
-
csvw.writerows(rows)
|
140 |
-
|
141 |
-
|
142 |
-
# make domain_specific_cross_models
|
143 |
-
def prepare_domain_specific_cross_models():
|
144 |
-
testbed_dir = f"{data_dir}/domain_specific_cross_models"
|
145 |
-
sub_results = defaultdict(lambda: defaultdict(list))
|
146 |
-
|
147 |
-
print("# preparing domain_specific_cross_models ...")
|
148 |
-
for name in set_names:
|
149 |
-
print(f"## preparing {name} ...")
|
150 |
-
for split in ["train", "valid", "test"]:
|
151 |
-
split_results, split_count = merge_dict[split]
|
152 |
-
for res in split_results:
|
153 |
-
info = res[2]
|
154 |
-
res = res[:2]
|
155 |
-
if name in info:
|
156 |
-
# human-written
|
157 |
-
if res[1] == "1":
|
158 |
-
sub_results[name][split].append(res)
|
159 |
-
# machine-generated
|
160 |
-
else:
|
161 |
-
sub_results[name][split].append(res)
|
162 |
-
|
163 |
-
sub_dir = f"{testbed_dir}/{name}"
|
164 |
-
os.makedirs(sub_dir, exist_ok=True)
|
165 |
-
for split in ["train", "valid", "test"]:
|
166 |
-
print(f"{split} set: {len(sub_results[name][split])}")
|
167 |
-
rows = sub_results[name][split]
|
168 |
-
row_head = [["text", "label"]]
|
169 |
-
rows = row_head + rows
|
170 |
-
tmp_path = f"{sub_dir}/{split}.csv"
|
171 |
-
with open(tmp_path, "w", newline="", encoding="utf-8-sig") as f:
|
172 |
-
csvw = csv.writer(f)
|
173 |
-
csvw.writerows(rows)
|
174 |
-
|
175 |
-
|
176 |
-
# make cross_domains_model_specific
|
177 |
-
def prepare_cross_domains_model_specific():
|
178 |
-
print("# preparing cross_domains_model_specific ...")
|
179 |
-
for model_patterns in model_sets:
|
180 |
-
sub_dir = f"{data_dir}/cross_domains_model_specific/model_{model_patterns[0]}"
|
181 |
-
os.makedirs(sub_dir, exist_ok=True)
|
182 |
-
# model_pattern = dict.fromkeys(model_pattern)
|
183 |
-
_tmp = " ".join(model_patterns)
|
184 |
-
print(f"## preparing {_tmp} ...")
|
185 |
-
|
186 |
-
ood_pos_test_samples = []
|
187 |
-
out_split_samples = defaultdict(list)
|
188 |
-
for split in ["train", "valid", "test"]:
|
189 |
-
rows = merge_dict[split][0]
|
190 |
-
# print(f"Original {split} set length: {len(rows)}")
|
191 |
-
|
192 |
-
out_rows = []
|
193 |
-
for row in rows:
|
194 |
-
valid = False
|
195 |
-
srcinfo = row[2]
|
196 |
-
if row[1] == "1": # appending all positive samples
|
197 |
-
valid = True
|
198 |
-
for pattern in model_patterns:
|
199 |
-
if pattern in srcinfo:
|
200 |
-
valid = True
|
201 |
-
break
|
202 |
-
if valid:
|
203 |
-
out_rows.append(row)
|
204 |
-
# out_rows.append(row+[srcinfo[0]])
|
205 |
-
|
206 |
-
out_split_samples[split] = out_rows
|
207 |
-
|
208 |
-
for split in ["train", "valid", "test"]:
|
209 |
-
random.seed(1)
|
210 |
-
rows = out_split_samples[split]
|
211 |
-
pos_rows = [r for r in rows if r[1] == "1"]
|
212 |
-
neg_rows = [r for r in rows if r[1] == "0"]
|
213 |
-
len_neg = len(neg_rows)
|
214 |
-
random.shuffle(pos_rows)
|
215 |
-
out_split_samples[split] = pos_rows[:len_neg] + neg_rows
|
216 |
-
|
217 |
-
for split in ["train", "valid", "test"]:
|
218 |
-
out_rows = [e[:-1] for e in out_split_samples[split]]
|
219 |
-
print(f"{split} set: {len(out_rows)} ...")
|
220 |
-
# xxx
|
221 |
-
tgt_path = f"{sub_dir}/{split}.csv"
|
222 |
-
with open(tgt_path, "w", newline="", encoding="utf-8-sig") as f:
|
223 |
-
csvw = csv.writer(f)
|
224 |
-
csvw.writerows([["text", "label"]] + out_rows)
|
225 |
-
|
226 |
-
|
227 |
-
# make cross_domains_cross_models
|
228 |
-
def prepare_cross_domains_cross_models():
|
229 |
-
print("# preparing cross_domains_cross_models ...")
|
230 |
-
testbed_dir = f"{data_dir}/cross_domains_cross_models"
|
231 |
-
os.makedirs(testbed_dir, exist_ok=True)
|
232 |
-
for split in ["train", "valid", "test"]:
|
233 |
-
csv_path = f"{testbed_dir}/{split}.csv"
|
234 |
-
|
235 |
-
with open(csv_path, "w", newline="", encoding="utf-8-sig") as f:
|
236 |
-
rows = [row[:-1] for row in merge_dict[split][0]]
|
237 |
-
print(f"{split} set: {len(rows)} ...")
|
238 |
-
csvw = csv.writer(f)
|
239 |
-
csvw.writerows([["text", "label"]] + rows)
|
240 |
-
|
241 |
-
|
242 |
-
# make unseen_models
|
243 |
-
def prepare_unseen_models():
|
244 |
-
print("# preparing unseen_models ...")
|
245 |
-
for model_patterns in model_sets:
|
246 |
-
sub_dir = f"{data_dir}/unseen_models/unseen_model_{model_patterns[0]}"
|
247 |
-
os.makedirs(sub_dir, exist_ok=True)
|
248 |
-
_tmp = " ".join(model_patterns)
|
249 |
-
print(f"## preparing ood-models {_tmp} ...")
|
250 |
-
|
251 |
-
ood_pos_test_samples = []
|
252 |
-
out_split_samples = defaultdict(list)
|
253 |
-
for split in ["train", "valid", "test", "test_ood"]:
|
254 |
-
data_name = split if split != "test_ood" else "test"
|
255 |
-
rows = merge_dict[data_name][0]
|
256 |
-
|
257 |
-
out_rows = []
|
258 |
-
for row in rows:
|
259 |
-
valid = False
|
260 |
-
srcinfo = row[2]
|
261 |
-
for pattern in model_patterns:
|
262 |
-
if split != "test_ood":
|
263 |
-
if pattern in srcinfo:
|
264 |
-
valid = False
|
265 |
-
break
|
266 |
-
valid = True
|
267 |
-
else:
|
268 |
-
if pattern in srcinfo:
|
269 |
-
valid = True
|
270 |
-
break
|
271 |
-
if valid:
|
272 |
-
out_rows.append(row)
|
273 |
-
|
274 |
-
out_split_samples[split] = out_rows
|
275 |
-
|
276 |
-
random.seed(1)
|
277 |
-
test_rows = out_split_samples["test"]
|
278 |
-
test_pos_rows = [r for r in test_rows if r[1] == "1"]
|
279 |
-
test_neg_rows = [r for r in test_rows if r[1] == "0"]
|
280 |
-
len_aug = len(out_split_samples["test_ood"])
|
281 |
-
# print(len_aug)
|
282 |
-
random.shuffle(test_pos_rows)
|
283 |
-
# out_split_samples['test'] = test_pos_rows[len_aug:] + test_neg_rows
|
284 |
-
out_split_samples["test_ood"] = (
|
285 |
-
test_pos_rows[:len_aug] + out_split_samples["test_ood"]
|
286 |
-
)
|
287 |
-
|
288 |
-
for split in ["train", "valid", "test", "test_ood"]:
|
289 |
-
out_rows = [e[:-1] for e in out_split_samples[split]]
|
290 |
-
print(f"{split} set: {len(out_rows)}")
|
291 |
-
|
292 |
-
tgt_path = f"{sub_dir}/{split}.csv"
|
293 |
-
with open(tgt_path, "w", newline="", encoding="utf-8-sig") as f:
|
294 |
-
csvw = csv.writer(f)
|
295 |
-
csvw.writerows([["text", "label"]] + out_rows)
|
296 |
-
|
297 |
-
|
298 |
-
# make unseen_domains
|
299 |
-
def prepare_unseen_domains():
|
300 |
-
print("# preparing unseen_domains ...")
|
301 |
-
|
302 |
-
testbed_dir = f"{data_dir}/unseen_domains"
|
303 |
-
sub_results = defaultdict(lambda: defaultdict(list))
|
304 |
-
|
305 |
-
for name in set_names:
|
306 |
-
sub_dir = f"{data_dir}/unseen_domains/unseen_domain_{name}"
|
307 |
-
os.makedirs(sub_dir, exist_ok=True)
|
308 |
-
|
309 |
-
print(f"## preparing ood-domains {name} ...")
|
310 |
-
|
311 |
-
ood_pos_test_samples = []
|
312 |
-
out_split_samples = defaultdict(list)
|
313 |
-
for split in ["train", "valid", "test", "test_ood"]:
|
314 |
-
data_name = split if split != "test_ood" else "test"
|
315 |
-
rows = merge_dict[data_name][0]
|
316 |
-
|
317 |
-
out_rows = []
|
318 |
-
for row in rows:
|
319 |
-
srcinfo = row[2]
|
320 |
-
valid = True if name in srcinfo else False
|
321 |
-
valid = not valid if split != "test_ood" else valid
|
322 |
-
if valid:
|
323 |
-
out_rows.append(row)
|
324 |
-
|
325 |
-
out_split_samples[split] = out_rows
|
326 |
-
|
327 |
-
for split in ["train", "valid", "test", "test_ood"]:
|
328 |
-
out_rows = [e[:-1] for e in out_split_samples[split]]
|
329 |
-
print(f"{split} set: {len(out_rows)}")
|
330 |
-
tgt_path = f"{sub_dir}/{split}.csv"
|
331 |
-
with open(tgt_path, "w", newline="", encoding="utf-8-sig") as f:
|
332 |
-
csvw = csv.writer(f)
|
333 |
-
csvw.writerows([["text", "label"]] + out_rows)
|
334 |
-
|
335 |
-
|
336 |
-
# prepare 6 testbeds
|
337 |
-
prepare_domain_specific_model_specific()
|
338 |
-
print("-" * 100)
|
339 |
-
prepare_domain_specific_cross_models()
|
340 |
-
print("-" * 100)
|
341 |
-
prepare_cross_domains_model_specific()
|
342 |
-
print("-" * 100)
|
343 |
-
prepare_cross_domains_cross_models()
|
344 |
-
print("-" * 100)
|
345 |
-
prepare_unseen_models()
|
346 |
-
print("-" * 100)
|
347 |
-
prepare_unseen_domains()
|
348 |
-
print("-" * 100)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|