Haskell eDSLs 101

The other day I started playing around with Accelerate, a Haskell eDSL for GPGPU computing. Accelerate provides multidimensional arrays and a set of functions for manipulating them, which we combine to build expressions. These expressions can be compiled to CUDA code by the CUDA backend and then run on the GPU. To me the library seemed imbued with some form of arcane magic, so I decided to investigate this eDSL business further.

My first stop was the wiki page on Haskell eDSLs. Eventually, I stumbled upon several tutorials, and they all seemed to start with a language for a very simple calculator. I then decided to make my own version of that language and make a compiler that would translate expressions to something executable.

This post is therefore an introductory tutorial to Haskell eDSLs. My intent is that you get a first grasp of the idea so that you can then move on to more complicated matters.

The full source code is available here.

The Calc Language

The language we'll be building is one for a simple calculator that can add, subtract, and multiply numbers. We start off by formalising what an expression in this language looks like:

type Number = Integer
data Expr
    = Lit Number
    | Add Expr Expr
    | Sub Expr Expr
    | Mul Expr Expr
    deriving Show

The syntax should be readable even by non-Haskell programmers. What we're saying is that an expression can be either a literal, the sum of two expressions, the difference between two expressions or their product. The deriving Show part is simply to make the expressions printable in ghci.

Now we can go ahead and build expressions. The literal number 2 is simply

Lit 2

The expression 2 + 3 can be written as

Add (Lit 2) (Lit 3)

And we can even nest expressions, like in (2+3)*(4-2):

Mul (Add (Lit 2) (Lit 3)) (Sub (Lit 4) (Lit 2))

And so on. But of course, constructing an AST by hand is a pain; we'd like to type 2+3 and (2+3)*(4-2) just as we normally would. This is where we make Expr an instance of Num:

instance Num Expr where
         e1 + e2  = Add e1 e2
         e1 - e2  = Sub e1 e2
         e1 * e2  = Mul e1 e2
         abs e    = e
         signum e = e
         fromInteger = Lit

The definitions for abs and signum are a cheat, but for this small example it's ok. Once we've rolled this instance we can start typing expressions just as if they were regular integers:
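The instance also leaves out negate. That is allowed because the Num class supplies a default, negate x = 0 - x, written in terms of (-), which we did define. As a consequence, a negative literal like -3, which Haskell reads as negate (fromInteger 3), turns into a subtraction from zero. A small self-contained check, repeating the type and instance from above (minusThree is just an illustrative name):

```haskell
type Number = Integer

data Expr
    = Lit Number
    | Add Expr Expr
    | Sub Expr Expr
    | Mul Expr Expr
    deriving Show

instance Num Expr where
    e1 + e2     = Add e1 e2
    e1 - e2     = Sub e1 e2
    e1 * e2     = Mul e1 e2
    abs e       = e
    signum e    = e
    fromInteger = Lit

-- No negate in the instance, so the class default (negate x = 0 - x)
-- is used: -3 means negate (fromInteger 3), i.e. Sub (Lit 0) (Lit 3).
minusThree :: Expr
minusThree = -3
```

Evaluating minusThree in ghci shows Sub (Lit 0) (Lit 3).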

> (2+3)*4*2 :: Int
40
> (2+3)*4*2 :: Expr
Mul (Mul (Add (Lit 2) (Lit 3)) (Lit 4)) (Lit 2)

This is quite powerful. We type a bunch of expressions using everyday syntax, and Haskell builds an AST for us.

But that's not all. Notice that Expr is now an instance of Num. That means any Haskell function that is polymorphic over Num can now be used to build an expression. Take, for instance, the list of all Fibonacci numbers:

fib :: Num a => [a]
fib = 0 : 1 : zipWith (+) fib (tail fib)

fib uses nothing but the numeric literals 0 and 1 and addition, so its most general type is Num a => [a]. Now we can build a list of the first 5 Fibonacci numbers as plain Ints:

> take 5 fib :: [Int]
[0,1,1,2,3]

But here's the cool part: Expr is an instance of Num, so we can build a list of Exprs as well:

> take 5 fib :: [Expr]
[ Lit 0
, Lit 1
, Add (Lit 0) (Lit 1)
, Add (Lit 1) (Add (Lit 0) (Lit 1))
, Add (Add (Lit 0) (Lit 1)) (Add (Lit 1) (Add (Lit 0) (Lit 1)))
]

Take a look at that! fib builds the expressions that yield the first 5 numbers in the sequence. The power of this is that any function polymorphic over Num, such as fib, can be used to build an Expr. In other words, we can now use the host language, Haskell, to build complicated expressions in our Calc language.
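The same holds for any ordinary numeric function we write ourselves. As a sketch, poly below is a hypothetical example function, not part of the original code; because its type only mentions Num, the same definition computes a number or builds an AST depending on the type we ask for:

```haskell
type Number = Integer

data Expr
    = Lit Number
    | Add Expr Expr
    | Sub Expr Expr
    | Mul Expr Expr
    deriving Show

instance Num Expr where
    e1 + e2     = Add e1 e2
    e1 - e2     = Sub e1 e2
    e1 * e2     = Mul e1 e2
    abs e       = e
    signum e    = e
    fromInteger = Lit

-- Hypothetical example: x^2 + 2x + 1, written once for every Num type.
poly :: Num a => a -> a
poly x = x * x + 2 * x + 1
```

In ghci, poly 3 :: Integer gives 16, while poly 3 :: Expr gives Add (Add (Mul (Lit 3) (Lit 3)) (Mul (Lit 2) (Lit 3))) (Lit 1), since (+) and (*) associate to the left.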

The Calc Compiler

An expression is rather useless if all we can do is build it. What we want now is to evaluate an Expr. We could write a simple interpreter for that, but I figured that a compiler would be cooler.
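For comparison, the interpreter route takes only a few lines. This eval function is a sketch that is not part of the original code, but it is handy as a reference for checking what the compiled programs return:

```haskell
type Number = Integer

data Expr
    = Lit Number
    | Add Expr Expr
    | Sub Expr Expr
    | Mul Expr Expr
    deriving Show

instance Num Expr where
    e1 + e2     = Add e1 e2
    e1 - e2     = Sub e1 e2
    e1 * e2     = Mul e1 e2
    abs e       = e
    signum e    = e
    fromInteger = Lit

-- Reference interpreter: walks the AST and computes the value directly.
eval :: Expr -> Number
eval (Lit n)   = n
eval (Add x y) = eval x + eval y
eval (Sub x y) = eval x - eval y
eval (Mul x y) = eval x * eval y
```

For example, eval ((2+3)*(4-2)) gives 10.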

The compiler we are going to write translates an Expr into 32-bit x86 Linux assembly. This assembly code is then assembled and linked like a regular assembly program, and we'll make that program return the result of the evaluated expression back to Haskell. For this task we'll be using nasm and ld.

First we define a helper function, nconcat, that concatenates a list of strings, inserting a newline between consecutive strings:

import Data.List (intercalate)

nconcat :: [String] -> String
nconcat = intercalate "\n"

Next we define functions for each operation that can be done on expressions, namely addition, subtraction and multiplication. The convention we take is that these functions read their arguments off the stack and return the result in the eax register.

add = nconcat
    [ "add:"
    , "mov eax, [esp+4]"
    , "mov ebx, [esp+8]"
    , "add eax, ebx"
    , "ret"
    , ""
    ]

sub = nconcat
    [ "sub:"
    , "mov ebx, [esp+4]"
    , "mov eax, [esp+8]"
    , "sub eax, ebx"
    , "ret"
    , ""
    ]

mul = nconcat
    [ "mul:"
    , "mov eax, [esp+4]"
    , "mov ebx, [esp+8]"
    , "mul ebx"
    , "ret"
    , ""
    ]
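The three routines share one shape: a label, two loads from the stack, the operation itself and ret. As a sketch of how that repetition could be factored out, the helper routine below (a hypothetical name, not part of the original source) generates them. It always loads the second-pushed argument from [esp+4] into ebx and the first-pushed one from [esp+8] into eax, which keeps the operand order right for sub and reproduces the sub routine above exactly; add and mul only differ from the listings above in the order of the loads, which does not matter for commutative operations.

```haskell
import Data.List (intercalate)

nconcat :: [String] -> String
nconcat = intercalate "\n"

-- Hypothetical helper: a labelled routine following the calling convention
-- above. First-pushed argument ([esp+8]) goes to eax, second-pushed
-- ([esp+4]) goes to ebx, the body runs, and the result is left in eax.
routine :: String -> [String] -> String
routine name body = nconcat $
    [ name ++ ":"
    , "mov ebx, [esp+4]"
    , "mov eax, [esp+8]"
    ] ++ body ++ ["ret", ""]

add, sub, mul :: String
add = routine "add" ["add eax, ebx"]
sub = routine "sub" ["sub eax, ebx"]
mul = routine "mul" ["mul ebx"]  -- note: mul also overwrites edx
```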

Moving on, we define the compile' function, which translates an expression into a string of assembly code that leaves the expression's value in eax:

compile' :: Expr -> String
compile' (Lit x) = "mov eax, " ++ show x
compile' (Add x y) = binOp "add" x y
compile' (Sub x y) = binOp "sub" x y
compile' (Mul x y) = binOp "mul" x y

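The compile' function relies on binOp, which we define next. binOp takes the name of an assembly routine and two expressions. It emits code that evaluates each expression and pushes its result onto the stack, calls the routine, and finally removes the two arguments from the stack again with add esp, 8: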

type Op = String

binOp :: Op -> Expr -> Expr -> String
binOp op x y
    = nconcat
    [ compile' x
    , "push eax"
    , compile' y
    , "push eax"
    , "call " ++ op
    , "add esp, 8"
    ]

So binOp and compile' are mutually recursive: compile' compiles a single expression and hands off to binOp whenever that expression combines two subexpressions.

A question that arises is how to make the assembly program return the result of an evaluation back to Haskell. Since the Calc language only defines expressions that evaluate to integers, we're going to use a little hack and have the assembly program return the result via its exit code. For this purpose, we define the exit function:

exit
    = nconcat
    [ "exit:"
    , "mov ebx, [esp+4]"
    , "mov eax, 1"
    , "int 0x80"
    , ""
    ]

exit reads a number from the stack and exits with that number as the exit code. int 0x80 is how we perform a system call on 32-bit Linux; eax=1 selects the exit system call, and ebx holds the exit code.

This exit-code hack has one limitation: only the low 8 bits of ebx reach the parent process as the exit status, so only results in the range 0..255 come back intact. For our illustrative purposes, however, this is fine.
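To make the wrap-around concrete, here is a small model of what Haskell gets back, written as a hypothetical function observedExit that is not part of the original code. It assumes the literals fit in 32 bits; eax arithmetic then wraps modulo 2^32, and since 256 divides 2^32, the exit status ends up being the true result modulo 256:

```haskell
-- Hypothetical model of the exit-code channel: only the low 8 bits of
-- the computed value survive, i.e. the result modulo 256.
-- Haskell's `mod` is always non-negative for a positive divisor,
-- which matches how a negative result such as -1 shows up as 255.
observedExit :: Integer -> Int
observedExit n = fromInteger (n `mod` 256)
```

So observedExit 33 is 33, observedExit 300 is 44 and observedExit (-1) is 255.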

Now we have all of the elements to build an assembly program. For readability, we define a program to be

newtype Prog = Compute Expr deriving Show

Next we define the function that compiles a program:

compile :: Prog -> String
compile (Compute e)
        = nconcat
        [ header
        , add
        , sub
        , mul
        , exit
        , "_start:"
        , compile' e
        , "push eax"
        , "call exit"
        ]

where header is defined as

header
    = nconcat
    [ "BITS 32"
    , "section .text"
    , "global _start"
    , ""
    ]

This header code is just a bunch of directives nasm expects. BITS 32 tells nasm we're making a 32-bit program. section .text specifies that we are defining the .text section, where the executable code is, and global _start specifies the entry point.

The compile function takes a program, compiles its expression with compile', and wraps the result with the header, the routines for the operations, and a call to exit that returns the result as the program's exit code.

To visualise all of this, let's compile an example expression to see the code that is produced:

> let e = 17 :: Expr
> putStrLn (compile (Compute e))

The resulting code is

BITS 32
section .text
global _start

add:
mov eax, [esp+4]
mov ebx, [esp+8]
add eax, ebx
ret

sub:
mov ebx, [esp+4]
mov eax, [esp+8]
sub eax, ebx
ret

mul:
mov eax, [esp+4]
mov ebx, [esp+8]
mul ebx
ret

exit:
mov ebx, [esp+4]
mov eax, 1
int 0x80

_start:
mov eax, 17
push eax
call exit

Notice how in _start, the value 17 is moved to eax, pushed onto the stack and followed by a call to exit. This makes the program quit with exit code 17.

The following code is what the expression 2*3 compiles to, omitting all of the boilerplate:

_start:
mov eax, 2
push eax
mov eax, 3
push eax
call mul
add esp, 8
push eax
call exit

The generated code could be better, for example by pushing the literals 2 and 3 directly instead of moving them into eax and then pushing eax, but for our purposes it's sufficient.
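A sketch of that improvement, assuming the same Expr type, the same routines and the same calling convention: literal operands are pushed with a single push instruction, and only compound operands are evaluated into eax first. The names compileOpt, pushOperand and binOpOpt are hypothetical and not part of the original source.

```haskell
import Data.List (intercalate)

type Number = Integer

data Expr
    = Lit Number
    | Add Expr Expr
    | Sub Expr Expr
    | Mul Expr Expr
    deriving Show

nconcat :: [String] -> String
nconcat = intercalate "\n"

-- Hypothetical optimised variant of compile': still leaves the value in eax.
compileOpt :: Expr -> String
compileOpt (Lit x)   = "mov eax, " ++ show x
compileOpt (Add x y) = binOpOpt "add" x y
compileOpt (Sub x y) = binOpOpt "sub" x y
compileOpt (Mul x y) = binOpOpt "mul" x y

-- Push one operand: literals directly, anything else via eax.
pushOperand :: Expr -> String
pushOperand (Lit x) = "push dword " ++ show x
pushOperand e       = nconcat [compileOpt e, "push eax"]

binOpOpt :: String -> Expr -> Expr -> String
binOpOpt op x y = nconcat
    [ pushOperand x
    , pushOperand y
    , "call " ++ op
    , "add esp, 8"
    ]
```

With this, 2*3 compiles to push dword 2, push dword 3, call mul, add esp, 8, two instructions shorter than before.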

Now we need to compile the generated assembly code. This is exactly what the nasm function does:

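import System.Process (system)

nasm :: String -> IO String
nasm code
     = do writeFile "foo.s" code
          _ <- system "nasm -f elf foo.s"           -- assemble into a 32-bit ELF object
          _ <- system "ld -m elf_i386 -o foo foo.o" -- link as 32-bit (needed on 64-bit hosts)
          return "./foo"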

The nasm function takes some code, dumps it into foo.s, compiles it with nasm, links it with ld, and then returns the command that we must execute to run the generated program.
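As written, nasm throws away the exit codes returned by system, so when nasm or ld fails, run still goes on to execute ./foo, which is then either missing or left over from an earlier build. A minimal sketch of a guard, using a hypothetical helper named runSteps that is not part of the original code, runs the shell commands in order and stops at the first failure:

```haskell
import System.Exit (ExitCode (..))
import System.Process (system)

-- Hypothetical helper: run shell commands in order, stopping at the
-- first one that fails and returning its exit code.
runSteps :: [String] -> IO ExitCode
runSteps [] = return ExitSuccess
runSteps (cmd : cmds) = do
    code <- system cmd
    case code of
        ExitSuccess -> runSteps cmds
        failure     -> return failure
```

nasm could then pass its assemble and link commands to runSteps and only return "./foo" when the result is ExitSuccess, reporting the failing code otherwise.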

Finally, we define the run function, which takes a program, compiles it, runs it and interprets the result:

import System.Exit (ExitCode (..))

run :: Prog -> IO Int
run prog =
    let code = compile prog
    in readExit <$> (nasm code >>= system)

-- Turn the program's exit status back into the computed value.
readExit :: ExitCode -> Int
readExit ExitSuccess = 0
readExit (ExitFailure x) = x

Et voilà. Now we can compute those Fibonacci numbers and any expression that we fancy:

> let e = fib !! 6 :: Expr
> run . Compute $ e
8
> fib !! 6 :: Int
8
> 2*3 + 5*6 - 3
33
> run . Compute $ 2*3 + 5*6 - 3
33

Where To Go From Here

The Calc language is easy to model because all expressions evaluate to the same type: Integer. As soon as we add expressions of different types the language gets more complicated and we need something like a GADT. The GADTs section on the wiki has an excellent tutorial on modeling more sophisticated languages, so it is a good step to take from here.
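To give a flavour of that next step, here is a minimal sketch of a typed language with integers and booleans using a GADT. The names TExpr and evalT are hypothetical and not taken from the wiki tutorial. Because each constructor states the type it produces, an ill-typed term such as Plus (BoolLit True) (IntLit 1) is rejected by GHC at compile time:

```haskell
{-# LANGUAGE GADTs #-}

-- The type index records what each expression evaluates to.
data TExpr a where
    IntLit  :: Integer -> TExpr Integer
    BoolLit :: Bool -> TExpr Bool
    Plus    :: TExpr Integer -> TExpr Integer -> TExpr Integer
    Equal   :: TExpr Integer -> TExpr Integer -> TExpr Bool
    If      :: TExpr Bool -> TExpr a -> TExpr a -> TExpr a

-- The result type follows the index, so no runtime type checks are needed.
evalT :: TExpr a -> a
evalT (IntLit n)  = n
evalT (BoolLit b) = b
evalT (Plus x y)  = evalT x + evalT y
evalT (Equal x y) = evalT x == evalT y
evalT (If c t e)  = if evalT c then evalT t else evalT e
```

For instance, evalT (If (Equal (IntLit 2) (Plus (IntLit 1) (IntLit 1))) (IntLit 10) (IntLit 20)) gives 10.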