#import "@preview/cetz:0.1.2"
#import cetz.draw: *

// quick reference
// - `graphic` is for inline icons/images
// - `canvas` is for centered, more prominent figures
// - `group` (from cetz) can be used inside either of the above, but not standalone

// Renders cetz drawing commands inside an inline `box`, so the result flows
// with surrounding text like an icon or image would.
// Before the caller's commands run, mark tips are set to a 90° angle and
// strokes get round caps and joins.
#let graphic(what) = box(cetz.canvas({
  // shared preamble styling for everything drawn in `what`
  set-style(
    stroke: (join: "round", cap: "round"),
    mark: (angle: 90deg),
  )
  what
}))

// Same as `graphic`, but horizontally centered, for larger stand-alone figures.
#let canvas(what) = align(center, graphic(what))


// smaller stuff

// Small inline arrow pointing to the right, for use in running text.
// - length: horizontal length of the arrow
// - lift: how far above the bottom edge of the graphic the arrow is drawn
// - stroke: passed to the arrowhead (mark) only; the shaft keeps the canvas style
//   NOTE(review): confirm whether the shaft was also meant to use `stroke`
#let arrow(length: 0.4cm, lift: 3pt, stroke: 1pt) = graphic({
  let start = (0, lift)
  let head = (end: ">", stroke: stroke)
  line(start, (rel: (length, 0)), mark: head)

  // Invisible zero-length line at the origin. It pulls the bounding box down
  // to y = 0. Without it, the canvas would be cropped tightly around the arrow
  // and `lift` would have no visible effect.
  line((0, 0), (0, 0), stroke: none)
})

// larger stuff

// Lays out the given labels in a horizontal row and connects each label to
// the next one with an arrow.
// - distance: horizontal offset between consecutive labels
// - arrow-spacing: gap between an arrow end and the label it touches
// - style: extra arguments forwarded to each label's `content` call
// - ..labels: the labels themselves, as positional arguments
#let sequence(
  distance: 3cm,
  arrow-spacing: 0.15cm,
  // cetz will support rounded rects in 0.2.0
  style: (frame: "rect", padding: 0.1cm),
  ..labels,
) = group({
  let items = labels.pos()
  let anchor-name(idx) = "label-" + str(idx)

  // place every label; each label after the first shifts the origin one step to the right
  for (idx, item) in items.enumerate() {
    if idx > 0 { set-origin((distance, 0)) }
    content((0, 0), item, name: anchor-name(idx), ..style)
  }

  // n labels need n - 1 arrows, one from each label to its successor
  for idx in range(1, items.len()) {
    line(
      (rel: (arrow-spacing, 0), to: anchor-name(idx - 1) + ".right"),
      (rel: (-arrow-spacing, 0), to: anchor-name(idx) + ".left"),
      mark: (end: ">"),
    )
  }
})

// Centered overview figure: Source -> Graph IR -> Runtime.
#let stages-overview = canvas(sequence(
  [Source],
  [Graph IR],
  [Runtime],
))

// A few commands to help with demonstrations in the docs.
// Each input or output is supplied as a string (its socket name), which marks it as a simple socket.
// (FWIW: in Typst, parentheses around a single expression just evaluate the expression and don't
// create an array, hence the trailing comma in one-element arrays like `("data",)`.)
#let cmds = (
  // command name: (socket names drawn on the top edge, socket names drawn on the bottom edge)
  "const":  (inputs: (),                  outputs: ("data",)),
  "open":   (inputs: ("path",),           outputs: ("data",)),
  "save":   (inputs: ("data", "path"),    outputs: ()),
  "show":   (inputs: ("data",),           outputs: ()),
  // single unnamed output; it is referenced as "<node>/" in `connect`
  "invert": (inputs: ("base",),           outputs: ("",)),
  "mask":   (inputs: ("base", "stencil"), outputs: ("masked", "rest")),
)

// Returns the anchor on the opposite side of an element:
// "top" <-> "bottom", "left" <-> "right", the corner anchors map to the
// diagonally opposite corner, and "center" maps to itself.
// Panics with a descriptive message for any other anchor name.
#let opposite(anchor) = {
  let opposites = (
    "top": "bottom",
    "bottom": "top",
    "left": "right",
    "right": "left",
    "top-left": "bottom-right",
    "bottom-right": "top-left",
    "top-right": "bottom-left",
    "bottom-left": "top-right",
    "center": "center",
  )
  if anchor not in opposites {
    panic("no opposite known for anchor " + repr(anchor))
  }
  opposites.at(anchor)
}

// Draws one filled socket per entry of `sockets`, evenly spaced along the segment
// from `start` to `stop`, and labels each socket with its name.
// - start, stop: endpoints of the segment the sockets sit on
// - sockets: array of socket names; each socket element is named
//   `parent-name + "/" + <socket name>` so it can be used as an anchor later
// - socket-size: (width, height) for "rect" sockets; "circle" sockets use the
//   second entry as their radius
// - socket-shape: "rect" or "circle"; anything else panics
// - parent-name: prefix for the socket element names
// - label-anchor: which side the label is meant to go on (see the quirk note below)
#let sockets(
  start,
  stop,
  sockets,
  socket-size: (0.5, 0.1),
  socket-shape: "circle",
  parent-name: "",
  label-anchor: "bottom",
) = {
  let count = sockets.len()
  for (idx, entry) in sockets.enumerate() {
    // with n sockets, they sit at 1/(n+1), 2/(n+1), ..., n/(n+1) of the segment
    let pos = (start, (idx + 1) / (count + 1), stop)
    let el-name = parent-name + "/" + entry
    let shared = (name: el-name, fill: black)

    if socket-shape == "circle" {
      circle(pos, radius: socket-size.at(1), ..shared)
    } else if socket-shape == "rect" {
      // start half the size down-left of `pos`, so the rect is centered on it
      rect((rel: ((0, 0), -0.5, socket-size), to: pos), (rel: socket-size), ..shared)
    } else {
      panic("unknown socket shape: `" + socket-shape + "`")
    }
    set-style(fill: none)

    // Circle sockets attach the label at the *opposite* anchor; rect sockets don't.
    // The original author noted they did not know why this is needed.
    // NOTE(review): presumably a quirk of circle anchors in cetz 0.1.2; confirm.
    let attach = if socket-shape == "circle" { opposite(label-anchor) } else { label-anchor }
    content(
      el-name + "." + attach,
      anchor: opposite(label-anchor),
      box(inset: 0.25em, text(8pt, entry)),
    )
  }
}

// Draws a graph node: a rectangle labelled with its command type (plus an optional
// monospace body), input sockets along the top edge and output sockets along the
// bottom edge, as declared in `cmds`.
// - at: where the node's (0, 0) corner goes; the origin is shifted there while
//   drawing and is always shifted back afterwards
// - size: (width, height); the height grows by 0.5 when `body` is given
// - ty: key into `cmds`, or none for a plain box without sockets
// - body: optional extra text shown below the type
// - name: element name; sockets are named `name + "/" + <socket name>`
#let node(
  at,
  size: (3, 1.5),
  ty: none,
  body: none,
  name: "unnamed",
) = {
  set-origin(at)
  let label = [#ty]
  if body != none {
    label += [\ ] + text(0.7em, font: "IBM Plex Mono", body)
    size.at(1) += 0.5
  }
  rect((0, 0), (rel: size), name: name)
  content(((0, 0), 0.5, size), align(center, label))

  // Input and output sockets. Don't `return` early when `ty` is none: the origin
  // reset at the end must always run, or every later element stays shifted by `at`.
  if ty != none {
    let cmd = cmds.at(ty)

    let sockets = sockets.with(parent-name: name)
    // inputs: along the top edge, from the top-left to the top-right corner
    sockets(
      ((0, 0), "|-", size),
      size,
      cmd.inputs,
    )
    // outputs: along the bottom edge, from the bottom-left to the bottom-right corner
    sockets(
      (0, 0),
      ((0, 0), "-|", size),
      label-anchor: "top",
      cmd.outputs,
    )

    // dim "in"/"out" hint to the left of an edge, shown only if that edge has sockets
    let helper(entries, label, where) = {
      if entries.len() > 0 {
        content(
          name + "." + where + "-left",
          anchor: "right",
          box(inset: 0.25em, text(fill: luma(75%), label)),
        )
      }
    }

    helper(cmd.inputs, [in], "top")
    helper(cmd.outputs, [out], "bottom")
  }

  // reset the origin transform so other nodes can still work in the global coord system
  // can't use groups since otherwise the anchors are not exported
  set-origin(((0, 0), -1, at))
}

// Draws a bezier edge that leaves anchor `from` downwards and enters anchor `to`
// from above, plus a downward-pointing arrowhead that ends just above `to`.
// - bend: how far (in cm) the control points extend below `from` and above `to`
// - mark-cfg: arrowhead `size`, and the gap (`offset`) between its tip and `to`
#let connect(from, to, bend: 1.5, mark-cfg: (size: 0.25, offset: 0.1)) = {
  let reach = bend * 1cm
  let ctrl-from = (rel: (0, -reach), to: from)
  let ctrl-to = (rel: (0, reach), to: to)
  bezier(from, to, ctrl-from, ctrl-to)

  // start of the mark sits (offset + size) above `to`; it runs `size` downwards
  let tail = (rel: (0, mark-cfg.offset + mark-cfg.size), to: to)
  mark(tail, (rel: (0, -mark-cfg.size)), symbol: ">")
}

// TODO: could the layout be automated?
// If the graph is guaranteed to be acyclic,
// the nodes could simply be laid out in "columns".
#let graph-example = canvas({
  // dx: horizontal offset of the left/right columns from the center
  // dy: base vertical step (negative, so later rows go downwards)
  let dx = 2.25
  let dy = -3

  // (position, command type, body, name), drawn in this order
  let nodes = (
    ((-dx, -0.75 * dy), "const", "\"base.png\"", "base"),
    ((dx, -0.75 * dy), "const", "\"stencil.png\"", "stencil"),
    ((-dx, 0), "open", none, "a"),
    ((dx, 0), "open", none, "b"),
    ((0, dy), "mask", none, "c"),
    ((-dx, 2 * dy), "invert", none, "d"),
    ((-dx, 2.75 * dy), "show", none, "e"),
    ((dx, 2.75 * dy), "show", none, "f"),
  )
  for (pos, kind, literal, id) in nodes {
    node(pos, ty: kind, body: literal, name: id)
  }

  // (source socket, target socket, bend), drawn in this order
  let edges = (
    ("base/data", "a/path", 1.5),
    ("stencil/data", "b/path", 1.5),
    ("a/data", "c/base", 1.5),
    ("b/data", "c/stencil", 1.5),
    ("c/masked", "d/base", 1.5),
    // `invert` has a single unnamed output, hence the bare "d/"
    ("d/", "e/data", 1.5),
    ("c/rest", "f/data", 2.5),
  )
  for (src, dst, amount) in edges {
    connect(src, dst, bend: amount)
  }
})

// Standalone preview only: compiling this file directly renders `graph-example`
// with the project template, on a page whose width and height are `auto`, so it
// shrinks to fit the graphic.
#import "../../template.typ": conf
#show: conf.with(render-outline: false)
#set page(width: auto, height: auto)

#graph-example