Rework Decision Trees to use path to find the subject to test
This commit is contained in:
parent
20385a7ecd
commit
43e859f1ba
|
@ -1,6 +1,6 @@
|
|||
use std::{cmp::Ordering, rc::Rc};
|
||||
|
||||
use itertools::{Itertools, Tuples};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::{
|
||||
ast::{Pattern, TypedClause, TypedPattern},
|
||||
|
@ -13,7 +13,7 @@ struct Occurrence {
|
|||
amount: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
enum Path {
|
||||
Pair(usize),
|
||||
Tuple(usize),
|
||||
|
@ -80,22 +80,22 @@ impl Ord for CaseTest {
|
|||
}
|
||||
}
|
||||
|
||||
enum DecisionTree {
|
||||
enum DecisionTree<'a> {
|
||||
Switch {
|
||||
subject_name: String,
|
||||
subject_tipo: Rc<Type>,
|
||||
column_to_test: usize,
|
||||
cases: Vec<(CaseTest, Vec<Path>, DecisionTree)>,
|
||||
default: (Vec<Path>, Box<DecisionTree>),
|
||||
path: Vec<Path>,
|
||||
cases: Vec<(CaseTest, DecisionTree<'a>)>,
|
||||
default: Box<DecisionTree<'a>>,
|
||||
},
|
||||
Leaf(TypedExpr),
|
||||
Leaf(Vec<Assign>, &'a TypedExpr),
|
||||
}
|
||||
|
||||
fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> {
|
||||
while let Some((p, rest)) = path.split_first() {
|
||||
let index = p.get_index();
|
||||
|
||||
subject_tipo = subject_tipo.arg_types().unwrap().remove(index);
|
||||
subject_tipo = subject_tipo.arg_types().unwrap().swap_remove(index);
|
||||
path = rest
|
||||
}
|
||||
subject_tipo
|
||||
|
@ -104,7 +104,7 @@ fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> {
|
|||
fn map_pattern_to_row<'a>(
|
||||
pattern: &'a TypedPattern,
|
||||
subject_name: &String,
|
||||
subject_tipo: Rc<Type>,
|
||||
subject_tipo: &Rc<Type>,
|
||||
path: Vec<Path>,
|
||||
) -> (Vec<Assign>, Vec<RowItem<'a>>) {
|
||||
let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path);
|
||||
|
@ -141,7 +141,7 @@ fn map_pattern_to_row<'a>(
|
|||
path: path.clone(),
|
||||
assigned: name.clone(),
|
||||
}],
|
||||
map_pattern_to_row(pattern, subject_name, subject_tipo.clone(), path).1,
|
||||
map_pattern_to_row(pattern, subject_name, subject_tipo, path).1,
|
||||
),
|
||||
Pattern::Int { .. }
|
||||
| Pattern::ByteArray { .. }
|
||||
|
@ -166,10 +166,10 @@ fn map_pattern_to_row<'a>(
|
|||
snd_path.push(Path::Pair(1));
|
||||
|
||||
let (mut assigns, mut patts) =
|
||||
map_pattern_to_row(fst, subject_name, subject_tipo.clone(), fst_path);
|
||||
map_pattern_to_row(fst, subject_name, subject_tipo, fst_path);
|
||||
|
||||
let (assign_snd, patt_snd) =
|
||||
map_pattern_to_row(snd, subject_name, subject_tipo.clone(), snd_path);
|
||||
map_pattern_to_row(snd, subject_name, subject_tipo, snd_path);
|
||||
|
||||
assigns.extend(assign_snd.into_iter());
|
||||
|
||||
|
@ -187,7 +187,7 @@ fn map_pattern_to_row<'a>(
|
|||
item_path.push(Path::Tuple(index));
|
||||
|
||||
let (assigns, patts) =
|
||||
map_pattern_to_row(item, subject_name, subject_tipo.clone(), item_path);
|
||||
map_pattern_to_row(item, subject_name, subject_tipo, item_path);
|
||||
|
||||
acc.0.extend(assigns.into_iter());
|
||||
acc.1.extend(patts.into_iter());
|
||||
|
@ -206,16 +206,16 @@ fn match_wild_card(pattern: &TypedPattern) -> bool {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn build_tree(
|
||||
pub fn build_tree<'a>(
|
||||
subject_name: &String,
|
||||
subject_tipo: Rc<Type>,
|
||||
clauses: &Vec<TypedClause>,
|
||||
) -> DecisionTree {
|
||||
subject_tipo: &Rc<Type>,
|
||||
clauses: &'a Vec<TypedClause>,
|
||||
) -> DecisionTree<'a> {
|
||||
let rows = clauses
|
||||
.iter()
|
||||
.map(|clause| {
|
||||
let (assign, row_items) =
|
||||
map_pattern_to_row(&clause.pattern, subject_name, subject_tipo.clone(), vec![]);
|
||||
map_pattern_to_row(&clause.pattern, subject_name, subject_tipo, vec![]);
|
||||
|
||||
Row {
|
||||
assigns: assign.into_iter().collect_vec(),
|
||||
|
@ -232,7 +232,7 @@ pub fn do_build_tree<'a>(
|
|||
subject_name: &String,
|
||||
subject_tipo: &Rc<Type>,
|
||||
matrix: PatternMatrix<'a>,
|
||||
) -> DecisionTree {
|
||||
) -> DecisionTree<'a> {
|
||||
let column_length = matrix.rows[0].columns.len();
|
||||
|
||||
assert!(matrix
|
||||
|
@ -274,98 +274,99 @@ pub fn do_build_tree<'a>(
|
|||
}
|
||||
});
|
||||
|
||||
if column_length > 1 {
|
||||
DecisionTree::Switch {
|
||||
subject_name: subject_name.clone(),
|
||||
subject_tipo: subject_tipo.clone(),
|
||||
column_to_test: highest_occurrence.0,
|
||||
cases: todo!(),
|
||||
default: todo!(),
|
||||
}
|
||||
} else {
|
||||
let mut collection_vec = matrix.rows.into_iter().fold(
|
||||
vec![],
|
||||
|mut collection_vec: Vec<(CaseTest, Vec<Path>, Vec<Row<'a>>)>, mut item: Row<'a>| {
|
||||
let col = item.columns.remove(highest_occurrence.0);
|
||||
let (path, mut collection_vec) = matrix.rows.into_iter().fold(
|
||||
(vec![], vec![]),
|
||||
|mut collection_vec: (Vec<Path>, Vec<(CaseTest, Vec<Row<'a>>)>), mut item: Row<'a>| {
|
||||
let col = item.columns.remove(highest_occurrence.0);
|
||||
|
||||
assert!(!matches!(col.pattern, Pattern::Assign { .. }));
|
||||
assert!(!matches!(col.pattern, Pattern::Assign { .. }));
|
||||
|
||||
let (mapped_args, case) = match col.pattern {
|
||||
Pattern::Int { value, .. } => (vec![], CaseTest::Int(value.clone())),
|
||||
Pattern::ByteArray { value, .. } => (vec![], CaseTest::Bytes(value.clone())),
|
||||
Pattern::Var { .. } | Pattern::Discard { .. } => (vec![], CaseTest::Wild),
|
||||
Pattern::List { elements, .. } => (
|
||||
elements
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, item)| {
|
||||
let mut item_path = col.path.clone();
|
||||
let (mapped_args, case) = match col.pattern {
|
||||
Pattern::Int { value, .. } => (vec![], CaseTest::Int(value.clone())),
|
||||
Pattern::ByteArray { value, .. } => (vec![], CaseTest::Bytes(value.clone())),
|
||||
Pattern::Var { .. } | Pattern::Discard { .. } => (vec![], CaseTest::Wild),
|
||||
Pattern::List { elements, .. } => (
|
||||
elements
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, item)| {
|
||||
let mut item_path = col.path.clone();
|
||||
|
||||
item_path.push(Path::Tuple(index));
|
||||
item_path.push(Path::Tuple(index));
|
||||
|
||||
map_pattern_to_row(
|
||||
item,
|
||||
subject_name,
|
||||
subject_tipo.clone(),
|
||||
item_path,
|
||||
)
|
||||
})
|
||||
.collect_vec(),
|
||||
CaseTest::List(elements.len()),
|
||||
),
|
||||
map_pattern_to_row(item, subject_name, subject_tipo, item_path)
|
||||
})
|
||||
.collect_vec(),
|
||||
CaseTest::List(elements.len()),
|
||||
),
|
||||
|
||||
Pattern::Constructor { .. } => {
|
||||
todo!()
|
||||
}
|
||||
_ => unreachable!("{:#?}", col.pattern),
|
||||
};
|
||||
|
||||
item.assigns
|
||||
.extend(mapped_args.iter().map(|x| x.0.clone()).flatten());
|
||||
item.columns
|
||||
.extend(mapped_args.into_iter().map(|x| x.1).flatten());
|
||||
|
||||
if let Some(index) = collection_vec.iter().position(|item| item.0 == case) {
|
||||
let entry = collection_vec.get_mut(index).unwrap();
|
||||
|
||||
entry.2.push(item);
|
||||
collection_vec
|
||||
} else {
|
||||
collection_vec.push((case, col.path, vec![item]));
|
||||
|
||||
collection_vec
|
||||
Pattern::Constructor { .. } => {
|
||||
todo!()
|
||||
}
|
||||
},
|
||||
);
|
||||
_ => unreachable!("{:#?}", col.pattern),
|
||||
};
|
||||
|
||||
collection_vec.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
let mut collection_iter = collection_vec
|
||||
.into_iter()
|
||||
.map(|x| {
|
||||
(
|
||||
x.0,
|
||||
x.1,
|
||||
do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.2 }),
|
||||
)
|
||||
})
|
||||
.peekable();
|
||||
item.assigns
|
||||
.extend(mapped_args.iter().map(|x| x.0.clone()).flatten());
|
||||
item.columns
|
||||
.extend(mapped_args.into_iter().map(|x| x.1).flatten());
|
||||
|
||||
let cases = collection_iter
|
||||
.peeking_take_while(|a| !matches!(a.0, CaseTest::Wild))
|
||||
assert!(collection_vec.0.is_empty() || collection_vec.0 == col.path);
|
||||
|
||||
if collection_vec.0.is_empty() {
|
||||
collection_vec.0 = col.path;
|
||||
}
|
||||
|
||||
if let Some(entry) = collection_vec.1.iter_mut().find(|item| item.0 == case) {
|
||||
entry.1.push(item);
|
||||
collection_vec
|
||||
} else {
|
||||
collection_vec.1.push((case, vec![item]));
|
||||
|
||||
collection_vec
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
collection_vec.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
let mut collection_iter = collection_vec.into_iter().peekable();
|
||||
|
||||
let cases = collection_iter
|
||||
.peeking_take_while(|a| !matches!(a.0, CaseTest::Wild))
|
||||
.map(|x| {
|
||||
(
|
||||
x.0,
|
||||
do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.1 }),
|
||||
)
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
if cases.is_empty() {
|
||||
let mut fallback = collection_iter.collect_vec();
|
||||
|
||||
assert!(fallback.len() == 1);
|
||||
|
||||
let mut remaining = fallback.swap_remove(0).1;
|
||||
|
||||
assert!(remaining.len() == 1);
|
||||
|
||||
let thing = remaining.swap_remove(0);
|
||||
|
||||
DecisionTree::Leaf(thing.assigns, thing.then)
|
||||
} else {
|
||||
let mut fallback = collection_iter
|
||||
.map(|x| do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.1 }).into())
|
||||
.collect_vec();
|
||||
|
||||
let mut fallback = collection_iter.map(|x| (x.1, x.2.into())).collect_vec();
|
||||
|
||||
assert!(fallback.len() == 1);
|
||||
|
||||
DecisionTree::Switch {
|
||||
subject_name: subject_name.clone(),
|
||||
subject_tipo: subject_tipo.clone(),
|
||||
column_to_test: highest_occurrence.0,
|
||||
path,
|
||||
cases,
|
||||
default: fallback.remove(0),
|
||||
default: fallback.swap_remove(0),
|
||||
}
|
||||
};
|
||||
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue