diff --git a/crates/aiken-lang/src/gen_uplc/decision_tree.rs b/crates/aiken-lang/src/gen_uplc/decision_tree.rs index 8d591a85..87f39186 100644 --- a/crates/aiken-lang/src/gen_uplc/decision_tree.rs +++ b/crates/aiken-lang/src/gen_uplc/decision_tree.rs @@ -1,6 +1,6 @@ use std::{cmp::Ordering, rc::Rc}; -use itertools::Itertools; +use itertools::{Itertools, Tuples}; use crate::{ ast::{Pattern, TypedClause, TypedPattern}, @@ -13,60 +13,32 @@ struct Occurrence { amount: usize, } -#[derive(Clone)] +#[derive(Clone, Debug)] +enum Path { + Pair(usize), + Tuple(usize), +} + +impl Path { + pub fn get_index(&self) -> usize { + match self { + Path::Pair(u) | Path::Tuple(u) => *u, + } + } +} + +#[derive(Clone, Debug)] struct RowItem<'a> { - assign: Option, + path: Vec, pattern: &'a TypedPattern, } -#[derive(Clone, Eq, PartialEq)] -pub enum CaseTest { - Constr(PatternConstructor), - Int(String), - Bytes(Vec), - List(usize), - Wild, -} -impl PartialOrd for CaseTest { - fn partial_cmp(&self, other: &Self) -> Option { - match (self, other) { - (CaseTest::Wild, CaseTest::Wild) => Some(Ordering::Equal), - (CaseTest::Wild, _) => Some(Ordering::Less), - (_, CaseTest::Wild) => Some(Ordering::Greater), - (_, _) => Some(Ordering::Equal), - } - } -} - -impl Ord for CaseTest { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - match (self, other) { - (CaseTest::Wild, CaseTest::Wild) => Ordering::Equal, - (CaseTest::Wild, _) => Ordering::Less, - (_, CaseTest::Wild) => Ordering::Greater, - (_, _) => Ordering::Equal, - } - } -} - +#[derive(Clone, Debug)] struct Assign { - subject_name: String, - subject_tuple_index: Option, + path: Vec, assigned: String, } -enum DecisionTree { - Switch { - subject_name: String, - subject_tuple_index: Option, - subject_tipo: Rc, - column_to_test: usize, - cases: Vec<(CaseTest, DecisionTree)>, - default: Box, - }, - Leaf(TypedExpr), -} - struct Row<'a> { assigns: Vec, columns: Vec>, @@ -77,60 +49,152 @@ struct PatternMatrix<'a> { rows: Vec>, } -fn map_to_row<'a>( +#[derive(Clone, Eq, PartialEq)] +pub enum CaseTest { + Constr(PatternConstructor), + Int(String), + Bytes(Vec), + List(usize), + Wild, +} + +impl PartialOrd for CaseTest { + fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (CaseTest::Wild, CaseTest::Wild) => Some(Ordering::Equal), + (CaseTest::Wild, _) => Some(Ordering::Less), + (_, CaseTest::Wild) => Some(Ordering::Greater), + (_, _) => Some(Ordering::Equal), + } + } +} + +impl Ord for CaseTest { + fn cmp(&self, other: &Self) -> Ordering { + match (self, other) { + (CaseTest::Wild, CaseTest::Wild) => Ordering::Equal, + (CaseTest::Wild, _) => Ordering::Less, + (_, CaseTest::Wild) => Ordering::Greater, + (_, _) => Ordering::Equal, + } + } +} + +enum DecisionTree { + Switch { + subject_name: String, + subject_tipo: Rc, + column_to_test: usize, + cases: Vec<(CaseTest, Vec, DecisionTree)>, + default: (Vec, Box), + }, + Leaf(TypedExpr), +} + +fn get_tipo_by_path(mut subject_tipo: Rc, mut path: &[Path]) -> Rc { + while let Some((p, rest)) = path.split_first() { + let index = p.get_index(); + + subject_tipo = subject_tipo.arg_types().unwrap().remove(index); + path = rest + } + subject_tipo +} + +fn map_pattern_to_row<'a>( pattern: &'a TypedPattern, subject_name: &String, - column_count: usize, -) -> Vec> { - match pattern { - Pattern::Var { name, .. } => vec![RowItem { - assign: Some(name.clone()), - pattern, - }] - .into_iter() - .cycle() - .take(column_count) - .collect_vec(), + subject_tipo: Rc, + path: Vec, +) -> (Vec, Vec>) { + let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path); - Pattern::Assign { name, pattern, .. } => { - let p = map_to_row(pattern, subject_name, column_count); - p.into_iter() - .map(|mut item| { - item.assign = Some(name.clone()); - item - }) - .collect_vec() - } + let new_columns_added = if current_tipo.is_pair() { + 2 + } else if current_tipo.is_tuple() { + let Type::Tuple { elems, .. } = subject_tipo.as_ref() else { + unreachable!() + }; + elems.len() + } else { + 1 + }; + + match pattern { + Pattern::Var { name, .. } => ( + vec![Assign { + path: path.clone(), + assigned: name.clone(), + }], + vec![RowItem { + pattern, + path: path.clone(), + }] + .into_iter() + .cycle() + .take(new_columns_added) + .collect_vec(), + ), + + Pattern::Assign { name, pattern, .. } => ( + vec![Assign { + path: path.clone(), + assigned: name.clone(), + }], + map_pattern_to_row(pattern, subject_name, subject_tipo.clone(), path).1, + ), Pattern::Int { .. } | Pattern::ByteArray { .. } | Pattern::Discard { .. } | Pattern::List { .. } - | Pattern::Constructor { .. } => vec![RowItem { - assign: None, - pattern, - }] - .into_iter() - .cycle() - .take(column_count) - .collect_vec(), - - Pattern::Pair { fst, snd, .. } => vec![ - RowItem { - assign: None, - pattern: fst, - }, - RowItem { - assign: None, - pattern: snd, - }, - ], - Pattern::Tuple { elems, .. } => elems - .iter() - .map(|elem| RowItem { - assign: None, - pattern: elem, - }) + | Pattern::Constructor { .. } => ( + vec![], + vec![RowItem { + pattern, + path: path.clone(), + }] + .into_iter() + .cycle() + .take(new_columns_added) .collect_vec(), + ), + + Pattern::Pair { fst, snd, .. } => { + let mut fst_path = path.clone(); + fst_path.push(Path::Pair(0)); + let mut snd_path = path; + snd_path.push(Path::Pair(1)); + + let (mut assigns, mut patts) = + map_pattern_to_row(fst, subject_name, subject_tipo.clone(), fst_path); + + let (assign_snd, patt_snd) = + map_pattern_to_row(snd, subject_name, subject_tipo.clone(), snd_path); + + assigns.extend(assign_snd.into_iter()); + + patts.extend(patt_snd.into_iter()); + + (assigns, patts) + } + Pattern::Tuple { elems, .. } => { + elems + .iter() + .enumerate() + .fold((vec![], vec![]), |mut acc, (index, item)| { + let mut item_path = path.clone(); + + item_path.push(Path::Tuple(index)); + + let (assigns, patts) = + map_pattern_to_row(item, subject_name, subject_tipo.clone(), item_path); + + acc.0.extend(assigns.into_iter()); + acc.1.extend(patts.into_iter()); + + acc + }) + } } } @@ -147,64 +211,36 @@ pub fn build_tree( subject_tipo: Rc, clauses: &Vec, ) -> DecisionTree { - let column_count = if subject_tipo.is_pair() { - 2 - } else if subject_tipo.is_tuple() { - let Type::Tuple { elems, .. } = subject_tipo.as_ref() else { - unreachable!() - }; - elems.len() - } else { - 1 - }; - let rows = clauses .iter() .map(|clause| { - let row_items = map_to_row(&clause.pattern, subject_name, column_count); + let (assign, row_items) = + map_pattern_to_row(&clause.pattern, subject_name, subject_tipo.clone(), vec![]); Row { - assigns: vec![], + assigns: assign.into_iter().collect_vec(), columns: row_items, then: &clause.then, } }) .collect_vec(); - let subject_per_column = if column_count > 1 { - (0..column_count) - .map(|index| (subject_name.clone(), Some(index))) - .collect_vec() - } else { - vec![(subject_name.clone(), None)] - }; - - do_build_tree( - subject_name, - subject_tipo, - subject_per_column, - PatternMatrix { rows }, - ) + do_build_tree(subject_name, &subject_tipo, PatternMatrix { rows }) } pub fn do_build_tree<'a>( subject_name: &String, - subject_tipo: Rc, - subject_per_column: Vec<(String, Option)>, + subject_tipo: &Rc, matrix: PatternMatrix<'a>, ) -> DecisionTree { - let column_count = if subject_tipo.is_pair() { - 2 - } else if subject_tipo.is_tuple() { - let Type::Tuple { elems, .. } = subject_tipo.as_ref() else { - unreachable!() - }; - elems.len() - } else { - 1 - }; + let column_length = matrix.rows[0].columns.len(); - let occurrences = [Occurrence::default()].repeat(column_count); + assert!(matrix + .rows + .iter() + .all(|row| { row.columns.len() == column_length })); + + let occurrences = [Occurrence::default()].repeat(column_length); let occurrences = matrix @@ -238,10 +274,9 @@ pub fn do_build_tree<'a>( } }); - if column_count > 1 { + if column_length > 1 { DecisionTree::Switch { subject_name: subject_name.clone(), - subject_tuple_index: None, subject_tipo: subject_tipo.clone(), column_to_test: highest_occurrence.0, cases: todo!(), @@ -250,42 +285,53 @@ pub fn do_build_tree<'a>( } else { let mut collection_vec = matrix.rows.into_iter().fold( vec![], - |mut collection_vec: Vec<(CaseTest, Vec>)>, mut item: Row<'a>| { + |mut collection_vec: Vec<(CaseTest, Vec, Vec>)>, mut item: Row<'a>| { let col = item.columns.remove(highest_occurrence.0); - let mut patt = col.pattern; - if let Pattern::Assign { pattern, .. } = patt { - patt = pattern; - } + assert!(!matches!(col.pattern, Pattern::Assign { .. })); - if let Some(assign) = col.assign { - item.assigns.push(Assign { - subject_name: subject_name.clone(), - subject_tuple_index: None, - assigned: assign, - }); - } + let (mapped_args, case) = match col.pattern { + Pattern::Int { value, .. } => (vec![], CaseTest::Int(value.clone())), + Pattern::ByteArray { value, .. } => (vec![], CaseTest::Bytes(value.clone())), + Pattern::Var { .. } | Pattern::Discard { .. } => (vec![], CaseTest::Wild), + Pattern::List { elements, .. } => ( + elements + .iter() + .enumerate() + .map(|(index, item)| { + let mut item_path = col.path.clone(); - let case = match patt { - Pattern::Int { value, .. } => CaseTest::Int(value.clone()), - Pattern::ByteArray { value, .. } => CaseTest::Bytes(value.clone()), - Pattern::Var { .. } | Pattern::Discard { .. } => CaseTest::Wild, - Pattern::List { elements, .. } => CaseTest::List(elements.len()), - Pattern::Constructor { constructor, .. } => { - CaseTest::Constr(constructor.clone()) + item_path.push(Path::Tuple(index)); + + map_pattern_to_row( + item, + subject_name, + subject_tipo.clone(), + item_path, + ) + }) + .collect_vec(), + CaseTest::List(elements.len()), + ), + + Pattern::Constructor { .. } => { + todo!() } - Pattern::Pair { .. } => todo!(), - Pattern::Tuple { .. } => todo!(), - _ => unreachable!(), + _ => unreachable!("{:#?}", col.pattern), }; + item.assigns + .extend(mapped_args.iter().map(|x| x.0.clone()).flatten()); + item.columns + .extend(mapped_args.into_iter().map(|x| x.1).flatten()); + if let Some(index) = collection_vec.iter().position(|item| item.0 == case) { let entry = collection_vec.get_mut(index).unwrap(); - entry.1.push(item); + entry.2.push(item); collection_vec } else { - collection_vec.push((case, vec![item])); + collection_vec.push((case, col.path, vec![item])); collection_vec } @@ -293,27 +339,31 @@ pub fn do_build_tree<'a>( ); collection_vec.sort_by(|a, b| a.0.cmp(&b.0)); - let mut collection_iter = collection_vec.into_iter().peekable(); + let mut collection_iter = collection_vec + .into_iter() + .map(|x| { + ( + x.0, + x.1, + do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.2 }), + ) + }) + .peekable(); let cases = collection_iter .peeking_take_while(|a| !matches!(a.0, CaseTest::Wild)) .collect_vec(); - let mut fallback = collection_iter.collect_vec(); + let mut fallback = collection_iter.map(|x| (x.1, x.2.into())).collect_vec(); assert!(fallback.len() == 1); - let fallback_matrix = PatternMatrix { - rows: fallback.remove(0).1, - }; - DecisionTree::Switch { subject_name: subject_name.clone(), - subject_tuple_index: None, subject_tipo: subject_tipo.clone(), column_to_test: highest_occurrence.0, - cases: todo!(), - default: todo!(), + cases, + default: fallback.remove(0), } };