File size: 10,553 Bytes
8e6cbe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""
Main Flask application for the watermark detection web interface.
"""

from flask import Flask, render_template, request, jsonify, Response, stream_with_context
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json

from ..core.detector import MarylandDetector, MarylandDetectorZ, OpenaiDetector, OpenaiDetectorZ
from ..core.generator import WmGenerator, OpenaiGenerator, MarylandGenerator
from .utils import get_token_details, template_prompt

# Download/cache directory for the Hugging Face model and tokenizer that
# create_app() loads via from_pretrained(cache_dir=...). This is a relative
# path, so it resolves against the process's current working directory, not
# against this file's location.
CACHE_DIR = "wm_interactive/static/hf_cache"

def convert_nan_to_null(obj):
    """Recursively replace float NaN values with None for JSON serialization.

    Python's ``json`` encoder writes float('nan') as a bare ``NaN`` token,
    which is not valid JSON. Dicts, lists and tuples are walked recursively
    (tuples come back as plain tuples); any other value is returned unchanged.

    Args:
        obj: A value, possibly a nested dict/list/tuple structure.

    Returns:
        A structure mirroring ``obj`` in which every float NaN is None.
    """
    import math
    if isinstance(obj, float):
        return None if math.isnan(obj) else obj
    if isinstance(obj, dict):
        return {k: convert_nan_to_null(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_nan_to_null(item) for item in obj]
    if isinstance(obj, tuple):
        # Tuples serialize as JSON arrays too, so NaNs inside them must be
        # converted as well.
        return tuple(convert_nan_to_null(item) for item in obj)
    return obj

def set_to_int(value, default_value=None):
    """Coerce ``value`` to int, returning ``default_value`` if that fails.

    Used to sanitize request parameters that may arrive as JSON numbers or
    strings. Floats are truncated toward zero by ``int()``; strings must be
    integer literals ("7" converts, "7.0" does not).

    OverflowError is caught alongside ValueError/TypeError because Python's
    ``json`` parser accepts ``Infinity``, and ``int(float('inf'))`` raises
    OverflowError, which would otherwise escape to the request handler.

    Args:
        value: Any value to convert.
        default_value: Returned when ``value`` cannot be converted.

    Returns:
        ``int(value)``, or ``default_value`` on failure.
    """
    try:
        return int(value)
    except (ValueError, TypeError, OverflowError):
        return default_value

def create_detector(detector_type, tokenizer, **kwargs):
    """Instantiate the watermark detector registered under ``detector_type``.

    Recognized types are 'maryland', 'marylandz', 'openai' and 'openaiz';
    anything else falls back to MarylandDetector. When ``seed`` or ``ngram``
    is supplied it is coerced to int, using 0 and 1 respectively if the value
    cannot be converted. ``tokenizer`` and all keyword arguments are passed
    on to the detector's constructor.
    """
    # (parameter name, fallback used when the supplied value is not an int)
    int_params = (('seed', 0), ('ngram', 1))
    for name, fallback in int_params:
        if name in kwargs:
            kwargs[name] = set_to_int(kwargs[name], default_value=fallback)

    registry = {
        'maryland': MarylandDetector,
        'marylandz': MarylandDetectorZ,
        'openai': OpenaiDetector,
        'openaiz': OpenaiDetectorZ,
    }
    chosen_class = registry.get(detector_type, MarylandDetector)
    return chosen_class(tokenizer=tokenizer, **kwargs)

# Summary fields returned by /tokenize, with the values used when the detector
# reports nothing usable (None or NaN) for them.
_SUMMARY_DEFAULTS = {'final_score': 0, 'ntoks_scored': 0, 'final_pvalue': 0.5}

# Maximum number of new tokens /generate produces per request.
_MAX_GEN_LEN = 100


def _sse(payload):
    """Format ``payload`` as one server-sent-events ``data:`` message."""
    return f"data: {json.dumps(payload)}\n\n"


def _empty_tokenize_response():
    """Return the /tokenize payload used when the submitted text is empty."""
    return {
        'token_count': 0,
        'tokens': [],
        'colors': [],
        'scores': [],
        'pvalues': [],
        **_SUMMARY_DEFAULTS,
    }


def _build_tokenize_response(display_info):
    """Convert ``get_token_details`` output into the /tokenize JSON payload.

    ``display_info`` holds one dict per token followed by a summary-stats dict
    as its last element. Per-token score/p-value are reported only for tokens
    flagged ``is_scored``. NaNs become None, and missing/NaN summary values
    fall back to ``_SUMMARY_DEFAULTS``.
    """
    *token_infos, stats = display_info
    response = {
        'token_count': len(token_infos),
        'tokens': [info['token'] for info in token_infos],
        'colors': [info['color'] for info in token_infos],
        'scores': [info['score'] if info.get('is_scored', False) else None for info in token_infos],
        'pvalues': [info['pvalue'] if info.get('is_scored', False) else None for info in token_infos],
    }
    response.update({key: stats.get(key) for key in _SUMMARY_DEFAULTS})

    # Convert any NaN values to null before sending, then make sure the
    # summary numbers are never null.
    response = convert_nan_to_null(response)
    for key, default in _SUMMARY_DEFAULTS.items():
        if response[key] is None:
            response[key] = default
    return response


def _collect_special_tokens(tokenizer):
    """Return the special-token strings that must not be sent to the client.

    Covers the chat-template markers plus the tokenizer's pad/eos/bos/sep
    tokens; undefined (None) or empty entries are dropped, since an empty
    string would match every piece of text.
    """
    candidates = {
        '<|im_start|>', '<|im_end|>',
        tokenizer.pad_token, tokenizer.eos_token,
        getattr(tokenizer, 'bos_token', None),
        getattr(tokenizer, 'sep_token', None),
    }
    return {tok for tok in candidates if tok}


def _strip_special_tokens(text, special_tokens):
    """Remove every occurrence of each special token from ``text``."""
    for tok in special_tokens:
        text = text.replace(tok, '')
    return text


def create_app():
    """Build the Flask app, load the model and tokenizer, and register routes.

    Routes:
        GET  /          -- renders ``index.html``.
        POST /tokenize  -- scores ``text`` with the requested detector and
                           returns per-token colors/scores plus summary stats.
        POST /generate  -- streams watermarked text for ``prompt`` as
                           server-sent events (``text/event-stream``).
    """
    app = Flask(__name__,
                static_folder='../static',
                template_folder='../templates')

    # Add zip to Jinja's global context
    app.jinja_env.globals.update(zip=zip)

    # Pick a model (a larger alternative: "meta-llama/Llama-3.2-1B-Instruct").
    model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR)
    model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=CACHE_DIR).to(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

    @app.route("/", methods=["GET"])
    def index():
        return render_template("index.html")

    @app.route("/tokenize", methods=["POST"])
    def tokenize():
        try:
            # silent=True: malformed JSON or a non-JSON content type gives
            # None (answered with a 400 below) instead of raising.
            data = request.get_json(silent=True)
            if not data:
                return jsonify({'error': 'No JSON data received'}), 400

            text = data.get('text', '')
            params = data.get('params') or {}

            # Create a detector instance with the provided parameters
            detector = create_detector(
                detector_type=params.get('detector_type', 'maryland'),
                tokenizer=tokenizer,
                seed=params.get('seed', 0),
                ngram=params.get('ngram', 1),
            )

            if not text:
                return jsonify(_empty_tokenize_response())

            try:
                display_info = get_token_details(text, detector)
                return jsonify(_build_tokenize_response(display_info))
            except Exception as e:
                app.logger.exception('Error processing text: %s', e)
                return jsonify({'error': f'Error processing text: {str(e)}'}), 500

        except Exception as e:
            app.logger.exception('Server error: %s', e)
            return jsonify({'error': f'Server error: {str(e)}'}), 500

    @app.route("/generate", methods=["POST"])
    def generate():
        try:
            data = request.get_json(silent=True)
            if not data:
                return jsonify({'error': 'No JSON data received'}), 400

            prompt = template_prompt(data.get('prompt', ''))
            params = data.get('params') or {}
            temperature = float(params.get('temperature', 0.8))

            def generate_stream():
                try:
                    # Both OpenAI-scheme detectors ('openai' and its z-score
                    # variant 'openaiz') need text from OpenaiGenerator;
                    # every other type uses the Maryland scheme.
                    generator_class = (
                        OpenaiGenerator
                        if params.get('detector_type') in ('openai', 'openaiz')
                        else MarylandGenerator
                    )
                    generator = generator_class(
                        model=model,
                        tokenizer=tokenizer,
                        # Same fallbacks as create_detector, so an unparsable
                        # value never reaches the generator as None.
                        ngram=set_to_int(params.get('ngram', 1), default_value=1),
                        seed=set_to_int(params.get('seed', 0), default_value=0),
                        delta=float(params.get('delta', 2.0)),
                    )
                    special_tokens = _collect_special_tokens(tokenizer)

                    prompt_tokens = tokenizer.encode(prompt)
                    prompt_size = len(prompt_tokens)
                    total_len = min(
                        getattr(model.config, 'max_position_embeddings', 2048),
                        _MAX_GEN_LEN + prompt_size,
                    )
                    if prompt_size >= total_len:
                        # The prompt already fills the context window; there is
                        # no position left to generate into.
                        yield _sse({'error': 'Prompt is too long for the model context window'})
                        return

                    # One row of pad ids with the prompt written at the front;
                    # generated ids fill positions prompt_size, prompt_size+1, ...
                    tokens = torch.full((1, total_len), model.config.pad_token_id).to(model.device).long()
                    tokens[0, :prompt_size] = torch.tensor(prompt_tokens).long()

                    prev_pos = 0
                    outputs = None
                    end_pos = prompt_size  # one past the last generated id written
                    for cur_pos in range(prompt_size, total_len):
                        # Inference only: without no_grad every forward pass
                        # records an autograd graph, and because the KV cache
                        # is fed back in, that graph keeps growing per token.
                        with torch.no_grad():
                            # The first step feeds the whole prompt; later steps
                            # feed only the newest token plus the cached keys/values.
                            outputs = model.forward(
                                tokens[:, prev_pos:cur_pos],
                                use_cache=True,
                                past_key_values=outputs.past_key_values if prev_pos > 0 else None,
                            )
                            aux = {
                                'ngram_tokens': tokens[0, cur_pos - generator.ngram:cur_pos].tolist(),
                                'cur_pos': cur_pos,
                            }
                            next_token = generator.sample_next(
                                outputs.logits[:, -1, :],
                                aux,
                                temperature=temperature,
                                top_p=0.9,
                            )
                        if next_token == model.config.eos_token_id:
                            break

                        # Stream the piece unless it contains a special token.
                        new_text = tokenizer.decode([next_token])
                        if not any(st in new_text for st in special_tokens):
                            yield _sse({'token': new_text, 'done': False})

                        tokens[0, cur_pos] = next_token
                        prev_pos = cur_pos
                        end_pos = cur_pos + 1

                    # Decode only the ids actually generated (excludes the
                    # unwritten pad slot after an EOS break).
                    final_tokens = tokens[0, prompt_size:end_pos].tolist()
                    final_text = _strip_special_tokens(tokenizer.decode(final_tokens), special_tokens)
                    yield _sse({'text': final_text, 'done': True})

                except Exception as e:
                    app.logger.exception('Error generating text: %s', e)
                    yield _sse({'error': str(e)})

            return Response(stream_with_context(generate_stream()), mimetype='text/event-stream')

        except Exception as e:
            app.logger.exception('Server error: %s', e)
            return jsonify({'error': f'Server error: {str(e)}'}), 500

    return app

# Build the application at import time so this module exposes a ready `app`.
app = create_app()

if __name__ == "__main__":
    # Flask's built-in server, listening on every network interface.
    server_options = {'host': '0.0.0.0', 'port': 7860}
    app.run(**server_options)