Friday, December 8, 2006

Haskell code for ifThenElse

if/then/else can be so ugly. I prefer the C ternary operator (cond ? t : f).

In Haskell, since functions are first-class values, this becomes especially powerful:


module ArrowChoiceOps((?),(??),(???),(????)) where

import Control.Arrow
import Data.Either

infix 1 ?, ??, ???, ????

-- | Tag a value according to a condition: 'True' wraps it in 'Left',
-- 'False' wraps it in 'Right'.
(?) :: Bool -> a -> Either a a
(?) cond = case cond of
  True  -> Left
  False -> Right

-- | Choose one side of a pair according to a condition: 'True' yields the
-- first component in 'Left', 'False' yields the second component in 'Right'.
(??) :: Bool -> (a, b) -> Either a b
(??) cond
  | cond      = Left . fst
  | otherwise = Right . snd

-- | Predicate-driven choice: apply the predicate @p@ to the input, tag the
-- input with '?' using the result, and hand the tagged value to @q@.
(???) :: (a -> Bool) -> (Either a a -> d) -> a -> d
p ??? q = \x -> q (p x ? x)

-- | Predicate-driven choice over pairs: apply the predicate @p@ to the pair,
-- pick a component with '??' using the result, and hand it to @q@.
(????) :: ((a, b) -> Bool) -> (Either a b -> d) -> (a, b) -> d
p ???? q = \pair -> q (p pair ?? pair)


No code is complete without a test harness:


module Main where

import ArrowChoiceOps
import Control.Monad
import Control.Arrow
import Data.Either
import Data.Maybe


-- | The Collatz-style iteration: halve even numbers, map odd numbers to
-- @3n + 1@. Does this halt for all n? Values @<= 1@ step to the sentinel 0,
-- and the result keeps only the leading strictly positive values, so the
-- sequence ends right after reaching 1 (and is empty for inputs @<= 0@).
recurse :: Int -> [Int]
recurse n = takeWhile (> 0) (iterate step n)
  where
    -- One step: stop marker for n <= 1, otherwise halve or triple-plus-one.
    step         = (<= 1) ??? (const 0 ||| halveOrGrow)
    halveOrGrow  = even ??? ((`div` 2) ||| (\k -> k * 3 + 1))

-- | A stand-in for 'mplus' on 'Maybe': returns the first argument when it is
-- a 'Just', otherwise the second argument.
mplusMaybe :: Maybe a -> Maybe a -> Maybe a
mplusMaybe = curry (firstIsJust ???? (id ||| id))
  where
    firstIsJust = isJust . fst

-- | A stand-in for 'abs': negates values below zero and leaves the rest
-- untouched.
abs' :: (Num a, Ord a) => a -> a
abs' = isNegative ??? (negate ||| id)
  where
    isNegative x = x < 0

-- | Self-checks for the choice operators and the helpers built on them.
-- Every entry is expected to be 'True'.
checks :: [Bool]
checks =
  [ -- (?) tags a single value
    (True ? 3) == Left 3
  , (False ? 4) == Right 4
    -- (??) picks one side of a pair
  , (True ?? (3, 4)) == Left 3
  , (False ?? (3, 4)) == Right 4
    -- (???) with (|||): branch on a predicate over the value
  , ((even ??? (`div` 2) ||| (+ 1) . (* 3)) $ 3) == 10
  , ((even ??? (`div` 2) ||| (+ 1) . (* 3)) $ 4) == 2
    -- (????) with (|||): branch on a predicate over the pair
  , ((uncurry (==) ???? (+ 1) ||| (* 3)) $ (3, 3)) == 4
  , ((uncurry (==) ???? (+ 1) ||| (* 3)) $ (3, 4)) == 12
    -- (????) with (+++): keep the Left/Right tag on the result
  , ((uncurry (==) ???? (+ 1) +++ (* 3)) $ (3, 3)) == Left 4
  , ((uncurry (==) ???? (+ 1) +++ (* 3)) $ (3, 4)) == Right 12
    -- mplusMaybe agrees with mplus on every Just/Nothing combination
  , (Nothing `mplusMaybe` Nothing :: Maybe ())
      == (Nothing `mplus` Nothing :: Maybe ())
  , (Just 3 `mplusMaybe` Nothing) == (Just 3 `mplus` Nothing)
  , (Nothing `mplusMaybe` Just 4) == (Nothing `mplus` Just 4)
  , (Just 3 `mplusMaybe` Just 4) == (Just 3 `mplus` Just 4)
    -- recurse over a range that includes non-positive inputs
  , map recurse [-5 .. 5]
      == [ [], [], [], [], [], []
         , [1]
         , [2, 1]
         , [3, 10, 5, 16, 8, 4, 2, 1]
         , [4, 2, 1]
         , [5, 16, 8, 4, 2, 1]
         ]
    -- abs' agrees with abs
  , map abs' [-5 .. 5] == map abs [-5 .. 5]
  ]

-- | Evaluate every entry of 'checks' and print 'True' only if all of them
-- hold.
main :: IO ()
main = print allPassed
  where
    allPassed = and checks