open Prelude

(** Typed terms of the combinator language. The index ['a] is the OCaml type
    of the value a term evaluates to (see [eval]). *)
type 'a t =
  | Builtin : ('param, 'inp, 'ret) Builtins.t -> ('param -> 'inp -> 'ret) t
      (** A builtin, seen as a curried function of its parameter and then its
          input. *)
  | Const : 'a Ty.t * 'a -> 'a t
      (** A constant value together with its type witness. *)
  | App : ('a -> 'b) t * 'a t -> 'b t  (** Function application. *)
  | PostCompose : ('a -> 'b) t * ('b -> 'c) t -> ('a -> 'c) t
      (** [PostCompose (f, g)] runs [f] first, then [g] on its result. *)
  | Id : 'a Ty.t -> ('a -> 'a) t  (** The identity function at a type. *)
  | Pair : 'a t * 'b t -> ('a * 'b) t  (** A pair of terms. *)

(** Alias for [t], usable where another module's [t] shadows it (e.g. after
    [let open Parsing] in [parse]). *)
type 'a tm = 'a t

(** [typ_of term] reconstructs the type witness of [term] structurally.
    Builtins delegate to [Builtins.typ_of]; constants and identities carry
    their witness directly.

    @raise Failure ["typ_of"] if the witness computed for a function-typed
    sub-term (the head of an [App], or either side of a [PostCompose]) is not
    built with [Ty.Arrow]. NOTE(review): presumably unreachable when every
    arrow witness is a [Ty.Arrow] — confirm against [Ty]. *)
let rec typ_of : type a. a t -> a Ty.t =
 fun term ->
  match term with
  | Const (typ, _) -> typ
  | Id typ -> Ty.(typ --> typ)
  | Builtin b -> Builtins.typ_of b
  | Pair (t, u) -> Ty.(typ_of t *** typ_of u)
  | App (fn, _) -> (
      let fn_typ = typ_of fn in
      match fn_typ with
      | Ty.Arrow (_, codomain) -> codomain
      | _ -> failwith "typ_of")
  | PostCompose (first, second) -> (
      let first_typ = typ_of first and second_typ = typ_of second in
      match (first_typ, second_typ) with
      | Ty.Arrow (domain, _), Ty.Arrow (_, codomain) ->
          Ty.Arrow (domain, codomain)
      | _ -> failwith "typ_of")

(** [apply_builtin b param] applies the builtin [b] to the parameter term
    [param], leaving a function from [b]'s input type to its result type. *)
let apply_builtin (b : ('param, 'inp, 'out) Builtins.t) (param : 'param t) :
    ('inp -> 'out) t =
  let head = Builtin b in
  App (head, param)

(** [unapply_builtin builtin term] inverts [apply_builtin]. If [term] is
    [App (Builtin b', param)], [b'] has the same name as [builtin], and the
    parameter types agree, it returns [param], typed as [builtin]'s
    parameter. Otherwise it fails with a message describing the mismatch.

    The name comparison matters: without it, any builtin sharing [builtin]'s
    parameter type (for example any unit-parameter builtin) would be accepted
    as if it were [builtin]. *)
let unapply_builtin :
    type param.
    (param, 'inp, 'out) Builtins.t -> ('inp -> 'out) t -> param t optional =
 fun builtin term ->
  match term with
  | App (Builtin builtin', param) -> (
      let expected = builtin.Builtins.name
      and found = builtin'.Builtins.name in
      if not (String.equal expected found) then
        fail
          (Printf.sprintf "unapply_builtin: expected builtin %s, found %s"
             expected found)
      else
        match Ty.test builtin'.Builtins.param builtin.Builtins.param with
        | Ok Refl -> Ok param
        | _ ->
            fail
              (Printf.sprintf
                 "unapply_builtin: parameter type %s of builtin %s does not \
                  match expected %s"
                 (Ty.show builtin'.Builtins.param)
                 found
                 (Ty.show builtin.Builtins.param)))
  | _ -> fail "unapply_builtin: term is not an applied builtin"

(** [compose b k] builds the function that runs builtin [b], with its
    parameter fixed to the unit constant, and then feeds the result to [k]. *)
let compose :
    (unit, 'obj, 'out) Builtins.t -> ('out -> 'ret) t -> ('obj -> 'ret) t =
 fun b k ->
  let unit_param = Const (Ty.unit, ()) in
  PostCompose (apply_builtin b unit_param, k)

(** [uncompose b term] inverts [compose]. If [term] is
    [PostCompose (App (Builtin b', _), k)], [b'] has the same name as [b],
    and the result types agree, it returns the continuation [k], typed as
    consuming [b]'s result. The argument [b'] was applied to is discarded.
    Otherwise it fails with a message describing the mismatch.

    The name comparison matters: without it, any builtin sharing [b]'s result
    type would be accepted as if it were [b]. *)
let uncompose (type param inp out ret) (b : (param, inp, out) Builtins.t) :
    (inp -> ret) t -> (out -> ret) t optional = function
  | PostCompose (App (Builtin b', _), t) -> (
      let expected = b.Builtins.name and found = b'.Builtins.name in
      if not (String.equal expected found) then
        fail
          (Printf.sprintf "uncompose: expected builtin %s, found %s" expected
             found)
      else
        match Ty.test b.Builtins.ret b'.Builtins.ret with
        | Ok Refl -> return t
        | _ ->
            fail
              (Printf.sprintf
                 "uncompose: result type %s of builtin %s does not match \
                  expected %s"
                 (Ty.show b'.Builtins.ret) found (Ty.show b.Builtins.ret)))
  | _ -> fail "uncompose: term is not a builtin followed by a continuation"

(** [eval term] interprets [term] as an OCaml value. Function-valued terms
    become closures. Under [PostCompose], both sides are re-evaluated each
    time the resulting closure is called. *)
let rec eval : type a. a t -> a = function
  | Id _ -> fun v -> v
  | Builtin b -> Builtins.eval b
  | Const (_, value) -> value
  | App (fn, arg) -> eval fn (eval arg)
  | Pair (left, right) -> (eval left, eval right)
  | PostCompose (first, second) -> fun v -> eval second (eval first v)

(** [show_aux term] renders [term] without wrapping it in parentheses:
    - builtins by name;
    - constants as [!type:value];
    - applications as [f @ x], with compound operands parenthesized;
    - post-compositions as [.[f] g];
    - identities as [id:type];
    - pairs as [(a, b)]. *)
let rec show_aux : type a. a t -> string =
 fun term ->
  match term with
  | Builtin b -> b.Builtins.name
  | Const (ty, value) ->
      Printf.sprintf "!%s:%s" (Ty.show ty) (Ty.Methods.show ty value)
  | Id typ -> Printf.sprintf "id:%s" (Ty.show typ)
  | Pair (left, right) -> Printf.sprintf "(%s, %s)" (show left) (show right)
  | App (fn, arg) ->
      Printf.sprintf "%s @ %s" (show ~paren:true fn) (show ~paren:true arg)
  | PostCompose (first, second) ->
      Printf.sprintf ".[%s] %s" (show first) (show ~paren:true second)

(** [show ?paren term] renders [term]. When [paren] is [true] (default
    [false]) and [term] is an application or a post-composition, the output
    is wrapped in parentheses. *)
and show : type a. ?paren:bool -> a t -> string =
 fun ?(paren = false) x ->
  let atomic =
    match x with
    | Const _ | Builtin _ | Pair _ | Id _ -> true
    | App _ | PostCompose _ -> false
  in
  let body = show_aux x in
  if paren && not atomic then Printf.sprintf "(%s)" body else body

(** Concrete syntax for terms: an untyped [raw] parse tree, a type checker
    ([inject]) that turns it into a typed term, and the combined [packed]
    parser.

    Syntax accepted by [raw]:
    - [name]: a builtin, looked up by name;
    - [!<object>]: a constant, read by [Ty.Object.conv];
    - [id:<type>]: the identity at a type, read by [Ty.Untyped.conv];
    - [.[t] u]: post-composition, [t] first and then [u];
    - [(t)]: grouping;
    - [t @ u]: application, chained with [chainl1] (presumably
      left-associative).

    NOTE(review): there is no syntax for [Pair], so [show] output for pairs
    cannot be parsed back. *)
module Parser = struct
  (** A typed term together with its type witness, with the type index
      hidden. *)
  type packed = H : 'a Ty.t * 'a t -> packed

  (** Untyped parse tree, checked and converted by [inject]. *)
  type raw =
    | Builtin of string
    | Const of Ty.Object.obj
    | App of raw * raw
    | PostCompose of raw * raw
    | Id of Ty.Untyped.t

  (* From here on, [t] means [Parsing.t] rather than the term type. The
     constructors of [raw] shadow the term constructors; inside [H (...)] the
     expected type selects the term constructors. *)
  open Parsing

  (** [inject r] type-checks the raw tree [r] and builds the typed term. It
      fails in three cases: an applied term is not a function, an argument
      type does not match the function's domain, or the middle types of a
      composition differ.
      NOTE(review): how [Builtins.find_by_name_untyped] reports an unknown
      name (exception or otherwise) is not visible here; confirm. *)
  let rec inject : raw -> packed Parsing.t = function
    | Const (Ty.Object.O (typ, v)) ->
        Parsing.return (H (typ, Const (typ, v)))
    | Id (Ty.Untyped.T t) -> Parsing.return (H (Ty.(t --> t), Id t))
    | Builtin s ->
        let (Builtins.U b) = Builtins.find_by_name_untyped s in
        Parsing.return (H (Builtins.typ_of b, Builtin b))
    | App (t, u) -> (
        let* (H (ty, t')) = inject t in
        let* (H (sigma', u')) = inject u in
        match ty with
        | Ty.Arrow (sigma, tau) -> (
          match Ty.test sigma sigma' with
          | Ok Refl -> return @@ H (tau, App (t', u'))
          | _ ->
              fail
              @@ Printf.sprintf
                   "In application, type of argument %s is incompatible \
                    with argument type %s"
                   (Ty.show sigma') (Ty.show sigma) )
        | _ ->
            fail
            @@ Printf.sprintf
                 "In application, term of type %s is not a function"
                 (Ty.show ty) )
    | PostCompose (t, u) -> (
        let* (H (ty, t')) = inject t in
        let* (H (ty', u')) = inject u in
        match (ty, ty') with
        | Ty.Arrow (sigma, tau), Ty.Arrow (tau', upsilon) -> (
          match Ty.test tau tau' with
          | Ok Refl ->
              return @@ H (Ty.Arrow (sigma, upsilon), PostCompose (t', u'))
          | _ ->
              fail
              @@ Printf.sprintf
                   "In composition, middle types %s and %s do not match"
                   (Ty.show tau) (Ty.show tau') )
        | _ ->
            fail
            @@ Printf.sprintf
                 "In composition, expected two functions, got terms of type \
                  %s and %s"
                 (Ty.show ty) (Ty.show ty') )

  (** A bare identifier, taken as the name of a builtin. *)
  let builtin : raw t =
    let* s = ident in
    return @@ Builtin s

  (** [id:<type>], the identity at [<type>]. *)
  let id = keyword "id:" >> Ty.Untyped.conv.parse >>| fun x -> Id x

  (** [!<object>], a constant. *)
  let const = keyword "!" >> Ty.Object.conv.parse >>| fun n -> Const n

  (** [.[t1] t2]: post-composition. [t1] sits inside the delimiters and [t2]
      follows after [trim]. *)
  let post_compose term =
    let* t1 = delim ".[" "]" term in
    let* t2 = trim >> term in
    return @@ PostCompose (t1, t2)

  (** The full raw-term grammar, defined as a fixpoint over [term]. *)
  let raw =
    fix (fun term ->
        (* [id] is tried before [builtin]. If [ident] accepts the "id" prefix
           of "id:<type>", trying [builtin] first would succeed there and
           leave ":<type>" unparsed, so [id] would never be reached. *)
        let basic =
          const <|> post_compose term <|> parens term <|> id <|> builtin
        in
        chainl1 basic (keyword "@" >> return (fun x y -> App (x, y))))

  (** Parse a term, type-check it, and pack it with its type. *)
  let packed = raw >>= inject
end

(** [parse ty] parses a term and checks that its type equals [ty]. It fails
    with "Expected term of type ..., got term of type ..." otherwise. *)
let parse (type a) (ty : a Ty.t) : a t Parsing.t =
  let check_type (Parser.H (found, term)) =
    match Ty.test ty found with
    | Ok Refl -> (Parsing.return term : a t Parsing.t)
    | _ ->
        Parsing.fail
          (Printf.sprintf "Expected term of type %s, got term of type %s"
             (Ty.show ty) (Ty.show found))
  in
  Parsing.( >>= ) Parser.packed check_type

(** [conv ty] builds, via [Conv.make], a converter for terms of type [ty].
    Terms are printed with [show] (no outer parentheses) and read back with
    [parse ty]. *)
let conv ty =
  let print_term term = show ~paren:false term in
  Conv.make ~show:print_term ~parse:(parse ty)
