ApsidalSolid4 committed
Commit 024e85c · verified · 1 Parent(s): ea611e8

Update app.py

Files changed (1): app.py (+52, -359)
app.py CHANGED
@@ -18,10 +18,6 @@ from openpyxl.utils import get_column_letter
  from io import BytesIO
  import base64
  import hashlib
- import requests
- import tempfile
- from pathlib import Path
- import mimetypes

  # Configure logging
  logging.basicConfig(level=logging.INFO)
@@ -36,17 +32,6 @@ CONFIDENCE_THRESHOLD = 0.65
  BATCH_SIZE = 8 # Reduced batch size for CPU
  MAX_WORKERS = 4 # Number of worker threads for processing

- # IMPORTANT: Set PyTorch thread configuration at the module level
- # before any parallel work starts
- if not torch.cuda.is_available():
-     # Set thread configuration only once at the beginning
-     torch.set_num_threads(MAX_WORKERS)
-     try:
-         # Only set interop threads if it hasn't been set already
-         torch.set_num_interop_threads(MAX_WORKERS)
-     except RuntimeError as e:
-         logger.warning(f"Could not set interop threads: {str(e)}")
-
  # Get password hash from environment variable (more secure)
  ADMIN_PASSWORD_HASH = os.environ.get('ADMIN_PASSWORD_HASH')

@@ -56,138 +41,10 @@ if not ADMIN_PASSWORD_HASH:
  # Excel file path for logs
  EXCEL_LOG_PATH = "/tmp/prediction_logs.xlsx"

- # OCR API settings
- OCR_API_KEY = "9e11346f1288957" # This is a partial key - replace with the full one
- OCR_API_ENDPOINT = "https://api.ocr.space/parse/image"
- OCR_MAX_PDF_PAGES = 3
- OCR_MAX_FILE_SIZE_MB = 1
-
- # Configure logging for OCR module
- ocr_logger = logging.getLogger("ocr_module")
- ocr_logger.setLevel(logging.INFO)
-
- class OCRProcessor:
-     """
-     Handles OCR processing of image and document files using OCR.space API
-     """
-     def __init__(self, api_key: str = OCR_API_KEY):
-         self.api_key = api_key
-         self.endpoint = OCR_API_ENDPOINT
-
-     def process_file(self, file_path: str) -> Dict:
-         """
-         Process a file using OCR.space API
-         """
-         start_time = time.time()
-         ocr_logger.info(f"Starting OCR processing for file: {os.path.basename(file_path)}")
-
-         # Validate file size
-         file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
-         if file_size_mb > OCR_MAX_FILE_SIZE_MB:
-             ocr_logger.warning(f"File size ({file_size_mb:.2f} MB) exceeds limit of {OCR_MAX_FILE_SIZE_MB} MB")
-             return {
-                 "success": False,
-                 "error": f"File size ({file_size_mb:.2f} MB) exceeds limit of {OCR_MAX_FILE_SIZE_MB} MB",
-                 "text": ""
-             }
-
-         # Determine file type and handle accordingly
-         file_type = self._get_file_type(file_path)
-         ocr_logger.info(f"Detected file type: {file_type}")
-
-         # Prepare the API request
-         with open(file_path, 'rb') as f:
-             file_data = f.read()
-
-         # Set up API parameters
-         payload = {
-             'isOverlayRequired': 'false',
-             'language': 'eng',
-             'OCREngine': '2', # Use more accurate engine
-             'scale': 'true',
-             'detectOrientation': 'true',
-         }
-
-         # For PDF files, check page count limitations
-         if file_type == 'application/pdf':
-             ocr_logger.info("PDF document detected, enforcing page limit")
-             payload['filetype'] = 'PDF'
-
-         # Prepare file for OCR API
-         files = {
-             'file': (os.path.basename(file_path), file_data, file_type)
-         }
-
-         headers = {
-             'apikey': self.api_key,
-         }
-
-         # Make the OCR API request
-         try:
-             ocr_logger.info("Sending request to OCR.space API")
-             response = requests.post(
-                 self.endpoint,
-                 files=files,
-                 data=payload,
-                 headers=headers
-             )
-             response.raise_for_status()
-             result = response.json()
-
-             # Process the OCR results
-             if result.get('OCRExitCode') in [1, 2]: # Success or partial success
-                 extracted_text = self._extract_text_from_result(result)
-                 processing_time = time.time() - start_time
-                 ocr_logger.info(f"OCR processing completed in {processing_time:.2f} seconds")
-
-                 return {
-                     "success": True,
-                     "text": extracted_text,
-                     "word_count": len(extracted_text.split()),
-                     "processing_time_ms": int(processing_time * 1000)
-                 }
-             else:
-                 ocr_logger.error(f"OCR API error: {result.get('ErrorMessage', 'Unknown error')}")
-                 return {
-                     "success": False,
-                     "error": result.get('ErrorMessage', 'OCR processing failed'),
-                     "text": ""
-                 }
-
-         except requests.exceptions.RequestException as e:
-             ocr_logger.error(f"OCR API request failed: {str(e)}")
-             return {
-                 "success": False,
-                 "error": f"OCR API request failed: {str(e)}",
-                 "text": ""
-             }
-
-     def _extract_text_from_result(self, result: Dict) -> str:
-         """
-         Extract all text from the OCR API result
-         """
-         extracted_text = ""
-
-         if 'ParsedResults' in result and result['ParsedResults']:
-             for parsed_result in result['ParsedResults']:
-                 if parsed_result.get('ParsedText'):
-                     extracted_text += parsed_result['ParsedText']
-
-         return extracted_text
-
-     def _get_file_type(self, file_path: str) -> str:
-         """
-         Determine MIME type of a file
-         """
-         mime_type, _ = mimetypes.guess_type(file_path)
-         if mime_type is None:
-             # Default to binary if MIME type can't be determined
-             return 'application/octet-stream'
-         return mime_type
-
  def is_admin_password(input_text: str) -> bool:
      """
      Check if the input text matches the admin password using secure hash comparison.
+     This prevents the password from being visible in the source code.
      """
      # Hash the input text
      input_hash = hashlib.sha256(input_text.strip().encode()).hexdigest()
@@ -248,6 +105,11 @@ class TextWindowProcessor:

  class TextClassifier:
      def __init__(self):
+         # Set thread configuration before any model loading or parallel work
+         if not torch.cuda.is_available():
+             torch.set_num_threads(MAX_WORKERS)
+             torch.set_num_interop_threads(MAX_WORKERS)
+
          self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
          self.model_name = MODEL_NAME
          self.tokenizer = None
@@ -391,7 +253,7 @@ class TextClassifier:
              for window_idx, indices in enumerate(batch_indices):
                  center_idx = len(indices) // 2
                  center_weight = 0.7 # Higher weight for center sentence
-                 edge_weight = 0.3 / (len(indices) - 1) if len(indices) > 1 else 0 # Distribute remaining weight
+                 edge_weight = 0.3 / (len(indices) - 1) # Distribute remaining weight

                  for pos, sent_idx in enumerate(indices):
                      # Apply higher weight to center sentence
@@ -414,10 +276,10 @@ class TextClassifier:

              # Apply minimal smoothing at prediction boundaries
              if i > 0 and i < len(sentences) - 1:
-                 prev_human = sentence_scores[i-1]['human_prob'] / max(sentence_appearances[i-1], 1e-10)
-                 prev_ai = sentence_scores[i-1]['ai_prob'] / max(sentence_appearances[i-1], 1e-10)
-                 next_human = sentence_scores[i+1]['human_prob'] / max(sentence_appearances[i+1], 1e-10)
-                 next_ai = sentence_scores[i+1]['ai_prob'] / max(sentence_appearances[i+1], 1e-10)
+                 prev_human = sentence_scores[i-1]['human_prob'] / sentence_appearances[i-1]
+                 prev_ai = sentence_scores[i-1]['ai_prob'] / sentence_appearances[i-1]
+                 next_human = sentence_scores[i+1]['human_prob'] / sentence_appearances[i+1]
+                 next_ai = sentence_scores[i+1]['ai_prob'] / sentence_appearances[i+1]

                  # Check if we're at a prediction boundary
                  current_pred = 'human' if human_prob > ai_prob else 'ai'
@@ -492,72 +354,6 @@ class TextClassifier:
              'num_sentences': num_sentences
          }

- # Function to handle file upload, OCR processing, and text analysis
- def handle_file_upload_and_analyze(file_obj, mode: str, classifier) -> tuple:
-     """
-     Handle file upload, OCR processing, and text analysis
-     """
-     if file_obj is None:
-         return (
-             "No file uploaded",
-             "Please upload a file to analyze",
-             "No file uploaded for analysis"
-         )
-
-     # Create a temporary file with an appropriate extension based on content
-     content_start = file_obj[:20] # Look at the first few bytes
-
-     # Default to .bin extension
-     file_ext = ".bin"
-
-     # Try to detect PDF files
-     if content_start.startswith(b'%PDF'):
-         file_ext = ".pdf"
-     # For images, detect by common magic numbers
-     elif content_start.startswith(b'\xff\xd8'): # JPEG
-         file_ext = ".jpg"
-     elif content_start.startswith(b'\x89PNG'): # PNG
-         file_ext = ".png"
-     elif content_start.startswith(b'GIF'): # GIF
-         file_ext = ".gif"
-
-     # Create a temporary file with the detected extension
-     with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
-         temp_file_path = temp_file.name
-         # Write uploaded file data to the temporary file
-         temp_file.write(file_obj)
-
-     try:
-         # Process the file with OCR
-         ocr_processor = OCRProcessor()
-         ocr_result = ocr_processor.process_file(temp_file_path)
-
-         if not ocr_result["success"]:
-             return (
-                 "OCR Processing Error",
-                 ocr_result["error"],
-                 "Failed to extract text from the uploaded file"
-             )
-
-         # Get the extracted text
-         extracted_text = ocr_result["text"]
-
-         # If no text was extracted
-         if not extracted_text.strip():
-             return (
-                 "No text extracted",
-                 "The OCR process did not extract any text from the uploaded file.",
-                 "No text was found in the uploaded file"
-             )
-
-         # Call the original text analysis function with the extracted text
-         return analyze_text(extracted_text, mode, classifier)
-
-     finally:
-         # Clean up the temporary file
-         if os.path.exists(temp_file_path):
-             os.remove(temp_file_path)
-
  def initialize_excel_log():
      """Initialize the Excel log file if it doesn't exist."""
      if not os.path.exists(EXCEL_LOG_PATH):
@@ -585,7 +381,6 @@ def initialize_excel_log():
          wb.save(EXCEL_LOG_PATH)
          logger.info(f"Initialized Excel log file at {EXCEL_LOG_PATH}")

-
  def log_prediction_data(input_text, word_count, prediction, confidence, execution_time, mode):
      """Log prediction data to an Excel file in the /tmp directory."""
      # Initialize the Excel file if it doesn't exist
@@ -628,7 +423,6 @@ def log_prediction_data(input_text, word_count, prediction, confidence, executio
          logger.error(f"Error logging prediction data to Excel: {str(e)}")
          return False

-
  def get_logs_as_base64():
      """Read the Excel logs file and return as base64 for downloading."""
      if not os.path.exists(EXCEL_LOG_PATH):
@@ -647,7 +441,6 @@
          logger.error(f"Error reading Excel logs: {str(e)}")
          return None

-
  def analyze_text(text: str, mode: str, classifier: TextClassifier) -> tuple:
      """Analyze text using specified mode and return formatted results."""
      # Check if the input text matches the admin password using secure comparison
@@ -770,151 +563,51 @@ def analyze_text(text: str, mode: str, classifier: TextClassifier) -> tuple:
  # Initialize the classifier globally
  classifier = TextClassifier()

- # Create Gradio interface with a properly sized file upload button
- def create_interface():
-     # Custom CSS for the interface
-     css = """
-     #analyze-btn {
-         background-color: #FF8C00 !important;
-         border-color: #FF8C00 !important;
-         color: white !important;
-     }
-
-     /* Style the file upload container to match the radio buttons */
-     .file-upload-container {
-         margin-left: 15px;
-         display: inline-block;
-         vertical-align: middle;
-     }
-
-     /* Hide file info and preview */
-     .file-upload-container .file-preview {
-         display: none !important;
-     }
-
-     /* Style the upload button to a proper size */
-     .file-upload-container [data-testid="chunkFileDropArea"] {
-         width: 150px !important;
-         height: 40px !important;
-         background-color: #f0f0f0 !important;
-         border: 1px solid #d9d9d9 !important;
-         border-radius: 4px !important;
-         display: flex !important;
-         align-items: center !important;
-         justify-content: center !important;
-         padding: 0 10px !important;
-         margin: 0 !important;
-     }
-
-     /* Show only the "Upload Document" text */
-     .file-upload-container [data-testid="chunkFileDropArea"] * {
-         display: none !important;
-     }
-
-     /* Add a new label */
-     .file-upload-container [data-testid="chunkFileDropArea"]::before {
-         content: "Upload Document" !important;
-         display: block !important;
-         font-size: 14px !important;
-         color: #444 !important;
-     }
-
-     /* Hover effect */
-     .file-upload-container [data-testid="chunkFileDropArea"]:hover {
-         background-color: #e0e0e0 !important;
-         cursor: pointer !important;
-     }
-     """
-
-     with gr.Blocks(css=css, title="AI Text Detector") as demo:
-         gr.Markdown("# AI Text Detector")
-         gr.Markdown("Analyze text to detect if it was written by a human or AI. Choose between quick scan and detailed sentence-level analysis. 200+ words suggested for accurate predictions.")
-
-         with gr.Row():
-             # Left column - Input
-             with gr.Column(scale=1):
-                 # Text input area
-                 text_input = gr.Textbox(
-                     lines=8,
-                     placeholder="Enter text to analyze...",
-                     label="Input Text"
-                 )
-
-                 # Analysis Mode section
-                 gr.Markdown("Analysis Mode")
-                 gr.Markdown("Quick mode for faster analysis. Detailed mode for sentence-level analysis.")
-
-                 # Simple row layout for radio buttons and file upload
-                 with gr.Row():
-                     mode_selection = gr.Radio(
-                         choices=["quick", "detailed"],
-                         value="quick",
-                         label="",
-                         show_label=False
-                     )
-
-                     # File upload component with compact styling
-                     with gr.Column(elem_classes=["file-upload-container"], scale=0):
-                         file_upload = gr.File(
-                             file_types=["image", "pdf", "doc", "docx"],
-                             type="binary",
-                             label="",
-                             show_label=False,
-                             elem_id="file-upload"
-                         )
-
-                 # Analyze button
-                 analyze_btn = gr.Button("Analyze Text", elem_id="analyze-btn")
-
-             # Right column - Results
-             with gr.Column(scale=1):
-                 output_html = gr.HTML(label="Highlighted Analysis")
-                 output_sentences = gr.Textbox(label="Sentence-by-Sentence Analysis", lines=10)
-                 output_result = gr.Textbox(label="Overall Result", lines=4)
-
-         # Connect components
-         # 1. Analyze button click
-         analyze_btn.click(
-             fn=lambda text, mode: analyze_text(text, mode, classifier),
-             inputs=[text_input, mode_selection],
-             outputs=[output_html, output_sentences, output_result]
-         )
-
-         # 2. File upload change event
-         file_upload.change(
-             fn=handle_file_upload_and_analyze,
-             inputs=[file_upload, mode_selection],
-             outputs=[output_html, output_sentences, output_result]
+ # Create Gradio interface
+ demo = gr.Interface(
+     fn=lambda text, mode: analyze_text(text, mode, classifier),
+     inputs=[
+         gr.Textbox(
+             lines=8,
+             placeholder="Enter text to analyze...",
+             label="Input Text"
+         ),
+         gr.Radio(
+             choices=["quick", "detailed"],
+             value="quick",
+             label="Analysis Mode",
+             info="Quick mode for faster analysis, Detailed mode for sentence-level analysis"
          )
-
-     return demo
-
- # Setup the app with CORS middleware
- def setup_app():
-     demo = create_interface()
-
-     # Get the FastAPI app from Gradio
-     app = demo.app
-
-     # Add CORS middleware
-     app.add_middleware(
-         CORSMiddleware,
-         allow_origins=["*"], # For development
-         allow_credentials=True,
-         allow_methods=["GET", "POST", "OPTIONS"],
-         allow_headers=["*"],
-     )
-
-     return demo
-
- # Initialize the application
+     ],
+     outputs=[
+         gr.HTML(label="Highlighted Analysis"),
+         gr.Textbox(label="Sentence-by-Sentence Analysis", lines=10),
+         gr.Textbox(label="Overall Result", lines=4)
+     ],
+     title="AI Text Detector",
+     description="Analyze text to detect if it was written by a human or AI. Choose between quick scan and detailed sentence-level analysis. 200+ words suggested for accurate predictions.",
+     api_name="predict",
+     flagging_mode="never"
+ )
+
+ # Get the FastAPI app from Gradio
+ app = demo.app
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"], # For development
+     allow_credentials=True,
+     allow_methods=["GET", "POST", "OPTIONS"],
+     allow_headers=["*"],
+ )
+
+ # Ensure CORS is applied before launching
  if __name__ == "__main__":
-     demo = setup_app()
-
-     # Start the server
      demo.queue()
      demo.launch(
          server_name="0.0.0.0",
          server_port=7860,
          share=True
-     )
+     )
+