From 3a8d1f63baf99686b2d9d81198807aa1e7d758f0 Mon Sep 17 00:00:00 2001 From: MultisampledNight Date: Fri, 19 Jan 2024 02:01:30 +0100 Subject: [PATCH] feat(ir): dependents and dependencies --- Cargo.lock | 1 + crates/cli/src/main.rs | 2 +- crates/ir/Cargo.toml | 1 + crates/ir/src/lib.rs | 136 +++++++++++++++++++++++++++++++++++++---- 4 files changed, 127 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 867d4b8..373ca0c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -321,6 +321,7 @@ dependencies = [ name = "ir" version = "0.1.0" dependencies = [ + "either", "ron", "serde", ] diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index c7b00f9..16def99 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -17,5 +17,5 @@ fn main() { .expect("reading IR failed — come back to this later handle errors properly"); let pl = ir::from_ron(&f).expect("handle me properly"); - dbg!(pl); + dbg!(pl.topological_sort()); } diff --git a/crates/ir/Cargo.toml b/crates/ir/Cargo.toml index fdd00e8..77f8de4 100644 --- a/crates/ir/Cargo.toml +++ b/crates/ir/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +either = "1.9" ron = "0.8" serde = { version = "1.0.193", features = ["derive"] } diff --git a/crates/ir/src/lib.rs b/crates/ir/src/lib.rs index 316e31d..cfe93aa 100644 --- a/crates/ir/src/lib.rs +++ b/crates/ir/src/lib.rs @@ -1,8 +1,6 @@ -use std::{ - collections::{BTreeMap, BTreeSet}, - ops::RangeInclusive, -}; +use std::{collections::BTreeSet, iter, ops::RangeInclusive}; +use either::Either; use instruction::SocketCount; use serde::{Deserialize, Serialize}; @@ -72,6 +70,56 @@ pub struct GraphIr { } impl GraphIr { + /// Look "forwards" in the graph to see what other instructions this instruction feeds into. + /// + /// The output slots represent the top-level iterator, + /// and each one's connections are emitted one level below. + /// + /// Just [`Iterator::flatten`] if you are not interested in the slots. + /// + /// The same caveats as for [`GraphIr::resolve`] apply. + #[must_use] + pub fn dependents( + &self, + subject: &id::Instruction, + ) -> Option> + '_> { + let (subject, kind) = self.instructions.get_key_value(subject)?; + let SocketCount { inputs, .. } = kind.socket_count(); + + Some((0..inputs).map(|idx| { + let output = id::Output(socket(subject, idx)); + self.edges + .get(&output) + .map_or(Either::Right(iter::empty()), |targets| { + Either::Left(targets.iter().map(|input| &input.socket().belongs_to)) + }) + })) + } + + /// Look "backwards" in the graph, + /// and find out what instructions need to be done before this one. + /// The input slots are visited in order. + /// + /// - The iterator returns individually [`Some`]`(`[`None`]`)` if the corresponding slot is + /// not connected. + /// + /// The same caveats as for [`GraphIr::resolve`] apply. + #[must_use] + pub fn dependencies( + &self, + subject: &id::Instruction, + ) -> Option> + '_> { + let (subject, kind) = self.instructions.get_key_value(subject)?; + let SocketCount { inputs, .. } = kind.socket_count(); + + Some((0..inputs).map(|idx| { + let input = id::Input(socket(subject, idx)); + self.rev_edges + .get(&input) + .map(|output| &output.socket().belongs_to) + })) + } + // TODO: this function, but actually the whole module, screams for tests /// Returns the instruction corresponding to the given ID. /// Returns [`None`] if there is no such instruction in this graph IR. @@ -82,8 +130,8 @@ impl GraphIr { /// to actually have multiple [`GraphIr`]s at one point in time. /// Open an issue if that poses a problem for you. #[must_use] - pub fn resolve<'ir>(&'ir self, id: &id::Instruction) -> Option> { - let (id, kind) = self.instructions.get_key_value(id)?; + pub fn resolve<'ir>(&'ir self, subject: &id::Instruction) -> Option> { + let (id, kind) = self.instructions.get_key_value(subject)?; // just try each slot and see if it's connected // very crude, but it works for a proof of concept @@ -131,17 +179,73 @@ impl GraphIr { self.resolve(&output.socket().belongs_to) } + /// Returns the order in which the instructions could be visited + /// in order to ensure that all dependencies are resolved + /// before a vertex is visited. + /// + /// # Panics + /// + /// Panics if there are any cycles in the IR, as it needs to be a DAG. #[must_use] + // yes, this function could actually return an iterator and be lazy + // no, not today pub fn topological_sort(&self) -> Vec { // count how many incoming edges each vertex has - // chances are the BTreeMap is overkill - let incoming_counts: BTreeMap<_, _> = self - .rev_edges - .iter() - .map(|(input, _)| (self.owner_of_input(input), 1)) + let input_counts: Map<_, usize> = + self.rev_edges + .iter() + .fold(Map::new(), |mut count, (input, _)| { + *count.entry(input.socket().belongs_to.clone()).or_default() += 1; + count + }); + + // could experiment with a VecDeque here + let mut active_queue = Vec::new(); + + // what vertices can we start with? in other words, which ones have 0 inputs? + let unresolved_input_count: Map = input_counts + .into_iter() + .filter_map(|(instr, count)| { + dbg!(count); + if count == 0 { + active_queue.push(instr); + None + } else { + Some((instr, count)) + } + }) .collect(); - todo!() + // then let's find the order! + let mut order = Vec::new(); + + while let Some(current) = active_queue.pop() { + // now that this vertex is visited and resolved, + // make sure all dependents notice that + + for dependent in self + .dependents(¤t) + .expect("graph to be consistent") + .flatten() + { + dbg!(dependent); + } + + order.push(self.resolve(¤t).expect("graph to be consistent")); + } + + assert!( + !unresolved_input_count.is_empty(), + concat!( + "topological sort didn't cover all instructions\n", + "either there are unconnected inputs, or there is a cycle\n", + "unresolved instructions:\n", + "{:#?}" + ), + unresolved_input_count + ); + + order } } @@ -192,3 +296,11 @@ impl From> for Span { } } } + +/// Constructs an [`id::Socket`] a bit more tersely. +fn socket(id: &id::Instruction, idx: u16) -> id::Socket { + id::Socket { + belongs_to: id.clone(), + idx: id::SocketIdx(idx), + } +}