diff --git a/crates/aiken-lang/src/gen_uplc/builder.rs b/crates/aiken-lang/src/gen_uplc/builder.rs index 4d9b9d76..c6a9276b 100644 --- a/crates/aiken-lang/src/gen_uplc/builder.rs +++ b/crates/aiken-lang/src/gen_uplc/builder.rs @@ -342,57 +342,6 @@ pub fn erase_opaque_type_operations( } } -/// Determine whether this air_tree node introduces any shadowing over `potential_matches` -pub fn find_introduced_variables(air_tree: &AirTree) -> Vec { - match air_tree { - AirTree::Let { name, .. } => vec![name.clone()], - AirTree::SoftCastLet { name, .. } => vec![name.clone()], - AirTree::TupleGuard { indices, .. } | AirTree::TupleClause { indices, .. } => { - indices.iter().map(|(_, name)| name.clone()).collect() - } - AirTree::PairGuard { - fst_name, snd_name, .. - } => fst_name - .iter() - .cloned() - .chain(snd_name.iter().cloned()) - .collect_vec(), - AirTree::PairAccessor { fst, snd, .. } => { - fst.iter().cloned().chain(snd.iter().cloned()).collect_vec() - } - AirTree::PairClause { - fst_name, snd_name, .. - } => fst_name - .iter() - .cloned() - .chain(snd_name.iter().cloned()) - .collect_vec(), - AirTree::Fn { params, .. } => params.to_vec(), - AirTree::ListAccessor { names, .. } => names.clone(), - AirTree::ListExpose { - tail, - tail_head_names, - .. - } => { - let mut ret = vec![]; - if let Some((_, head)) = tail { - ret.push(head.clone()) - } - - for name in tail_head_names.iter().map(|(_, head)| head) { - ret.push(name.clone()); - } - ret - } - AirTree::TupleAccessor { names, .. } => names.clone(), - AirTree::FieldsExpose { indices, .. } => { - indices.iter().map(|(_, name, _)| name.clone()).collect() - } - AirTree::When { subject_name, .. } => vec![subject_name.clone()], - _ => vec![], - } -} - /// Determine whether a function is recursive, and if so, get the arguments pub fn is_recursive_function_call<'a>( air_tree: &'a AirTree, diff --git a/crates/aiken-lang/src/gen_uplc/decision_tree.rs b/crates/aiken-lang/src/gen_uplc/decision_tree.rs index 397f5ccc..a7094d54 100644 --- a/crates/aiken-lang/src/gen_uplc/decision_tree.rs +++ b/crates/aiken-lang/src/gen_uplc/decision_tree.rs @@ -1,4 +1,4 @@ -use std::{cmp::Ordering, rc::Rc}; +use std::rc::Rc; use itertools::{Itertools, Position}; @@ -21,18 +21,18 @@ pub enum Path { ListTail(usize), } -#[derive(Clone, Debug)] -struct RowItem<'a> { - path: Vec, - pattern: &'a TypedPattern, -} - #[derive(Clone, Debug)] pub struct Assigned { path: Vec, assigned: String, } +#[derive(Clone, Debug)] +struct RowItem<'a> { + path: Vec, + pattern: &'a TypedPattern, +} + #[derive(Clone, Debug)] struct Row<'a> { assigns: Vec, @@ -55,28 +55,6 @@ pub enum CaseTest { Wild, } -impl PartialOrd for CaseTest { - fn partial_cmp(&self, other: &Self) -> Option { - match (self, other) { - (CaseTest::Wild, CaseTest::Wild) => Some(Ordering::Equal), - (CaseTest::Wild, _) => Some(Ordering::Greater), - (_, CaseTest::Wild) => Some(Ordering::Less), - (_, _) => Some(Ordering::Equal), - } - } -} - -impl Ord for CaseTest { - fn cmp(&self, other: &Self) -> Ordering { - match (self, other) { - (CaseTest::Wild, CaseTest::Wild) => Ordering::Equal, - (CaseTest::Wild, _) => Ordering::Greater, - (_, CaseTest::Wild) => Ordering::Less, - (_, _) => Ordering::Equal, - } - } -} - #[derive(Debug, Clone)] pub enum DecisionTree<'a> { Switch { @@ -92,11 +70,11 @@ pub enum DecisionTree<'a> { path: Vec, cases: Vec<(CaseTest, DecisionTree<'a>)>, tail_cases: Vec<(CaseTest, DecisionTree<'a>)>, - default: Box>, + default: Option>>, }, Leaf(Vec, &'a TypedExpr), HoistedLeaf(String), - HoistThen(Vec, &'a TypedExpr, Box>), + HoistThen(String, Box>, Box>), } fn get_tipo_by_path(mut subject_tipo: Rc, mut path: &[Path]) -> Rc { @@ -215,29 +193,6 @@ fn match_wild_card(pattern: &TypedPattern) -> bool { } } -pub fn build_tree<'a>( - subject_name: &String, - subject_tipo: &Rc, - clauses: &'a [TypedClause], -) -> DecisionTree<'a> { - let rows = clauses - .iter() - .map(|clause| { - let (assign, row_items) = map_pattern_to_row(&clause.pattern, subject_tipo, vec![]); - - Row { - assigns: assign.into_iter().collect_vec(), - columns: row_items, - then: &clause.then, - } - }) - .collect_vec(); - - println!("INITIAL ROWS ARE {:#?}", rows); - - do_build_tree(subject_name, subject_tipo, PatternMatrix { rows }, None) -} - // A function to get which column has the most pattern matches before a wild card fn highest_occurrence(matrix: &PatternMatrix, column_length: usize) -> usize { let occurrences = [Occurrence::default()].repeat(column_length); @@ -277,6 +232,27 @@ fn highest_occurrence(matrix: &PatternMatrix, column_length: usize) -> usize { highest_occurrence.0 } +pub fn build_tree<'a>( + subject_name: &String, + subject_tipo: &Rc, + clauses: &'a [TypedClause], +) -> DecisionTree<'a> { + let rows = clauses + .iter() + .map(|clause| { + let (assign, row_items) = map_pattern_to_row(&clause.pattern, subject_tipo, vec![]); + + Row { + assigns: assign.into_iter().collect_vec(), + columns: row_items, + then: &clause.then, + } + }) + .collect_vec(); + + do_build_tree(subject_name, subject_tipo, PatternMatrix { rows }, None) +} + fn do_build_tree<'a>( subject_name: &String, subject_tipo: &Rc, @@ -292,43 +268,74 @@ fn do_build_tree<'a>( let occurrence_col = highest_occurrence(&matrix, column_length); - let mut longest_elems = None; + let mut longest_elems_no_tail = None; + let mut longest_elems_with_tail = None; + let mut has_list_pattern = false; matrix.rows.iter().for_each(|item| { let col = &item.columns[occurrence_col]; match col.pattern { - Pattern::List { elements, .. } => match longest_elems { - Some(elems_count) => { - if elems_count < elements.len() { - longest_elems = Some(elements.len()); + Pattern::List { elements, tail, .. } => { + has_list_pattern = true; + if tail.is_none() { + match longest_elems_no_tail { + Some(elems_count) => { + if elems_count < elements.len() { + longest_elems_no_tail = Some(elements.len()); + } + } + None => { + longest_elems_no_tail = Some(elements.len()); + } + } + } else { + match longest_elems_with_tail { + Some(elems_count) => { + if elems_count < elements.len() { + longest_elems_with_tail = Some(elements.len()); + } + } + None => { + longest_elems_with_tail = Some(elements.len()); + } } } - None => { - longest_elems = Some(elements.len()); - } - }, + } _ => (), } }); - let (path, mut collection_vec) = matrix.rows.into_iter().fold( - (vec![], vec![]), - |mut collection_vec: (Vec, Vec<(CaseTest, Vec>)>), mut item: Row<'a>| { - if item.columns.is_empty() { - collection_vec.1.push((CaseTest::Wild, vec![item])); - return collection_vec; + let path = matrix + .rows + .get(0) + .unwrap() + .columns + .get(occurrence_col) + .map(|col| col.path.clone()) + .unwrap_or(vec![]); + + let mut row_iter = matrix.rows.into_iter().peekable(); + + let specialized_matrices = row_iter + .peeking_take_while(|row| !match_wild_card(&row.columns[occurrence_col].pattern)) + .fold(vec![], |mut case_matrices, mut row| { + if row.columns.is_empty() { + case_matrices.push((CaseTest::Wild, vec![row])); + return case_matrices; } - let col = item.columns.remove(occurrence_col); + let col = row.columns.remove(occurrence_col); - assert!(!matches!(col.pattern, Pattern::Assign { .. })); - - let (mapped_args, case) = match col.pattern { - Pattern::Int { value, .. } => (vec![], CaseTest::Int(value.clone())), - Pattern::ByteArray { value, .. } => (vec![], CaseTest::Bytes(value.clone())), - Pattern::Var { .. } | Pattern::Discard { .. } => (vec![], CaseTest::Wild), + let (case, remaining_patts) = match col.pattern { + Pattern::Int { value, .. } => (CaseTest::Int(value.clone()), vec![]), + Pattern::ByteArray { value, .. } => (CaseTest::Bytes(value.clone()), vec![]), Pattern::List { elements, tail, .. } => ( + if tail.is_none() { + CaseTest::List(elements.len()) + } else { + CaseTest::ListWithTail(elements.len()) + }, elements .iter() .chain(tail.as_ref().map(|tail| tail.as_ref())) @@ -362,109 +369,246 @@ fn do_build_tree<'a>( } }) .collect_vec(), - if tail.is_none() { - CaseTest::List(elements.len()) - } else { - CaseTest::ListWithTail(elements.len()) - }, ), Pattern::Constructor { .. } => { todo!() } - _ => unreachable!("{:#?}", col.pattern), + Pattern::Tuple { .. } + | Pattern::Pair { .. } + | Pattern::Assign { .. } + | Pattern::Var { .. } + | Pattern::Discard { .. } => { + unreachable!("{:#?}", col.pattern) + } }; - // Assert path is matches for each row except for wild_card - assert!( - collection_vec.0.is_empty() - || collection_vec.0 == col.path - || matches!(case, CaseTest::Wild) - ); - - if collection_vec.0.is_empty() { - collection_vec.0 = col.path; - } + // Assert path is the same for each specialized row + assert!(path == col.path); // expand assigns by newly added ones - item.assigns - .extend(mapped_args.iter().flat_map(|x| x.0.clone())); + row.assigns + .extend(remaining_patts.iter().flat_map(|x| x.0.clone())); // Add inner patterns to existing row - item.columns - .extend(mapped_args.into_iter().flat_map(|x| x.1)); + row.columns + .extend(remaining_patts.into_iter().flat_map(|x| x.1)); - // TODO: Handle special casetest of ListWithTail - if let Some(entry) = collection_vec.1.iter_mut().find(|item| item.0 == case) { - entry.1.push(item); - collection_vec + // For lists with tail it's a special case where we also add it to existing patterns + // all the way to the longest element. The reason being that each list size greater + // than the list with tail could also match with could also match depending on the inner pattern. + // See tests below for an example + if let CaseTest::ListWithTail(elems_len) = case { + if let Some(longest_elems_no_tail) = longest_elems_no_tail { + for elem_count in elems_len..=longest_elems_no_tail { + let case = CaseTest::List(elem_count); + + let mut row = row.clone(); + + let tail = row.columns.pop().unwrap(); + + let columns_to_fill = (0..(elem_count - elems_len)) + .map(|_| tail.clone()) + .collect_vec(); + + row.columns.extend(columns_to_fill); + + if let Some(entry) = case_matrices.iter_mut().find(|item| item.0 == case) { + entry.1.push(row); + } else { + case_matrices.push((case, vec![row])); + } + } + } + + let Some(longest_elems_with_tail) = longest_elems_with_tail else { + unreachable!() + }; + + for elem_count in elems_len..=longest_elems_with_tail { + let case = CaseTest::ListWithTail(elem_count); + + let mut row = row.clone(); + + let tail = row.columns.pop().unwrap(); + + let columns_to_fill = (0..(elem_count - elems_len)) + .map(|_| tail.clone()) + .collect_vec(); + + row.columns.extend(columns_to_fill); + + if let Some(entry) = case_matrices.iter_mut().find(|item| item.0 == case) { + entry.1.push(row); + } else { + case_matrices.push((case, vec![row])); + } + } } else { - collection_vec.1.push((case, vec![item])); - - collection_vec + if let Some(entry) = case_matrices.iter_mut().find(|item| item.0 == case) { + entry.1.push(row); + } else { + case_matrices.push((case, vec![row])); + } } - }, - ); - collection_vec.sort_by(|a, b| a.0.cmp(&b.0)); + case_matrices + }); - let mut collection_iter = collection_vec.into_iter().peekable(); + let default_matrix = PatternMatrix { + rows: row_iter.collect_vec(), + }; - let cases = collection_iter - .peeking_take_while(|a| !matches!(a.0, CaseTest::Wild)) - .collect_vec(); + if has_list_pattern { + // Since the list_tail case might cover the rest of the possible matches extensively + // then fallback is optional here + let fallback_option = if default_matrix.rows.is_empty() { + fallback_option + } else { + Some(do_build_tree( + subject_name, + subject_tipo, + // Since everything after this point had a wild card on or above + // the row for the selected column in front. Then we ignore the + // cases and continue to check other columns. + default_matrix, + fallback_option, + )) + }; - if cases.is_empty() { - let mut fallback = collection_iter.collect_vec(); + let (tail_cases, cases): (Vec<_>, Vec<_>) = specialized_matrices + .into_iter() + .partition(|(case, _)| matches!(case, CaseTest::ListWithTail(_))); - assert!(fallback.len() == 1); + // TODO: pass in interner and use unique string + let hoisted_name = "HoistedThing".to_string(); - let mut remaining = fallback.swap_remove(0).1; + if let Some(fallback) = fallback_option { + DecisionTree::HoistThen( + hoisted_name.clone(), + fallback.into(), + DecisionTree::ListSwitch { + subject_name: subject_name.clone(), + subject_tipo: get_tipo_by_path(subject_tipo.clone(), &path), + path, + cases: cases + .into_iter() + .map(|x| { + ( + x.0, + do_build_tree( + subject_name, + subject_tipo, + PatternMatrix { rows: x.1 }, + Some(DecisionTree::HoistedLeaf(hoisted_name.clone())), + ), + ) + }) + .collect_vec(), + tail_cases: tail_cases + .into_iter() + .map(|x| { + ( + x.0, + do_build_tree( + subject_name, + subject_tipo, + PatternMatrix { rows: x.1 }, + Some(DecisionTree::HoistedLeaf(hoisted_name.clone())), + ), + ) + }) + .collect_vec(), + default: Some(DecisionTree::HoistedLeaf(hoisted_name).into()), + } + .into(), + ) + } else { + DecisionTree::ListSwitch { + subject_name: subject_name.clone(), + subject_tipo: get_tipo_by_path(subject_tipo.clone(), &path), + path, + cases: cases + .into_iter() + .map(|x| { + ( + x.0, + do_build_tree( + subject_name, + subject_tipo, + PatternMatrix { rows: x.1 }, + None, + ), + ) + }) + .collect_vec(), + tail_cases: tail_cases + .into_iter() + .map(|x| { + ( + x.0, + do_build_tree( + subject_name, + subject_tipo, + PatternMatrix { rows: x.1 }, + None, + ), + ) + }) + .collect_vec(), + default: None, + } + } + } else if specialized_matrices.is_empty() { + // No more patterns to match on so we grab the first default row and return that + let mut fallback = default_matrix.rows; - assert!(remaining.len() == 1); - - let row = remaining.swap_remove(0); + let row = fallback.swap_remove(0); DecisionTree::Leaf(row.assigns, row.then) } else { - let mut fallback = collection_iter - .map(|x| { - do_build_tree( - subject_name, - subject_tipo, - PatternMatrix { rows: x.1 }, - None, - ) - }) - .collect_vec(); - assert!(fallback.len() == 1 || fallback_option.is_some()); - - let fallback = if !fallback.is_empty() { - fallback.swap_remove(0) - } else { + let fallback = if default_matrix.rows.is_empty() { fallback_option.unwrap() + } else { + do_build_tree( + subject_name, + subject_tipo, + // Since everything after this point had a wild card on or above + // the row for the selected column in front. Then we ignore the + // cases and continue to check other columns. + default_matrix, + fallback_option, + ) }; - DecisionTree::Switch { - subject_name: subject_name.clone(), - subject_tipo: get_tipo_by_path(subject_tipo.clone(), &path), - path, - cases: cases - .into_iter() - .map(|x| { - ( - x.0, - do_build_tree( - subject_name, - subject_tipo, - PatternMatrix { rows: x.1 }, - Some(fallback.clone()), - ), - ) - }) - .collect_vec(), - default: fallback.into(), - } + // TODO: pass in interner and use unique string + let hoisted_name = "HoistedThing".to_string(); + + DecisionTree::HoistThen( + hoisted_name.clone(), + fallback.into(), + DecisionTree::Switch { + subject_name: subject_name.clone(), + subject_tipo: get_tipo_by_path(subject_tipo.clone(), &path), + path, + cases: specialized_matrices + .into_iter() + .map(|x| { + ( + x.0, + do_build_tree( + subject_name, + subject_tipo, + PatternMatrix { rows: x.1 }, + Some(DecisionTree::HoistedLeaf(hoisted_name.clone())), + ), + ) + }) + .collect_vec(), + default: DecisionTree::HoistedLeaf(hoisted_name).into(), + } + .into(), + ) } } @@ -600,7 +744,82 @@ mod tester { ); println!("TREE IS {:#?}", tree); + } - panic!() + #[test] + fn thing3() { + let source_code = r#" + test thing(){ + when (1,2,#"",[]) is { + (2,b,#"", []) -> 4 == 4 + (a,b,#"", [2, ..y]) -> True + (1,b,#"", [a]) -> False + (3,b,#"aa", [x, y, ..z]) -> 2 == 2 + _ -> 1 == 1 + } + } + "#; + + let (_, ast) = check(parse(source_code)).unwrap(); + + let Definition::Test(function) = &ast.definitions[0] else { + panic!() + }; + + let TypedExpr::When { clauses, .. } = &function.body else { + panic!() + }; + + let tree = build_tree( + &"subject".to_string(), + &Type::tuple(vec![ + Type::int(), + Type::int(), + Type::byte_array(), + Type::list(Type::int()), + ]), + clauses, + ); + + println!("TREE IS {:#?}", tree); + } + + #[test] + fn thing4() { + let source_code = r#" + test thing(){ + when (1,2,#"",[]) is { + (2,b,#"", []) -> 4 == 4 + (a,b,#"", [2, ..y]) -> True + (1,b,#"", [a]) -> False + (3,b,#"aa", [x, y, ..z]) -> 2 == 2 + (3,b, c, [x, 3]) -> fail + _ -> 1 == 1 + } + } + "#; + + let (_, ast) = check(parse(source_code)).unwrap(); + + let Definition::Test(function) = &ast.definitions[0] else { + panic!() + }; + + let TypedExpr::When { clauses, .. } = &function.body else { + panic!() + }; + + let tree = build_tree( + &"subject".to_string(), + &Type::tuple(vec![ + Type::int(), + Type::int(), + Type::byte_array(), + Type::list(Type::int()), + ]), + clauses, + ); + + println!("TREE IS {:#?}", tree); } }