diff --git a/crates/aiken-lang/src/gen_uplc/decision_tree.rs b/crates/aiken-lang/src/gen_uplc/decision_tree.rs index a7094d54..ae7e907f 100644 --- a/crates/aiken-lang/src/gen_uplc/decision_tree.rs +++ b/crates/aiken-lang/src/gen_uplc/decision_tree.rs @@ -1,12 +1,15 @@ use std::rc::Rc; +use indexmap::IndexMap; use itertools::{Itertools, Position}; use crate::{ - ast::{Pattern, TypedClause, TypedPattern}, - expr::{PatternConstructor, Type, TypedExpr}, + ast::{DataTypeKey, Pattern, TypedClause, TypedDataType, TypedPattern}, + expr::{lookup_data_type_by_tipo, PatternConstructor, Type, TypedExpr}, }; +use super::interner::AirInterner; + #[derive(Clone, Default, Copy)] struct Occurrence { passed_wild_card: bool, @@ -77,6 +80,406 @@ pub enum DecisionTree<'a> { HoistThen(String, Box>, Box>), } +pub struct TreeGen<'a, 'b> { + interner: &'b mut AirInterner, + data_types: &'b IndexMap<&'a DataTypeKey, &'a TypedDataType>, +} + +impl<'a, 'b> TreeGen<'a, 'b> { + pub fn build_tree( + mut self, + subject_name: &String, + subject_tipo: &Rc, + clauses: &'a [TypedClause], + ) -> DecisionTree<'a> { + let rows = clauses + .iter() + .map(|clause| { + let (assign, row_items) = map_pattern_to_row(&clause.pattern, subject_tipo, vec![]); + + Row { + assigns: assign.into_iter().collect_vec(), + columns: row_items, + then: &clause.then, + } + }) + .collect_vec(); + + let tree_gen = &mut self; + + tree_gen.do_build_tree(subject_name, subject_tipo, PatternMatrix { rows }, None) + } + + fn do_build_tree( + &mut self, + subject_name: &String, + subject_tipo: &Rc, + matrix: PatternMatrix<'a>, + fallback_option: Option>, + ) -> DecisionTree<'a> { + let column_length = matrix.rows[0].columns.len(); + + assert!(matrix + .rows + .iter() + .all(|row| { row.columns.len() == column_length })); + + let occurrence_col = highest_occurrence(&matrix, column_length); + + let mut longest_elems_no_tail = None; + let mut longest_elems_with_tail = None; + let mut has_list_pattern = false; + + matrix.rows.iter().for_each(|item| { + let col = &item.columns[occurrence_col]; + + match col.pattern { + Pattern::List { elements, tail, .. } => { + has_list_pattern = true; + if tail.is_none() { + match longest_elems_no_tail { + Some(elems_count) => { + if elems_count < elements.len() { + longest_elems_no_tail = Some(elements.len()); + } + } + None => { + longest_elems_no_tail = Some(elements.len()); + } + } + } else { + match longest_elems_with_tail { + Some(elems_count) => { + if elems_count < elements.len() { + longest_elems_with_tail = Some(elements.len()); + } + } + None => { + longest_elems_with_tail = Some(elements.len()); + } + } + } + } + _ => (), + } + }); + + let path = matrix + .rows + .get(0) + .unwrap() + .columns + .get(occurrence_col) + .map(|col| col.path.clone()) + .unwrap_or(vec![]); + + let specialized_tipo = get_tipo_by_path(subject_tipo.clone(), &path); + + let mut row_iter = matrix.rows.into_iter().peekable(); + + let specialized_matrices = row_iter + .peeking_take_while(|row| !match_wild_card(&row.columns[occurrence_col].pattern)) + .fold(vec![], |mut case_matrices, mut row| { + if row.columns.is_empty() { + case_matrices.push((CaseTest::Wild, vec![row])); + return case_matrices; + } + + let col = row.columns.remove(occurrence_col); + + let (case, remaining_patts) = match col.pattern { + Pattern::Int { value, .. } => (CaseTest::Int(value.clone()), vec![]), + Pattern::ByteArray { value, .. } => (CaseTest::Bytes(value.clone()), vec![]), + Pattern::List { elements, tail, .. } => ( + if tail.is_none() { + CaseTest::List(elements.len()) + } else { + CaseTest::ListWithTail(elements.len()) + }, + elements + .iter() + .chain(tail.as_ref().map(|tail| tail.as_ref())) + .enumerate() + .with_position() + .map(|elem| match elem { + Position::First((index, element)) + | Position::Middle((index, element)) + | Position::Only((index, element)) => { + let mut item_path = col.path.clone(); + + item_path.push(Path::List(index)); + + map_pattern_to_row(element, subject_tipo, item_path) + } + + Position::Last((index, element)) => { + if tail.is_none() { + let mut item_path = col.path.clone(); + + item_path.push(Path::List(index)); + + map_pattern_to_row(element, subject_tipo, item_path) + } else { + let mut item_path = col.path.clone(); + + item_path.push(Path::ListTail(index)); + + map_pattern_to_row(element, subject_tipo, item_path) + } + } + }) + .collect_vec(), + ), + + Pattern::Constructor { + name, arguments, .. + } => { + let data_type = + lookup_data_type_by_tipo(&self.data_types, &specialized_tipo); + + todo!() + } + Pattern::Tuple { .. } + | Pattern::Pair { .. } + | Pattern::Assign { .. } + | Pattern::Var { .. } + | Pattern::Discard { .. } => { + unreachable!("{:#?}", col.pattern) + } + }; + + // Assert path is the same for each specialized row + assert!(path == col.path); + + // expand assigns by newly added ones + row.assigns + .extend(remaining_patts.iter().flat_map(|x| x.0.clone())); + + // Add inner patterns to existing row + row.columns + .extend(remaining_patts.into_iter().flat_map(|x| x.1)); + + // For lists with tail it's a special case where we also add it to existing patterns + // all the way to the longest element. The reason being that each list size greater + // than the list with tail could also match with could also match depending on the inner pattern. + // See tests below for an example + if let CaseTest::ListWithTail(elems_len) = case { + if let Some(longest_elems_no_tail) = longest_elems_no_tail { + for elem_count in elems_len..=longest_elems_no_tail { + let case = CaseTest::List(elem_count); + + let mut row = row.clone(); + + let tail = row.columns.pop().unwrap(); + + let columns_to_fill = (0..(elem_count - elems_len)) + .map(|_| tail.clone()) + .collect_vec(); + + row.columns.extend(columns_to_fill); + + if let Some(entry) = + case_matrices.iter_mut().find(|item| item.0 == case) + { + entry.1.push(row); + } else { + case_matrices.push((case, vec![row])); + } + } + } + + let Some(longest_elems_with_tail) = longest_elems_with_tail else { + unreachable!() + }; + + for elem_count in elems_len..=longest_elems_with_tail { + let case = CaseTest::ListWithTail(elem_count); + + let mut row = row.clone(); + + let tail = row.columns.pop().unwrap(); + + let columns_to_fill = (0..(elem_count - elems_len)) + .map(|_| tail.clone()) + .collect_vec(); + + row.columns.extend(columns_to_fill); + + if let Some(entry) = case_matrices.iter_mut().find(|item| item.0 == case) { + entry.1.push(row); + } else { + case_matrices.push((case, vec![row])); + } + } + } else { + if let Some(entry) = case_matrices.iter_mut().find(|item| item.0 == case) { + entry.1.push(row); + } else { + case_matrices.push((case, vec![row])); + } + } + + case_matrices + }); + + let default_matrix = PatternMatrix { + rows: row_iter.collect_vec(), + }; + + if has_list_pattern { + // Since the list_tail case might cover the rest of the possible matches extensively + // then fallback is optional here + let fallback_option = if default_matrix.rows.is_empty() { + fallback_option + } else { + Some(self.do_build_tree( + subject_name, + subject_tipo, + // Since everything after this point had a wild card on or above + // the row for the selected column in front. Then we ignore the + // cases and continue to check other columns. + default_matrix, + fallback_option, + )) + }; + + let (tail_cases, cases): (Vec<_>, Vec<_>) = specialized_matrices + .into_iter() + .partition(|(case, _)| matches!(case, CaseTest::ListWithTail(_))); + + // TODO: pass in interner and use unique string + let hoisted_name = "HoistedThing".to_string(); + + if let Some(fallback) = fallback_option { + DecisionTree::HoistThen( + hoisted_name.clone(), + fallback.into(), + DecisionTree::ListSwitch { + subject_name: subject_name.clone(), + subject_tipo: specialized_tipo.clone(), + path, + cases: cases + .into_iter() + .map(|x| { + ( + x.0, + self.do_build_tree( + subject_name, + subject_tipo, + PatternMatrix { rows: x.1 }, + Some(DecisionTree::HoistedLeaf(hoisted_name.clone())), + ), + ) + }) + .collect_vec(), + tail_cases: tail_cases + .into_iter() + .map(|x| { + ( + x.0, + self.do_build_tree( + subject_name, + subject_tipo, + PatternMatrix { rows: x.1 }, + Some(DecisionTree::HoistedLeaf(hoisted_name.clone())), + ), + ) + }) + .collect_vec(), + default: Some(DecisionTree::HoistedLeaf(hoisted_name).into()), + } + .into(), + ) + } else { + DecisionTree::ListSwitch { + subject_name: subject_name.clone(), + subject_tipo: specialized_tipo.clone(), + path, + cases: cases + .into_iter() + .map(|x| { + ( + x.0, + self.do_build_tree( + subject_name, + subject_tipo, + PatternMatrix { rows: x.1 }, + None, + ), + ) + }) + .collect_vec(), + tail_cases: tail_cases + .into_iter() + .map(|x| { + ( + x.0, + self.do_build_tree( + subject_name, + subject_tipo, + PatternMatrix { rows: x.1 }, + None, + ), + ) + }) + .collect_vec(), + default: None, + } + } + } else if specialized_matrices.is_empty() { + // No more patterns to match on so we grab the first default row and return that + let mut fallback = default_matrix.rows; + + let row = fallback.swap_remove(0); + + DecisionTree::Leaf(row.assigns, row.then) + } else { + let fallback = if default_matrix.rows.is_empty() { + fallback_option.unwrap() + } else { + self.do_build_tree( + subject_name, + subject_tipo, + // Since everything after this point had a wild card on or above + // the row for the selected column in front. Then we ignore the + // cases and continue to check other columns. + default_matrix, + fallback_option, + ) + }; + + // TODO: pass in interner and use unique string + let hoisted_name = "HoistedThing".to_string(); + + DecisionTree::HoistThen( + hoisted_name.clone(), + fallback.into(), + DecisionTree::Switch { + subject_name: subject_name.clone(), + subject_tipo: specialized_tipo.clone(), + path, + cases: specialized_matrices + .into_iter() + .map(|x| { + ( + x.0, + self.do_build_tree( + subject_name, + subject_tipo, + PatternMatrix { rows: x.1 }, + Some(DecisionTree::HoistedLeaf(hoisted_name.clone())), + ), + ) + }) + .collect_vec(), + default: DecisionTree::HoistedLeaf(hoisted_name).into(), + } + .into(), + ) + } + } +} + fn get_tipo_by_path(mut subject_tipo: Rc, mut path: &[Path]) -> Rc { while let Some((p, rest)) = path.split_first() { subject_tipo = match p { @@ -232,395 +635,17 @@ fn highest_occurrence(matrix: &PatternMatrix, column_length: usize) -> usize { highest_occurrence.0 } -pub fn build_tree<'a>( - subject_name: &String, - subject_tipo: &Rc, - clauses: &'a [TypedClause], -) -> DecisionTree<'a> { - let rows = clauses - .iter() - .map(|clause| { - let (assign, row_items) = map_pattern_to_row(&clause.pattern, subject_tipo, vec![]); - - Row { - assigns: assign.into_iter().collect_vec(), - columns: row_items, - then: &clause.then, - } - }) - .collect_vec(); - - do_build_tree(subject_name, subject_tipo, PatternMatrix { rows }, None) -} - -fn do_build_tree<'a>( - subject_name: &String, - subject_tipo: &Rc, - matrix: PatternMatrix<'a>, - fallback_option: Option>, -) -> DecisionTree<'a> { - let column_length = matrix.rows[0].columns.len(); - - assert!(matrix - .rows - .iter() - .all(|row| { row.columns.len() == column_length })); - - let occurrence_col = highest_occurrence(&matrix, column_length); - - let mut longest_elems_no_tail = None; - let mut longest_elems_with_tail = None; - let mut has_list_pattern = false; - - matrix.rows.iter().for_each(|item| { - let col = &item.columns[occurrence_col]; - - match col.pattern { - Pattern::List { elements, tail, .. } => { - has_list_pattern = true; - if tail.is_none() { - match longest_elems_no_tail { - Some(elems_count) => { - if elems_count < elements.len() { - longest_elems_no_tail = Some(elements.len()); - } - } - None => { - longest_elems_no_tail = Some(elements.len()); - } - } - } else { - match longest_elems_with_tail { - Some(elems_count) => { - if elems_count < elements.len() { - longest_elems_with_tail = Some(elements.len()); - } - } - None => { - longest_elems_with_tail = Some(elements.len()); - } - } - } - } - _ => (), - } - }); - - let path = matrix - .rows - .get(0) - .unwrap() - .columns - .get(occurrence_col) - .map(|col| col.path.clone()) - .unwrap_or(vec![]); - - let mut row_iter = matrix.rows.into_iter().peekable(); - - let specialized_matrices = row_iter - .peeking_take_while(|row| !match_wild_card(&row.columns[occurrence_col].pattern)) - .fold(vec![], |mut case_matrices, mut row| { - if row.columns.is_empty() { - case_matrices.push((CaseTest::Wild, vec![row])); - return case_matrices; - } - - let col = row.columns.remove(occurrence_col); - - let (case, remaining_patts) = match col.pattern { - Pattern::Int { value, .. } => (CaseTest::Int(value.clone()), vec![]), - Pattern::ByteArray { value, .. } => (CaseTest::Bytes(value.clone()), vec![]), - Pattern::List { elements, tail, .. } => ( - if tail.is_none() { - CaseTest::List(elements.len()) - } else { - CaseTest::ListWithTail(elements.len()) - }, - elements - .iter() - .chain(tail.as_ref().map(|tail| tail.as_ref())) - .enumerate() - .with_position() - .map(|elem| match elem { - Position::First((index, element)) - | Position::Middle((index, element)) - | Position::Only((index, element)) => { - let mut item_path = col.path.clone(); - - item_path.push(Path::List(index)); - - map_pattern_to_row(element, subject_tipo, item_path) - } - - Position::Last((index, element)) => { - if tail.is_none() { - let mut item_path = col.path.clone(); - - item_path.push(Path::List(index)); - - map_pattern_to_row(element, subject_tipo, item_path) - } else { - let mut item_path = col.path.clone(); - - item_path.push(Path::ListTail(index)); - - map_pattern_to_row(element, subject_tipo, item_path) - } - } - }) - .collect_vec(), - ), - - Pattern::Constructor { .. } => { - todo!() - } - Pattern::Tuple { .. } - | Pattern::Pair { .. } - | Pattern::Assign { .. } - | Pattern::Var { .. } - | Pattern::Discard { .. } => { - unreachable!("{:#?}", col.pattern) - } - }; - - // Assert path is the same for each specialized row - assert!(path == col.path); - - // expand assigns by newly added ones - row.assigns - .extend(remaining_patts.iter().flat_map(|x| x.0.clone())); - - // Add inner patterns to existing row - row.columns - .extend(remaining_patts.into_iter().flat_map(|x| x.1)); - - // For lists with tail it's a special case where we also add it to existing patterns - // all the way to the longest element. The reason being that each list size greater - // than the list with tail could also match with could also match depending on the inner pattern. - // See tests below for an example - if let CaseTest::ListWithTail(elems_len) = case { - if let Some(longest_elems_no_tail) = longest_elems_no_tail { - for elem_count in elems_len..=longest_elems_no_tail { - let case = CaseTest::List(elem_count); - - let mut row = row.clone(); - - let tail = row.columns.pop().unwrap(); - - let columns_to_fill = (0..(elem_count - elems_len)) - .map(|_| tail.clone()) - .collect_vec(); - - row.columns.extend(columns_to_fill); - - if let Some(entry) = case_matrices.iter_mut().find(|item| item.0 == case) { - entry.1.push(row); - } else { - case_matrices.push((case, vec![row])); - } - } - } - - let Some(longest_elems_with_tail) = longest_elems_with_tail else { - unreachable!() - }; - - for elem_count in elems_len..=longest_elems_with_tail { - let case = CaseTest::ListWithTail(elem_count); - - let mut row = row.clone(); - - let tail = row.columns.pop().unwrap(); - - let columns_to_fill = (0..(elem_count - elems_len)) - .map(|_| tail.clone()) - .collect_vec(); - - row.columns.extend(columns_to_fill); - - if let Some(entry) = case_matrices.iter_mut().find(|item| item.0 == case) { - entry.1.push(row); - } else { - case_matrices.push((case, vec![row])); - } - } - } else { - if let Some(entry) = case_matrices.iter_mut().find(|item| item.0 == case) { - entry.1.push(row); - } else { - case_matrices.push((case, vec![row])); - } - } - - case_matrices - }); - - let default_matrix = PatternMatrix { - rows: row_iter.collect_vec(), - }; - - if has_list_pattern { - // Since the list_tail case might cover the rest of the possible matches extensively - // then fallback is optional here - let fallback_option = if default_matrix.rows.is_empty() { - fallback_option - } else { - Some(do_build_tree( - subject_name, - subject_tipo, - // Since everything after this point had a wild card on or above - // the row for the selected column in front. Then we ignore the - // cases and continue to check other columns. - default_matrix, - fallback_option, - )) - }; - - let (tail_cases, cases): (Vec<_>, Vec<_>) = specialized_matrices - .into_iter() - .partition(|(case, _)| matches!(case, CaseTest::ListWithTail(_))); - - // TODO: pass in interner and use unique string - let hoisted_name = "HoistedThing".to_string(); - - if let Some(fallback) = fallback_option { - DecisionTree::HoistThen( - hoisted_name.clone(), - fallback.into(), - DecisionTree::ListSwitch { - subject_name: subject_name.clone(), - subject_tipo: get_tipo_by_path(subject_tipo.clone(), &path), - path, - cases: cases - .into_iter() - .map(|x| { - ( - x.0, - do_build_tree( - subject_name, - subject_tipo, - PatternMatrix { rows: x.1 }, - Some(DecisionTree::HoistedLeaf(hoisted_name.clone())), - ), - ) - }) - .collect_vec(), - tail_cases: tail_cases - .into_iter() - .map(|x| { - ( - x.0, - do_build_tree( - subject_name, - subject_tipo, - PatternMatrix { rows: x.1 }, - Some(DecisionTree::HoistedLeaf(hoisted_name.clone())), - ), - ) - }) - .collect_vec(), - default: Some(DecisionTree::HoistedLeaf(hoisted_name).into()), - } - .into(), - ) - } else { - DecisionTree::ListSwitch { - subject_name: subject_name.clone(), - subject_tipo: get_tipo_by_path(subject_tipo.clone(), &path), - path, - cases: cases - .into_iter() - .map(|x| { - ( - x.0, - do_build_tree( - subject_name, - subject_tipo, - PatternMatrix { rows: x.1 }, - None, - ), - ) - }) - .collect_vec(), - tail_cases: tail_cases - .into_iter() - .map(|x| { - ( - x.0, - do_build_tree( - subject_name, - subject_tipo, - PatternMatrix { rows: x.1 }, - None, - ), - ) - }) - .collect_vec(), - default: None, - } - } - } else if specialized_matrices.is_empty() { - // No more patterns to match on so we grab the first default row and return that - let mut fallback = default_matrix.rows; - - let row = fallback.swap_remove(0); - - DecisionTree::Leaf(row.assigns, row.then) - } else { - let fallback = if default_matrix.rows.is_empty() { - fallback_option.unwrap() - } else { - do_build_tree( - subject_name, - subject_tipo, - // Since everything after this point had a wild card on or above - // the row for the selected column in front. Then we ignore the - // cases and continue to check other columns. - default_matrix, - fallback_option, - ) - }; - - // TODO: pass in interner and use unique string - let hoisted_name = "HoistedThing".to_string(); - - DecisionTree::HoistThen( - hoisted_name.clone(), - fallback.into(), - DecisionTree::Switch { - subject_name: subject_name.clone(), - subject_tipo: get_tipo_by_path(subject_tipo.clone(), &path), - path, - cases: specialized_matrices - .into_iter() - .map(|x| { - ( - x.0, - do_build_tree( - subject_name, - subject_tipo, - PatternMatrix { rows: x.1 }, - Some(DecisionTree::HoistedLeaf(hoisted_name.clone())), - ), - ) - }) - .collect_vec(), - default: DecisionTree::HoistedLeaf(hoisted_name).into(), - } - .into(), - ) - } -} - #[cfg(test)] mod tester { use std::collections::HashMap; + use indexmap::IndexMap; + use crate::{ ast::{Definition, ModuleKind, TraceLevel, Tracing, TypedModule, UntypedModule}, builtins, expr::{Type, TypedExpr}, - gen_uplc::decision_tree::build_tree, + gen_uplc::{decision_tree::TreeGen, interner::AirInterner}, parser, tipo::error::{Error, Warning}, IdGenerator, @@ -703,8 +728,16 @@ mod tester { let TypedExpr::When { clauses, .. } = &function.body else { panic!() }; + let mut air_interner = AirInterner::new(); - let tree = build_tree(&"subject".to_string(), &Type::list(Type::int()), clauses); + let data_types = IndexMap::new(); + + let tree_gen = TreeGen { + interner: &mut air_interner, + data_types: &data_types, + }; + + let tree = tree_gen.build_tree(&"subject".to_string(), &Type::list(Type::int()), clauses); println!("TREE IS {:#?}", tree); } @@ -731,8 +764,16 @@ mod tester { let TypedExpr::When { clauses, .. } = &function.body else { panic!() }; + let mut air_interner = AirInterner::new(); - let tree = build_tree( + let data_types = IndexMap::new(); + + let tree_gen = TreeGen { + interner: &mut air_interner, + data_types: &data_types, + }; + + let tree = tree_gen.build_tree( &"subject".to_string(), &Type::tuple(vec![ Type::int(), @@ -770,7 +811,16 @@ mod tester { panic!() }; - let tree = build_tree( + let mut air_interner = AirInterner::new(); + + let data_types = IndexMap::new(); + + let tree_gen = TreeGen { + interner: &mut air_interner, + data_types: &data_types, + }; + + let tree = tree_gen.build_tree( &"subject".to_string(), &Type::tuple(vec![ Type::int(), @@ -793,7 +843,7 @@ mod tester { (a,b,#"", [2, ..y]) -> True (1,b,#"", [a]) -> False (3,b,#"aa", [x, y, ..z]) -> 2 == 2 - (3,b, c, [x, 3]) -> fail + (3,b, c, [x, 3 as q]) -> fail _ -> 1 == 1 } } @@ -809,7 +859,16 @@ mod tester { panic!() }; - let tree = build_tree( + let mut air_interner = AirInterner::new(); + + let data_types = IndexMap::new(); + + let tree_gen = TreeGen { + interner: &mut air_interner, + data_types: &data_types, + }; + + let tree = tree_gen.build_tree( &"subject".to_string(), &Type::tuple(vec![ Type::int(),