From b340de2cfdc5819229472490a444905fc12d4bc3 Mon Sep 17 00:00:00 2001 From: microproofs Date: Tue, 15 Oct 2024 14:37:19 -0400 Subject: [PATCH] Incremental commit for dealing with list tails --- .../aiken-lang/src/gen_uplc/decision_tree.rs | 134 +++++++++++++----- 1 file changed, 102 insertions(+), 32 deletions(-) diff --git a/crates/aiken-lang/src/gen_uplc/decision_tree.rs b/crates/aiken-lang/src/gen_uplc/decision_tree.rs index da13080a..397f5ccc 100644 --- a/crates/aiken-lang/src/gen_uplc/decision_tree.rs +++ b/crates/aiken-lang/src/gen_uplc/decision_tree.rs @@ -1,6 +1,6 @@ use std::{cmp::Ordering, rc::Rc}; -use itertools::Itertools; +use itertools::{Itertools, Position}; use crate::{ ast::{Pattern, TypedClause, TypedPattern}, @@ -18,6 +18,7 @@ pub enum Path { Pair(usize), Tuple(usize), List(usize), + ListTail(usize), } #[derive(Clone, Debug)] @@ -27,14 +28,14 @@ struct RowItem<'a> { } #[derive(Clone, Debug)] -pub struct Assign { +pub struct Assigned { path: Vec, assigned: String, } #[derive(Clone, Debug)] struct Row<'a> { - assigns: Vec, + assigns: Vec, columns: Vec>, then: &'a TypedExpr, } @@ -49,7 +50,8 @@ pub enum CaseTest { Constr(PatternConstructor), Int(String), Bytes(Vec), - List(usize, bool), + List(usize), + ListWithTail(usize), Wild, } @@ -84,7 +86,17 @@ pub enum DecisionTree<'a> { cases: Vec<(CaseTest, DecisionTree<'a>)>, default: Box>, }, - Leaf(Vec, &'a TypedExpr), + ListSwitch { + subject_name: String, + subject_tipo: Rc, + path: Vec, + cases: Vec<(CaseTest, DecisionTree<'a>)>, + tail_cases: Vec<(CaseTest, DecisionTree<'a>)>, + default: Box>, + }, + Leaf(Vec, &'a TypedExpr), + HoistedLeaf(String), + HoistThen(Vec, &'a TypedExpr, Box>), } fn get_tipo_by_path(mut subject_tipo: Rc, mut path: &[Path]) -> Rc { @@ -94,6 +106,7 @@ fn get_tipo_by_path(mut subject_tipo: Rc, mut path: &[Path]) -> Rc { subject_tipo.get_inner_types().swap_remove(*index) } Path::List(_) => subject_tipo.get_inner_types().swap_remove(0), + Path::ListTail(_) => subject_tipo, }; path = rest @@ -105,7 +118,7 @@ fn map_pattern_to_row<'a>( pattern: &'a TypedPattern, subject_tipo: &Rc, path: Vec, -) -> (Vec, Vec>) { +) -> (Vec, Vec>) { let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path); let new_columns_added = if current_tipo.is_pair() { @@ -121,7 +134,7 @@ fn map_pattern_to_row<'a>( match pattern { Pattern::Var { name, .. } => ( - vec![Assign { + vec![Assigned { path: path.clone(), assigned: name.clone(), }], @@ -136,7 +149,7 @@ fn map_pattern_to_row<'a>( ), Pattern::Assign { name, pattern, .. } => ( - vec![Assign { + vec![Assigned { path: path.clone(), assigned: name.clone(), }], @@ -225,19 +238,8 @@ pub fn build_tree<'a>( do_build_tree(subject_name, subject_tipo, PatternMatrix { rows }, None) } -fn do_build_tree<'a>( - subject_name: &String, - subject_tipo: &Rc, - matrix: PatternMatrix<'a>, - fallback_option: Option>, -) -> DecisionTree<'a> { - let column_length = matrix.rows[0].columns.len(); - - assert!(matrix - .rows - .iter() - .all(|row| { row.columns.len() == column_length })); - +// A function to get which column has the most pattern matches before a wild card +fn highest_occurrence(matrix: &PatternMatrix, column_length: usize) -> usize { let occurrences = [Occurrence::default()].repeat(column_length); let occurrences = @@ -272,6 +274,44 @@ fn do_build_tree<'a>( } }); + highest_occurrence.0 +} + +fn do_build_tree<'a>( + subject_name: &String, + subject_tipo: &Rc, + matrix: PatternMatrix<'a>, + fallback_option: Option>, +) -> DecisionTree<'a> { + let column_length = matrix.rows[0].columns.len(); + + assert!(matrix + .rows + .iter() + .all(|row| { row.columns.len() == column_length })); + + let occurrence_col = highest_occurrence(&matrix, column_length); + + let mut longest_elems = None; + + matrix.rows.iter().for_each(|item| { + let col = &item.columns[occurrence_col]; + + match col.pattern { + Pattern::List { elements, .. } => match longest_elems { + Some(elems_count) => { + if elems_count < elements.len() { + longest_elems = Some(elements.len()); + } + } + None => { + longest_elems = Some(elements.len()); + } + }, + _ => (), + } + }); + let (path, mut collection_vec) = matrix.rows.into_iter().fold( (vec![], vec![]), |mut collection_vec: (Vec, Vec<(CaseTest, Vec>)>), mut item: Row<'a>| { @@ -280,7 +320,7 @@ fn do_build_tree<'a>( return collection_vec; } - let col = item.columns.remove(highest_occurrence.0); + let col = item.columns.remove(occurrence_col); assert!(!matches!(col.pattern, Pattern::Assign { .. })); @@ -293,15 +333,40 @@ fn do_build_tree<'a>( .iter() .chain(tail.as_ref().map(|tail| tail.as_ref())) .enumerate() - .map(|(index, item)| { - let mut item_path = col.path.clone(); + .with_position() + .map(|elem| match elem { + Position::First((index, element)) + | Position::Middle((index, element)) + | Position::Only((index, element)) => { + let mut item_path = col.path.clone(); - item_path.push(Path::List(index)); + item_path.push(Path::List(index)); - map_pattern_to_row(item, subject_tipo, item_path) + map_pattern_to_row(element, subject_tipo, item_path) + } + + Position::Last((index, element)) => { + if tail.is_none() { + let mut item_path = col.path.clone(); + + item_path.push(Path::List(index)); + + map_pattern_to_row(element, subject_tipo, item_path) + } else { + let mut item_path = col.path.clone(); + + item_path.push(Path::ListTail(index)); + + map_pattern_to_row(element, subject_tipo, item_path) + } + } }) .collect_vec(), - CaseTest::List(elements.len(), tail.is_some()), + if tail.is_none() { + CaseTest::List(elements.len()) + } else { + CaseTest::ListWithTail(elements.len()) + }, ), Pattern::Constructor { .. } => { @@ -310,11 +375,7 @@ fn do_build_tree<'a>( _ => unreachable!("{:#?}", col.pattern), }; - item.assigns - .extend(mapped_args.iter().flat_map(|x| x.0.clone())); - item.columns - .extend(mapped_args.into_iter().flat_map(|x| x.1)); - + // Assert path is matches for each row except for wild_card assert!( collection_vec.0.is_empty() || collection_vec.0 == col.path @@ -325,6 +386,15 @@ fn do_build_tree<'a>( collection_vec.0 = col.path; } + // expand assigns by newly added ones + item.assigns + .extend(mapped_args.iter().flat_map(|x| x.0.clone())); + + // Add inner patterns to existing row + item.columns + .extend(mapped_args.into_iter().flat_map(|x| x.1)); + + // TODO: Handle special casetest of ListWithTail if let Some(entry) = collection_vec.1.iter_mut().find(|item| item.0 == case) { entry.1.push(item); collection_vec