rubyquiz/SudokuSolver.hs

248 lines
9.8 KiB
Haskell
Raw Normal View History

2012-10-24 18:18:18 +05:30
{-
A solution to rubyquiz 43 (http://rubyquiz.com/quiz43.html).
A fast multi-threaded Sudoku solver using recursive depth-first backtracking
for searching and constraint propagation for solving.
Solves the 49191 puzzles at http://school.maths.uwa.edu.au/~gordon/sudoku17
in 32 seconds on a quad core machine with output switched off.
Each puzzle should be formatted as a single line of 81 characters, top to bottom,
left to right, with digits 1 to 9 if the cell has a value else a dot (.).
Example:
4.....8.5.3..........7......2.....6.....8.4......1.......6.3.7.5..2.....1.4......
Usage:
cat sudoku17 | bin/SudokuSolver +RTS -N4 -H800m -K50m
echo "4.....8.5.3..........7......2.....6.....8.4......1.......6.3.7.5..2.....1.4......" | bin/SudokuSolver
Copyright 2012 Abhinav Sarkar <abhinav@abhinavsarkar.net>
-}
{-# LANGUAGE BangPatterns, RecordWildCards #-}
2012-10-08 14:24:00 +05:30
2012-10-27 11:07:22 +05:30
module SudokuSolver (Cell(..), Board, emptyBoard, boardCells, cellValues,
isBoardSolved, readBoard, showBoard, prettyShowBoard,
solveSudoku, main)
where
2012-10-08 14:24:00 +05:30
import qualified Data.Set as S
import qualified Data.HashMap.Strict as M
import Control.Concurrent (forkIO, newEmptyMVar, putMVar, takeMVar)
import Control.Monad (foldM, forM_, forM, (>=>))
import Data.Bits (testBit, (.&.), complement, popCount, bit)
2012-10-24 18:18:18 +05:30
import Data.Char (digitToInt, intToDigit, isDigit)
import Data.List (foldl', intersperse, intercalate, find, sortBy)
import Data.List.Split (chunksOf)
2012-10-08 14:24:00 +05:30
import Data.Maybe (fromJust)
import Data.Ord (comparing)
import Data.Word (Word16)
2012-10-08 14:24:00 +05:30
import System.CPUTime (getCPUTime)
import Text.Printf (printf)
2012-10-24 18:18:18 +05:30
-- | A single cell of the Sudoku board.
data Cell = Cell
  {-# UNPACK #-} !Int    -- ^ index of the cell, 0 to 80, row by row
  {-# UNPACK #-} !Word16 -- ^ possible values as a bitset: bit n set means n is possible
  {-# UNPACK #-} !Int    -- ^ number of possible values, i.e. number of set bits
-- | The Sudoku board: every cell keyed by its index, plus the set of cells
-- that have not been reduced to a single value yet.
data Board = Board
  { ixMap    :: !(M.HashMap Int Cell) -- ^ all cells, keyed by cell index
  , ambCells :: !(S.Set Cell)         -- ^ cells that are not solved yet
  } deriving (Eq)
-- | Cells are equal when they have the same index and the same possible
-- values; the cached count of possible values is not compared.
instance Eq Cell where
  {-# INLINE (==) #-}
  Cell ix bits _ == Cell ix' bits' _ = (ix, bits) == (ix', bits')
-- | Shows a cell as @<index [possible values]>@, e.g. @<3 [1,7]>@.
instance Show Cell where
  show cell@(Cell ix _ _) = concat ["<", show ix, " ", show (cellValues cell), ">"]
-- | Cells are ordered primarily by their number of possible values and then by
-- index, so the minimum of 'ambCells' is an unsolved cell with the fewest
-- candidates.
--
-- The bitset of possible values is used as a final tie-breaker. Without it,
-- two cells with the same index and count but different values would compare
-- EQ while '==' says they differ, which breaks the Ord/Eq consistency that
-- "Data.Set" relies on.
instance Ord Cell where
  compare c1@(Cell i1 v1 vl1) c2@(Cell i2 v2 vl2)
    | c1 == c2 = EQ
    | otherwise = compare (vl1, i1, v1) (vl2, i2, v2)
-- | Lists the possible values of a cell in ascending order.
cellValues :: Cell -> [Int]
cellValues (Cell _ bits _) = [n | n <- [1 .. 9], testBit bits n]
-- | Lists all cells of the board in ascending order of their index.
boardCells :: Board -> [Cell]
boardCells board =
  [cell | (_, cell) <- sortBy (comparing fst) (M.toList (ixMap board))]
-- | Returns the smallest value from 1 to 9 whose bit is set in the bitset.
-- Bit 0 is ignored. Fails with an error if none of the bits 1 to 9 is set.
firstSol :: Word16 -> Int
firstSol val = fromJust (find isSet digits)
  where
    isSet = testBit val
    digits = [1 .. 9]
-- | Looks up the cells at the given indices, in the same order.
-- Every index must be present in the board; a missing one is an error.
cells :: Board -> [Int] -> [Cell]
cells board = map lookupCell
  where
    lookupCell ix = fromJust (M.lookup ix (ixMap board))
-- | Cell indices of the units of the board, each unit in row-major order:
-- the 9 rows (top to bottom), the 9 columns (left to right) and the 9 3x3
-- boxes (bands of boxes top to bottom, boxes left to right within a band).
-- 'unitIxs' is all 27 of them.
rowIxs, columnIxs, boxIxs, unitIxs :: [[Int]]
rowIxs = [[r * 9 + c | c <- [0 .. 8]] | r <- [0 .. 8]]
columnIxs = [[r * 9 + c | r <- [0 .. 8]] | c <- [0 .. 8]]
boxIxs =
  [ [(3 * band + r) * 9 + 3 * stack + c | r <- [0 .. 2], c <- [0 .. 2]]
  | band <- [0 .. 2], stack <- [0 .. 2] ]
unitIxs = rowIxs ++ columnIxs ++ boxIxs
-- | Checks if a Sudoku board is solved.
-- A board is solved if every cell has exactly one possible value and every
-- row, column and box contains each of the values 1 to 9.
isBoardSolved :: Board -> Bool
isBoardSolved board =
  all isSingle (M.elems (ixMap board))
    && all (isUnitSolved . cells board) unitIxs
  where
    isSingle (Cell _ _ vl) = vl == 1
    -- A unit has nine cells, and each holds a single value (the first
    -- conjunct is checked first). So the unit obeys the Sudoku rule exactly
    -- when every value 1 to 9 occurs in it.
    -- NB: deduplicating the cells themselves (e.g. via a Set of cells) does
    -- not work, because cells with different indices never compare equal.
    isUnitSolved unit =
      all (\n -> any (\(Cell _ v _) -> testBit v n) unit) [1 .. 9 :: Int]
-- | An empty Sudoku board: all 81 cells can still take every value from 1
-- to 9, and all of them are unsolved.
emptyBoard :: Board
emptyBoard = Board (foldl' insertCell M.empty allCells) (S.fromList allCells)
  where
    insertCell m c@(Cell ix _ _) = M.insert ix c m
    -- 1022 == 0b1111111110: bits 1 to 9 set, i.e. all nine values possible
    allCells = [Cell ix 1022 9 | ix <- [0 .. 80]]
-- | Stores the given cell in the board at the cell's index and keeps
-- 'ambCells' in sync. The previous version of the cell is removed from it,
-- and the new one is added unless it has exactly one possible value.
-- The board is returned unchanged if the index is not on the board or the
-- cell equals the one already stored.
updateBoard :: Board -> Cell -> Board
updateBoard board@Board{..} cell@(Cell ix _ vl) =
  maybe board replaceWith (M.lookup ix ixMap)
  where
    replaceWith oldCell
      | oldCell == cell = board
      | otherwise =
          let withoutOld = S.delete oldCell ambCells
              ambCells' = if vl == 1 then withoutOld else S.insert cell withoutOld
          in Board (M.insert ix cell ixMap) ambCells'
-- | Removes the values of a constraining cell (first argument) from the
-- possible values of a target cell (third argument) in the given board.
--
-- Only the index of the target cell is used. Its current possible values are
-- looked up in the board, because earlier propagation steps may already have
-- narrowed the cell since the caller took its snapshot. Using the stale
-- snapshot could reinstate values that were already removed.
--
-- Returns Nothing on a conflict: the constraining cell has a single value and
-- the target is left with no possible values. If the target is reduced to a
-- single value, that value is propagated further via 'constrainBoard'.
constrainCell :: Cell -> Board -> Cell -> Maybe Board
constrainCell cell@(Cell _ val vl) board@Board{..} (Cell i _ _) =
  case M.lookup i ixMap of
    Nothing -> return board
    Just current -> constrainTarget current
  where
    constrainTarget target@(Cell _ pos pl)
      | target == cell = return board
      | pos' == 0 && vl == 1 = Nothing
      | pos' == 0 = return board
      | pl' == 1 && pl > 1 = constrainBoard board (Cell i pos' pl')
      | otherwise = return $ updateBoard board (Cell i pos' pl')
      where
        pos' = pos .&. complement val
        pl' = popCount pos'
-- | Applies 'constrainCell' with the given constraining cell to each target
-- cell in turn, passing the updated board from one step to the next.
-- Stops with Nothing at the first conflict.
constrainCells :: Cell -> Board -> [Cell] -> Maybe Board
constrainCells cell board targets = foldM (constrainCell cell) board targets
-- | Sets the given cell in the board, then applies the rule of Sudoku: a unit
-- (a row, a column or a 3x3 box) cannot hold the same value in more than one
-- cell. The cell's values are removed from every other cell of its row,
-- column and box, in that order.
-- Returns Nothing if this leads to a conflict.
constrainBoard :: Board -> Cell -> Maybe Board
constrainBoard board cell@(Cell ix _ _) =
  foldM constrainUnit (updateBoard board cell) [rowIs, colIs, boxIs]
  where
    -- the unit's cells are read from the board as it is at this step
    constrainUnit b unitIs = constrainCells cell b (cells b unitIs)
    (r, c) = ix `divMod` 9
    -- top-left corner of the 3x3 box containing the cell
    (r0, c0) = (r - r `mod` 3, c - c `mod` 3)
    rowIs = [r * 9 + k | k <- [0 .. 8]]
    colIs = [k * 9 + c | k <- [0 .. 8]]
    boxIs = [br * 9 + bc | br <- [r0 .. r0 + 2], bc <- [c0 .. c0 + 2]]
-- | Reads a board from a string of exactly 81 characters that lists the
-- cells row by row, starting at the top-left corner. The digits '1' to '9'
-- are given values. Any other character, conventionally '.', marks an empty
-- cell; this includes '0', which is not a valid Sudoku value.
-- Returns Nothing if the string does not have 81 characters or if
-- propagating the given values leads to a conflict.
readBoard :: String -> Maybe Board
readBoard str
  | length str /= 81 = Nothing
  | otherwise = foldM constrainBoard emptyBoard givens
  where
    givens =
      [Cell ix (bit (digitToInt ch)) 1 | (ix, ch) <- zip [0 ..] str, isGiven ch]
    -- Accepting '0' would create a cell whose only "value" is bit 0.
    -- 'showBoard' then fails on it, because 'firstSol' only looks at 1 to 9.
    isGiven ch = isDigit ch && ch /= '0'
-- | Represents a board as a single line of 81 characters in index order.
-- Each character is the cell's digit if the cell has exactly one possible
-- value, else a dot (.).
showBoard :: Board -> String
showBoard = map showCell . boardCells
  where
    showCell (Cell _ val vl)
      | vl == 1 = intToDigit (firstSol val)
      | otherwise = '.'
-- | Renders a board as a 9x9 grid. Boxes are separated by '|' and border
-- lines, and unsolved cells are shown as dots.
prettyShowBoard :: Board -> String
prettyShowBoard board = border ++ "\n" ++ body ++ border ++ "\n"
  where
    border = "+-------+-------+-------+"
    -- one string per board row, e.g. "| 4 . . | . . . | 8 . 5 |"
    rows = map showRow (chunksOf 9 (showBoard board))
    showRow r =
      "| " ++ intercalate " | " (map (intersperse ' ') (chunksOf 3 r)) ++ " |"
    -- bands of three rows, separated by a border line
    body = unlines (intercalate [border] (chunksOf 3 rows))
-- | Solves a Sudoku board using recursive backtracking depth-first search.
--
-- At each step, the unsolved cell with the fewest possible values is picked.
-- It is first fixed to its smallest value, and that choice is propagated with
-- 'constrainBoard'. If no solution follows, the value is removed from the
-- cell and the search continues with the remaining values.
--
-- Returns Nothing if the board has no solution.
solveSudoku :: Board -> Maybe Board
solveSudoku board
  -- if solved, return the board
  | isBoardSolved board = Just board
  -- if no more unsolved cells are left, this board cannot be solved
  | S.null (ambCells board) = Nothing
  -- a cell without any possible values means this board is contradictory
  | val == 0 = Nothing
  -- try the smallest value first; if that fails, try the remaining values
  | otherwise =
      case constrainBoard board nextPos >>= solveSudoku of
        Just solved -> Just solved
        Nothing -> boardR >>= solveSudoku
  where
    -- the unsolved cell with the fewest possible values
    Cell ix val vl = S.findMin (ambCells board)
    -- split the cell: one with only its smallest value, one with the rest
    fs = bit (firstSol val)
    nextPos = Cell ix fs 1
    restPos = Cell ix (val .&. complement fs) (vl - 1)
    -- If only one value remains, it is now fixed and must be propagated like
    -- any other fixed value. Otherwise a duplicate in one of its units could
    -- go unnoticed. A conflict here means no solution is left on this branch.
    boardR
      | vl - 1 == 1 = constrainBoard board restPos
      | otherwise = Just (updateBoard board restPos)
-- | Reads Sudoku puzzles from stdin, one per line, and solves them
-- concurrently: the input is split into chunks of 10 puzzles and each chunk
-- is solved in its own thread.
-- For every puzzle one line is printed, in input order. It contains the
-- puzzle, its solution (or "Unsolvable") and the CPU time elapsed while
-- solving it. Lines that are not valid boards are reported as invalid.
-- NOTE(review): getCPUTime is process-wide, so with several threads running
-- the reported time also includes work done by the other threads.
main :: IO ()
main = do
  puzzleChunks <- fmap (chunksOf 10 . lines) getContents
  resultVars <- forM puzzleChunks spawnSolver
  -- wait for all threads to finish and print the solutions (or errors)
  forM_ resultVars (takeMVar >=> mapM_ putStrLn)
  where
    -- Solves a chunk of puzzles in a new thread. The result lines are put
    -- into the returned MVar once the whole chunk is done.
    spawnSolver chunk = do
      resultVar <- newEmptyMVar
      _ <- forkIO (forM chunk solveLine >>= putMVar resultVar)
      return resultVar

    -- Reads and solves a single puzzle, returning the line to print for it.
    solveLine line = do
      start <- getCPUTime
      case readBoard line of
        Nothing -> return ("Invalid input sudoku: " ++ line)
        Just board -> do
          -- force the search (to WHNF) before reading the end time
          let !res = solveSudoku board
          end <- getCPUTime
          -- getCPUTime counts picoseconds; convert to milliseconds
          let diff = fromIntegral (end - start) / (10 ^ 9) :: Double
          return (printf "%s -> %s [%0.3f ms]" line
                    (maybe "Unsolvable" showBoard res) diff)