diff --git a/crates/aiken-lang/src/gen_uplc/decision_tree.rs b/crates/aiken-lang/src/gen_uplc/decision_tree.rs index 1de84aee..45837787 100644 --- a/crates/aiken-lang/src/gen_uplc/decision_tree.rs +++ b/crates/aiken-lang/src/gen_uplc/decision_tree.rs @@ -1,6 +1,6 @@ use core::fmt; use pretty::RcDoc; -use std::{fmt::Display, rc::Rc}; +use std::{cmp::Ordering, fmt::Display, rc::Rc}; use indexmap::IndexMap; use itertools::{Itertools, Position}; @@ -131,6 +131,72 @@ pub enum DecisionTree<'a> { }, } +#[derive(Eq, Hash, PartialEq, Clone, Debug)] +pub enum ScopePath { + Case(usize), + Fallback, +} + +impl PartialOrd for ScopePath { + fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (ScopePath::Case(a), ScopePath::Case(b)) => Some(b.cmp(a)), + (ScopePath::Case(_), ScopePath::Fallback) => Some(Ordering::Greater), + (ScopePath::Fallback, ScopePath::Case(_)) => Some(Ordering::Less), + (ScopePath::Fallback, ScopePath::Fallback) => Some(Ordering::Equal), + } + } +} + +impl Ord for ScopePath { + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap() + } +} + +#[derive(Eq, Hash, PartialEq, Clone, Debug, Default, PartialOrd, Ord)] +pub struct Scope { + scope: Vec, +} + +impl Scope { + pub fn new() -> Self { + Self { scope: vec![] } + } + + pub fn push(&mut self, path: ScopePath) { + self.scope.push(path); + } + + pub fn pop(&mut self) { + self.scope.pop(); + } + + pub fn common_ancestor(&mut self, other: &Scope) { + let scope = std::mem::replace(&mut self.scope, vec![]); + + self.scope = scope + .into_iter() + .zip(other.scope.iter()) + .map_while(|(a, b)| if a == *b { Some(a) } else { None }) + .collect_vec() + } + + pub fn len(&self) -> usize { + self.scope.len() + } + + pub fn is_empty(&self) -> bool { + self.scope.is_empty() + } +} + +enum Marker<'a, 'b> { + Pop, + Push(ScopePath, &'b DecisionTree<'a>), + PopPush(ScopePath, &'b DecisionTree<'a>), +} + impl<'a> DecisionTree<'a> { pub fn to_pretty(&self) -> String { let mut w = Vec::new(); @@ -162,34 +228,34 @@ impl<'a> DecisionTree<'a> { default, .. } => RcDoc::text("Switch(") - .append(RcDoc::line()) .append( path.iter() - .fold(RcDoc::text("path("), |acc, p| { - acc.append(RcDoc::line()) - .append(RcDoc::text(format!("{}", p))) + .fold(RcDoc::line().append(RcDoc::text("path(")), |acc, p| { + acc.append(RcDoc::line().append(RcDoc::text(format!("{}", p)).nest(4))) }) - .append(RcDoc::line_()) - .nest(2) - .append(RcDoc::text(")")), + .append(RcDoc::line()) + .append(RcDoc::text(")")) + .nest(4), ) - .append(RcDoc::line()) .append( cases .iter() - .fold(RcDoc::text("cases("), |acc, (con, tree)| { - acc.append(RcDoc::line()) - .append(format!("({}): ", con)) - .append(RcDoc::line()) - .append(tree.to_doc().nest(2)) - }) - .append(RcDoc::line_()) - .nest(2) - .append(RcDoc::text(")")), + .fold( + RcDoc::line().append(RcDoc::text("cases(")), + |acc, (con, tree)| { + acc.append(RcDoc::line()) + .append(RcDoc::text(format!("({}): ", con))) + .append(RcDoc::line()) + .append(tree.to_doc().nest(4)) + }, + ) + .append(RcDoc::line()) + .append(RcDoc::text(")")) + .nest(4), ) - .append(RcDoc::line()) .append( - RcDoc::text("default : ") + RcDoc::line() + .append(RcDoc::text("default : ")) .append(RcDoc::line()) .append( default @@ -197,9 +263,9 @@ impl<'a> DecisionTree<'a> { .map(|i| i.to_doc()) .unwrap_or(RcDoc::text("None")), ) - .nest(2), + .append(RcDoc::line()) + .nest(4), ) - .append(RcDoc::line_()) .append(RcDoc::text(")")), DecisionTree::ListSwitch { path, @@ -210,42 +276,48 @@ impl<'a> DecisionTree<'a> { } => RcDoc::text("ListSwitch(") .append( path.iter() - .fold(RcDoc::text("path("), |acc, p| { - acc.append(RcDoc::line()) - .append(RcDoc::text(format!("{}", p))) + .fold(RcDoc::line().append(RcDoc::text("path(")), |acc, p| { + acc.append(RcDoc::line().append(RcDoc::text(format!("{}", p)).nest(4))) }) - .append(RcDoc::line_()) - .nest(2) - .append(RcDoc::text(")")), + .append(RcDoc::line()) + .append(RcDoc::text(")")) + .nest(4), ) .append( cases .iter() - .fold(RcDoc::text("cases("), |acc, (con, tree)| { - acc.append(RcDoc::line()) - .append(format!("({}): ", con)) - .append(RcDoc::line()) - .append(tree.to_doc().nest(2)) - }) - .append(RcDoc::line_()) - .nest(2) - .append(RcDoc::text(")")), + .fold( + RcDoc::line().append(RcDoc::text("cases(")), + |acc, (con, tree)| { + acc.append(RcDoc::line()) + .append(RcDoc::text(format!("({}): ", con))) + .append(RcDoc::line()) + .append(tree.to_doc().nest(4)) + }, + ) + .append(RcDoc::line()) + .append(RcDoc::text(")")) + .nest(4), ) .append( tail_cases .iter() - .fold(RcDoc::text("tail cases("), |acc, (con, tree)| { - acc.append(RcDoc::line()) - .append(format!("({}): ", con)) - .append(RcDoc::line()) - .append(tree.to_doc().nest(2)) - }) - .append(RcDoc::line_()) - .nest(2) - .append(RcDoc::text(")")), + .fold( + RcDoc::line().append(RcDoc::text("tail_cases(")), + |acc, (con, tree)| { + acc.append(RcDoc::line()) + .append(RcDoc::text(format!("({}): ", con))) + .append(RcDoc::line()) + .append(tree.to_doc().nest(4)) + }, + ) + .append(RcDoc::line()) + .append(RcDoc::text(")")) + .nest(4), ) .append( - RcDoc::text("default : ") + RcDoc::line() + .append(RcDoc::text("default : ")) .append(RcDoc::line()) .append( default @@ -253,12 +325,208 @@ impl<'a> DecisionTree<'a> { .map(|i| i.to_doc()) .unwrap_or(RcDoc::text("None")), ) - .nest(2), + .append(RcDoc::line()) + .nest(4), ) - .append(RcDoc::line_()) .append(RcDoc::text(")")), DecisionTree::HoistedLeaf(name, _) => RcDoc::text(format!("Leaf({})", name)), - DecisionTree::HoistThen { .. } => todo!(), + DecisionTree::HoistThen { name, pattern, .. } => RcDoc::text("HoistThen(") + .append( + RcDoc::line() + .append(RcDoc::text(format!("name : {}", name))) + .append(RcDoc::line()) + .nest(4), + ) + .append( + RcDoc::line() + .append(pattern.to_doc()) + .append(RcDoc::line()) + .nest(4), + ) + .append(RcDoc::text(")")), + } + } + + /// For fun I decided to do this without recursion + /// It doesn't look to bad lol + fn get_hoist_paths<'b>(&self, names: Vec<&'b String>) -> IndexMap<&'b String, Scope> { + let mut prev = vec![]; + + let mut current_path = Scope::new(); + + let mut tree = self; + + let mut scope_map: IndexMap<&String, Scope> = + names.into_iter().map(|item| (item, Scope::new())).collect(); + + loop { + match tree { + DecisionTree::Switch { cases, default, .. } => { + prev.push(Marker::Pop); + + if let Some(def) = default { + prev.push(Marker::PopPush(ScopePath::Fallback, def.as_ref())); + } + + cases + .iter() + .enumerate() + .rev() + .for_each(|(index, (_, detree))| { + if index == 0 { + prev.push(Marker::Push(ScopePath::Case(index), detree)); + } else { + prev.push(Marker::PopPush(ScopePath::Case(index), detree)); + } + }); + } + + DecisionTree::ListSwitch { + cases, + tail_cases, + default, + .. + } => { + prev.push(Marker::Pop); + + if let Some(def) = default { + prev.push(Marker::PopPush(ScopePath::Fallback, def.as_ref())); + } + + tail_cases + .iter() + .enumerate() + .rev() + .for_each(|(index, (_, detree))| { + prev.push(Marker::PopPush( + ScopePath::Case(index + cases.len()), + detree, + )); + }); + + cases + .iter() + .enumerate() + .rev() + .for_each(|(index, (_, detree))| { + if index == 0 { + prev.push(Marker::Push(ScopePath::Case(index), detree)); + } else { + prev.push(Marker::PopPush(ScopePath::Case(index), detree)); + } + }); + } + DecisionTree::HoistedLeaf(leaf_name, _) => { + let scope_for_name = scope_map + .get_mut(leaf_name) + .expect("Impossible, Leaf is based off of given names"); + + if scope_for_name.is_empty() { + *scope_for_name = current_path.clone(); + } else { + scope_for_name.common_ancestor(¤t_path); + } + } + DecisionTree::HoistThen { .. } => unreachable!(), + } + + if let Some(action) = prev.pop() { + match action { + Marker::Pop => { + current_path.pop(); + } + Marker::Push(p, dec_tree) => { + current_path.push(p); + + tree = dec_tree; + } + Marker::PopPush(p, dec_tree) => { + current_path.pop(); + + current_path.push(p); + + tree = dec_tree; + } + } + } else { + break; + } + } + + scope_map + } + + fn hoist_by_path( + &mut self, + current_path: &mut Scope, + name_paths: &mut Vec<(String, Scope)>, + hoistables: &mut IndexMap, &'a TypedExpr)>, + ) { + match self { + DecisionTree::Switch { cases, default, .. } => { + cases.iter_mut().enumerate().for_each(|(index, (_, tree))| { + current_path.push(ScopePath::Case(index)); + tree.hoist_by_path(current_path, name_paths, hoistables); + current_path.pop(); + }); + + current_path.push(ScopePath::Fallback); + if let Some(def) = default { + def.hoist_by_path(current_path, name_paths, hoistables); + } + current_path.pop(); + } + DecisionTree::ListSwitch { + cases, + tail_cases, + default, + .. + } => { + cases.iter_mut().enumerate().for_each(|(index, (_, tree))| { + current_path.push(ScopePath::Case(index)); + tree.hoist_by_path(current_path, name_paths, hoistables); + current_path.pop(); + }); + + tail_cases + .iter_mut() + .enumerate() + .for_each(|(index, (_, tree))| { + current_path.push(ScopePath::Case(index + cases.len())); + tree.hoist_by_path(current_path, name_paths, hoistables); + current_path.pop(); + }); + + current_path.push(ScopePath::Fallback); + if let Some(def) = default { + def.hoist_by_path(current_path, name_paths, hoistables); + } + current_path.pop(); + } + DecisionTree::HoistedLeaf(_, _) => (), + DecisionTree::HoistThen { .. } => unreachable!(), + } + + loop { + if let Some(name_path) = name_paths.pop() { + if name_path.1 == *current_path { + let (assigns, then) = hoistables.remove(&name_path.0).unwrap(); + let pattern = + std::mem::replace(self, DecisionTree::HoistedLeaf("".to_string(), vec![])); + + *self = DecisionTree::HoistThen { + name: name_path.0, + assigns, + pattern: pattern.into(), + then, + }; + } else { + name_paths.push(name_path); + break; + } + } else { + break; + } } } } @@ -297,7 +565,7 @@ impl<'a, 'b> TreeGen<'a, 'b> { subject_tipo: &Rc, clauses: &'a [TypedClause], ) -> DecisionTree<'a> { - let mut clause_then_map = IndexMap::new(); + let mut hoistables = IndexMap::new(); let rows = clauses .iter() @@ -311,7 +579,7 @@ impl<'a, 'b> TreeGen<'a, 'b> { .interner .lookup_interned(&format!("__clause_then_{}", index)); - clause_then_map.insert(clause_then_name.clone(), (vec![], &clause.then)); + hoistables.insert(clause_then_name.clone(), (vec![], &clause.then)); let row = Row { assigns: assign.into_iter().collect_vec(), @@ -325,13 +593,23 @@ impl<'a, 'b> TreeGen<'a, 'b> { }) .collect_vec(); - let tree = self.do_build_tree( + let mut tree = self.do_build_tree( subject_name, subject_tipo, PatternMatrix { rows }, - &mut clause_then_map, + &mut hoistables, ); + let scope_map = tree.get_hoist_paths(hoistables.keys().collect_vec()); + + let mut name_paths = scope_map + .into_iter() + .sorted_by(|a, b| a.1.cmp(&b.1)) + .map(|(name, path)| (name.clone(), path)) + .collect_vec(); + + tree.hoist_by_path(&mut Scope::new(), &mut name_paths, &mut hoistables); + // Do hoisting of thens here tree } @@ -1059,7 +1337,7 @@ mod tester { let tree = tree_gen.build_tree(&"subject".to_string(), &Type::list(Type::int()), clauses); - println!("{:#?}", tree); + println!("{}", tree); } #[test] @@ -1106,7 +1384,7 @@ mod tester { clauses, ); - println!("{:#?}", tree); + println!("{}", tree); } #[test] @@ -1155,7 +1433,7 @@ mod tester { clauses, ); - println!("{:#?}", tree); + println!("{}", tree); } #[test] @@ -1230,7 +1508,7 @@ mod tester { panic!() }; - let TypedExpr::When { clauses, tipo, .. } = &function.body else { + let TypedExpr::When { clauses, .. } = &function.body else { panic!() };