From def268d96614c37076387cd5c5718c83e2e91608 Mon Sep 17 00:00:00 2001 From: microproofs Date: Fri, 11 Oct 2024 01:18:49 -0400 Subject: [PATCH] Now working for all kinds of patterns except for constr --- .../aiken-lang/src/gen_uplc/decision_tree.rs | 233 +++++++++++++++--- 1 file changed, 201 insertions(+), 32 deletions(-) diff --git a/crates/aiken-lang/src/gen_uplc/decision_tree.rs b/crates/aiken-lang/src/gen_uplc/decision_tree.rs index 19f5353e..13c0633c 100644 --- a/crates/aiken-lang/src/gen_uplc/decision_tree.rs +++ b/crates/aiken-lang/src/gen_uplc/decision_tree.rs @@ -17,14 +17,7 @@ struct Occurrence { pub enum Path { Pair(usize), Tuple(usize), -} - -impl Path { - pub fn get_index(&self) -> usize { - match self { - Path::Pair(u) | Path::Tuple(u) => *u, - } - } + List(usize), } #[derive(Clone, Debug)] @@ -39,17 +32,19 @@ pub struct Assign { assigned: String, } +#[derive(Clone, Debug)] struct Row<'a> { assigns: Vec, columns: Vec>, then: &'a TypedExpr, } +#[derive(Clone, Debug)] struct PatternMatrix<'a> { rows: Vec>, } -#[derive(Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq)] pub enum CaseTest { Constr(PatternConstructor), Int(String), @@ -62,8 +57,8 @@ impl PartialOrd for CaseTest { fn partial_cmp(&self, other: &Self) -> Option { match (self, other) { (CaseTest::Wild, CaseTest::Wild) => Some(Ordering::Equal), - (CaseTest::Wild, _) => Some(Ordering::Less), - (_, CaseTest::Wild) => Some(Ordering::Greater), + (CaseTest::Wild, _) => Some(Ordering::Greater), + (_, CaseTest::Wild) => Some(Ordering::Less), (_, _) => Some(Ordering::Equal), } } @@ -73,13 +68,14 @@ impl Ord for CaseTest { fn cmp(&self, other: &Self) -> Ordering { match (self, other) { (CaseTest::Wild, CaseTest::Wild) => Ordering::Equal, - (CaseTest::Wild, _) => Ordering::Less, - (_, CaseTest::Wild) => Ordering::Greater, + (CaseTest::Wild, _) => Ordering::Greater, + (_, CaseTest::Wild) => Ordering::Less, (_, _) => Ordering::Equal, } } } +#[derive(Debug, Clone)] pub enum DecisionTree<'a> { Switch { subject_name: String, @@ -93,9 +89,13 @@ pub enum DecisionTree<'a> { fn get_tipo_by_path(mut subject_tipo: Rc, mut path: &[Path]) -> Rc { while let Some((p, rest)) = path.split_first() { - let index = p.get_index(); + subject_tipo = match p { + Path::Pair(index) | Path::Tuple(index) => { + subject_tipo.get_inner_types().swap_remove(*index) + } + Path::List(_) => subject_tipo.get_inner_types().swap_remove(0), + }; - subject_tipo = subject_tipo.arg_types().unwrap().swap_remove(index); path = rest } subject_tipo @@ -112,7 +112,7 @@ fn map_pattern_to_row<'a>( let new_columns_added = if current_tipo.is_pair() { 2 } else if current_tipo.is_tuple() { - let Type::Tuple { elems, .. } = subject_tipo.as_ref() else { + let Type::Tuple { elems, .. } = current_tipo.as_ref() else { unreachable!() }; elems.len() @@ -225,13 +225,16 @@ pub fn build_tree<'a>( }) .collect_vec(); - do_build_tree(subject_name, &subject_tipo, PatternMatrix { rows }) + println!("INITIAL ROWS ARE {:#?}", rows); + + do_build_tree(subject_name, &subject_tipo, PatternMatrix { rows }, None) } fn do_build_tree<'a>( subject_name: &String, subject_tipo: &Rc, matrix: PatternMatrix<'a>, + fallback_option: Option>, ) -> DecisionTree<'a> { let column_length = matrix.rows[0].columns.len(); @@ -277,6 +280,11 @@ fn do_build_tree<'a>( let (path, mut collection_vec) = matrix.rows.into_iter().fold( (vec![], vec![]), |mut collection_vec: (Vec, Vec<(CaseTest, Vec>)>), mut item: Row<'a>| { + if item.columns.is_empty() { + collection_vec.1.push((CaseTest::Wild, vec![item])); + return collection_vec; + } + let col = item.columns.remove(highest_occurrence.0); assert!(!matches!(col.pattern, Pattern::Assign { .. })); @@ -292,7 +300,7 @@ fn do_build_tree<'a>( .map(|(index, item)| { let mut item_path = col.path.clone(); - item_path.push(Path::Tuple(index)); + item_path.push(Path::List(index)); map_pattern_to_row(item, subject_name, subject_tipo, item_path) }) @@ -311,7 +319,11 @@ fn do_build_tree<'a>( item.columns .extend(mapped_args.into_iter().map(|x| x.1).flatten()); - assert!(collection_vec.0.is_empty() || collection_vec.0 == col.path); + assert!( + collection_vec.0.is_empty() + || collection_vec.0 == col.path + || matches!(case, CaseTest::Wild) + ); if collection_vec.0.is_empty() { collection_vec.0 = col.path; @@ -334,12 +346,6 @@ fn do_build_tree<'a>( let cases = collection_iter .peeking_take_while(|a| !matches!(a.0, CaseTest::Wild)) - .map(|x| { - ( - x.0, - do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.1 }), - ) - }) .collect_vec(); if cases.is_empty() { @@ -351,22 +357,185 @@ fn do_build_tree<'a>( assert!(remaining.len() == 1); - let thing = remaining.swap_remove(0); + let row = remaining.swap_remove(0); - DecisionTree::Leaf(thing.assigns, thing.then) + DecisionTree::Leaf(row.assigns, row.then) } else { let mut fallback = collection_iter - .map(|x| do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.1 }).into()) + .map(|x| { + do_build_tree( + subject_name, + subject_tipo, + PatternMatrix { rows: x.1 }, + None, + ) + .into() + }) .collect_vec(); + assert!(fallback.len() == 1 || fallback_option.is_some()); - assert!(fallback.len() == 1); + let fallback = if !fallback.is_empty() { + fallback.swap_remove(0) + } else { + fallback_option.unwrap() + }; DecisionTree::Switch { subject_name: subject_name.clone(), - subject_tipo: subject_tipo.clone(), + subject_tipo: get_tipo_by_path(subject_tipo.clone(), &path), path, - cases, - default: fallback.swap_remove(0), + cases: cases + .into_iter() + .map(|x| { + ( + x.0, + do_build_tree( + subject_name, + subject_tipo, + PatternMatrix { rows: x.1 }, + Some(fallback.clone()), + ), + ) + }) + .collect_vec(), + default: fallback.into(), } } } + +#[cfg(test)] +mod tester { + use std::collections::HashMap; + + use crate::{ + ast::{Definition, ModuleKind, TraceLevel, Tracing, TypedModule, UntypedModule}, + builtins, + expr::{Type, TypedExpr}, + gen_uplc::decision_tree::build_tree, + parser, + tipo::error::{Error, Warning}, + IdGenerator, + }; + + fn parse(source_code: &str) -> UntypedModule { + let kind = ModuleKind::Lib; + let (ast, _) = parser::module(source_code, kind).expect("Failed to parse module"); + ast + } + + fn check_module( + ast: UntypedModule, + extra: Vec<(String, UntypedModule)>, + kind: ModuleKind, + tracing: Tracing, + ) -> Result<(Vec, TypedModule), (Vec, Error)> { + let id_gen = IdGenerator::new(); + + let mut warnings = vec![]; + + let mut module_types = HashMap::new(); + module_types.insert("aiken".to_string(), builtins::prelude(&id_gen)); + module_types.insert("aiken/builtin".to_string(), builtins::plutus(&id_gen)); + + for (package, module) in extra { + let mut warnings = vec![]; + let typed_module = module + .infer( + &id_gen, + kind, + &package, + &module_types, + Tracing::All(TraceLevel::Verbose), + &mut warnings, + None, + ) + .expect("extra dependency did not compile"); + module_types.insert(package.clone(), typed_module.type_info.clone()); + } + + let result = ast.infer( + &id_gen, + kind, + "test/project", + &module_types, + tracing, + &mut warnings, + None, + ); + + result + .map(|o| (warnings.clone(), o)) + .map_err(|e| (warnings, e)) + } + + fn check(ast: UntypedModule) -> Result<(Vec, TypedModule), (Vec, Error)> { + check_module(ast, Vec::new(), ModuleKind::Lib, Tracing::verbose()) + } + + #[test] + fn thing() { + let source_code = r#" + test thing(){ + when [1, 2, 3] is { + [] -> False + [4] -> fail + [a, 2, b] -> True + _ -> False + } + } + "#; + + let (_, ast) = check(parse(source_code)).unwrap(); + + let Definition::Test(function) = &ast.definitions[0] else { + panic!() + }; + + let TypedExpr::When { clauses, .. } = &function.body else { + panic!() + }; + + let tree = build_tree(&"subject".to_string(), &Type::list(Type::int()), clauses); + + println!("TREE IS {:#?}", tree); + } + + #[test] + fn thing2() { + let source_code = r#" + test thing(){ + when (1,2,#"",[]) is { + (a,b,#"", []) -> True + (1,b,#"", [1]) -> False + (3,b,#"aa", _) -> 2 == 2 + _ -> 1 == 1 + } + } + "#; + + let (_, ast) = check(parse(source_code)).unwrap(); + + let Definition::Test(function) = &ast.definitions[0] else { + panic!() + }; + + let TypedExpr::When { clauses, .. } = &function.body else { + panic!() + }; + + let tree = build_tree( + &"subject".to_string(), + &Type::tuple(vec![ + Type::int(), + Type::int(), + Type::byte_array(), + Type::list(Type::int()), + ]), + clauses, + ); + + println!("TREE IS {:#?}", tree); + + panic!() + } +}