In [ ]:
from pyeda.inter import *

In [ ]:
# Characters for filled cells: DIGITS[v - 1] is the character for value v (1..9).
DIGITS = "".join(str(d) for d in range(1, 10))

In [ ]:
# X[r, c, v] is true iff cell (row r, column c) holds digit v.
# Each of the three indices runs over 1..9: the (1, 10) ranges are half-open,
# matching the range(1, 10) loops used throughout.
X = exprvars('x', *([(1, 10)] * 3))

In [ ]:
# V: every cell (row, col) holds exactly one digit.
V = And(*(
    And(*(OneHot(*(X[row, col, digit] for digit in range(1, 10)))
          for col in range(1, 10)))
    for row in range(1, 10)
))

In [ ]:
# R: within each row, every digit appears in exactly one column.
R = And(*(
    And(*(OneHot(*(X[row, col, digit] for col in range(1, 10)))
          for digit in range(1, 10)))
    for row in range(1, 10)
))

In [ ]:
# C: within each column, every digit appears in exactly one row.
C = And(*(
    And(*(OneHot(*(X[row, col, digit] for row in range(1, 10)))
          for digit in range(1, 10)))
    for col in range(1, 10)
))

In [ ]:
# B: within each 3x3 box, every digit appears in exactly one cell.
# (box_row, box_col) in 0..2 select the box; (dr, dc) in 1..3 are the offsets
# inside it, so 3 * box_row + dr covers that box's 1-based rows (same for columns).
B = And(*(
    And(*(OneHot(*(X[3 * box_row + dr, 3 * box_col + dc, digit]
                   for dr in range(1, 4)
                   for dc in range(1, 4)))
          for digit in range(1, 10)))
    for box_row in range(3)
    for box_col in range(3)
))

In [ ]:
# Full rule set: cell, row, column and box constraints must all hold.
S = And(*(V, R, C, B))

In [ ]:
# Number of top-level operands of S before CNF conversion
# (compare with the count shown after to_cnf).
len(S.xs)

In [ ]:
# Convert S to conjunctive normal form. This step also does absorption, so
# redundant clauses can be dropped: R and B both constrain pairs of cells that
# share a row inside one box (and C and B pairs sharing a column inside a box),
# so some of their clauses presumably coincide.
# NOTE: this rebinds the global S; solve() uses whatever S is bound to.
S = S.to_cnf()

In [ ]:
# Number of top-level operands of S after CNF conversion and absorption.
len(S.xs)

In [ ]:
def parse_grid(grid):
    """Turn a puzzle string into the conjunction of its given-digit literals.

    Only characters in DIGITS, '0' and '.' count as cells ('0' and '.' mean
    an empty cell); anything else, such as '|', '-' or '+', is ignored.
    Exactly 81 cells are required, read row by row from the top left.

    Returns:
        And of X[row, col, digit] for every filled cell (1-based row/col).
    """
    cells = [ch for ch in grid if ch in DIGITS or ch in "0."]
    assert len(cells) == 9 ** 2
    givens = []
    for idx, ch in enumerate(cells):
        if ch in DIGITS:
            row, col = divmod(idx, 9)
            givens.append(X[row + 1, col + 1, int(ch)])
    return And(*givens)

In [ ]:
# Example puzzle: '.' marks an empty cell; '|', '-' and '+' are layout only
# and are skipped by parse_grid.
grid = "".join([
    ".73|...|8..",
    "..4|13.|.5.",
    ".85|..6|31.",
    "---+---+---",
    "5..|.9.|.3.",
    "..8|.1.|5..",
    ".1.|.6.|..7",
    "---+---+---",
    ".51|6..|28.",
    ".4.|.52|9..",
    "..2|...|64.",
])

In [ ]:
def get_val(point, r, c):
    """Return the digit character assigned to cell (r, c) in ``point``.

    Candidate values 1..9 are checked in order. The character for the first
    value whose variable X[r, c, v] is true in ``point`` is returned, or "X"
    when none is.
    """
    return next(
        (DIGITS[digit - 1] for digit in range(1, 10) if point[X[r, c, digit]]),
        "X",
    )

In [ ]:
def display(point):
    """Print a solved grid using the same ASCII layout as ``grid``.

    NOTE: this name shadows IPython's built-in ``display`` for the rest of the
    notebook session.

    Args:
        point: a satisfying assignment as returned by ``solve`` (a mapping
            from the X[r, c, v] variables to truth values), or ``None`` when
            the puzzle has no solution.
    """
    if point is None:
        # satisfy_one returns None for an unsatisfiable formula; report it
        # instead of failing with "'NoneType' object is not subscriptable".
        print("No solution")
        return
    chars = []
    for r in range(1, 10):
        for c in range(1, 10):
            # Vertical box separators before columns 4 and 7.
            if c in (4, 7):
                chars.append("|")
            chars.append(get_val(point, r, c))
        if r != 9:
            chars.append("\n")
            # Horizontal box separators after rows 3 and 6.
            if r in (3, 6):
                chars.append("---+---+---\n")
    print("".join(chars))

In [ ]:
def solve(grid):
    """Return one satisfying assignment of S consistent with ``grid``, or None.

    The givens from ``parse_grid`` are entered as a ``with`` context around
    ``S.satisfy_one()``. This is presumably pyeda's mechanism for passing them
    as solver assumptions for the duration of the call (TODO confirm in the
    pyeda docs). satisfy_one returns None when no solution exists.
    """
    givens = parse_grid(grid)
    with givens:
        solution = S.satisfy_one()
    return solution

In [ ]:
# Solve the example puzzle and print the filled-in grid.
display(solve(grid))

In [ ]: