diff --git a/3/3.hs b/3/3.hs index 312593e..649f8e0 100644 --- a/3/3.hs +++ b/3/3.hs @@ -3,8 +3,10 @@ module Main where import Control.Applicative (some) import Data.Bits (Bits(shift)) +import Data.Function (on) +import qualified Data.Set as Set import qualified Data.Tree as T -import Data.List (maximumBy, foldl', nub) +import Data.List (maximumBy, foldl', sort) import Data.Ord (comparing) import Text.Parsec hiding (Empty) @@ -15,7 +17,13 @@ data Rect = Rect { rectID :: Int , rectTop :: Int , rectWidth :: Int , rectHeight :: Int - } deriving (Eq) + } + +instance Eq Rect where + (==) = (==) `on` rectID + +instance Ord Rect where + compare = compare `on` rectID instance Show Rect where show (Rect id l t w h) = "#" ++ show id ++ " " ++ show l ++ "," ++ show t ++ ":" ++ show (l+w) ++ "," ++ show (t+h) @@ -60,7 +68,7 @@ bruteForceSolve rects = ---------------- Interval tree ---------------- -newtype Interval a = Interval (a,a) deriving (Eq) +newtype Interval a = Interval { unInterval :: (a,a) } deriving (Eq, Ord) instance Show a => Show (Interval a) where show (Interval (a, b)) = "<" ++ show a ++ "," ++ show b ++ ">" @@ -76,12 +84,12 @@ rightOf, leftOf :: Ord a => Interval a -> a -> Bool rightOf (Interval (start, _)) x = x < start leftOf (Interval (_, end)) x = end <= x -insert :: (Ord a, Bits a, Num a) => (Interval a, b) -> IntervalTree a b -> IntervalTree a b +insert :: (Ord a, Ord b, Bits a, Num a) => (Interval a, b) -> IntervalTree a b -> IntervalTree a b insert o@(interval, _) tree = case tree of Empty start end -> go start end (start + half (end - start)) Node l center is r | interval `leftOf` center -> Node (insert o l) center is r Node l center is r | interval `rightOf` center -> Node l center is (insert o r) - Node l center is r -> Node l center (o:is) r + Node l center is r -> Node l center (sort (o:is)) r where go start end center | interval `leftOf` center = Node (insert o (Empty start center)) center [] (Empty center end) @@ -95,16 +103,18 @@ includingIntervals interval = go [] where go acc t = case t of Empty _ _ -> acc - Node l center is r | interval `leftOf` center -> go (acc' is acc) l - Node l center is r | interval `rightOf` center -> go (acc' is acc) r - Node l center is r -> go (go (acc' is acc) l) r + Node l _ is _ | not (null is) && interval `leftOf` leftmostStart is -> go acc l + Node l center is _ | interval `leftOf` center -> go (acc' is acc) l + Node _ center is r | interval `rightOf` center -> go (acc' is acc) r + Node l _ is r -> go (go (acc' is acc) l) r where acc' is acc = filter (\(i,_) -> i `includes` interval) is ++ acc + leftmostStart = fst . unInterval . fst . head includes (Interval (start1, end1)) (Interval (start2, end2)) = start1 <= start2 && end2 <= end1 -fromList :: (Ord a, Bits a, Num a) => a -> a -> [(Interval a, b)] -> IntervalTree a b +fromList :: (Ord a, Ord b, Bits a, Num a) => a -> a -> [(Interval a, b)] -> IntervalTree a b fromList start end = foldl' (flip insert) (Empty start end) rectIntervalTrees :: [Rect] -> (IntervalTree Int Rect, IntervalTree Int Rect) @@ -132,6 +142,8 @@ intervalTreeSolve rects = nub . map snd $ includingIntervals (Interval (x, x+1)) xTree ++ includingIntervals (Interval (y, y+1)) yTree + nub = Set.toList . Set.fromList + main :: IO () main = do rects <- readInput . lines <$> getContents