Interactive DAGs
  • Confounders
  • Mediators
  • Colliders

Colliders

d3 = require("d3@7")
jStat = require("jstat@1.9.6")
dag = import(new URL("js/dag-utils.js", document.baseURI).href)
Relationships between nodes
viewof strength_xz = Inputs.range([0, 1], {
  value: 0.5, 
  step: 0.05, 
  label: html`<span class="node node-x">X</span> → <span class="node node-z">Z</span> strength`
})

viewof strength_yz = Inputs.range([0, 1], {
  value: 0.5, 
  step: 0.05, 
  label: html`<span class="node node-y">Y</span> → <span class="node node-z">Z</span> strength`
})
Relationship between X and Y
viewof xy_exists = Inputs.toggle({
  label: html`<span class="node node-x">X</span> → <span class="node node-y">Y</span> exists (<em>true causal effect</em>)`
})

viewof strength_xy = Inputs.range([0, 1], {
  value: 0.3, 
  step: 0.05,
  label: html`<span class="node node-x">X</span> → <span class="node node-y">Y</span> strength`
})
Adjustments
viewof adjust_z = Inputs.toggle({
  label: html`<span class="node node-z">Adjust for Z</span> (<em>or "condition on"</em> <span class="node node-z">Z</span>)`
})
// ----------------
// Status readout
// ----------------
{
  const trueEffect = xy_exists ? strength_xy : 0;

  // Helper functions for formatting true negative signs
  const fmt_number = (x, d = 3) =>
    (x < 0 ? "\u2212" : "") + Math.abs(x).toFixed(d);
  const fmt_signed = (x, d = 3) =>
    (x >= 0 ? "+" : "\u2212") + Math.abs(x).toFixed(d);
  
  const bias = slope_cond - slope_all;

  return html`<div class="alert alert-secondary status-readout">
    <h5 class="alert-heading">Observed effects</h5>
    <table>
      <tr>
        <td>True <span class="node node-x">X</span> → <span class="node node-y">Y</span> effect</td>
        <td>${trueEffect === 0 ? "none" : fmt_number(trueEffect, 2)}</td>
      </tr>
      <tr>
        <td>Overall slope</td>
        <td>${fmt_number(slope_all)}</td>
      </tr>
      <tr class="summary ${adjust_z ? '' : 'dimmed'}">
        <td>Apparent slope (conditioning on <span class="node node-z">Z</span>)</td>
        <td>${adjust_z ? fmt_number(slope_cond) : "—"}</td>
      </tr>
      <tr>
        <td>N (selected / total)</td>
        <td>${adjust_z ? html`${n_selected} / ${simData.length}` : html`${simData.length} / ${simData.length}`}</td>
      </tr>
      <tr class="${adjust_z ? '' : 'dimmed'}">
        <td>Bias introduced</td>
        <td>${adjust_z
          ? html`<span class="${Math.abs(bias) > 0.05 ? 'text-danger' : ''}">${fmt_signed(bias)}</span>`
          : "—"
        }</td>
      </tr>
    </table>
  </div>`;
}
// Javascript doesn't have native seed functions like set.seed(), so we make one here
_seed = {
  // Mulberry32 seeded PRNG
  function mulberry32(a) {
    return () => {
      a |= 0; a = a + 0x6D2B79F5 | 0;
      let t = Math.imul(a ^ a >>> 15, 1 | a);
      t = t + Math.imul(t ^ t >>> 7, 61 | t) ^ t;
      return ((t ^ t >>> 14) >>> 0) / 4294967296;
    }
  }

  const rng = mulberry32(674751);  // From random.org
  const N = 500;
  const randn = (sd = 1) =>
    jStat.normal.inv(rng() * 0.998 + 0.001, 0, sd);

  return {
    xVals: Array.from({ length: N }, () => randn()),
    noiseY: Array.from({ length: N }, () => randn()),
    noiseZ: Array.from({ length: N }, () => randn(0.5))
  };
}
// ----------------
// Simulated data
// ----------------
// Z is binary: 1 if the latent combination exceeds 0
simData = {
  const { xVals, noiseY, noiseZ } = _seed;
  const N = xVals.length;
  const beta = xy_exists ? strength_xy : 0;

  return xVals.map((x, i) => {
    const y = beta * x + noiseY[i];
    const z_latent = strength_xz * x + strength_yz * y + noiseZ[i];
    const z = z_latent > 0 ? 1 : 0;
    return { x, y, z, group: z === 1 ? "Z = 1" : "Z = 0" };
  });
}
z1_points = simData.filter(d => d.z === 1)
z0_points = simData.filter(d => d.z === 0)
selected = adjust_z ? z1_points : simData
n_selected = selected.length
/**
 * Ordinary-least-squares slope of y on x for an array of `{ x, y }` records,
 * computed directly as Sxy / Sxx (equivalent to r · sd(y) / sd(x)).
 *
 * @param {Array<{x: number, y: number}>} data - points to fit.
 * @returns {number} the fitted slope, or 0 when it cannot be estimated:
 *   fewer than 3 points, or no spread in x. The zero-spread case would
 *   otherwise divide by zero and yield NaN.
 */
function ols_slope(data) {
  if (data.length < 3) return 0;

  const n = data.length;
  const meanX = data.reduce((sum, d) => sum + d.x, 0) / n;
  const meanY = data.reduce((sum, d) => sum + d.y, 0) / n;

  let sxy = 0;
  let sxx = 0;
  for (const { x, y } of data) {
    const dx = x - meanX;
    sxy += dx * (y - meanY);
    sxx += dx * dx;
  }

  // Constant x: slope is undefined; fall back to the same 0 as the size guard.
  return sxx === 0 ? 0 : sxy / sxx;
}

// OLS slope of Y on X over the full simulated sample
slope_all = ols_slope(simData)
// OLS slope over the analysed rows: only Z = 1 when adjusting for Z, otherwise all rows
slope_cond = ols_slope(selected)
// -----------------
// Interactive DAG
// -----------------
{
  const width = 600;
  const height = 250;
  const nodeRadius = 36;

  const nodes = {
    X: { x: 130, y: 200, label: "X" },
    Z: { x: width / 2, y: 60, label: "Z" },
    Y: { x: 470, y: 200, label: "Y" }
  };

  const svg = d3.create("svg")
    .attr("viewBox", `0 0 ${width} ${height}`)
    .attr("width", width)
    .attr("height", height)
    .style("max-width", "100%");

  const defs = svg.append("defs");

  dag.addArrowMarkers(defs);

  // Arrows
  const edges = [
    {
      id: "xz", from: nodes.X, to: nodes.Z,
      strength: strength_xz, blocked: false
    },
    {
      id: "yz", from: nodes.Y, to: nodes.Z,
      strength: strength_yz, blocked: false
    }
  ];

  if (xy_exists) {
    edges.push({
      id: "xy", from: nodes.X, to: nodes.Y,
      strength: strength_xy, blocked: false
    });
  }

  for (const edge of edges) {
    dag.drawEdge(svg, edge, nodeRadius);
  }

  // Add highlighted area behind X and Y when adjusting for Z
  if (adjust_z) {
    const padX = 8;
    const padY = 10;
    const left = nodes.X.x - nodeRadius - padX;
    const right = nodes.Y.x + nodeRadius + padX;
    const top = nodes.X.y - nodeRadius - padY;
    const h = (nodeRadius + padY) * 2;
    const w = right - left;

    svg.append("rect")
      .attr("x", left)
      .attr("y", top)
      .attr("width", w)
      .attr("height", h)
      .attr("rx", h / 2)
      .attr("ry", h / 2)
      .attr("fill", dag.colorZ)
      .attr("opacity", 0.18);
  }

  // Nodes
  dag.drawSolidNode(
    svg, nodes.X.x, nodes.X.y, nodeRadius, dag.colorX
  );

  dag.drawSolidNode(
    svg, nodes.Y.x, nodes.Y.y, nodeRadius, dag.colorY
  );

  // Z: solid gold, semi-transparent when not conditioning,
  // fully opaque when conditioning
  dag.drawSolidNode(
    svg, nodes.Z.x, nodes.Z.y, nodeRadius, dag.colorZ,
    adjust_z ? 1 : 0.35
  );

  // Labels
  for (const n of Object.values(nodes)) {
    dag.drawLabel(svg, n.x, n.y, n.label);
  }

  return svg.node();
}
// Scatterplot of points with regression line(s)
Plot.plot({
  width: 550,
  height: 340,
  style: { fontSize: "12px" },
  x: { label: "X" },
  y: { label: "Y" },
  color: {
    domain: ["Z = 0", "Z = 1"],
    range: [dag.colorZ0, dag.colorZ]
  },
  marks: [
    // Faded excluded points (Z = 0 when conditioning)
    adjust_z
      ? Plot.dot(z0_points, {
          x: "x", y: "y",
          fill: dag.colorZ0,
          r: 3,
          fillOpacity: 0.2
        })
      : null,

    // Active points
    Plot.dot(selected, {
      x: "x", y: "y",
      fill: "group",
      r: 4,
      fillOpacity: 0.75,
      stroke: "#fff",
      strokeWidth: 0.5
    }),

    // Overall regression line
    Plot.linearRegressionY(simData, {
      x: "x", y: "y",
      stroke: "#8b8b99",
      strokeWidth: 2
    }),

    // Conditioned regression line
    adjust_z
      ? Plot.linearRegressionY(z1_points, {
          x: "x", y: "y",
          stroke: dag.apparentLine,
          strokeWidth: 2.5,
          strokeDasharray: "8 5"
        })
      : null
  ].filter(Boolean)
})