diff --git a/3/3.hs b/3/3.hs index 7429be9..312593e 100644 --- a/3/3.hs +++ b/3/3.hs @@ -72,6 +72,10 @@ data IntervalTree a b = Node { itLeft :: IntervalTree a b } | Empty a a deriving (Show, Eq) +rightOf, leftOf :: Ord a => Interval a -> a -> Bool +rightOf (Interval (start, _)) x = x < start +leftOf (Interval (_, end)) x = end <= x + insert :: (Ord a, Bits a, Num a) => (Interval a, b) -> IntervalTree a b -> IntervalTree a b insert o@(interval, _) tree = case tree of Empty start end -> go start end (start + half (end - start)) @@ -84,8 +88,6 @@ insert o@(interval, _) tree = case tree of | interval `rightOf` center = Node (Empty start center) center [] (insert o (Empty center end)) | otherwise = Node (Empty start center) center [o] (Empty center end) - rightOf (Interval (start, _)) x = x < start - leftOf (Interval (_, end)) x = end <= x half = flip shift (-1) includingIntervals :: Ord a => Interval a -> IntervalTree a b -> [(Interval a, b)] @@ -93,7 +95,11 @@ includingIntervals interval = go [] where go acc t = case t of Empty _ _ -> acc - Node l _ is r -> go (go (filter (\(i,_) -> i `includes` interval) is ++ acc) l) r + Node l center is r | interval `leftOf` center -> go (acc' is acc) l + Node l center is r | interval `rightOf` center -> go (acc' is acc) r + Node l center is r -> go (go (acc' is acc) l) r + where + acc' is acc = filter (\(i,_) -> i `includes` interval) is ++ acc includes (Interval (start1, end1)) (Interval (start2, end2)) = start1 <= start2 && end2 <= end1 @@ -121,8 +127,8 @@ intervalTreeSolve rects = (xTree, yTree) = rectIntervalTrees rects overlapArea = length . filter (\c -> isOverlapCell (cellRects xTree yTree c) c) $ cells in overlapArea - where - cellRects xTree yTree (x,y) = + where + cellRects xTree yTree (x,y) = nub . map snd $ includingIntervals (Interval (x, x+1)) xTree ++ includingIntervals (Interval (y, y+1)) yTree