From fa2893bc77024915112fec79374e4b82b8b9c8b5 Mon Sep 17 00:00:00 2001 From: MultisampledNight Date: Fri, 19 Jan 2024 02:59:15 +0100 Subject: [PATCH] feat(ir): actually get toposort working --- crates/ir/src/lib.rs | 153 +++++++++++++++++++++---------------------- 1 file changed, 75 insertions(+), 78 deletions(-) diff --git a/crates/ir/src/lib.rs b/crates/ir/src/lib.rs index 86f2ff4..3865d64 100644 --- a/crates/ir/src/lib.rs +++ b/crates/ir/src/lib.rs @@ -1,6 +1,5 @@ -use std::{collections::BTreeSet, iter, ops::RangeInclusive}; +use std::{num::NonZeroUsize, ops::RangeInclusive}; -use either::Either; use instruction::SocketCount; use serde::{Deserialize, Serialize}; @@ -9,7 +8,7 @@ pub mod instruction; pub mod semi_human; pub type Map = std::collections::BTreeMap; -pub type Set = std::collections::BTreeSet; +pub type Set = std::collections::BTreeSet; /// Gives you a super well typed graph IR for a given human-readable repr. /// @@ -41,8 +40,8 @@ pub fn from_ron(source: &str) -> ron::error::SpannedResult { /// to come back to an already visited node. /// /// Here, if an edge points from _A_ to _B_ (`A --> B`), -/// then _A_ is called a **dependency** of _B_, -/// and _B_ is called a **dependent** of _A_. +/// then _A_ is called a **dependency** or an **input source** of _B_, +/// and _B_ is called a **dependent** or an **output target** of _A_. /// /// The DAG also enables another neat operation: /// [Topological sorting](https://en.wikipedia.org/wiki/Topological_sorting). @@ -69,33 +68,8 @@ pub struct GraphIr { rev_edges: Map, } +// TODO: this impl block, but actually the whole module, screams for tests impl GraphIr { - /// Look "forwards" in the graph to see what other instructions this instruction feeds into. - /// - /// The output slots represent the top-level iterator, - /// and each one's connections are emitted one level below. - /// - /// Just [`Iterator::flatten`] if you are not interested in the slots. 
- /// - /// The same caveats as for [`GraphIr::resolve`] apply. - #[must_use] - pub fn dependents( - &self, - subject: &id::Instruction, - ) -> Option> + '_> { - let (subject, kind) = self.instructions.get_key_value(subject)?; - let SocketCount { inputs, .. } = kind.socket_count(); - - Some((0..inputs).map(|idx| { - let output = id::Output(socket(subject, idx)); - self.edges - .get(&output) - .map_or(Either::Right(iter::empty()), |targets| { - Either::Left(targets.iter().map(|input| &input.socket().belongs_to)) - }) - })) - } - /// Look "backwards" in the graph, /// and find out what instructions need to be done before this one. /// The input slots are visited in order. @@ -105,22 +79,41 @@ impl GraphIr { /// /// The same caveats as for [`GraphIr::resolve`] apply. #[must_use] - pub fn dependencies( + pub fn input_sources( &self, subject: &id::Instruction, - ) -> Option> + '_> { + ) -> Option> + '_> { let (subject, kind) = self.instructions.get_key_value(subject)?; let SocketCount { inputs, .. } = kind.socket_count(); Some((0..inputs).map(|idx| { let input = id::Input(socket(subject, idx)); - self.rev_edges - .get(&input) - .map(|output| &output.socket().belongs_to) + self.rev_edges.get(&input) + })) + } + + /// Look "forwards" in the graph to see what other instructions this instruction feeds into. + /// + /// The output slots represent the top-level iterator, + /// and each one's connections are emitted one level below. + /// + /// Just [`Iterator::flatten`] if you are not interested in the slots. + /// + /// The same caveats as for [`GraphIr::resolve`] apply. + #[must_use] + pub fn output_targets( + &self, + subject: &id::Instruction, + ) -> Option>> + '_> { + let (subject, kind) = self.instructions.get_key_value(subject)?; + let SocketCount { outputs, .. 
} = kind.socket_count(); + + Some((0..outputs).map(|idx| { + let output = id::Output(socket(subject, idx)); + self.edges.get(&output) })) } - // TODO: this function, but actually the whole module, screams for tests /// Returns the instruction corresponding to the given ID. /// Returns [`None`] if there is no such instruction in this graph IR. /// @@ -133,33 +126,14 @@ impl GraphIr { pub fn resolve<'ir>(&'ir self, subject: &id::Instruction) -> Option> { let (id, kind) = self.instructions.get_key_value(subject)?; - // just try each slot and see if it's connected - // very crude, but it works for a proof of concept - let SocketCount { inputs, outputs } = kind.socket_count(); - let socket = |id: &id::Instruction, idx| id::Socket { - belongs_to: id.clone(), - // impossible since the length is limited to a u16 already - #[allow(clippy::cast_possible_truncation)] - idx: id::SocketIdx(idx as u16), - }; - - let mut inputs_from = vec![None; inputs.into()]; - for (idx, slot) in inputs_from.iter_mut().enumerate() { - let input = id::Input(socket(id, idx)); - *slot = self.rev_edges.get(&input); - } - - let mut outputs_to = vec![None; outputs.into()]; - for (idx, slot) in outputs_to.iter_mut().enumerate() { - let output = id::Output(socket(id, idx)); - *slot = self.edges.get(&output); - } + let input_sources = self.input_sources(subject)?.collect(); + let output_targets = self.output_targets(subject)?.collect(); Some(Instruction { id, kind, - inputs_from, - outputs_to, + input_sources, + output_targets, }) } @@ -187,15 +161,18 @@ impl GraphIr { /// /// Panics if there are any cycles in the IR, as it needs to be a DAG. 
#[must_use] - // yes, this function could actually return an iterator and be lazy + // yes, this function could probably return an iterator and be lazy // no, not today pub fn topological_sort(&self) -> Vec { // count how many incoming edges each vertex has - let nonzero_input_counts: Map<_, usize> = + let mut nonzero_input_counts: Map<_, NonZeroUsize> = self.rev_edges .iter() .fold(Map::new(), |mut count, (input, _)| { - *count.entry(input.socket().belongs_to.clone()).or_default() += 1; + let _ = *count + .entry(input.socket().belongs_to.clone()) + .and_modify(|count| *count = count.saturating_add(1)) + .or_insert(NonZeroUsize::MIN); count }); @@ -204,32 +181,52 @@ impl GraphIr { let no_inputs: Vec<_> = { let nonzero: Set<_> = nonzero_input_counts.keys().collect(); let all: Set<_> = self.instructions.keys().collect(); - all.difference(&nonzero).copied().collect() + all.difference(&nonzero).copied().cloned().collect() }; - let mut active_queue = no_inputs; // then let's find the order! let mut order = Vec::new(); + let mut active_queue = no_inputs; while let Some(current) = active_queue.pop() { // now that this vertex is visited and resolved, // make sure all dependents notice that - for dependent in self - .dependents(current) + let dependents = self + .output_targets(&current) .expect("graph to be consistent") .flatten() - { - dbg!(dependent); + .flatten(); + + for dependent_input in dependents { + let dependent = &dependent_input.socket().belongs_to; + + // how many inputs are connected to this dependent without us? + let count = nonzero_input_counts + .get_mut(dependent) + .expect("connected output must refer to non-zero input"); + + let new = NonZeroUsize::new(count.get() - 1); + if let Some(new) = new { + // aww, still some + *count = new; + continue; + } + + // none, that means this one is free now!
let's throw it onto the active queue then + let (now_active, _) = nonzero_input_counts + .remove_entry(dependent) + .expect("connected output must refer to non-zero input"); + active_queue.push(now_active); } // TODO: check if this instruction is "well-fed", that is, has all the inputs it needs, // and if not, panic - order.push(self.resolve(current).expect("graph to be consistent")); + order.push(self.resolve(&current).expect("graph to be consistent")); } assert!( - !nonzero_input_counts.is_empty(), + nonzero_input_counts.is_empty(), concat!( "topological sort didn't cover all instructions\n", "either there are unconnected inputs, or there is a cycle\n", @@ -250,8 +247,8 @@ pub struct Instruction<'ir> { pub kind: &'ir instruction::Kind, // can't have these two public since then a user might corrupt their length - inputs_from: Vec>, - outputs_to: Vec>>, + input_sources: Vec>, + output_targets: Vec>>, } impl<'ir> Instruction<'ir> { @@ -260,14 +257,14 @@ impl<'ir> Instruction<'ir> { /// [`None`] means that this input is unfilled, /// and must be filled before the instruction can be ran. #[must_use] - pub fn inputs_from(&self) -> &[Option<&'ir id::Output>] { - &self.inputs_from + pub fn input_sources(&self) -> &[Option<&'ir id::Output>] { + &self.input_sources } - /// To whom outputs are sent. [`None`] means that this output is unused. + /// To whom outputs are sent. #[must_use] - pub fn outputs_to(&self) -> &[Option<&'ir BTreeSet>] { - &self.outputs_to + pub fn output_targets(&self) -> &[Option<&'ir Set>] { + &self.output_targets } }