Incremental commit for dealing with list tails

This commit is contained in:
microproofs 2024-10-15 14:37:19 -04:00
parent ca161d8a68
commit b340de2cfd
No known key found for this signature in database
GPG Key ID: 14F93C84DE6AFD17
1 changed files with 102 additions and 32 deletions

View File

@ -1,6 +1,6 @@
use std::{cmp::Ordering, rc::Rc}; use std::{cmp::Ordering, rc::Rc};
use itertools::Itertools; use itertools::{Itertools, Position};
use crate::{ use crate::{
ast::{Pattern, TypedClause, TypedPattern}, ast::{Pattern, TypedClause, TypedPattern},
@ -18,6 +18,7 @@ pub enum Path {
Pair(usize), Pair(usize),
Tuple(usize), Tuple(usize),
List(usize), List(usize),
ListTail(usize),
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -27,14 +28,14 @@ struct RowItem<'a> {
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Assign { pub struct Assigned {
path: Vec<Path>, path: Vec<Path>,
assigned: String, assigned: String,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
struct Row<'a> { struct Row<'a> {
assigns: Vec<Assign>, assigns: Vec<Assigned>,
columns: Vec<RowItem<'a>>, columns: Vec<RowItem<'a>>,
then: &'a TypedExpr, then: &'a TypedExpr,
} }
@ -49,7 +50,8 @@ pub enum CaseTest {
Constr(PatternConstructor), Constr(PatternConstructor),
Int(String), Int(String),
Bytes(Vec<u8>), Bytes(Vec<u8>),
List(usize, bool), List(usize),
ListWithTail(usize),
Wild, Wild,
} }
@ -84,7 +86,17 @@ pub enum DecisionTree<'a> {
cases: Vec<(CaseTest, DecisionTree<'a>)>, cases: Vec<(CaseTest, DecisionTree<'a>)>,
default: Box<DecisionTree<'a>>, default: Box<DecisionTree<'a>>,
}, },
Leaf(Vec<Assign>, &'a TypedExpr), ListSwitch {
subject_name: String,
subject_tipo: Rc<Type>,
path: Vec<Path>,
cases: Vec<(CaseTest, DecisionTree<'a>)>,
tail_cases: Vec<(CaseTest, DecisionTree<'a>)>,
default: Box<DecisionTree<'a>>,
},
Leaf(Vec<Assigned>, &'a TypedExpr),
HoistedLeaf(String),
HoistThen(Vec<Assigned>, &'a TypedExpr, Box<DecisionTree<'a>>),
} }
fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> { fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> {
@ -94,6 +106,7 @@ fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> {
subject_tipo.get_inner_types().swap_remove(*index) subject_tipo.get_inner_types().swap_remove(*index)
} }
Path::List(_) => subject_tipo.get_inner_types().swap_remove(0), Path::List(_) => subject_tipo.get_inner_types().swap_remove(0),
Path::ListTail(_) => subject_tipo,
}; };
path = rest path = rest
@ -105,7 +118,7 @@ fn map_pattern_to_row<'a>(
pattern: &'a TypedPattern, pattern: &'a TypedPattern,
subject_tipo: &Rc<Type>, subject_tipo: &Rc<Type>,
path: Vec<Path>, path: Vec<Path>,
) -> (Vec<Assign>, Vec<RowItem<'a>>) { ) -> (Vec<Assigned>, Vec<RowItem<'a>>) {
let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path); let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path);
let new_columns_added = if current_tipo.is_pair() { let new_columns_added = if current_tipo.is_pair() {
@ -121,7 +134,7 @@ fn map_pattern_to_row<'a>(
match pattern { match pattern {
Pattern::Var { name, .. } => ( Pattern::Var { name, .. } => (
vec![Assign { vec![Assigned {
path: path.clone(), path: path.clone(),
assigned: name.clone(), assigned: name.clone(),
}], }],
@ -136,7 +149,7 @@ fn map_pattern_to_row<'a>(
), ),
Pattern::Assign { name, pattern, .. } => ( Pattern::Assign { name, pattern, .. } => (
vec![Assign { vec![Assigned {
path: path.clone(), path: path.clone(),
assigned: name.clone(), assigned: name.clone(),
}], }],
@ -225,19 +238,8 @@ pub fn build_tree<'a>(
do_build_tree(subject_name, subject_tipo, PatternMatrix { rows }, None) do_build_tree(subject_name, subject_tipo, PatternMatrix { rows }, None)
} }
fn do_build_tree<'a>( // A function to get which column has the most pattern matches before a wild card
subject_name: &String, fn highest_occurrence(matrix: &PatternMatrix, column_length: usize) -> usize {
subject_tipo: &Rc<Type>,
matrix: PatternMatrix<'a>,
fallback_option: Option<DecisionTree<'a>>,
) -> DecisionTree<'a> {
let column_length = matrix.rows[0].columns.len();
assert!(matrix
.rows
.iter()
.all(|row| { row.columns.len() == column_length }));
let occurrences = [Occurrence::default()].repeat(column_length); let occurrences = [Occurrence::default()].repeat(column_length);
let occurrences = let occurrences =
@ -272,6 +274,44 @@ fn do_build_tree<'a>(
} }
}); });
highest_occurrence.0
}
fn do_build_tree<'a>(
subject_name: &String,
subject_tipo: &Rc<Type>,
matrix: PatternMatrix<'a>,
fallback_option: Option<DecisionTree<'a>>,
) -> DecisionTree<'a> {
let column_length = matrix.rows[0].columns.len();
assert!(matrix
.rows
.iter()
.all(|row| { row.columns.len() == column_length }));
let occurrence_col = highest_occurrence(&matrix, column_length);
let mut longest_elems = None;
matrix.rows.iter().for_each(|item| {
let col = &item.columns[occurrence_col];
match col.pattern {
Pattern::List { elements, .. } => match longest_elems {
Some(elems_count) => {
if elems_count < elements.len() {
longest_elems = Some(elements.len());
}
}
None => {
longest_elems = Some(elements.len());
}
},
_ => (),
}
});
let (path, mut collection_vec) = matrix.rows.into_iter().fold( let (path, mut collection_vec) = matrix.rows.into_iter().fold(
(vec![], vec![]), (vec![], vec![]),
|mut collection_vec: (Vec<Path>, Vec<(CaseTest, Vec<Row<'a>>)>), mut item: Row<'a>| { |mut collection_vec: (Vec<Path>, Vec<(CaseTest, Vec<Row<'a>>)>), mut item: Row<'a>| {
@ -280,7 +320,7 @@ fn do_build_tree<'a>(
return collection_vec; return collection_vec;
} }
let col = item.columns.remove(highest_occurrence.0); let col = item.columns.remove(occurrence_col);
assert!(!matches!(col.pattern, Pattern::Assign { .. })); assert!(!matches!(col.pattern, Pattern::Assign { .. }));
@ -293,15 +333,40 @@ fn do_build_tree<'a>(
.iter() .iter()
.chain(tail.as_ref().map(|tail| tail.as_ref())) .chain(tail.as_ref().map(|tail| tail.as_ref()))
.enumerate() .enumerate()
.map(|(index, item)| { .with_position()
.map(|elem| match elem {
Position::First((index, element))
| Position::Middle((index, element))
| Position::Only((index, element)) => {
let mut item_path = col.path.clone(); let mut item_path = col.path.clone();
item_path.push(Path::List(index)); item_path.push(Path::List(index));
map_pattern_to_row(item, subject_tipo, item_path) map_pattern_to_row(element, subject_tipo, item_path)
}
Position::Last((index, element)) => {
if tail.is_none() {
let mut item_path = col.path.clone();
item_path.push(Path::List(index));
map_pattern_to_row(element, subject_tipo, item_path)
} else {
let mut item_path = col.path.clone();
item_path.push(Path::ListTail(index));
map_pattern_to_row(element, subject_tipo, item_path)
}
}
}) })
.collect_vec(), .collect_vec(),
CaseTest::List(elements.len(), tail.is_some()), if tail.is_none() {
CaseTest::List(elements.len())
} else {
CaseTest::ListWithTail(elements.len())
},
), ),
Pattern::Constructor { .. } => { Pattern::Constructor { .. } => {
@ -310,11 +375,7 @@ fn do_build_tree<'a>(
_ => unreachable!("{:#?}", col.pattern), _ => unreachable!("{:#?}", col.pattern),
}; };
item.assigns // Assert path is matches for each row except for wild_card
.extend(mapped_args.iter().flat_map(|x| x.0.clone()));
item.columns
.extend(mapped_args.into_iter().flat_map(|x| x.1));
assert!( assert!(
collection_vec.0.is_empty() collection_vec.0.is_empty()
|| collection_vec.0 == col.path || collection_vec.0 == col.path
@ -325,6 +386,15 @@ fn do_build_tree<'a>(
collection_vec.0 = col.path; collection_vec.0 = col.path;
} }
// expand assigns by newly added ones
item.assigns
.extend(mapped_args.iter().flat_map(|x| x.0.clone()));
// Add inner patterns to existing row
item.columns
.extend(mapped_args.into_iter().flat_map(|x| x.1));
// TODO: Handle special casetest of ListWithTail
if let Some(entry) = collection_vec.1.iter_mut().find(|item| item.0 == case) { if let Some(entry) = collection_vec.1.iter_mut().find(|item| item.0 == case) {
entry.1.push(item); entry.1.push(item);
collection_vec collection_vec