(* Type-directed memoization *)

open Printf;;

(* The type of a memoizer: it takes a function and returns a memoized
   function of the same type. *)
type ('a, 'b) memo_dict = ('a -> 'b) -> ('a -> 'b);;

(* Self-test helper.  A [let true = ...] binding is a non-exhaustive
   pattern: it triggers warning 8, and when a test fails it raises an
   anonymous Match_failure.  This helper instead fails with a message that
   names the test.  Unlike [assert], the checked expression is evaluated
   even under -noassert, so the tests' printed output does not depend on
   compiler flags. *)
let check_test name ok = if not ok then failwith (name ^ " failed");;

(* Memoizing functions of primitive types: unit and bool *)

(* A function of unit has a single possible argument, so its memo table is
   one optional cell, filled on the first call.  If [f] raises, nothing is
   cached and the next call retries. *)
let md_unit : (unit, 'b) memo_dict =
 fun f ->
  let cell = ref None in
  fun () ->
    match !cell with
    | Some v -> v
    | None ->
        let v = f () in
        cell := Some v;
        v
;;

(* We rely on the fact that BOOL = UNIT + UNIT (compare md_sum below).
   A function of bool is therefore a pair of functions of unit, one per
   truth value, and each is memoized with [md_unit]. *)
let md_bool : (bool, 'b) memo_dict =
 fun f ->
  let on_true = md_unit (fun () -> f true) in
  let on_false = md_unit (fun () -> f false) in
  fun x -> if x then on_true () else on_false ()
;;

(* Memoizing functions of complex, constructed types.  The memoizer is
   built as a composition of memoizers for functions of simpler types, and
   the composition is directed by the structure of the argument type. *)

(* Memoizing a function whose argument is of a product type.
   We rely on the currying isomorphism:
     ('a * 'b) -> 'c   is isomorphic to   'a -> ('b -> 'c).
   The outer table (built by [mda]) maps each x to an inner memoized
   function of y (built by [mdb]). *)
let md_prod (mda : ('a, 'b -> 'c) memo_dict) (mdb : ('b, 'c) memo_dict) :
    ('a * 'b, 'c) memo_dict =
 fun f ->
  let fb x = mdb (fun y -> f (x, y)) in
  let fa = mda fb in
  fun (x, y) -> fa x y
;;

(* Memoizing a function whose argument is of a sum type.
   We rely on the isomorphism
     ('a + 'b) -> 'c   is isomorphic to   ('a -> 'c) * ('b -> 'c),
   so each summand gets its own memo table. *)
type ('a, 'b) either = Left of 'a | Right of 'b;;

let md_sum (mda : ('a, 'c) memo_dict) (mdb : ('b, 'c) memo_dict) :
    (('a, 'b) either, 'c) memo_dict =
 fun f ->
  let on_left = mda (fun x -> f (Left x)) in
  let on_right = mdb (fun y -> f (Right y)) in
  function Left x -> on_left x | Right y -> on_right y
;;

(* A few tests *)

(* "test1 f" is printed only once. *)
let () =
  let f = md_unit (fun () -> printf "test1 f\n"; 1) in
  check_test "test1" (f () = f ())
;;

(* "test2 f" is printed only twice: once for the argument true and once
   for false. *)
let () =
  let f = md_bool (fun x -> printf "test2 f\n"; not x) in
  check_test "test2"
    (f true = f true && f false = f false && f true = not (f false))
;;

(* "test3 f" is printed 3 times, once for each distinct argument. *)
let () =
  let f = md_prod md_bool md_bool (fun (x, y) -> printf "test3 f\n"; x <> y) in
  check_test "test3"
    (f (true, true) = f (true, true)
    && f (true, false) = f (true, false)
    && f (false, false) = f (false, false)
    && f (true, true) = f (false, false))
;;

(* Memoizing functions of a recursive type *)

(* A recursive type is described by its one-step unrolling ['self t],
   together with a memoizer for that unrolling.  The memoizer is
   parameterized by a memoizer for the recursive occurrences ['self]. *)
module type MU = sig
  type 'self t
  val mdt : ('self, 'b) memo_dict -> ('self t, 'b) memo_dict
end;;

(* We need an additional level of indirection ([fm] below is a reference
   cell) to prevent divergence.  The trick is not unlike the eta-expansion
   that turns the ordinary fixpoint combinator into the applicative one:
   we delay computing the fixpoint until we receive an argument for the
   fixpointed function.  Still, the computed fixpoint must be shared among
   all applications of the memoized function; hence the reference cell. *)
module FIX (S : MU) = struct
  (* The fixpoint of S.t: a tfix wraps one unrolled layer of itself. *)
  type tfix = Fix of tfix S.t

  (* The memo table for the unrolled type, [S.mdt md_fix ...], is built
     on the first call and then reused by every later call through [fm]. *)
  let rec md_fix : (tfix, 'b) memo_dict =
   fun f ->
    let fm = ref None in
    function
    | Fix x -> (
        match !fm with
        | Some fixedm -> fixedm x
        | None ->
            let fixedm = S.mdt md_fix (fun y -> f (Fix y)) in
            fm := Some fixedm;
            fixedm x)
end;;

(* Representing natural numbers as a recursive sum type:
     NAT = 1 + NAT   or   NAT = mu self.(unit + self)
   The OCaml notation below closely matches the categorical notation above,
   and the expression constructing the memo table matches the structure of
   the type. *)
module NAT = FIX (struct
  type 'self t = (unit, 'self) either
  let mdt mself = md_sum md_unit mself
end);;

(* Conversions between NATs and ints, for printing NATs nicely.
   [nat_of_int] raises Invalid_argument on negative input (instead of
   recursing forever).  Both conversions are tail-recursive. *)
let nat_of_int n =
  if n < 0 then invalid_arg "nat_of_int: negative argument";
  let rec build acc k =
    if k = 0 then acc else build (NAT.Fix (Right acc)) (k - 1)
  in
  build (NAT.Fix (Left ())) n
;;

let int_of_nat n =
  let rec count acc = function
    | NAT.Fix (Left ()) -> acc
    | NAT.Fix (Right m) -> count (acc + 1) m
  in
  count 0 n
;;

(* Create a few NATs for the tests *)
let nat0 = nat_of_int 0
and nat1 = nat_of_int 1
and nat2 = nat_of_int 2
and nat5 = nat_of_int 5
and nat10 = nat_of_int 10
;;

(* For each distinct value of x, the line "testnat f" is printed exactly
   once. *)
let () =
  let f =
    NAT.md_fix (fun x ->
        printf "testnat f: %d\n" (int_of_nat x);
        NAT.Fix (Right x))
  in
  check_test "testnat"
    (f nat0 = f nat0 && f nat10 = f nat10 && f nat5 = f (f (f (f nat2))))
;;

(* Representing lists of booleans as a recursive sum-of-product type:
     BList = 1 + Bool * BList   or   BList = mu self.(unit + bool * self)
   Again, the OCaml notation below closely matches the categorical
   notation, and the expression constructing the memo table matches the
   structure of the type. *)
module BLST = FIX (struct
  type 'self t = (unit, bool * 'self) either
  let mdt self = md_sum md_unit (md_prod md_bool self)
end);;

let nil = BLST.Fix (Left ())
let cons h t = BLST.Fix (Right (h, t));;

(* Conversion helpers, for the sake of the tests *)
let blst_of_list l = List.fold_right cons l nil;;

(* Render a BLST as a string of 'T'/'F' characters, head first. *)
let show_blst l =
  let buf = Buffer.create 16 in
  let rec go = function
    | BLST.Fix (Left ()) -> ()
    | BLST.Fix (Right (h, t)) ->
        Buffer.add_char buf (if h then 'T' else 'F');
        go t
  in
  go l;
  Buffer.contents buf
;;

(* Reversing a BLST; this function is memoized in the test below. *)
let rev l =
  let rec loop acc = function
    | BLST.Fix (Left ()) -> acc
    | BLST.Fix (Right (h, t)) -> loop (cons h acc) t
  in
  loop nil l
;;

let () =
  let f = BLST.md_fix (fun x -> printf "testlst f: %s\n" (show_blst x); rev x) in
  let l0 = blst_of_list []
  and l5 = blst_of_list [ true; false; false; true; true ] in
  check_test "testlst" (f l0 = f l0 && f l5 = f (f (f l5)))
;;

(* Printed output:
     testlst f:
     testlst f: TFFTT
     testlst f: TTFFT
   Again, the line "testlst f" is printed only once for each distinct
   list argument. *)