Ritvik19 commited on
Commit
01c1e7a
·
verified ·
1 Parent(s): ec12063

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+
4
+ os.environ["KERAS_BACKEND"] = "torch"
5
+ import streamlit as st
6
+ import numpy as np
7
+ from keras import models, utils
8
+ import pandas as pd
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ def echo_sudoku(sudoku, model_name):
12
+ model = models.load_model(hf_hub_download(
13
+ repo_id="Ritvik19/SudokuNet",
14
+ filename=f"{model_name}.keras",
15
+ ))
16
+ puzzles = sudoku.copy().values.reshape(1, 9, 9)
17
+ for _ in range((puzzles == 0).sum((1, 2)).max()):
18
+ model_preds = model.predict(
19
+ utils.to_categorical(puzzles, num_classes=10), verbose=0
20
+ )
21
+ preds = np.zeros((puzzles.shape[0], 81, 9))
22
+ for i in range(9):
23
+ for j in range(9):
24
+ preds[:, i * 9 + j] = model_preds[f"position_{i+1}_{j+1}"]
25
+ probs = preds.max(2)
26
+ values = preds.argmax(2) + 1
27
+ zeros = (puzzles == 0).reshape((puzzles.shape[0], 81))
28
+ for grid, prob, value, zero in zip(puzzles, probs, values, zeros):
29
+ if any(zero):
30
+ where = np.where(zero)[0]
31
+ confidence_position = where[prob[zero].argmax()]
32
+ confidence_value = value[confidence_position]
33
+ grid.flat[confidence_position] = confidence_value
34
+ return puzzles[0]
35
+
36
+ model_types = ['ffn', 'cnn', 'rnn', 'lstm', 'gru']
37
+ model_sizes = ['64x2', '64x4', '128x2', '128x4']
38
+ model_names = [f"{model_type}__{model_size}" for model_type in model_types for model_size in model_sizes]
39
+
40
+ DEFAULT_PUZZLE = """
41
+ 0 0 4 3 0 0 2 0 9
42
+ 0 0 5 0 0 9 0 0 1
43
+ 0 7 0 0 6 0 0 4 3
44
+ 0 0 6 0 0 2 0 8 7
45
+ 1 9 0 0 0 7 4 0 0
46
+ 0 5 0 0 8 3 0 0 0
47
+ 6 0 0 0 0 0 1 0 5
48
+ 0 0 3 5 0 8 6 9 0
49
+ 0 4 2 9 1 0 3 0 0
50
+ """.strip()
51
+ DEFAULT_PUZZLE = np.array([int(digit) for digit in DEFAULT_PUZZLE.split()]).reshape(9, 9)
52
+
53
+
54
+ interface = gr.Interface(
55
+ fn=echo_sudoku,
56
+ inputs=[
57
+ gr.Dataframe(label="Input Sudoku Puzzle", datatype="number", row_count=9, col_count=9, value=DEFAULT_PUZZLE),
58
+ gr.Dropdown(label="Select Model", choices=model_names, value="cnn__64x2")
59
+ ],
60
+ outputs=gr.Dataframe(label="Input Sudoku Puzzle", datatype="number", row_count=9, col_count=9),
61
+ title="Sudoku Solver",
62
+ description='A demo app for <a href="https://ritvik19.github.io/sudoku-net" target="_blank">SudokuNet</a>'
63
+ )
64
+
65
+ # Run the app
66
+ if __name__ == "__main__":
67
+ interface.launch(debug=True)