File size: 2,308 Bytes
01c1e7a
 
a4bb658
01c1e7a
a4bb658
 
01c1e7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413bacc
01c1e7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
os.environ["KERAS_BACKEND"] = "torch"

from keras import models, utils
import gradio as gr
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download

# Loaded models keyed by name, so each checkpoint is read from disk only once
# per process instead of on every request.
_MODEL_CACHE = {}


def _load_sudoku_model(model_name: str):
    """Return the SudokuNet model called *model_name*, loading it at most once."""
    if model_name not in _MODEL_CACHE:
        _MODEL_CACHE[model_name] = models.load_model(hf_hub_download(
            repo_id="Ritvik19/SudokuNet",
            filename=f"{model_name}.keras",
        ))
    return _MODEL_CACHE[model_name]


def echo_sudoku(sudoku: pd.DataFrame, model_name: str) -> np.ndarray:
    """Fill in the empty cells of a Sudoku grid using a SudokuNet model.

    The grid is completed one cell per iteration: the model scores every
    cell, and the empty cell with the highest-scoring digit is committed
    before the model is queried again on the updated grid.

    Args:
        sudoku: 9x9 grid of digits, where 0 marks an empty cell. Blank or
            non-numeric cells are treated as empty.
        model_name: Name of the checkpoint in the ``Ritvik19/SudokuNet``
            repository, without the ``.keras`` suffix.

    Returns:
        A 9x9 integer array with every empty cell filled in. The input
        object is not modified.

    Raises:
        ValueError: If the grid does not have 81 cells or contains a value
            that is not a whole number between 0 and 9.
    """
    model = _load_sudoku_model(model_name)

    # The UI may hand over floats, strings or blanks; normalise to numbers
    # and treat anything unparseable as an empty cell.
    cells = (
        pd.DataFrame(sudoku)
        .apply(pd.to_numeric, errors="coerce")
        .fillna(0)
        .to_numpy(dtype=float)
    )
    if cells.size != 81:
        raise ValueError(
            f"Expected a 9x9 Sudoku grid (81 cells), got {cells.size} cells."
        )
    # Out-of-range digits would be one-hot encoded incorrectly (or fail with
    # an opaque IndexError), so reject them up front.
    if ((cells < 0) | (cells > 9) | (cells != np.floor(cells))).any():
        raise ValueError("Sudoku cells must be whole numbers between 0 and 9.")
    # astype() copies, so the caller's data is never mutated.
    puzzles = cells.astype(int).reshape(1, 9, 9)

    # One cell is filled per pass, so the number of empty cells bounds the loop.
    for _ in range((puzzles == 0).sum((1, 2)).max()):
        model_preds = model.predict(
            utils.to_categorical(puzzles, num_classes=10), verbose=0
        )
        # The model output is indexed by "position_<row>_<col>" (1-based);
        # stack those entries row-major into shape (batch, 81, 9).
        preds = np.stack(
            [
                model_preds[f"position_{i + 1}_{j + 1}"]
                for i in range(9)
                for j in range(9)
            ],
            axis=1,
        )
        probs = preds.max(2)
        values = preds.argmax(2) + 1  # class index 0..8 -> digit 1..9
        zeros = (puzzles == 0).reshape((puzzles.shape[0], 81))
        for grid, prob, value, zero in zip(puzzles, probs, values, zeros):
            if zero.any():
                where = np.where(zero)[0]
                # Commit only the most confident guess among the empty cells.
                confidence_position = where[prob[zero].argmax()]
                grid.flat[confidence_position] = value[confidence_position]
    return puzzles[0]

# Checkpoint names offered in the UI: every "<type>__<size>" combination,
# grouped by type in the order listed below.
model_types = ["ffn", "cnn"]
model_sizes = ["64x2", "64x4", "128x2", "128x4"]
model_names = [
    f"{kind}__{size}"
    for kind in model_types
    for size in model_sizes
]

# Sample puzzle pre-filled in the input grid, one text line per board row.
_DEFAULT_PUZZLE_ROWS = """
0 0 4 3 0 0 2 0 9
0 0 5 0 0 9 0 0 1
0 7 0 0 6 0 0 4 3
0 0 6 0 0 2 0 8 7
1 9 0 0 0 7 4 0 0
0 5 0 0 8 3 0 0 0
6 0 0 0 0 0 1 0 5
0 0 3 5 0 8 6 9 0
0 4 2 9 1 0 3 0 0
""".strip().splitlines()
# Parse each whitespace-separated digit into a 9x9 integer array.
DEFAULT_PUZZLE = np.array(
    [[int(cell) for cell in row.split()] for row in _DEFAULT_PUZZLE_ROWS]
)


# Gradio UI: an editable 9x9 grid plus a model picker go in, and the grid
# returned by echo_sudoku comes out.
interface = gr.Interface(
    fn=echo_sudoku,
    inputs=[
        gr.Dataframe(
            label="Input Sudoku Puzzle",
            datatype="number",
            row_count=9,
            col_count=9,
            value=DEFAULT_PUZZLE,
        ),
        gr.Dropdown(label="Select Model", choices=model_names, value="cnn__64x2"),
    ],
    # Labelled distinctly from the input grid so the two tables can be told
    # apart on the page.
    outputs=gr.Dataframe(
        label="Solved Sudoku Puzzle",
        datatype="number",
        row_count=9,
        col_count=9,
    ),
    title="Sudoku Solver",
    description='A demo app for <a href="https://ritvik19.github.io/sudoku-net" target="_blank">SudokuNet</a>',
)

# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    interface.launch(debug=True)