Rework Decision Trees to use path to find the subject to test

This commit is contained in:
microproofs 2024-10-10 12:11:13 -04:00
parent 20385a7ecd
commit 43e859f1ba
No known key found for this signature in database
GPG Key ID: 14F93C84DE6AFD17
1 changed files with 97 additions and 96 deletions

View File

@ -1,6 +1,6 @@
use std::{cmp::Ordering, rc::Rc}; use std::{cmp::Ordering, rc::Rc};
use itertools::{Itertools, Tuples}; use itertools::Itertools;
use crate::{ use crate::{
ast::{Pattern, TypedClause, TypedPattern}, ast::{Pattern, TypedClause, TypedPattern},
@ -13,7 +13,7 @@ struct Occurrence {
amount: usize, amount: usize,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug, Eq, PartialEq)]
enum Path { enum Path {
Pair(usize), Pair(usize),
Tuple(usize), Tuple(usize),
@ -80,22 +80,22 @@ impl Ord for CaseTest {
} }
} }
enum DecisionTree { enum DecisionTree<'a> {
Switch { Switch {
subject_name: String, subject_name: String,
subject_tipo: Rc<Type>, subject_tipo: Rc<Type>,
column_to_test: usize, path: Vec<Path>,
cases: Vec<(CaseTest, Vec<Path>, DecisionTree)>, cases: Vec<(CaseTest, DecisionTree<'a>)>,
default: (Vec<Path>, Box<DecisionTree>), default: Box<DecisionTree<'a>>,
}, },
Leaf(TypedExpr), Leaf(Vec<Assign>, &'a TypedExpr),
} }
fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> { fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> {
while let Some((p, rest)) = path.split_first() { while let Some((p, rest)) = path.split_first() {
let index = p.get_index(); let index = p.get_index();
subject_tipo = subject_tipo.arg_types().unwrap().remove(index); subject_tipo = subject_tipo.arg_types().unwrap().swap_remove(index);
path = rest path = rest
} }
subject_tipo subject_tipo
@ -104,7 +104,7 @@ fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> {
fn map_pattern_to_row<'a>( fn map_pattern_to_row<'a>(
pattern: &'a TypedPattern, pattern: &'a TypedPattern,
subject_name: &String, subject_name: &String,
subject_tipo: Rc<Type>, subject_tipo: &Rc<Type>,
path: Vec<Path>, path: Vec<Path>,
) -> (Vec<Assign>, Vec<RowItem<'a>>) { ) -> (Vec<Assign>, Vec<RowItem<'a>>) {
let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path); let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path);
@ -141,7 +141,7 @@ fn map_pattern_to_row<'a>(
path: path.clone(), path: path.clone(),
assigned: name.clone(), assigned: name.clone(),
}], }],
map_pattern_to_row(pattern, subject_name, subject_tipo.clone(), path).1, map_pattern_to_row(pattern, subject_name, subject_tipo, path).1,
), ),
Pattern::Int { .. } Pattern::Int { .. }
| Pattern::ByteArray { .. } | Pattern::ByteArray { .. }
@ -166,10 +166,10 @@ fn map_pattern_to_row<'a>(
snd_path.push(Path::Pair(1)); snd_path.push(Path::Pair(1));
let (mut assigns, mut patts) = let (mut assigns, mut patts) =
map_pattern_to_row(fst, subject_name, subject_tipo.clone(), fst_path); map_pattern_to_row(fst, subject_name, subject_tipo, fst_path);
let (assign_snd, patt_snd) = let (assign_snd, patt_snd) =
map_pattern_to_row(snd, subject_name, subject_tipo.clone(), snd_path); map_pattern_to_row(snd, subject_name, subject_tipo, snd_path);
assigns.extend(assign_snd.into_iter()); assigns.extend(assign_snd.into_iter());
@ -187,7 +187,7 @@ fn map_pattern_to_row<'a>(
item_path.push(Path::Tuple(index)); item_path.push(Path::Tuple(index));
let (assigns, patts) = let (assigns, patts) =
map_pattern_to_row(item, subject_name, subject_tipo.clone(), item_path); map_pattern_to_row(item, subject_name, subject_tipo, item_path);
acc.0.extend(assigns.into_iter()); acc.0.extend(assigns.into_iter());
acc.1.extend(patts.into_iter()); acc.1.extend(patts.into_iter());
@ -206,16 +206,16 @@ fn match_wild_card(pattern: &TypedPattern) -> bool {
} }
} }
pub fn build_tree( pub fn build_tree<'a>(
subject_name: &String, subject_name: &String,
subject_tipo: Rc<Type>, subject_tipo: &Rc<Type>,
clauses: &Vec<TypedClause>, clauses: &'a Vec<TypedClause>,
) -> DecisionTree { ) -> DecisionTree<'a> {
let rows = clauses let rows = clauses
.iter() .iter()
.map(|clause| { .map(|clause| {
let (assign, row_items) = let (assign, row_items) =
map_pattern_to_row(&clause.pattern, subject_name, subject_tipo.clone(), vec![]); map_pattern_to_row(&clause.pattern, subject_name, subject_tipo, vec![]);
Row { Row {
assigns: assign.into_iter().collect_vec(), assigns: assign.into_iter().collect_vec(),
@ -232,7 +232,7 @@ pub fn do_build_tree<'a>(
subject_name: &String, subject_name: &String,
subject_tipo: &Rc<Type>, subject_tipo: &Rc<Type>,
matrix: PatternMatrix<'a>, matrix: PatternMatrix<'a>,
) -> DecisionTree { ) -> DecisionTree<'a> {
let column_length = matrix.rows[0].columns.len(); let column_length = matrix.rows[0].columns.len();
assert!(matrix assert!(matrix
@ -274,18 +274,9 @@ pub fn do_build_tree<'a>(
} }
}); });
if column_length > 1 { let (path, mut collection_vec) = matrix.rows.into_iter().fold(
DecisionTree::Switch { (vec![], vec![]),
subject_name: subject_name.clone(), |mut collection_vec: (Vec<Path>, Vec<(CaseTest, Vec<Row<'a>>)>), mut item: Row<'a>| {
subject_tipo: subject_tipo.clone(),
column_to_test: highest_occurrence.0,
cases: todo!(),
default: todo!(),
}
} else {
let mut collection_vec = matrix.rows.into_iter().fold(
vec![],
|mut collection_vec: Vec<(CaseTest, Vec<Path>, Vec<Row<'a>>)>, mut item: Row<'a>| {
let col = item.columns.remove(highest_occurrence.0); let col = item.columns.remove(highest_occurrence.0);
assert!(!matches!(col.pattern, Pattern::Assign { .. })); assert!(!matches!(col.pattern, Pattern::Assign { .. }));
@ -303,12 +294,7 @@ pub fn do_build_tree<'a>(
item_path.push(Path::Tuple(index)); item_path.push(Path::Tuple(index));
map_pattern_to_row( map_pattern_to_row(item, subject_name, subject_tipo, item_path)
item,
subject_name,
subject_tipo.clone(),
item_path,
)
}) })
.collect_vec(), .collect_vec(),
CaseTest::List(elements.len()), CaseTest::List(elements.len()),
@ -325,13 +311,17 @@ pub fn do_build_tree<'a>(
item.columns item.columns
.extend(mapped_args.into_iter().map(|x| x.1).flatten()); .extend(mapped_args.into_iter().map(|x| x.1).flatten());
if let Some(index) = collection_vec.iter().position(|item| item.0 == case) { assert!(collection_vec.0.is_empty() || collection_vec.0 == col.path);
let entry = collection_vec.get_mut(index).unwrap();
entry.2.push(item); if collection_vec.0.is_empty() {
collection_vec.0 = col.path;
}
if let Some(entry) = collection_vec.1.iter_mut().find(|item| item.0 == case) {
entry.1.push(item);
collection_vec collection_vec
} else { } else {
collection_vec.push((case, col.path, vec![item])); collection_vec.1.push((case, vec![item]));
collection_vec collection_vec
} }
@ -339,33 +329,44 @@ pub fn do_build_tree<'a>(
); );
collection_vec.sort_by(|a, b| a.0.cmp(&b.0)); collection_vec.sort_by(|a, b| a.0.cmp(&b.0));
let mut collection_iter = collection_vec
.into_iter() let mut collection_iter = collection_vec.into_iter().peekable();
.map(|x| {
(
x.0,
x.1,
do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.2 }),
)
})
.peekable();
let cases = collection_iter let cases = collection_iter
.peeking_take_while(|a| !matches!(a.0, CaseTest::Wild)) .peeking_take_while(|a| !matches!(a.0, CaseTest::Wild))
.map(|x| {
(
x.0,
do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.1 }),
)
})
.collect_vec(); .collect_vec();
let mut fallback = collection_iter.map(|x| (x.1, x.2.into())).collect_vec(); if cases.is_empty() {
let mut fallback = collection_iter.collect_vec();
assert!(fallback.len() == 1);
let mut remaining = fallback.swap_remove(0).1;
assert!(remaining.len() == 1);
let thing = remaining.swap_remove(0);
DecisionTree::Leaf(thing.assigns, thing.then)
} else {
let mut fallback = collection_iter
.map(|x| do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.1 }).into())
.collect_vec();
assert!(fallback.len() == 1); assert!(fallback.len() == 1);
DecisionTree::Switch { DecisionTree::Switch {
subject_name: subject_name.clone(), subject_name: subject_name.clone(),
subject_tipo: subject_tipo.clone(), subject_tipo: subject_tipo.clone(),
column_to_test: highest_occurrence.0, path,
cases, cases,
default: fallback.remove(0), default: fallback.swap_remove(0),
}
} }
};
todo!()
} }