From 43e859f1ba18fe63316b2259a88d4f3994b7566e Mon Sep 17 00:00:00 2001 From: microproofs Date: Thu, 10 Oct 2024 12:11:13 -0400 Subject: [PATCH] Rework Decision Trees to use path to find the subject to test --- .../aiken-lang/src/gen_uplc/decision_tree.rs | 193 +++++++++--------- 1 file changed, 97 insertions(+), 96 deletions(-) diff --git a/crates/aiken-lang/src/gen_uplc/decision_tree.rs b/crates/aiken-lang/src/gen_uplc/decision_tree.rs index 87f39186..ccb2e830 100644 --- a/crates/aiken-lang/src/gen_uplc/decision_tree.rs +++ b/crates/aiken-lang/src/gen_uplc/decision_tree.rs @@ -1,6 +1,6 @@ use std::{cmp::Ordering, rc::Rc}; -use itertools::{Itertools, Tuples}; +use itertools::Itertools; use crate::{ ast::{Pattern, TypedClause, TypedPattern}, @@ -13,7 +13,7 @@ struct Occurrence { amount: usize, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] enum Path { Pair(usize), Tuple(usize), @@ -80,22 +80,22 @@ impl Ord for CaseTest { } } -enum DecisionTree { +enum DecisionTree<'a> { Switch { subject_name: String, subject_tipo: Rc, - column_to_test: usize, - cases: Vec<(CaseTest, Vec, DecisionTree)>, - default: (Vec, Box), + path: Vec, + cases: Vec<(CaseTest, DecisionTree<'a>)>, + default: Box>, }, - Leaf(TypedExpr), + Leaf(Vec, &'a TypedExpr), } fn get_tipo_by_path(mut subject_tipo: Rc, mut path: &[Path]) -> Rc { while let Some((p, rest)) = path.split_first() { let index = p.get_index(); - subject_tipo = subject_tipo.arg_types().unwrap().remove(index); + subject_tipo = subject_tipo.arg_types().unwrap().swap_remove(index); path = rest } subject_tipo @@ -104,7 +104,7 @@ fn get_tipo_by_path(mut subject_tipo: Rc, mut path: &[Path]) -> Rc { fn map_pattern_to_row<'a>( pattern: &'a TypedPattern, subject_name: &String, - subject_tipo: Rc, + subject_tipo: &Rc, path: Vec, ) -> (Vec, Vec>) { let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path); @@ -141,7 +141,7 @@ fn map_pattern_to_row<'a>( path: path.clone(), assigned: name.clone(), }], - map_pattern_to_row(pattern, subject_name, subject_tipo.clone(), path).1, + map_pattern_to_row(pattern, subject_name, subject_tipo, path).1, ), Pattern::Int { .. } | Pattern::ByteArray { .. } @@ -166,10 +166,10 @@ fn map_pattern_to_row<'a>( snd_path.push(Path::Pair(1)); let (mut assigns, mut patts) = - map_pattern_to_row(fst, subject_name, subject_tipo.clone(), fst_path); + map_pattern_to_row(fst, subject_name, subject_tipo, fst_path); let (assign_snd, patt_snd) = - map_pattern_to_row(snd, subject_name, subject_tipo.clone(), snd_path); + map_pattern_to_row(snd, subject_name, subject_tipo, snd_path); assigns.extend(assign_snd.into_iter()); @@ -187,7 +187,7 @@ fn map_pattern_to_row<'a>( item_path.push(Path::Tuple(index)); let (assigns, patts) = - map_pattern_to_row(item, subject_name, subject_tipo.clone(), item_path); + map_pattern_to_row(item, subject_name, subject_tipo, item_path); acc.0.extend(assigns.into_iter()); acc.1.extend(patts.into_iter()); @@ -206,16 +206,16 @@ fn match_wild_card(pattern: &TypedPattern) -> bool { } } -pub fn build_tree( +pub fn build_tree<'a>( subject_name: &String, - subject_tipo: Rc, - clauses: &Vec, -) -> DecisionTree { + subject_tipo: &Rc, + clauses: &'a Vec, +) -> DecisionTree<'a> { let rows = clauses .iter() .map(|clause| { let (assign, row_items) = - map_pattern_to_row(&clause.pattern, subject_name, subject_tipo.clone(), vec![]); + map_pattern_to_row(&clause.pattern, subject_name, subject_tipo, vec![]); Row { assigns: assign.into_iter().collect_vec(), @@ -232,7 +232,7 @@ pub fn do_build_tree<'a>( subject_name: &String, subject_tipo: &Rc, matrix: PatternMatrix<'a>, -) -> DecisionTree { +) -> DecisionTree<'a> { let column_length = matrix.rows[0].columns.len(); assert!(matrix @@ -274,98 +274,99 @@ pub fn do_build_tree<'a>( } }); - if column_length > 1 { - DecisionTree::Switch { - subject_name: subject_name.clone(), - subject_tipo: subject_tipo.clone(), - column_to_test: highest_occurrence.0, - cases: todo!(), - default: todo!(), - } - } else { - let mut collection_vec = matrix.rows.into_iter().fold( - vec![], - |mut collection_vec: Vec<(CaseTest, Vec, Vec>)>, mut item: Row<'a>| { - let col = item.columns.remove(highest_occurrence.0); + let (path, mut collection_vec) = matrix.rows.into_iter().fold( + (vec![], vec![]), + |mut collection_vec: (Vec, Vec<(CaseTest, Vec>)>), mut item: Row<'a>| { + let col = item.columns.remove(highest_occurrence.0); - assert!(!matches!(col.pattern, Pattern::Assign { .. })); + assert!(!matches!(col.pattern, Pattern::Assign { .. })); - let (mapped_args, case) = match col.pattern { - Pattern::Int { value, .. } => (vec![], CaseTest::Int(value.clone())), - Pattern::ByteArray { value, .. } => (vec![], CaseTest::Bytes(value.clone())), - Pattern::Var { .. } | Pattern::Discard { .. } => (vec![], CaseTest::Wild), - Pattern::List { elements, .. } => ( - elements - .iter() - .enumerate() - .map(|(index, item)| { - let mut item_path = col.path.clone(); + let (mapped_args, case) = match col.pattern { + Pattern::Int { value, .. } => (vec![], CaseTest::Int(value.clone())), + Pattern::ByteArray { value, .. } => (vec![], CaseTest::Bytes(value.clone())), + Pattern::Var { .. } | Pattern::Discard { .. } => (vec![], CaseTest::Wild), + Pattern::List { elements, .. } => ( + elements + .iter() + .enumerate() + .map(|(index, item)| { + let mut item_path = col.path.clone(); - item_path.push(Path::Tuple(index)); + item_path.push(Path::Tuple(index)); - map_pattern_to_row( - item, - subject_name, - subject_tipo.clone(), - item_path, - ) - }) - .collect_vec(), - CaseTest::List(elements.len()), - ), + map_pattern_to_row(item, subject_name, subject_tipo, item_path) + }) + .collect_vec(), + CaseTest::List(elements.len()), + ), - Pattern::Constructor { .. } => { - todo!() - } - _ => unreachable!("{:#?}", col.pattern), - }; - - item.assigns - .extend(mapped_args.iter().map(|x| x.0.clone()).flatten()); - item.columns - .extend(mapped_args.into_iter().map(|x| x.1).flatten()); - - if let Some(index) = collection_vec.iter().position(|item| item.0 == case) { - let entry = collection_vec.get_mut(index).unwrap(); - - entry.2.push(item); - collection_vec - } else { - collection_vec.push((case, col.path, vec![item])); - - collection_vec + Pattern::Constructor { .. } => { + todo!() } - }, - ); + _ => unreachable!("{:#?}", col.pattern), + }; - collection_vec.sort_by(|a, b| a.0.cmp(&b.0)); - let mut collection_iter = collection_vec - .into_iter() - .map(|x| { - ( - x.0, - x.1, - do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.2 }), - ) - }) - .peekable(); + item.assigns + .extend(mapped_args.iter().map(|x| x.0.clone()).flatten()); + item.columns + .extend(mapped_args.into_iter().map(|x| x.1).flatten()); - let cases = collection_iter - .peeking_take_while(|a| !matches!(a.0, CaseTest::Wild)) + assert!(collection_vec.0.is_empty() || collection_vec.0 == col.path); + + if collection_vec.0.is_empty() { + collection_vec.0 = col.path; + } + + if let Some(entry) = collection_vec.1.iter_mut().find(|item| item.0 == case) { + entry.1.push(item); + collection_vec + } else { + collection_vec.1.push((case, vec![item])); + + collection_vec + } + }, + ); + + collection_vec.sort_by(|a, b| a.0.cmp(&b.0)); + + let mut collection_iter = collection_vec.into_iter().peekable(); + + let cases = collection_iter + .peeking_take_while(|a| !matches!(a.0, CaseTest::Wild)) + .map(|x| { + ( + x.0, + do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.1 }), + ) + }) + .collect_vec(); + + if cases.is_empty() { + let mut fallback = collection_iter.collect_vec(); + + assert!(fallback.len() == 1); + + let mut remaining = fallback.swap_remove(0).1; + + assert!(remaining.len() == 1); + + let thing = remaining.swap_remove(0); + + DecisionTree::Leaf(thing.assigns, thing.then) + } else { + let mut fallback = collection_iter + .map(|x| do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.1 }).into()) .collect_vec(); - let mut fallback = collection_iter.map(|x| (x.1, x.2.into())).collect_vec(); - assert!(fallback.len() == 1); DecisionTree::Switch { subject_name: subject_name.clone(), subject_tipo: subject_tipo.clone(), - column_to_test: highest_occurrence.0, + path, cases, - default: fallback.remove(0), + default: fallback.swap_remove(0), } - }; - - todo!() + } }