144 lines
5.3 KiB
Haskell
144 lines
5.3 KiB
Haskell
|
{-# LANGUAGE BangPatterns #-}
|
||
|
|
||
|
module Main where
|
||
|
|
||
|
import qualified Data.Array as A
|
||
|
import qualified Data.Map as M
|
||
|
import qualified Data.Set as S
|
||
|
import Control.Monad (foldM, forM_)
|
||
|
import Data.Char (digitToInt, intToDigit)
|
||
|
import Data.List (foldl', intersperse, (\\), sortBy, groupBy)
|
||
|
import Data.List.Split (splitEvery)
|
||
|
import Data.Maybe (fromJust)
|
||
|
import Data.Ord (comparing)
|
||
|
import System.CPUTime (getCPUTime)
|
||
|
import Text.Printf (printf)
|
||
|
|
||
|
countingSortBy f lo hi =
|
||
|
concatMap snd
|
||
|
. A.assocs . A.accumArray (\ e a -> a : e) [] (lo, hi) . map (\i -> (f i, i))
|
||
|
|
||
|
data Cell = Cell !Int ![Int] deriving (Eq, Ord)
|
||
|
type Board = M.Map Int Cell
|
||
|
|
||
|
instance Show Cell where
|
||
|
show cell@(Cell ix val) = "<" ++ show ix ++ " " ++ show val ++ ">"
|
||
|
|
||
|
emptyBoard =
|
||
|
foldl' (\m i -> M.insert i (Cell i [1..9]) m) M.empty [0..80]
|
||
|
|
||
|
constrainCell cell@(Cell ix val) board c@(Cell i pos) =
|
||
|
case () of _
|
||
|
| c == cell -> return board
|
||
|
| null pos' && length val == 1 -> Nothing
|
||
|
| null pos' -> return board
|
||
|
| length pos' == 1 && length pos > 1 -> constrainBoard board (Cell i pos')
|
||
|
| otherwise -> return $ M.insert i (Cell i pos') board
|
||
|
where pos' = pos \\ val
|
||
|
|
||
|
constrainCells :: Cell -> S.Set Cell -> Board -> [Cell] -> Maybe Board
|
||
|
constrainCells cell@(Cell ix val) seen board unit = do
|
||
|
board' <- foldM (constrainCell cell) board unit
|
||
|
let unit' = map (\(Cell ix _) -> fromJust $ M.lookup ix board') $ unit
|
||
|
|
||
|
foldM (\board'' cell' ->
|
||
|
constrainCells cell' (S.insert cell' seen) board'' unit')
|
||
|
board' (dups unit')
|
||
|
where
|
||
|
dups = filter (not . flip S.member seen) . map head
|
||
|
. filter (\xs@(Cell _ v : _) -> length xs == length v)
|
||
|
. groupBy (\(Cell _ v1) (Cell _ v2) -> v1 == v2)
|
||
|
. sortBy (comparing (\(Cell _ val) -> val))
|
||
|
. filter (\(Cell _ v) -> length v > 1 && length v < 4)
|
||
|
|
||
|
constrainBoard :: Board -> Cell -> Maybe Board
|
||
|
constrainBoard board cell@(Cell ix _) =
|
||
|
foldM (\board'' unitf ->
|
||
|
constrainCells cell (S.singleton cell) board'' (unitf cell board''))
|
||
|
(M.insert ix cell board) [row, column, box]
|
||
|
where
|
||
|
(rowIx, colIx) = ix `divMod` 9
|
||
|
(rowIx', colIx') = ((rowIx `div` 3) * 3, (colIx `div` 3) * 3)
|
||
|
|
||
|
cells board = map (fromJust . flip M.lookup board)
|
||
|
row (Cell ix _) board = cells board $ take 9 [rowIx * 9 ..]
|
||
|
column (Cell ix _) board = cells board $ take 9 [colIx, colIx + 9 ..]
|
||
|
box (Cell ix _) board =
|
||
|
cells board [r * 9 + c | r <- [rowIx' .. rowIx' + 2], c <- [colIx' .. colIx' + 2]]
|
||
|
|
||
|
readBoard :: String -> Maybe Board
|
||
|
readBoard str =
|
||
|
foldM constrainBoard emptyBoard
|
||
|
. map (\(ix, n) -> Cell ix [digitToInt n])
|
||
|
. filter ((/= '.') . snd)
|
||
|
. zip [0..] $ str
|
||
|
|
||
|
showBoard :: Board -> String
|
||
|
showBoard board =
|
||
|
zipWith (\(Cell _ val) dot ->
|
||
|
if length val == 1 then intToDigit (head val) else dot)
|
||
|
(map snd . M.toList $ board)
|
||
|
(repeat '.')
|
||
|
|
||
|
printBoard :: Board -> IO ()
|
||
|
printBoard board =
|
||
|
putStrLn
|
||
|
. (\t -> line ++ "\n" ++ t ++ line ++ "\n")
|
||
|
. unlines . concat . intersperse [line] . splitEvery 3
|
||
|
. map ((\r -> "| " ++ r ++ " |") . concat
|
||
|
. intersperse " | " . map (intersperse ' ') . splitEvery 3)
|
||
|
. splitEvery 9
|
||
|
. showBoard $ board
|
||
|
where line = "+-------+-------+-------+"
|
||
|
|
||
|
solveSudoku :: Board -> Maybe Board
|
||
|
solveSudoku = fst . flip solve S.empty
|
||
|
where
|
||
|
solve board invalid = go (cells board) board invalid
|
||
|
|
||
|
cells board =
|
||
|
-- countingSortBy (\(Cell _ val) -> length val) 2 9
|
||
|
sortBy (comparing (\(Cell _ val) -> length val))
|
||
|
. filter (\(Cell _ val) -> length val /= 1)
|
||
|
. map snd . M.toList $ board
|
||
|
|
||
|
isSolved = all (\(Cell _ val) -> length val == 1) . M.elems
|
||
|
|
||
|
go [] board invalid = (Nothing, S.insert board invalid)
|
||
|
go (Cell ix val : cs) board invalid
|
||
|
| S.member board invalid = (Nothing, invalid)
|
||
|
| null val = go cs board (S.insert board invalid)
|
||
|
| otherwise = let
|
||
|
nextPos = Cell ix [head val]
|
||
|
restPos = Cell ix (tail val)
|
||
|
board' = M.insert ix nextPos board
|
||
|
in case constrainBoard board nextPos of
|
||
|
Nothing -> go (restPos : cs) board (S.insert board' invalid)
|
||
|
Just board'' | isSolved board'' -> (Just board'', invalid)
|
||
|
| otherwise -> let
|
||
|
(mBoard', invalid') = solve board'' invalid
|
||
|
in case mBoard' of
|
||
|
Just board''' -> (Just board''', invalid')
|
||
|
Nothing ->
|
||
|
go (restPos : cs) board (S.insert board'' invalid')
|
||
|
|
||
|
main = do
|
||
|
lns <- fmap lines getContents
|
||
|
forM_ lns $ \line -> do
|
||
|
start <- getCPUTime
|
||
|
let sudoku = readBoard line
|
||
|
case sudoku of
|
||
|
Nothing -> putStrLn ("Invalid input sudoku: " ++ line)
|
||
|
Just board -> do
|
||
|
let !res = solveSudoku board
|
||
|
end <- getCPUTime
|
||
|
let diff = (fromIntegral (end - start)) / (10^12)
|
||
|
|
||
|
putStrLn (printf "%s -> %s [%0.3f sec]" line (maybe "Unsolvable" showBoard res) (diff :: Double))
|
||
|
|
||
|
--putStrLn (printf "Time taken: %0.3f sec" (diff :: Double))
|
||
|
--printBoard board
|
||
|
|
||
|
--case res of
|
||
|
-- Nothing -> putStrLn "Unsolvable"
|
||
|
-- Just board' -> printBoard board'
|