open Prelude

(** Extensible family of witnesses for registered base types. [Custom.add]
    extends it with one fresh constructor per call. *)
type 'a custom = ..

(** Runtime representation of a type. The index ['a] is the OCaml type being
    described. *)
type 'a t =
  | Arrow : 'a t * 'b t -> ('a -> 'b) t  (** Function from ['a] to ['b]. *)
  | List : 'a t -> 'a list t  (** List whose elements have type ['a]. *)
  | Pair : 'a t * 'b t -> ('a * 'b) t  (** Pair of an ['a] and a ['b]. *)
  | Custom : (string * 'a custom) -> 'a t
      (** Registered base type: its name together with its witness. *)

(** Registry of base types, keyed by name and identified by a witness
    constructor of ['a custom]. *)
module Custom = struct
  (** Everything recorded about one registered type. *)
  type 'a data =
    { conv: 'a Conv.t option  (** Converter, when one was given to [add]. *)
    ; name: string  (** Name the type was registered under. *)
    ; friendly: string  (** Label given to [add] as [~friendly]. *)
    ; constructor: 'a custom  (** Witness constructor created by [add]. *)
    ; cast: 'b. 'b custom -> ('a, 'b) Equality.t option
          (** Returns an equality proof when given this very witness, and
              [None] for any other one. *)
    ; compare: 'a -> 'a -> int  (** Ordering on values of the type. *) }

  (** A registry entry with its type parameter hidden. *)
  type entry = E : 'a data -> entry

  (** Registered entries, most recently added first. *)
  let dict = ref []

  (** [find (name, witness)] returns the data of the entry registered under
      [name] whose [cast] accepts [witness].

      Entries sharing [name] are all tried, most recent first, so a witness
      obtained before its name was registered a second time still resolves to
      its own entry.

      @raise Not_found if no entry is registered under [name].
      @raise Failure
        with message [find_custom:cast] if entries named [name] exist but
        none of them accepts [witness]. *)
  let find : type a. string * a custom -> a data =
   fun (name, witness) ->
    (* [name_seen] remembers that some entry carried [name] but belonged to a
       different witness, which keeps the two failure modes distinct. *)
    let rec search (name_seen : bool) : entry list -> a data = function
      | [] ->
          if name_seen then failwith "find_custom:cast" else raise Not_found
      | E data :: rest when name = data.name -> (
        match data.cast witness with
        | Some Equality.Refl -> data
        | None -> search true rest )
      | _ :: rest -> search name_seen rest
    in
    search false !dict

  (** [add ?conv ?compare ~friendly name] registers a new base type called
      [name] and returns its representation.

      Every call creates a fresh witness constructor, so two registrations
      are distinct types even when they share a name. [compare] defaults to
      [Stdlib.compare]. The new entry is pushed at the front of [dict]. *)
  let add (type a) ?(conv : a Conv.t option) ?(compare = Stdlib.compare)
      ~friendly name =
    let module M = struct
      type 'a custom += T : a custom
    end in
    (* Only [M.T] can prove equality with [a]; any other witness is
       rejected. *)
    let cast (type b) : b custom -> (a, b) Equality.t option = function
      | M.T -> Some Equality.Refl
      | _ -> None
    in
    let data = {cast; friendly; compare; conv; name; constructor = M.T} in
    dict := E data :: !dict ;
    Custom (name, M.T)
end

(** Integers, registered under the name [int]. *)
let int = Custom.add "int" ~friendly:"Number" ~conv:Conv.int

(** Booleans, registered under the name [bool]. *)
let bool = Custom.add "bool" ~friendly:"Booleans" ~conv:Conv.bool

(** The unit type, registered under the name [unit]. *)
let unit = Custom.add "unit" ~friendly:"Unit" ~conv:Conv.unit

(** Strings, registered under the name [string] and ordered with
    [Sort.compare_string] rather than the default comparison. *)
let string =
  Custom.add "string" ~friendly:"String of characters" ~conv:Conv.string
    ~compare:Sort.compare_string

(** Floating-point numbers, registered under the name [float]. *)
let float = Custom.add "float" ~friendly:"Decimal number" ~conv:Conv.float

(** [show ty] renders [ty] as text: the registered name for a custom type,
    [[elt]] for lists, [(a -> b)] for functions and [(a, b)] for pairs.
    Custom types are resolved through [Custom.find], so its exceptions can
    escape from here. *)
let rec show : type a. a t -> string =
 fun ty ->
  match ty with
  | Arrow (dom, cod) -> "(" ^ show dom ^ " -> " ^ show cod ^ ")"
  | Pair (left, right) -> "(" ^ show left ^ ", " ^ show right ^ ")"
  | List elt -> "[" ^ show elt ^ "]"
  | Custom key -> (Custom.find key).name

(** [left *** right] is the pair type with components [left] and [right]. *)
let ( *** ) left right = Pair (left, right)

(** [dom --> cod] is the type of functions from [dom] to [cod]. *)
let ( --> ) dom cod = Arrow (dom, cod)

(** [list elt] is the type of lists whose elements have type [elt]. *)
let list elt = List elt

(** [test lhs rhs] decides whether two representations describe the same
    type. On success it returns a proof of equality of their indices;
    otherwise it returns an error naming the two types that differ.

    Custom types are compared through their witnesses, structured types
    component by component. Both components of a pair or arrow are always
    tested, and the error of the first component wins. *)
let rec test : type a b. a t -> b t -> (a, b) Equality.t Prelude.optional =
 fun lhs rhs ->
  let open Equality in
  match (lhs, rhs) with
  | Arrow (dom, cod), Arrow (dom', cod') -> (
    match (test dom dom', test cod cod') with
    | Error e, _ -> Error e
    | _, Error e -> Error e
    | Ok Refl, Ok Refl -> Ok Refl )
  | Pair (left, right), Pair (left', right') -> (
    match (test left left', test right right') with
    | Error e, _ -> Error e
    | _, Error e -> Error e
    | Ok Refl, Ok Refl -> Ok Refl )
  | List elt, List elt' -> (
    match test elt elt' with Ok Refl -> Ok Refl | Error e -> Error e )
  | Custom (name, witness), Custom (name', witness') -> (
    match (Custom.find (name, witness)).cast witness' with
    | Some Refl -> return Refl
    | None -> fail "Type %s and %s are not equal" name name' )
  | other, other' ->
      (* The head constructors differ. *)
      fail "Type %s and %s are not equal" (show other) (show other')

(** [equal t u] is [true] exactly when [test t u] succeeds. *)
let equal t u = match test t u with Error _ -> false | Ok _ -> true

(** [cast src dst value] returns [value] at the type described by [dst] when
    [test src dst] succeeds, and the error of [test] otherwise. *)
let cast : type a b. a t -> b t -> a -> b Prelude.optional =
 fun src dst value ->
  match test src dst with Ok Refl -> Ok value | Error e -> Error e

(** Type representations with the type index erased, together with their
    textual syntax. *)
module Untyped = struct
  (** A typed representation whose index is hidden. *)
  type _t = T : 'a t -> _t

  (** Index-free mirror of the typed representation. Custom types are
      referred to by name only. *)
  type descr =
    | List of descr
    | Pair of descr * descr
    | Arrow of descr * descr
    | Custom of string
  [@@deriving yojson]

  (** [project ty] forgets the index of [ty], keeping only its shape and the
      names of the custom types it mentions. *)
  let rec project : type a. a t -> descr = function
    | List t -> List (project t)
    | Pair (t, u) -> Pair (project t, project u)
    | Arrow (t, u) -> Arrow (project t, project u)
    | Custom (s, _) -> Custom s

  (** [inject d] rebuilds a typed representation from [d], resolving custom
      names against [Custom.dict].

      @raise Failure
        if a custom name is unknown, which the lookup signals with
        [Not_found]. *)
  let rec inject : descr -> _t = function
    | List t -> ( match inject t with T t' -> T (List t') )
    | Pair (t, u) -> (
      match (inject t, inject u) with T t, T u -> T (Pair (t, u)) )
    | Arrow (t, u) -> (
      match (inject t, inject u) with T t, T u -> T (Arrow (t, u)) )
    | Custom s -> (
      try
        find_map
          (fun (Custom.E data) ->
            if data.name = s then
              Some (T (Custom (data.name, data.constructor)))
            else None)
          !Custom.dict
      with Not_found ->
        failwith (Printf.sprintf "Unknown custom type %s" s) )

  (** Parser for type descriptions: identifiers for custom types, square
      brackets for lists, [->] between types for functions, and parentheses
      either for grouping or, with a comma, for pairs. This makes the pair
      syntax printed by [show] readable again. *)
  let parse =
    let open Parsing in
    let simple = trim >> ident >>= fun s -> trim >> return (Custom s) in
    let list x = square x >>| fun x -> List x in
    (* The first component is parsed once; a following comma turns the group
       into a pair, otherwise the parentheses are plain grouping. *)
    let group x =
      parens
        ( x
        >>= fun first ->
        keyword "," >> x
        >>| (fun second -> Pair (first, second))
        <|> return first )
    in
    let basic x = simple <|> list x <|> group x in
    (* NOTE(review): arrows are combined with [chainl1], so unparenthesised
       chains presumably associate to the left; confirm against Parsing. *)
    fix (fun ty ->
        chainl1 (basic ty)
          (keyword "->" >> return (fun sigma tau -> Arrow (sigma, tau))))

  (** [show d] prints [d] using the same layout as the typed [show]:
      [[elt]], [(a, b)] and [(a -> b)]. *)
  let rec show = function
    | Custom s -> s
    | List l -> "[" ^ show l ^ "]"
    | Pair (s, t) -> Printf.sprintf "(%s, %s)" (show s) (show t)
    | Arrow (s, t) -> Printf.sprintf "(%s -> %s)" (show s) (show t)

  (** Converter for hidden-index representations: shows through [project],
      parses through [parse] followed by [inject]. *)
  let conv =
    Conv.make
      ~show:(fun (T x) -> show (project x))
      ~parse:(Angstrom.map ~f:inject parse)

  (** Alias of [_t], the main type of this module. *)
  type t = _t
end

(** Operations on values, derived from the representation of their type. *)
module Methods = struct
  (** [compare ty] is the ordering registered for [ty] when it is a custom
      type. For any other shape it is the [compare] in scope at this point,
      which does not consult the orderings registered for the components. *)
  let compare : type a. a t -> a -> a -> int = function
    | Custom c -> (Custom.find c).compare
    | _ -> compare

  (** [conv ty] builds the converter for values of type [ty].

      @raise Failure
        if [ty] contains a function type, or a custom type that has no
        registered converter or cannot be found in the registry. *)
  let rec conv : type a. a t -> a Conv.t = function
    | Custom (c, k) -> (
      let missing () =
        failwith @@ Printf.sprintf "Custom type %s has no conv functions" c
      in
      (* Only the lookup failures are turned into the message above; any
         other exception is left to propagate unchanged. *)
      match (Custom.find (c, k)).conv with
      | Some cv -> cv
      | None -> missing ()
      | exception (Not_found | Failure _) -> missing () )
    | Arrow (_, _) -> failwith "Cannot conv functional value"
    | List typ -> Conv.list (conv typ)
    | Pair (a, b) -> Conv.pair (conv a) (conv b)

  (** [show ty] is the printing half of [conv ty]. *)
  let show ty = (conv ty).Conv.show

  (** [parse ty] is the parsing half of [conv ty]. *)
  let parse ty = (conv ty).Conv.parse

  (** [to_yojson ty x] encodes [x] as a JSON string holding its printed
      form. *)
  let to_yojson ty x = `String (show ty x)

  (** [of_yojson ty json] parses the contents of a JSON string with
      [conv ty]. Any other JSON value is an error. *)
  let of_yojson ty = function
    | `String s -> Conv.parse (conv ty) s
    | _ -> fail "of_yojson: expected a JSON string"
end

(** Values packed together with the representation of their type. *)
module Object = struct
  (** A value and the representation of its type, with the type hidden. *)
  type obj = O : 'a t * 'a -> obj

  (** [unpack ty obj] casts the payload of [obj] from its stored type to
      [ty]. The cast and the payload are handed to [enter] along with a
      message naming the expected and the stored type; presumably [enter]
      attaches that message to a failure (confirm against Prelude). *)
  let unpack ty packed =
    match packed with
    | O (stored, payload) ->
        enter (cast stored ty) payload
          "When unpacking an object, expected type: %s, got %s" (show ty)
          (show stored)

  (** [show obj] prints the payload with the converter of its stored type. *)
  let show = function
    | O (stored, payload) -> Conv.show (Methods.conv stored) payload

  (** Parser for objects: a type description, the keyword [:], then a value
      read with the converter of that type. *)
  let parse =
    let open Parsing in
    Untyped.conv.parse
    >>= fun (Untyped.T typ) ->
    keyword ":" >> (Methods.conv typ).parse
    >>| fun payload -> O (typ, payload)

  (** Converter for packed objects, assembled from [show] and [parse]. *)
  let conv = Conv.make ~show ~parse

  (** Markup node carrying a packed object along with some markup. *)
  type Markup.t += Link of obj * Markup.t
end

(** Infix form of [equal]. From this point on, [=] compares type
    representations instead of being the equality operator previously in
    scope. *)
let ( = ) t u = equal t u

(** Type-indexed dispatch: associates values of type ['a S.t] with type
    representations and retrieves the one matching a given type. *)
module Matcher (S : sig
  type 'a t
end) =
struct
  (** A representation paired with the value to return for that type. *)
  type clause = C : 'a t * 'a S.t -> clause

  (** Builds the result for a list type from the result for its elements. *)
  type list_clause = {list: 'a. 'a S.t -> 'a list S.t}

  (** Builds the result for a pair type from the results for its
      components. *)
  type pair_clause = {pair: 'a 'b. 'a S.t -> 'b S.t -> ('a * 'b) S.t}

  (** [on ty ~return] is the clause answering [return] for type [ty]. *)
  let on ty ~return = C (ty, return)

  (** [run_clauses ty clauses] returns the value of the first clause whose
      type passes [test] against [ty], or fails with an empty message when
      there is none. *)
  let rec run_clauses : type a. a t -> clause list -> a S.t optional =
   fun ty clauses ->
    match clauses with
    | C (candidate, result) :: rest -> (
      match test ty candidate with
      | Ok Refl -> return result
      | _ -> run_clauses ty rest )
    | [] -> fail ""

  (** [run ?list ?pair clauses ty] computes the result for [ty].

      List and pair types are handled structurally: the matching handler
      must be supplied, otherwise the run fails, and it is applied to the
      results obtained for the components. Every other type is looked up in
      [clauses]. *)
  let rec run :
      type a.
         ?list:list_clause
      -> ?pair:pair_clause
      -> clause list
      -> a t
      -> a S.t optional =
   fun ?list ?pair clauses ty ->
    match ty with
    | List elt ->
        let* handler =
          match list with None -> fail "" | Some h -> return h
        in
        let* elt' = run ?list ?pair clauses elt in
        return (handler.list elt')
    | Pair (left, right) ->
        let* handler =
          match pair with None -> fail "" | Some h -> return h
        in
        let* left' = run ?list ?pair clauses left in
        let* right' = run ?list ?pair clauses right in
        return (handler.pair left' right')
    | _ -> run_clauses ty clauses

  (** Same as [run], answering [default] when the run fails. *)
  let run_def ?list ?pair ~default cs ty =
    run ?list ?pair cs ty |> Result.value ~default

  (** Polymorphic registration function, wrapped in a record so that it can
      be used at every type. *)
  type create_function = {f: 'a. 'a t -> 'a S.t -> unit}

  (** A mutable list of clauses together with the function extending it. *)
  type dictionary = clause list ref * create_function

  (** [create ()] is a fresh, empty dictionary. Its registration function
      pushes new clauses at the front of the list. *)
  let create () =
    let clauses = ref [] in
    let register ty value = clauses := on ty ~return:value :: !clauses in
    (clauses, {f = register})

  (** [add dict ty value] registers [value] for [ty] in [dict]. *)
  let add (_, registrar) = registrar.f

  (** [lookup ?list ?pair ~default dict ty] runs the clauses currently held
      by [dict] against [ty], answering [default] on failure. *)
  let lookup ?list ?pair ~default (clauses, _) ty =
    run_def ?list ?pair ~default !clauses ty

  (** Same as [lookup], returning the failure instead of a default. *)
  let lookup_opt ?list ?pair (clauses, _) ty = run ?list ?pair !clauses ty
end
