Handle tuples and pairs properly now

This commit is contained in:
microproofs 2024-10-09 22:04:12 -04:00
parent 97ee1a8ba6
commit 20385a7ecd
No known key found for this signature in database
GPG Key ID: 14F93C84DE6AFD17
1 changed files with 218 additions and 168 deletions

View File

@ -1,6 +1,6 @@
use std::{cmp::Ordering, rc::Rc}; use std::{cmp::Ordering, rc::Rc};
use itertools::Itertools; use itertools::{Itertools, Tuples};
use crate::{ use crate::{
ast::{Pattern, TypedClause, TypedPattern}, ast::{Pattern, TypedClause, TypedPattern},
@ -13,60 +13,32 @@ struct Occurrence {
amount: usize, amount: usize,
} }
#[derive(Clone)] #[derive(Clone, Debug)]
enum Path {
Pair(usize),
Tuple(usize),
}
impl Path {
pub fn get_index(&self) -> usize {
match self {
Path::Pair(u) | Path::Tuple(u) => *u,
}
}
}
#[derive(Clone, Debug)]
struct RowItem<'a> { struct RowItem<'a> {
assign: Option<String>, path: Vec<Path>,
pattern: &'a TypedPattern, pattern: &'a TypedPattern,
} }
#[derive(Clone, Eq, PartialEq)] #[derive(Clone, Debug)]
pub enum CaseTest {
Constr(PatternConstructor),
Int(String),
Bytes(Vec<u8>),
List(usize),
Wild,
}
impl PartialOrd for CaseTest {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match (self, other) {
(CaseTest::Wild, CaseTest::Wild) => Some(Ordering::Equal),
(CaseTest::Wild, _) => Some(Ordering::Less),
(_, CaseTest::Wild) => Some(Ordering::Greater),
(_, _) => Some(Ordering::Equal),
}
}
}
impl Ord for CaseTest {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
match (self, other) {
(CaseTest::Wild, CaseTest::Wild) => Ordering::Equal,
(CaseTest::Wild, _) => Ordering::Less,
(_, CaseTest::Wild) => Ordering::Greater,
(_, _) => Ordering::Equal,
}
}
}
struct Assign { struct Assign {
subject_name: String, path: Vec<Path>,
subject_tuple_index: Option<usize>,
assigned: String, assigned: String,
} }
enum DecisionTree {
Switch {
subject_name: String,
subject_tuple_index: Option<usize>,
subject_tipo: Rc<Type>,
column_to_test: usize,
cases: Vec<(CaseTest, DecisionTree)>,
default: Box<DecisionTree>,
},
Leaf(TypedExpr),
}
struct Row<'a> { struct Row<'a> {
assigns: Vec<Assign>, assigns: Vec<Assign>,
columns: Vec<RowItem<'a>>, columns: Vec<RowItem<'a>>,
@ -77,60 +49,152 @@ struct PatternMatrix<'a> {
rows: Vec<Row<'a>>, rows: Vec<Row<'a>>,
} }
fn map_to_row<'a>( #[derive(Clone, Eq, PartialEq)]
pub enum CaseTest {
Constr(PatternConstructor),
Int(String),
Bytes(Vec<u8>),
List(usize),
Wild,
}
impl PartialOrd for CaseTest {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match (self, other) {
(CaseTest::Wild, CaseTest::Wild) => Some(Ordering::Equal),
(CaseTest::Wild, _) => Some(Ordering::Less),
(_, CaseTest::Wild) => Some(Ordering::Greater),
(_, _) => Some(Ordering::Equal),
}
}
}
impl Ord for CaseTest {
fn cmp(&self, other: &Self) -> Ordering {
match (self, other) {
(CaseTest::Wild, CaseTest::Wild) => Ordering::Equal,
(CaseTest::Wild, _) => Ordering::Less,
(_, CaseTest::Wild) => Ordering::Greater,
(_, _) => Ordering::Equal,
}
}
}
enum DecisionTree {
Switch {
subject_name: String,
subject_tipo: Rc<Type>,
column_to_test: usize,
cases: Vec<(CaseTest, Vec<Path>, DecisionTree)>,
default: (Vec<Path>, Box<DecisionTree>),
},
Leaf(TypedExpr),
}
fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> {
while let Some((p, rest)) = path.split_first() {
let index = p.get_index();
subject_tipo = subject_tipo.arg_types().unwrap().remove(index);
path = rest
}
subject_tipo
}
fn map_pattern_to_row<'a>(
pattern: &'a TypedPattern, pattern: &'a TypedPattern,
subject_name: &String, subject_name: &String,
column_count: usize, subject_tipo: Rc<Type>,
) -> Vec<RowItem<'a>> { path: Vec<Path>,
) -> (Vec<Assign>, Vec<RowItem<'a>>) {
let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path);
let new_columns_added = if current_tipo.is_pair() {
2
} else if current_tipo.is_tuple() {
let Type::Tuple { elems, .. } = subject_tipo.as_ref() else {
unreachable!()
};
elems.len()
} else {
1
};
match pattern { match pattern {
Pattern::Var { name, .. } => vec![RowItem { Pattern::Var { name, .. } => (
assign: Some(name.clone()), vec![Assign {
path: path.clone(),
assigned: name.clone(),
}],
vec![RowItem {
pattern, pattern,
path: path.clone(),
}] }]
.into_iter() .into_iter()
.cycle() .cycle()
.take(column_count) .take(new_columns_added)
.collect_vec(), .collect_vec(),
),
Pattern::Assign { name, pattern, .. } => { Pattern::Assign { name, pattern, .. } => (
let p = map_to_row(pattern, subject_name, column_count); vec![Assign {
p.into_iter() path: path.clone(),
.map(|mut item| { assigned: name.clone(),
item.assign = Some(name.clone()); }],
item map_pattern_to_row(pattern, subject_name, subject_tipo.clone(), path).1,
}) ),
.collect_vec()
}
Pattern::Int { .. } Pattern::Int { .. }
| Pattern::ByteArray { .. } | Pattern::ByteArray { .. }
| Pattern::Discard { .. } | Pattern::Discard { .. }
| Pattern::List { .. } | Pattern::List { .. }
| Pattern::Constructor { .. } => vec![RowItem { | Pattern::Constructor { .. } => (
assign: None, vec![],
vec![RowItem {
pattern, pattern,
path: path.clone(),
}] }]
.into_iter() .into_iter()
.cycle() .cycle()
.take(column_count) .take(new_columns_added)
.collect_vec(), .collect_vec(),
),
Pattern::Pair { fst, snd, .. } => vec![ Pattern::Pair { fst, snd, .. } => {
RowItem { let mut fst_path = path.clone();
assign: None, fst_path.push(Path::Pair(0));
pattern: fst, let mut snd_path = path;
}, snd_path.push(Path::Pair(1));
RowItem {
assign: None, let (mut assigns, mut patts) =
pattern: snd, map_pattern_to_row(fst, subject_name, subject_tipo.clone(), fst_path);
},
], let (assign_snd, patt_snd) =
Pattern::Tuple { elems, .. } => elems map_pattern_to_row(snd, subject_name, subject_tipo.clone(), snd_path);
assigns.extend(assign_snd.into_iter());
patts.extend(patt_snd.into_iter());
(assigns, patts)
}
Pattern::Tuple { elems, .. } => {
elems
.iter() .iter()
.map(|elem| RowItem { .enumerate()
assign: None, .fold((vec![], vec![]), |mut acc, (index, item)| {
pattern: elem, let mut item_path = path.clone();
item_path.push(Path::Tuple(index));
let (assigns, patts) =
map_pattern_to_row(item, subject_name, subject_tipo.clone(), item_path);
acc.0.extend(assigns.into_iter());
acc.1.extend(patts.into_iter());
acc
}) })
.collect_vec(), }
} }
} }
@ -147,64 +211,36 @@ pub fn build_tree(
subject_tipo: Rc<Type>, subject_tipo: Rc<Type>,
clauses: &Vec<TypedClause>, clauses: &Vec<TypedClause>,
) -> DecisionTree { ) -> DecisionTree {
let column_count = if subject_tipo.is_pair() {
2
} else if subject_tipo.is_tuple() {
let Type::Tuple { elems, .. } = subject_tipo.as_ref() else {
unreachable!()
};
elems.len()
} else {
1
};
let rows = clauses let rows = clauses
.iter() .iter()
.map(|clause| { .map(|clause| {
let row_items = map_to_row(&clause.pattern, subject_name, column_count); let (assign, row_items) =
map_pattern_to_row(&clause.pattern, subject_name, subject_tipo.clone(), vec![]);
Row { Row {
assigns: vec![], assigns: assign.into_iter().collect_vec(),
columns: row_items, columns: row_items,
then: &clause.then, then: &clause.then,
} }
}) })
.collect_vec(); .collect_vec();
let subject_per_column = if column_count > 1 { do_build_tree(subject_name, &subject_tipo, PatternMatrix { rows })
(0..column_count)
.map(|index| (subject_name.clone(), Some(index)))
.collect_vec()
} else {
vec![(subject_name.clone(), None)]
};
do_build_tree(
subject_name,
subject_tipo,
subject_per_column,
PatternMatrix { rows },
)
} }
pub fn do_build_tree<'a>( pub fn do_build_tree<'a>(
subject_name: &String, subject_name: &String,
subject_tipo: Rc<Type>, subject_tipo: &Rc<Type>,
subject_per_column: Vec<(String, Option<usize>)>,
matrix: PatternMatrix<'a>, matrix: PatternMatrix<'a>,
) -> DecisionTree { ) -> DecisionTree {
let column_count = if subject_tipo.is_pair() { let column_length = matrix.rows[0].columns.len();
2
} else if subject_tipo.is_tuple() {
let Type::Tuple { elems, .. } = subject_tipo.as_ref() else {
unreachable!()
};
elems.len()
} else {
1
};
let occurrences = [Occurrence::default()].repeat(column_count); assert!(matrix
.rows
.iter()
.all(|row| { row.columns.len() == column_length }));
let occurrences = [Occurrence::default()].repeat(column_length);
let occurrences = let occurrences =
matrix matrix
@ -238,10 +274,9 @@ pub fn do_build_tree<'a>(
} }
}); });
if column_count > 1 { if column_length > 1 {
DecisionTree::Switch { DecisionTree::Switch {
subject_name: subject_name.clone(), subject_name: subject_name.clone(),
subject_tuple_index: None,
subject_tipo: subject_tipo.clone(), subject_tipo: subject_tipo.clone(),
column_to_test: highest_occurrence.0, column_to_test: highest_occurrence.0,
cases: todo!(), cases: todo!(),
@ -250,42 +285,53 @@ pub fn do_build_tree<'a>(
} else { } else {
let mut collection_vec = matrix.rows.into_iter().fold( let mut collection_vec = matrix.rows.into_iter().fold(
vec![], vec![],
|mut collection_vec: Vec<(CaseTest, Vec<Row<'a>>)>, mut item: Row<'a>| { |mut collection_vec: Vec<(CaseTest, Vec<Path>, Vec<Row<'a>>)>, mut item: Row<'a>| {
let col = item.columns.remove(highest_occurrence.0); let col = item.columns.remove(highest_occurrence.0);
let mut patt = col.pattern;
if let Pattern::Assign { pattern, .. } = patt { assert!(!matches!(col.pattern, Pattern::Assign { .. }));
patt = pattern;
}
if let Some(assign) = col.assign { let (mapped_args, case) = match col.pattern {
item.assigns.push(Assign { Pattern::Int { value, .. } => (vec![], CaseTest::Int(value.clone())),
subject_name: subject_name.clone(), Pattern::ByteArray { value, .. } => (vec![], CaseTest::Bytes(value.clone())),
subject_tuple_index: None, Pattern::Var { .. } | Pattern::Discard { .. } => (vec![], CaseTest::Wild),
assigned: assign, Pattern::List { elements, .. } => (
}); elements
} .iter()
.enumerate()
.map(|(index, item)| {
let mut item_path = col.path.clone();
let case = match patt { item_path.push(Path::Tuple(index));
Pattern::Int { value, .. } => CaseTest::Int(value.clone()),
Pattern::ByteArray { value, .. } => CaseTest::Bytes(value.clone()), map_pattern_to_row(
Pattern::Var { .. } | Pattern::Discard { .. } => CaseTest::Wild, item,
Pattern::List { elements, .. } => CaseTest::List(elements.len()), subject_name,
Pattern::Constructor { constructor, .. } => { subject_tipo.clone(),
CaseTest::Constr(constructor.clone()) item_path,
)
})
.collect_vec(),
CaseTest::List(elements.len()),
),
Pattern::Constructor { .. } => {
todo!()
} }
Pattern::Pair { .. } => todo!(), _ => unreachable!("{:#?}", col.pattern),
Pattern::Tuple { .. } => todo!(),
_ => unreachable!(),
}; };
item.assigns
.extend(mapped_args.iter().map(|x| x.0.clone()).flatten());
item.columns
.extend(mapped_args.into_iter().map(|x| x.1).flatten());
if let Some(index) = collection_vec.iter().position(|item| item.0 == case) { if let Some(index) = collection_vec.iter().position(|item| item.0 == case) {
let entry = collection_vec.get_mut(index).unwrap(); let entry = collection_vec.get_mut(index).unwrap();
entry.1.push(item); entry.2.push(item);
collection_vec collection_vec
} else { } else {
collection_vec.push((case, vec![item])); collection_vec.push((case, col.path, vec![item]));
collection_vec collection_vec
} }
@ -293,27 +339,31 @@ pub fn do_build_tree<'a>(
); );
collection_vec.sort_by(|a, b| a.0.cmp(&b.0)); collection_vec.sort_by(|a, b| a.0.cmp(&b.0));
let mut collection_iter = collection_vec.into_iter().peekable(); let mut collection_iter = collection_vec
.into_iter()
.map(|x| {
(
x.0,
x.1,
do_build_tree(subject_name, subject_tipo, PatternMatrix { rows: x.2 }),
)
})
.peekable();
let cases = collection_iter let cases = collection_iter
.peeking_take_while(|a| !matches!(a.0, CaseTest::Wild)) .peeking_take_while(|a| !matches!(a.0, CaseTest::Wild))
.collect_vec(); .collect_vec();
let mut fallback = collection_iter.collect_vec(); let mut fallback = collection_iter.map(|x| (x.1, x.2.into())).collect_vec();
assert!(fallback.len() == 1); assert!(fallback.len() == 1);
let fallback_matrix = PatternMatrix {
rows: fallback.remove(0).1,
};
DecisionTree::Switch { DecisionTree::Switch {
subject_name: subject_name.clone(), subject_name: subject_name.clone(),
subject_tuple_index: None,
subject_tipo: subject_tipo.clone(), subject_tipo: subject_tipo.clone(),
column_to_test: highest_occurrence.0, column_to_test: highest_occurrence.0,
cases: todo!(), cases,
default: todo!(), default: fallback.remove(0),
} }
}; };