Incremental commit for dealing with list tails

This commit is contained in:
microproofs 2024-10-15 14:37:19 -04:00
parent ca161d8a68
commit b340de2cfd
No known key found for this signature in database
GPG Key ID: 14F93C84DE6AFD17
1 changed files with 102 additions and 32 deletions

View File

@ -1,6 +1,6 @@
use std::{cmp::Ordering, rc::Rc};
use itertools::Itertools;
use itertools::{Itertools, Position};
use crate::{
ast::{Pattern, TypedClause, TypedPattern},
@ -18,6 +18,7 @@ pub enum Path {
Pair(usize),
Tuple(usize),
List(usize),
ListTail(usize),
}
#[derive(Clone, Debug)]
@ -27,14 +28,14 @@ struct RowItem<'a> {
}
#[derive(Clone, Debug)]
pub struct Assign {
pub struct Assigned {
path: Vec<Path>,
assigned: String,
}
#[derive(Clone, Debug)]
struct Row<'a> {
assigns: Vec<Assign>,
assigns: Vec<Assigned>,
columns: Vec<RowItem<'a>>,
then: &'a TypedExpr,
}
@ -49,7 +50,8 @@ pub enum CaseTest {
Constr(PatternConstructor),
Int(String),
Bytes(Vec<u8>),
List(usize, bool),
List(usize),
ListWithTail(usize),
Wild,
}
@ -84,7 +86,17 @@ pub enum DecisionTree<'a> {
cases: Vec<(CaseTest, DecisionTree<'a>)>,
default: Box<DecisionTree<'a>>,
},
Leaf(Vec<Assign>, &'a TypedExpr),
ListSwitch {
subject_name: String,
subject_tipo: Rc<Type>,
path: Vec<Path>,
cases: Vec<(CaseTest, DecisionTree<'a>)>,
tail_cases: Vec<(CaseTest, DecisionTree<'a>)>,
default: Box<DecisionTree<'a>>,
},
Leaf(Vec<Assigned>, &'a TypedExpr),
HoistedLeaf(String),
HoistThen(Vec<Assigned>, &'a TypedExpr, Box<DecisionTree<'a>>),
}
fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> {
@ -94,6 +106,7 @@ fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> {
subject_tipo.get_inner_types().swap_remove(*index)
}
Path::List(_) => subject_tipo.get_inner_types().swap_remove(0),
Path::ListTail(_) => subject_tipo,
};
path = rest
@ -105,7 +118,7 @@ fn map_pattern_to_row<'a>(
pattern: &'a TypedPattern,
subject_tipo: &Rc<Type>,
path: Vec<Path>,
) -> (Vec<Assign>, Vec<RowItem<'a>>) {
) -> (Vec<Assigned>, Vec<RowItem<'a>>) {
let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path);
let new_columns_added = if current_tipo.is_pair() {
@ -121,7 +134,7 @@ fn map_pattern_to_row<'a>(
match pattern {
Pattern::Var { name, .. } => (
vec![Assign {
vec![Assigned {
path: path.clone(),
assigned: name.clone(),
}],
@ -136,7 +149,7 @@ fn map_pattern_to_row<'a>(
),
Pattern::Assign { name, pattern, .. } => (
vec![Assign {
vec![Assigned {
path: path.clone(),
assigned: name.clone(),
}],
@ -225,19 +238,8 @@ pub fn build_tree<'a>(
do_build_tree(subject_name, subject_tipo, PatternMatrix { rows }, None)
}
fn do_build_tree<'a>(
subject_name: &String,
subject_tipo: &Rc<Type>,
matrix: PatternMatrix<'a>,
fallback_option: Option<DecisionTree<'a>>,
) -> DecisionTree<'a> {
let column_length = matrix.rows[0].columns.len();
assert!(matrix
.rows
.iter()
.all(|row| { row.columns.len() == column_length }));
// A function to get which column has the most pattern matches before a wild card
fn highest_occurrence(matrix: &PatternMatrix, column_length: usize) -> usize {
let occurrences = [Occurrence::default()].repeat(column_length);
let occurrences =
@ -272,6 +274,44 @@ fn do_build_tree<'a>(
}
});
highest_occurrence.0
}
fn do_build_tree<'a>(
subject_name: &String,
subject_tipo: &Rc<Type>,
matrix: PatternMatrix<'a>,
fallback_option: Option<DecisionTree<'a>>,
) -> DecisionTree<'a> {
let column_length = matrix.rows[0].columns.len();
assert!(matrix
.rows
.iter()
.all(|row| { row.columns.len() == column_length }));
let occurrence_col = highest_occurrence(&matrix, column_length);
let mut longest_elems = None;
matrix.rows.iter().for_each(|item| {
let col = &item.columns[occurrence_col];
match col.pattern {
Pattern::List { elements, .. } => match longest_elems {
Some(elems_count) => {
if elems_count < elements.len() {
longest_elems = Some(elements.len());
}
}
None => {
longest_elems = Some(elements.len());
}
},
_ => (),
}
});
let (path, mut collection_vec) = matrix.rows.into_iter().fold(
(vec![], vec![]),
|mut collection_vec: (Vec<Path>, Vec<(CaseTest, Vec<Row<'a>>)>), mut item: Row<'a>| {
@ -280,7 +320,7 @@ fn do_build_tree<'a>(
return collection_vec;
}
let col = item.columns.remove(highest_occurrence.0);
let col = item.columns.remove(occurrence_col);
assert!(!matches!(col.pattern, Pattern::Assign { .. }));
@ -293,15 +333,40 @@ fn do_build_tree<'a>(
.iter()
.chain(tail.as_ref().map(|tail| tail.as_ref()))
.enumerate()
.map(|(index, item)| {
let mut item_path = col.path.clone();
.with_position()
.map(|elem| match elem {
Position::First((index, element))
| Position::Middle((index, element))
| Position::Only((index, element)) => {
let mut item_path = col.path.clone();
item_path.push(Path::List(index));
item_path.push(Path::List(index));
map_pattern_to_row(item, subject_tipo, item_path)
map_pattern_to_row(element, subject_tipo, item_path)
}
Position::Last((index, element)) => {
if tail.is_none() {
let mut item_path = col.path.clone();
item_path.push(Path::List(index));
map_pattern_to_row(element, subject_tipo, item_path)
} else {
let mut item_path = col.path.clone();
item_path.push(Path::ListTail(index));
map_pattern_to_row(element, subject_tipo, item_path)
}
}
})
.collect_vec(),
CaseTest::List(elements.len(), tail.is_some()),
if tail.is_none() {
CaseTest::List(elements.len())
} else {
CaseTest::ListWithTail(elements.len())
},
),
Pattern::Constructor { .. } => {
@ -310,11 +375,7 @@ fn do_build_tree<'a>(
_ => unreachable!("{:#?}", col.pattern),
};
item.assigns
.extend(mapped_args.iter().flat_map(|x| x.0.clone()));
item.columns
.extend(mapped_args.into_iter().flat_map(|x| x.1));
// Assert path is matches for each row except for wild_card
assert!(
collection_vec.0.is_empty()
|| collection_vec.0 == col.path
@ -325,6 +386,15 @@ fn do_build_tree<'a>(
collection_vec.0 = col.path;
}
// expand assigns by newly added ones
item.assigns
.extend(mapped_args.iter().flat_map(|x| x.0.clone()));
// Add inner patterns to existing row
item.columns
.extend(mapped_args.into_iter().flat_map(|x| x.1));
// TODO: Handle special casetest of ListWithTail
if let Some(entry) = collection_vec.1.iter_mut().find(|item| item.0 == case) {
entry.1.push(item);
collection_vec