diff --git a/crates/aiken-lang/src/gen_uplc/decision_tree.rs b/crates/aiken-lang/src/gen_uplc/decision_tree.rs index ae7e907f..0bf1ba18 100644 --- a/crates/aiken-lang/src/gen_uplc/decision_tree.rs +++ b/crates/aiken-lang/src/gen_uplc/decision_tree.rs @@ -5,25 +5,43 @@ use itertools::{Itertools, Position}; use crate::{ ast::{DataTypeKey, Pattern, TypedClause, TypedDataType, TypedPattern}, - expr::{lookup_data_type_by_tipo, PatternConstructor, Type, TypedExpr}, + expr::{lookup_data_type_by_tipo, Type, TypedExpr}, }; use super::interner::AirInterner; +const PAIR_NEW_COLUMNS: usize = 2; + +const MIN_NEW_COLUMNS: usize = 1; + #[derive(Clone, Default, Copy)] struct Occurrence { passed_wild_card: bool, amount: usize, } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] pub enum Path { Pair(usize), Tuple(usize), + Constr(Rc, usize), List(usize), ListTail(usize), } +impl PartialEq for Path { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Path::Pair(a), Path::Pair(b)) + | (Path::Tuple(a), Path::Tuple(b)) + | (Path::Constr(_, a), Path::Constr(_, b)) + | (Path::List(a), Path::List(b)) + | (Path::ListTail(a), Path::ListTail(b)) => a == b, + _ => false, + } + } +} + #[derive(Clone, Debug)] pub struct Assigned { path: Vec, @@ -50,7 +68,7 @@ struct PatternMatrix<'a> { #[derive(Debug, Clone, Eq, PartialEq)] pub enum CaseTest { - Constr(PatternConstructor), + Constr(usize), Int(String), Bytes(Vec), List(usize), @@ -95,7 +113,8 @@ impl<'a, 'b> TreeGen<'a, 'b> { let rows = clauses .iter() .map(|clause| { - let (assign, row_items) = map_pattern_to_row(&clause.pattern, subject_tipo, vec![]); + let (assign, row_items) = + self.map_pattern_to_row(&clause.pattern, subject_tipo, vec![]); Row { assigns: assign.into_iter().collect_vec(), @@ -209,7 +228,7 @@ impl<'a, 'b> TreeGen<'a, 'b> { item_path.push(Path::List(index)); - map_pattern_to_row(element, subject_tipo, item_path) + self.map_pattern_to_row(element, subject_tipo, item_path) } Position::Last((index, element)) => { @@ -218,13 +237,13 @@ impl<'a, 'b> TreeGen<'a, 'b> { item_path.push(Path::List(index)); - map_pattern_to_row(element, subject_tipo, item_path) + self.map_pattern_to_row(element, subject_tipo, item_path) } else { let mut item_path = col.path.clone(); item_path.push(Path::ListTail(index)); - map_pattern_to_row(element, subject_tipo, item_path) + self.map_pattern_to_row(element, subject_tipo, item_path) } } }) @@ -232,12 +251,35 @@ impl<'a, 'b> TreeGen<'a, 'b> { ), Pattern::Constructor { - name, arguments, .. + name, + arguments, + tipo, + .. } => { let data_type = - lookup_data_type_by_tipo(&self.data_types, &specialized_tipo); + lookup_data_type_by_tipo(&self.data_types, &specialized_tipo).unwrap(); - todo!() + let (constr_index, _) = data_type + .constructors + .iter() + .enumerate() + .find(|(_, dt)| &dt.name == name) + .unwrap(); + + ( + CaseTest::Constr(constr_index), + arguments + .iter() + .enumerate() + .map(|(index, arg)| { + let mut item_path = col.path.clone(); + + item_path.push(Path::Constr(tipo.clone(), index)); + + self.map_pattern_to_row(&arg.value, subject_tipo, item_path) + }) + .collect_vec(), + ) } Pattern::Tuple { .. } | Pattern::Pair { .. } @@ -478,6 +520,145 @@ impl<'a, 'b> TreeGen<'a, 'b> { ) } } + + fn map_pattern_to_row( + &self, + pattern: &'a TypedPattern, + subject_tipo: &Rc, + path: Vec, + ) -> (Vec, Vec>) { + let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path); + + let new_columns_added = if current_tipo.is_pair() { + PAIR_NEW_COLUMNS + } else if current_tipo.is_tuple() { + let Type::Tuple { elems, .. } = current_tipo.as_ref() else { + unreachable!() + }; + elems.len() + } else if let Some(data) = lookup_data_type_by_tipo(self.data_types, subject_tipo) { + if data.constructors.len() == 1 { + data.constructors[0].arguments.len() + } else { + MIN_NEW_COLUMNS + } + } else { + MIN_NEW_COLUMNS + }; + + match pattern { + Pattern::Var { name, .. } => ( + vec![Assigned { + path: path.clone(), + assigned: name.clone(), + }], + vec![RowItem { + pattern, + path: path.clone(), + }] + .into_iter() + .cycle() + .take(new_columns_added) + .collect_vec(), + ), + + Pattern::Assign { name, pattern, .. } => ( + vec![Assigned { + path: path.clone(), + assigned: name.clone(), + }], + self.map_pattern_to_row(pattern, subject_tipo, path).1, + ), + Pattern::Int { .. } + | Pattern::ByteArray { .. } + | Pattern::Discard { .. } + | Pattern::List { .. } => ( + vec![], + vec![RowItem { + pattern, + path: path.clone(), + }] + .into_iter() + .cycle() + .take(new_columns_added) + .collect_vec(), + ), + + Pattern::Constructor { + arguments, tipo, .. + } => { + let data_type = lookup_data_type_by_tipo(self.data_types, ¤t_tipo).unwrap(); + + if data_type.constructors.len() == 1 { + arguments + .iter() + .enumerate() + .fold((vec![], vec![]), |mut acc, (index, arg)| { + let arg_value = &arg.value; + + let mut item_path = path.clone(); + + item_path.push(Path::Constr(tipo.clone(), index)); + + let (assigns, patts) = + self.map_pattern_to_row(arg_value, subject_tipo, item_path); + + acc.0.extend(assigns); + acc.1.extend(patts); + + acc + }) + } else { + ( + vec![], + vec![RowItem { + pattern, + path: path.clone(), + }] + .into_iter() + .cycle() + .take(new_columns_added) + .collect_vec(), + ) + } + } + + Pattern::Pair { fst, snd, .. } => { + let mut fst_path = path.clone(); + fst_path.push(Path::Pair(0)); + let mut snd_path = path; + snd_path.push(Path::Pair(1)); + + let (mut assigns, mut patts) = self.map_pattern_to_row(fst, subject_tipo, fst_path); + + let (assign_snd, patt_snd) = self.map_pattern_to_row(snd, subject_tipo, snd_path); + + assigns.extend(assign_snd); + + patts.extend(patt_snd); + + (assigns, patts) + } + Pattern::Tuple { elems, .. } => { + elems + .iter() + .enumerate() + .fold((vec![], vec![]), |mut acc, (index, item)| { + let mut item_path = path.clone(); + + item_path.push(Path::Tuple(index)); + + let (assigns, patts) = + self.map_pattern_to_row(item, subject_tipo, item_path); + + acc.0.extend(assigns); + acc.1.extend(patts); + + acc + }) + } + } + } } fn get_tipo_by_path(mut subject_tipo: Rc, mut path: &[Path]) -> Rc { @@ -488,6 +669,7 @@ fn get_tipo_by_path(mut subject_tipo: Rc, mut path: &[Path]) -> Rc { } Path::List(_) => subject_tipo.get_inner_types().swap_remove(0), Path::ListTail(_) => subject_tipo, + Path::Constr(tipo, index) => tipo.arg_types().unwrap().swap_remove(*index), }; path = rest @@ -495,99 +677,6 @@ fn get_tipo_by_path(mut subject_tipo: Rc, mut path: &[Path]) -> Rc { subject_tipo } -fn map_pattern_to_row<'a>( - pattern: &'a TypedPattern, - subject_tipo: &Rc, - path: Vec, -) -> (Vec, Vec>) { - let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path); - - let new_columns_added = if current_tipo.is_pair() { - 2 - } else if current_tipo.is_tuple() { - let Type::Tuple { elems, .. } = current_tipo.as_ref() else { - unreachable!() - }; - elems.len() - } else { - 1 - }; - - match pattern { - Pattern::Var { name, .. } => ( - vec![Assigned { - path: path.clone(), - assigned: name.clone(), - }], - vec![RowItem { - pattern, - path: path.clone(), - }] - .into_iter() - .cycle() - .take(new_columns_added) - .collect_vec(), - ), - - Pattern::Assign { name, pattern, .. } => ( - vec![Assigned { - path: path.clone(), - assigned: name.clone(), - }], - map_pattern_to_row(pattern, subject_tipo, path).1, - ), - Pattern::Int { .. } - | Pattern::ByteArray { .. } - | Pattern::Discard { .. } - | Pattern::List { .. } - | Pattern::Constructor { .. } => ( - vec![], - vec![RowItem { - pattern, - path: path.clone(), - }] - .into_iter() - .cycle() - .take(new_columns_added) - .collect_vec(), - ), - - Pattern::Pair { fst, snd, .. } => { - let mut fst_path = path.clone(); - fst_path.push(Path::Pair(0)); - let mut snd_path = path; - snd_path.push(Path::Pair(1)); - - let (mut assigns, mut patts) = map_pattern_to_row(fst, subject_tipo, fst_path); - - let (assign_snd, patt_snd) = map_pattern_to_row(snd, subject_tipo, snd_path); - - assigns.extend(assign_snd); - - patts.extend(patt_snd); - - (assigns, patts) - } - Pattern::Tuple { elems, .. } => { - elems - .iter() - .enumerate() - .fold((vec![], vec![]), |mut acc, (index, item)| { - let mut item_path = path.clone(); - - item_path.push(Path::Tuple(index)); - - let (assigns, patts) = map_pattern_to_row(item, subject_tipo, item_path); - - acc.0.extend(assigns); - acc.1.extend(patts); - - acc - }) - } - } -} - fn match_wild_card(pattern: &TypedPattern) -> bool { match pattern { Pattern::Var { .. } | Pattern::Discard { .. } => true, @@ -642,13 +731,15 @@ mod tester { use indexmap::IndexMap; use crate::{ - ast::{Definition, ModuleKind, TraceLevel, Tracing, TypedModule, UntypedModule}, + ast::{ + well_known, Definition, ModuleKind, TraceLevel, Tracing, TypedModule, UntypedModule, + }, builtins, expr::{Type, TypedExpr}, gen_uplc::{decision_tree::TreeGen, interner::AirInterner}, parser, tipo::error::{Error, Warning}, - IdGenerator, + utils, IdGenerator, }; fn parse(source_code: &str) -> UntypedModule { @@ -881,4 +972,55 @@ mod tester { println!("TREE IS {:#?}", tree); } + + #[test] + fn thing5() { + let source_code = r#" + test thing(){ + when (1,[],#"",None) is { + (2,b,#"", Some(Seeded { choices: #"", .. })) -> 4 == 4 + (a,b,#"", None) -> True + (1,b,#"", Some(Seeded{ choices: #"", ..})) -> False + (3,b,#"aa", Some(Replayed(..))) -> 2 == 2 + (3,b, c, y) -> fail + _ -> 1 == 1 + } + } + "#; + + let (_, ast) = check(parse(source_code)).unwrap(); + + let Definition::Test(function) = &ast.definitions[0] else { + panic!() + }; + + let TypedExpr::When { clauses, .. } = &function.body else { + panic!() + }; + + let mut air_interner = AirInterner::new(); + + let id_gen = IdGenerator::new(); + + let data_types = builtins::prelude_data_types(&id_gen); + + let tree_gen = TreeGen { + interner: &mut air_interner, + data_types: &utils::indexmap::as_ref_values(&data_types), + }; + + let tree = tree_gen.build_tree( + &"subject".to_string(), + &Type::tuple(vec![ + Type::int(), + Type::int(), + Type::byte_array(), + Type::option(Type::prng()), + ]), + clauses, + ); + + println!("TREE IS {:#?}", tree); + panic!("SUPPPPPPPPPPPPPPPPPPPPPPPER DOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOONE"); + } }