More performance enhancements in SudokuSolver

This commit is contained in:
Abhinav Sarkar 2012-10-23 22:20:04 +05:30
parent 2532b5167b
commit d5a7f778f9
3 changed files with 51 additions and 29 deletions

1
.gitignore vendored
View File

@ -1,6 +1,7 @@
*.hi
*.o
*.bak
*.prof
input
bin
dist

View File

@ -2,10 +2,10 @@
module Main where
import qualified Data.Map as M
import qualified Data.Set as S
import Control.Concurrent (forkIO)
import Control.Monad (foldM, forM_)
import qualified Data.HashMap.Strict as M
import Control.Concurrent (forkIO, newEmptyMVar, putMVar, takeMVar)
import Control.Monad (foldM, forM_, forM)
import Data.Char (digitToInt, intToDigit)
import Data.List (foldl', intersperse, intercalate, (\\))
import Data.List.Split (chunksOf)
@ -16,19 +16,23 @@ import Text.Printf (printf)
data Cell = Cell {-# UNPACK #-} !Int
![Int]
{-# UNPACK #-} !Int
deriving (Eq)
data Board = Board { ixMap :: !(M.Map Int Cell),
data Board = Board { ixMap :: !(M.HashMap Int Cell),
ambCells :: !(S.Set Cell)
} deriving (Eq, Ord, Show)
} deriving (Eq, Show)
instance Eq Cell where
{-# INLINE (==) #-}
(Cell i1 v1 _) == (Cell i2 v2 _) = i1 == i2 && v1 == v2
instance Show Cell where
show (Cell ix val _) = "<" ++ show ix ++ " " ++ show val ++ ">"
instance Ord Cell where
(Cell i1 v1 vl1) `compare` (Cell i2 v2 vl2)
| i1 == i2 && v1 == v2 = EQ
| otherwise = (vl1, i1) `compare`(vl2, i2)
(Cell i1 v1 vl1) `compare` (Cell i2 v2 vl2) =
if i1 == i2 && v1 == v2
then EQ
else (vl1, i1) `compare`(vl2, i2)
emptyBoard :: Board
emptyBoard =
@ -39,10 +43,11 @@ emptyBoard =
updateBoard :: Board -> Cell -> Board
updateBoard board@Board{..} cell@(Cell ix _ vl) = case M.lookup ix ixMap of
Nothing -> board
Just oldCell | vl == 1 -> Board (M.insert ix cell ixMap)
(S.delete oldCell ambCells)
| otherwise -> Board (M.insert ix cell ixMap)
(S.insert cell (S.delete oldCell ambCells))
Just oldCell | oldCell == cell -> board
| vl == 1 -> Board (M.insert ix cell ixMap)
(S.delete oldCell ambCells)
| otherwise -> Board (M.insert ix cell ixMap)
(S.insert cell (S.delete oldCell ambCells))
constrainCell :: Cell -> Board -> Cell -> Maybe Board
constrainCell cell@(Cell _ val vl) board@Board{..} c@(Cell i pos pl) =
@ -53,9 +58,18 @@ constrainCell cell@(Cell _ val vl) board@Board{..} c@(Cell i pos pl) =
| pl' == 1 && pl > 1 -> constrainBoard board (Cell i pos' pl')
| otherwise -> return $ updateBoard board (Cell i pos' pl')
where
pos' = pos \\ val
pos' = diff pos val
pl' = length pos'
diff :: [Int] -> [Int] -> [Int]
diff [] [] = []
diff xs [] = xs
diff [] _ = []
diff xa@(x:xs) ya@(y:ys)
| x == y = diff xs ys
| x < y = x : diff xs ya
| x > y = diff xa ys
constrainCells :: Cell -> Board -> [Cell] -> Maybe Board
constrainCells cell = foldM (constrainCell cell)
@ -83,7 +97,7 @@ readBoard str =
showBoard :: Board -> String
showBoard board =
zipWith (\(Cell _ val vl) dot -> if vl == 1 then intToDigit (head val) else dot)
(map snd . M.toList . ixMap $ board)
(M.elems . ixMap $ board)
(repeat '.')
printBoard :: Board -> IO ()
@ -113,21 +127,28 @@ solveSudoku board
Just board'' -> Just board''
Nothing -> solveSudoku boardR
where
((Cell ix val vl), cs) = S.deleteFindMin (ambCells board)
(Cell ix val vl, cs) = S.deleteFindMin (ambCells board)
isSolved = all (\(Cell _ _ vl) -> vl == 1) . M.elems . ixMap
main :: IO ()
main = do
lns <- fmap lines getContents
forM_ lns $ \line -> forkIO $ do
start <- getCPUTime
let sudoku = readBoard line
case sudoku of
Nothing -> putStrLn ("Invalid input sudoku: " ++ line)
Just board -> do
let !res = solveSudoku board
end <- getCPUTime
let diff = fromIntegral (end - start) / (10 ^ 12)
chunks <- fmap (chunksOf 50 . lines) getContents
threads <- forM chunks $ \chunk -> do
done <- newEmptyMVar
forkIO $ do
forM_ chunk $ \line -> do
start <- getCPUTime
let sudoku = readBoard line
case sudoku of
Nothing -> putStrLn ("Invalid input sudoku: " ++ line)
Just board -> do
let !res = solveSudoku board
end <- getCPUTime
let diff = fromIntegral (end - start) / (10 ^ 9)
putStrLn (printf "%s -> %s [%0.3f sec]" line
(maybe "Unsolvable" showBoard res) (diff :: Double))
putStrLn (printf "%s -> %s [%0.3f ms]" line
(maybe "Unsolvable" showBoard res) (diff :: Double))
putMVar done ()
return done
mapM_ takeMVar threads

View File

@ -85,7 +85,7 @@ executable SudokuSolver
containers == 0.4.*,
mtl == 2.1.*,
split == 0.2.1.*,
array == 0.4.*
unordered-containers == 0.2.1.*
main-is : SudokuSolver.hs
ghc-options : -threaded
default-language : Haskell2010