Now working for all kinds of patterns except for constr

This commit is contained in:
microproofs 2024-10-11 01:18:49 -04:00
parent 9369cbc1a3
commit def268d966
No known key found for this signature in database
GPG Key ID: 14F93C84DE6AFD17
1 changed files with 201 additions and 32 deletions

View File

@ -17,14 +17,7 @@ struct Occurrence {
pub enum Path { pub enum Path {
Pair(usize), Pair(usize),
Tuple(usize), Tuple(usize),
} List(usize),
impl Path {
pub fn get_index(&self) -> usize {
match self {
Path::Pair(u) | Path::Tuple(u) => *u,
}
}
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -39,17 +32,19 @@ pub struct Assign {
assigned: String, assigned: String,
} }
#[derive(Clone, Debug)]
struct Row<'a> { struct Row<'a> {
assigns: Vec<Assign>, assigns: Vec<Assign>,
columns: Vec<RowItem<'a>>, columns: Vec<RowItem<'a>>,
then: &'a TypedExpr, then: &'a TypedExpr,
} }
#[derive(Clone, Debug)]
struct PatternMatrix<'a> { struct PatternMatrix<'a> {
rows: Vec<Row<'a>>, rows: Vec<Row<'a>>,
} }
#[derive(Clone, Eq, PartialEq)] #[derive(Debug, Clone, Eq, PartialEq)]
pub enum CaseTest { pub enum CaseTest {
Constr(PatternConstructor), Constr(PatternConstructor),
Int(String), Int(String),
@ -62,8 +57,8 @@ impl PartialOrd for CaseTest {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match (self, other) { match (self, other) {
(CaseTest::Wild, CaseTest::Wild) => Some(Ordering::Equal), (CaseTest::Wild, CaseTest::Wild) => Some(Ordering::Equal),
(CaseTest::Wild, _) => Some(Ordering::Less), (CaseTest::Wild, _) => Some(Ordering::Greater),
(_, CaseTest::Wild) => Some(Ordering::Greater), (_, CaseTest::Wild) => Some(Ordering::Less),
(_, _) => Some(Ordering::Equal), (_, _) => Some(Ordering::Equal),
} }
} }
@ -73,13 +68,14 @@ impl Ord for CaseTest {
fn cmp(&self, other: &Self) -> Ordering { fn cmp(&self, other: &Self) -> Ordering {
match (self, other) { match (self, other) {
(CaseTest::Wild, CaseTest::Wild) => Ordering::Equal, (CaseTest::Wild, CaseTest::Wild) => Ordering::Equal,
(CaseTest::Wild, _) => Ordering::Less, (CaseTest::Wild, _) => Ordering::Greater,
(_, CaseTest::Wild) => Ordering::Greater, (_, CaseTest::Wild) => Ordering::Less,
(_, _) => Ordering::Equal, (_, _) => Ordering::Equal,
} }
} }
} }
#[derive(Debug, Clone)]
pub enum DecisionTree<'a> { pub enum DecisionTree<'a> {
Switch { Switch {
subject_name: String, subject_name: String,
@ -93,9 +89,13 @@ pub enum DecisionTree<'a> {
fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> { fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> {
while let Some((p, rest)) = path.split_first() { while let Some((p, rest)) = path.split_first() {
let index = p.get_index(); subject_tipo = match p {
Path::Pair(index) | Path::Tuple(index) => {
subject_tipo.get_inner_types().swap_remove(*index)
}
Path::List(_) => subject_tipo.get_inner_types().swap_remove(0),
};
subject_tipo = subject_tipo.arg_types().unwrap().swap_remove(index);
path = rest path = rest
} }
subject_tipo subject_tipo
@ -112,7 +112,7 @@ fn map_pattern_to_row<'a>(
let new_columns_added = if current_tipo.is_pair() { let new_columns_added = if current_tipo.is_pair() {
2 2
} else if current_tipo.is_tuple() { } else if current_tipo.is_tuple() {
let Type::Tuple { elems, .. } = subject_tipo.as_ref() else { let Type::Tuple { elems, .. } = current_tipo.as_ref() else {
unreachable!() unreachable!()
}; };
elems.len() elems.len()
@ -225,13 +225,16 @@ pub fn build_tree<'a>(
}) })
.collect_vec(); .collect_vec();
do_build_tree(subject_name, &subject_tipo, PatternMatrix { rows }) println!("INITIAL ROWS ARE {:#?}", rows);
do_build_tree(subject_name, &subject_tipo, PatternMatrix { rows }, None)
} }
fn do_build_tree<'a>( fn do_build_tree<'a>(
subject_name: &String, subject_name: &String,
subject_tipo: &Rc<Type>, subject_tipo: &Rc<Type>,
matrix: PatternMatrix<'a>, matrix: PatternMatrix<'a>,
fallback_option: Option<DecisionTree<'a>>,
) -> DecisionTree<'a> { ) -> DecisionTree<'a> {
let column_length = matrix.rows[0].columns.len(); let column_length = matrix.rows[0].columns.len();
@ -277,6 +280,11 @@ fn do_build_tree<'a>(
let (path, mut collection_vec) = matrix.rows.into_iter().fold( let (path, mut collection_vec) = matrix.rows.into_iter().fold(
(vec![], vec![]), (vec![], vec![]),
|mut collection_vec: (Vec<Path>, Vec<(CaseTest, Vec<Row<'a>>)>), mut item: Row<'a>| { |mut collection_vec: (Vec<Path>, Vec<(CaseTest, Vec<Row<'a>>)>), mut item: Row<'a>| {
if item.columns.is_empty() {
collection_vec.1.push((CaseTest::Wild, vec![item]));
return collection_vec;
}
let col = item.columns.remove(highest_occurrence.0); let col = item.columns.remove(highest_occurrence.0);
assert!(!matches!(col.pattern, Pattern::Assign { .. })); assert!(!matches!(col.pattern, Pattern::Assign { .. }));
@ -292,7 +300,7 @@ fn do_build_tree<'a>(
.map(|(index, item)| { .map(|(index, item)| {
let mut item_path = col.path.clone(); let mut item_path = col.path.clone();
item_path.push(Path::Tuple(index)); item_path.push(Path::List(index));
map_pattern_to_row(item, subject_name, subject_tipo, item_path) map_pattern_to_row(item, subject_name, subject_tipo, item_path)
}) })
@ -311,7 +319,11 @@ fn do_build_tree<'a>(
item.columns item.columns
.extend(mapped_args.into_iter().map(|x| x.1).flatten()); .extend(mapped_args.into_iter().map(|x| x.1).flatten());
assert!(collection_vec.0.is_empty() || collection_vec.0 == col.path); assert!(
collection_vec.0.is_empty()
|| collection_vec.0 == col.path
|| matches!(case, CaseTest::Wild)
);
if collection_vec.0.is_empty() { if collection_vec.0.is_empty() {
collection_vec.0 = col.path; collection_vec.0 = col.path;
@ -334,12 +346,6 @@ fn do_build_tree<'a>(
let cases = collection_iter let cases = collection_iter
.peeking_take_while(|a| !matches!(a.0, CaseTest::Wild)) .peeking_take_while(|a| !matches!(a.0, CaseTest::Wild))
.map(|x| {
(
x.0,
do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.1 }),
)
})
.collect_vec(); .collect_vec();
if cases.is_empty() { if cases.is_empty() {
@ -351,22 +357,185 @@ fn do_build_tree<'a>(
assert!(remaining.len() == 1); assert!(remaining.len() == 1);
let thing = remaining.swap_remove(0); let row = remaining.swap_remove(0);
DecisionTree::Leaf(thing.assigns, thing.then) DecisionTree::Leaf(row.assigns, row.then)
} else { } else {
let mut fallback = collection_iter let mut fallback = collection_iter
.map(|x| do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.1 }).into()) .map(|x| {
do_build_tree(
subject_name,
subject_tipo,
PatternMatrix { rows: x.1 },
None,
)
.into()
})
.collect_vec(); .collect_vec();
assert!(fallback.len() == 1 || fallback_option.is_some());
assert!(fallback.len() == 1); let fallback = if !fallback.is_empty() {
fallback.swap_remove(0)
} else {
fallback_option.unwrap()
};
DecisionTree::Switch { DecisionTree::Switch {
subject_name: subject_name.clone(), subject_name: subject_name.clone(),
subject_tipo: subject_tipo.clone(), subject_tipo: get_tipo_by_path(subject_tipo.clone(), &path),
path, path,
cases, cases: cases
default: fallback.swap_remove(0), .into_iter()
.map(|x| {
(
x.0,
do_build_tree(
subject_name,
subject_tipo,
PatternMatrix { rows: x.1 },
Some(fallback.clone()),
),
)
})
.collect_vec(),
default: fallback.into(),
} }
} }
} }
#[cfg(test)]
mod tester {
use std::collections::HashMap;
use crate::{
ast::{Definition, ModuleKind, TraceLevel, Tracing, TypedModule, UntypedModule},
builtins,
expr::{Type, TypedExpr},
gen_uplc::decision_tree::build_tree,
parser,
tipo::error::{Error, Warning},
IdGenerator,
};
fn parse(source_code: &str) -> UntypedModule {
let kind = ModuleKind::Lib;
let (ast, _) = parser::module(source_code, kind).expect("Failed to parse module");
ast
}
fn check_module(
ast: UntypedModule,
extra: Vec<(String, UntypedModule)>,
kind: ModuleKind,
tracing: Tracing,
) -> Result<(Vec<Warning>, TypedModule), (Vec<Warning>, Error)> {
let id_gen = IdGenerator::new();
let mut warnings = vec![];
let mut module_types = HashMap::new();
module_types.insert("aiken".to_string(), builtins::prelude(&id_gen));
module_types.insert("aiken/builtin".to_string(), builtins::plutus(&id_gen));
for (package, module) in extra {
let mut warnings = vec![];
let typed_module = module
.infer(
&id_gen,
kind,
&package,
&module_types,
Tracing::All(TraceLevel::Verbose),
&mut warnings,
None,
)
.expect("extra dependency did not compile");
module_types.insert(package.clone(), typed_module.type_info.clone());
}
let result = ast.infer(
&id_gen,
kind,
"test/project",
&module_types,
tracing,
&mut warnings,
None,
);
result
.map(|o| (warnings.clone(), o))
.map_err(|e| (warnings, e))
}
fn check(ast: UntypedModule) -> Result<(Vec<Warning>, TypedModule), (Vec<Warning>, Error)> {
check_module(ast, Vec::new(), ModuleKind::Lib, Tracing::verbose())
}
#[test]
fn thing() {
let source_code = r#"
test thing(){
when [1, 2, 3] is {
[] -> False
[4] -> fail
[a, 2, b] -> True
_ -> False
}
}
"#;
let (_, ast) = check(parse(source_code)).unwrap();
let Definition::Test(function) = &ast.definitions[0] else {
panic!()
};
let TypedExpr::When { clauses, .. } = &function.body else {
panic!()
};
let tree = build_tree(&"subject".to_string(), &Type::list(Type::int()), clauses);
println!("TREE IS {:#?}", tree);
}
#[test]
fn thing2() {
let source_code = r#"
test thing(){
when (1,2,#"",[]) is {
(a,b,#"", []) -> True
(1,b,#"", [1]) -> False
(3,b,#"aa", _) -> 2 == 2
_ -> 1 == 1
}
}
"#;
let (_, ast) = check(parse(source_code)).unwrap();
let Definition::Test(function) = &ast.definitions[0] else {
panic!()
};
let TypedExpr::When { clauses, .. } = &function.body else {
panic!()
};
let tree = build_tree(
&"subject".to_string(),
&Type::tuple(vec![
Type::int(),
Type::int(),
Type::byte_array(),
Type::list(Type::int()),
]),
clauses,
);
println!("TREE IS {:#?}", tree);
panic!()
}
}