aiken/crates/aiken-lang/src/gen_uplc/decision_tree.rs

1572 lines
51 KiB
Rust

use core::fmt;
use pretty::RcDoc;
use std::{cmp::Ordering, fmt::Display, rc::Rc};
use indexmap::IndexMap;
use itertools::{Itertools, Position};
use crate::{
ast::{DataTypeKey, Pattern, TypedClause, TypedDataType, TypedPattern},
expr::{lookup_data_type_by_tipo, Type, TypeVar, TypedExpr},
};
use super::{interner::AirInterner, tree::AirTree};
// Number of row columns reserved for a subject whose type is a pair
// (see `TreeGen::map_pattern_to_row`).
const PAIR_NEW_COLUMNS: usize = 2;
// Number of row columns reserved for a subject that is not expanded into
// sub-columns (the fallback in `TreeGen::map_pattern_to_row`).
const MIN_NEW_COLUMNS: usize = 1;
/// Per-column tally built by `highest_occurrence`.
#[derive(Clone, Default, Copy)]
struct Occurrence {
/// Set once a wildcard pattern was seen in this column; later rows no longer count.
passed_wild_card: bool,
/// Number of non-wildcard patterns seen in this column before the first wildcard.
amount: usize,
}
/// One step from the matched subject down to a nested value.
/// A full location is a `Vec<Path>`; `get_tipo_by_path` follows it to recover
/// the type found at that location.
#[derive(Clone, Debug)]
pub enum Path {
/// Component `i` of a pair.
Pair(usize),
/// Element `i` of a tuple.
Tuple(usize),
/// Argument `i` of a constructor; the `Rc<Type>` is the constructor pattern's type.
Constr(Rc<Type>, usize),
/// The single argument of an opaque, single-field constructor.
OpaqueConstr(Rc<Type>),
/// Element `i` of a list.
List(usize),
/// The tail of a list pattern; `i` is the position of the tail after the
/// explicitly matched elements.
ListTail(usize),
}
impl ToString for Path {
fn to_string(&self) -> String {
match self {
Path::Pair(i) => {
format!("pair_{}", i)
}
Path::Tuple(i) => {
format!("tuple_{}", i)
}
Path::Constr(_, i) => {
format!("constr_{}", i)
}
Path::OpaqueConstr(_) => "opaqueconstr".to_string(),
Path::List(i) => {
format!("list_{}", i)
}
Path::ListTail(i) => {
format!("listtail_{}", i)
}
}
}
}
/// Two path steps are equal when they are the same kind of step with the same
/// index. The `Rc<Type>` carried by `Constr` and `OpaqueConstr` is ignored.
impl PartialEq for Path {
    fn eq(&self, other: &Self) -> bool {
        match self {
            Path::Pair(lhs) => matches!(other, Path::Pair(rhs) if lhs == rhs),
            Path::Tuple(lhs) => matches!(other, Path::Tuple(rhs) if lhs == rhs),
            Path::Constr(_, lhs) => matches!(other, Path::Constr(_, rhs) if lhs == rhs),
            Path::List(lhs) => matches!(other, Path::List(rhs) if lhs == rhs),
            Path::ListTail(lhs) => matches!(other, Path::ListTail(rhs) if lhs == rhs),
            Path::OpaqueConstr(_) => matches!(other, Path::OpaqueConstr(_)),
        }
    }
}

impl Eq for Path {}
/// A variable binding introduced by a pattern: the variable name together with
/// the path, relative to the matched subject, of the value it is bound to.
#[derive(Clone, Debug)]
pub struct Assigned {
/// Location of the bound value inside the subject.
pub path: Vec<Path>,
/// Name of the variable being bound.
pub assigned: String,
}
/// A single column entry of a pattern-matrix row: a pattern and the path of
/// the sub-value it is matched against.
#[derive(Clone, Debug)]
struct RowItem<'a> {
/// Location of the sub-value inside the subject.
path: Vec<Path>,
/// Pattern to test at that location.
pattern: &'a TypedPattern,
}
/// One row of the pattern matrix, built from one clause of the `when` expression.
#[derive(Clone, Debug)]
struct Row<'a> {
/// Variable bindings collected so far for this clause.
assigns: Vec<Assigned>,
/// Patterns that still have to be tested, one per column.
columns: Vec<RowItem<'a>>,
/// Interned name identifying the clause body to run when this row matches.
then: String,
}
/// The rows that `TreeGen::do_build_tree` compiles into a decision tree.
/// `do_build_tree` asserts that all rows have the same number of columns.
#[derive(Clone, Debug)]
struct PatternMatrix<'a> {
rows: Vec<Row<'a>>,
}
/// The test attached to one branch of a `DecisionTree` switch.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum CaseTest {
/// Matches the constructor with this index in its data type's constructor list.
Constr(usize),
/// Matches an integer literal, kept in its textual form.
Int(String),
/// Matches a byte-array literal.
Bytes(Vec<u8>),
/// Matches a list pattern without a tail that has this many elements.
List(usize),
/// Matches a list pattern with a tail that has this many leading elements.
ListWithTail(usize),
/// Produced for variable and discard patterns; such rows go to the default matrix.
Wild,
}
impl CaseTest {
    /// Builds the `AirTree` constant this test is compared against.
    ///
    /// A constructor test on a bool-typed subject becomes a boolean (`true`
    /// for constructor index 1); every other constructor test becomes its
    /// index as an integer.
    ///
    /// # Panics
    ///
    /// Hits `unreachable!()` for `List`, `ListWithTail` and `Wild`.
    pub fn get_air_pattern(&self, current_type: Rc<Type>) -> AirTree {
        match self {
            CaseTest::Constr(index) if current_type.is_bool() => AirTree::bool(1 == *index),
            CaseTest::Constr(index) => AirTree::int(index),
            CaseTest::Int(value) => AirTree::int(value),
            CaseTest::Bytes(bytes) => AirTree::byte_array(bytes.clone()),
            CaseTest::List(_) | CaseTest::ListWithTail(_) | CaseTest::Wild => unreachable!(),
        }
    }
}
/// Human-readable form of a case test, used when pretty-printing a tree.
impl Display for CaseTest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CaseTest::Wild => f.write_str("Wild"),
            CaseTest::Constr(index) => write!(f, "Constr({index})"),
            CaseTest::Int(value) => write!(f, "Int({value})"),
            CaseTest::Bytes(bytes) => write!(f, "Bytes({bytes:?})"),
            CaseTest::List(length) => write!(f, "List({length})"),
            CaseTest::ListWithTail(length) => write!(f, "ListWithTail({length})"),
        }
    }
}
/// Decision tree produced by `TreeGen::build_tree` for a list of clauses.
#[derive(Debug, Clone)]
pub enum DecisionTree<'a> {
/// Tests the value found at `path` against each case in turn, with an
/// optional default branch.
Switch {
path: Vec<Path>,
cases: Vec<(CaseTest, DecisionTree<'a>)>,
default: Option<Box<DecisionTree<'a>>>,
},
/// Like `Switch`, for a list subject. `cases` holds the `CaseTest::List`
/// branches and `tail_cases` the `CaseTest::ListWithTail` branches.
ListSwitch {
path: Vec<Path>,
cases: Vec<(CaseTest, DecisionTree<'a>)>,
tail_cases: Vec<(CaseTest, DecisionTree<'a>)>,
default: Option<Box<DecisionTree<'a>>>,
},
/// A leaf referring to a clause body by its interned name, together with
/// the bindings of the row that reached this leaf.
HoistedLeaf(String, Vec<Assigned>),
/// Wraps `pattern` and carries the clause body `then` under `name`.
/// Inserted by `hoist_by_path` at the scope computed by `get_hoist_paths`.
HoistThen {
name: String,
assigns: Vec<Assigned>,
pattern: Box<DecisionTree<'a>>,
then: &'a TypedExpr,
},
}
/// One step of a `Scope`: either the branch taken at a switch (`Case(index)`)
/// or its default branch (`Fallback`).
#[derive(Eq, Hash, PartialEq, Clone, Debug)]
pub enum ScopePath {
    Case(usize),
    Fallback,
}

impl PartialOrd for ScopePath {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Canonical form: the total order lives in `Ord::cmp`.
        Some(self.cmp(other))
    }
}

/// Total order in which a *lower* case index compares as *greater*, and
/// `Fallback` compares below every `Case`.
impl Ord for ScopePath {
    fn cmp(&self, other: &Self) -> Ordering {
        match (self, other) {
            // Reversed on purpose: Case(0) > Case(1) > ...
            (ScopePath::Case(a), ScopePath::Case(b)) => b.cmp(a),
            (ScopePath::Case(_), ScopePath::Fallback) => Ordering::Greater,
            (ScopePath::Fallback, ScopePath::Case(_)) => Ordering::Less,
            (ScopePath::Fallback, ScopePath::Fallback) => Ordering::Equal,
        }
    }
}
/// A position inside a `DecisionTree`, recorded as the sequence of branches
/// taken from the root. The empty scope is the root itself.
#[derive(Eq, Hash, PartialEq, Clone, Debug, Default, PartialOrd, Ord)]
pub struct Scope {
    scope: Vec<ScopePath>,
}

impl Scope {
    /// Creates the empty (root) scope.
    pub fn new() -> Self {
        Self::default()
    }

    /// Descends into the branch `path`.
    pub fn push(&mut self, path: ScopePath) {
        self.scope.push(path);
    }

    /// Goes back up one level; does nothing on the root scope.
    pub fn pop(&mut self) {
        self.scope.pop();
    }

    /// Shrinks `self` to the longest prefix it shares with `other`.
    pub fn common_ancestor(&mut self, other: &Scope) {
        // Count the shared leading steps, then cut in place instead of
        // rebuilding the vector.
        let shared = self
            .scope
            .iter()
            .zip(other.scope.iter())
            .take_while(|(mine, theirs)| mine == theirs)
            .count();

        self.scope.truncate(shared);
    }

    /// Number of steps between the root and this scope.
    pub fn len(&self) -> usize {
        self.scope.len()
    }

    /// Whether this is the root scope.
    pub fn is_empty(&self) -> bool {
        self.scope.is_empty()
    }
}
/// Work item of the explicit stack used by `DecisionTree::get_hoist_paths`.
enum Marker<'a, 'b> {
/// Leave the current branch: pop one step off the current scope.
Pop,
/// Enter the given branch: push the step, then visit the subtree.
Push(ScopePath, &'b DecisionTree<'a>),
/// Move to a sibling branch: pop the previous step, push the new one,
/// then visit the subtree.
PopPush(ScopePath, &'b DecisionTree<'a>),
}
impl<'a> DecisionTree<'a> {
/// Renders the tree as indented text, wrapped at 80 columns.
pub fn to_pretty(&self) -> String {
    let mut buffer = Vec::new();
    self.to_doc().render(80, &mut buffer).unwrap();
    let rendered = String::from_utf8(buffer).unwrap();

    // Nesting makes otherwise blank lines carry indentation whitespace;
    // replace such lines with truly empty ones.
    let cleaned: Vec<&str> = rendered
        .lines()
        .map(|line| if line.trim().is_empty() { "" } else { line })
        .collect();

    cleaned.join("\n")
}
fn to_doc(&self) -> RcDoc<()> {
match self {
DecisionTree::Switch {
path,
cases,
default,
..
} => RcDoc::text("Switch(")
.append(
path.iter()
.fold(RcDoc::line().append(RcDoc::text("path(")), |acc, p| {
acc.append(
RcDoc::line()
.append(RcDoc::text(format!("{}", p.to_string())).nest(4)),
)
})
.append(RcDoc::line())
.append(RcDoc::text(")"))
.nest(4),
)
.append(
cases
.iter()
.fold(
RcDoc::line().append(RcDoc::text("cases(")),
|acc, (con, tree)| {
acc.append(RcDoc::line())
.append(RcDoc::text(format!("({}): ", con)))
.append(RcDoc::line())
.append(tree.to_doc().nest(4))
},
)
.append(RcDoc::line())
.append(RcDoc::text(")"))
.nest(4),
)
.append(
RcDoc::line()
.append(RcDoc::text("default : "))
.append(RcDoc::line())
.append(
default
.as_ref()
.map(|i| i.to_doc())
.unwrap_or(RcDoc::text("None")),
)
.append(RcDoc::line())
.nest(4),
)
.append(RcDoc::text(")")),
DecisionTree::ListSwitch {
path,
cases,
tail_cases,
default,
..
} => RcDoc::text("ListSwitch(")
.append(
path.iter()
.fold(RcDoc::line().append(RcDoc::text("path(")), |acc, p| {
acc.append(
RcDoc::line()
.append(RcDoc::text(format!("{}", p.to_string())).nest(4)),
)
})
.append(RcDoc::line())
.append(RcDoc::text(")"))
.nest(4),
)
.append(
cases
.iter()
.fold(
RcDoc::line().append(RcDoc::text("cases(")),
|acc, (con, tree)| {
acc.append(RcDoc::line())
.append(RcDoc::text(format!("({}): ", con)))
.append(RcDoc::line())
.append(tree.to_doc().nest(4))
},
)
.append(RcDoc::line())
.append(RcDoc::text(")"))
.nest(4),
)
.append(
tail_cases
.iter()
.fold(
RcDoc::line().append(RcDoc::text("tail_cases(")),
|acc, (con, tree)| {
acc.append(RcDoc::line())
.append(RcDoc::text(format!("({}): ", con)))
.append(RcDoc::line())
.append(tree.to_doc().nest(4))
},
)
.append(RcDoc::line())
.append(RcDoc::text(")"))
.nest(4),
)
.append(
RcDoc::line()
.append(RcDoc::text("default : "))
.append(RcDoc::line())
.append(
default
.as_ref()
.map(|i| i.to_doc())
.unwrap_or(RcDoc::text("None")),
)
.append(RcDoc::line())
.nest(4),
)
.append(RcDoc::text(")")),
DecisionTree::HoistedLeaf(name, _) => RcDoc::text(format!("Leaf({})", name)),
DecisionTree::HoistThen { name, pattern, .. } => RcDoc::text("HoistThen(")
.append(
RcDoc::line()
.append(RcDoc::text(format!("name : {}", name)))
.append(RcDoc::line())
.nest(4),
)
.append(
RcDoc::line()
.append(pattern.to_doc())
.append(RcDoc::line())
.nest(4),
)
.append(RcDoc::text(")")),
}
}
/// For fun I decided to do this without recursion
/// It doesn't look to bad lol
fn get_hoist_paths<'b>(&self, names: Vec<&'b String>) -> IndexMap<&'b String, Scope> {
let mut prev = vec![];
let mut current_path = Scope::new();
let mut tree = self;
let mut scope_map: IndexMap<&String, Scope> =
names.into_iter().map(|item| (item, Scope::new())).collect();
loop {
match tree {
DecisionTree::Switch { cases, default, .. } => {
prev.push(Marker::Pop);
if let Some(def) = default {
prev.push(Marker::PopPush(ScopePath::Fallback, def.as_ref()));
}
cases
.iter()
.enumerate()
.rev()
.for_each(|(index, (_, detree))| {
if index == 0 {
prev.push(Marker::Push(ScopePath::Case(index), detree));
} else {
prev.push(Marker::PopPush(ScopePath::Case(index), detree));
}
});
}
DecisionTree::ListSwitch {
cases,
tail_cases,
default,
..
} => {
prev.push(Marker::Pop);
if let Some(def) = default {
prev.push(Marker::PopPush(ScopePath::Fallback, def.as_ref()));
}
tail_cases
.iter()
.enumerate()
.rev()
.for_each(|(index, (_, detree))| {
prev.push(Marker::PopPush(
ScopePath::Case(index + cases.len()),
detree,
));
});
cases
.iter()
.enumerate()
.rev()
.for_each(|(index, (_, detree))| {
if index == 0 {
prev.push(Marker::Push(ScopePath::Case(index), detree));
} else {
prev.push(Marker::PopPush(ScopePath::Case(index), detree));
}
});
}
DecisionTree::HoistedLeaf(leaf_name, _) => {
let scope_for_name = scope_map
.get_mut(leaf_name)
.expect("Impossible, Leaf is based off of given names");
if scope_for_name.is_empty() {
*scope_for_name = current_path.clone();
} else {
scope_for_name.common_ancestor(&current_path);
}
}
DecisionTree::HoistThen { .. } => unreachable!(),
}
if let Some(action) = prev.pop() {
match action {
Marker::Pop => {
current_path.pop();
}
Marker::Push(p, dec_tree) => {
current_path.push(p);
tree = dec_tree;
}
Marker::PopPush(p, dec_tree) => {
current_path.pop();
current_path.push(p);
tree = dec_tree;
}
}
} else {
break;
}
}
scope_map
}
/// Wraps subtrees in `HoistThen` nodes at the scopes listed in `name_paths`.
///
/// Children are processed first. Afterwards, entries are taken from the end
/// of `name_paths` for as long as their scope equals `current_path`; each one
/// wraps the current node (including earlier wrappers) in a `HoistThen`
/// carrying the clause body removed from `hoistables`.
///
/// # Panics
///
/// Panics if a matching name is missing from `hoistables`, or if the tree
/// already contains a `HoistThen` node.
fn hoist_by_path(
    &mut self,
    current_path: &mut Scope,
    name_paths: &mut Vec<(String, Scope)>,
    hoistables: &mut IndexMap<String, (Vec<Assigned>, &'a TypedExpr)>,
) {
    match self {
        DecisionTree::Switch { cases, default, .. } => {
            for (index, (_, subtree)) in cases.iter_mut().enumerate() {
                current_path.push(ScopePath::Case(index));
                subtree.hoist_by_path(current_path, name_paths, hoistables);
                current_path.pop();
            }

            current_path.push(ScopePath::Fallback);
            if let Some(fallback) = default {
                fallback.hoist_by_path(current_path, name_paths, hoistables);
            }
            current_path.pop();
        }
        DecisionTree::ListSwitch {
            cases,
            tail_cases,
            default,
            ..
        } => {
            // Tail cases continue the numbering after the regular cases.
            let offset = cases.len();

            for (index, (_, subtree)) in cases.iter_mut().enumerate() {
                current_path.push(ScopePath::Case(index));
                subtree.hoist_by_path(current_path, name_paths, hoistables);
                current_path.pop();
            }

            for (index, (_, subtree)) in tail_cases.iter_mut().enumerate() {
                current_path.push(ScopePath::Case(index + offset));
                subtree.hoist_by_path(current_path, name_paths, hoistables);
                current_path.pop();
            }

            current_path.push(ScopePath::Fallback);
            if let Some(fallback) = default {
                fallback.hoist_by_path(current_path, name_paths, hoistables);
            }
            current_path.pop();
        }
        DecisionTree::HoistedLeaf(_, _) => {}
        DecisionTree::HoistThen { .. } => unreachable!(),
    }

    while let Some((name, scope)) = name_paths.pop() {
        if scope != *current_path {
            // Not ours: put it back and stop.
            name_paths.push((name, scope));
            break;
        }

        let (assigns, then) = hoistables.remove(&name).unwrap();
        // Temporarily park a placeholder so the current node can be moved
        // into the wrapper.
        let pattern = std::mem::replace(self, DecisionTree::HoistedLeaf("".to_string(), vec![]));

        *self = DecisionTree::HoistThen {
            name,
            assigns,
            pattern: pattern.into(),
            then,
        };
    }
}
}
/// Displays the tree using its pretty-printed form (see `to_pretty`).
impl<'a> Display for DecisionTree<'a> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let rendered = self.to_pretty();
        f.write_str(&rendered)
    }
}
/// Builds a `DecisionTree` from the clauses of a `when` expression.
pub struct TreeGen<'a, 'b> {
// Used to obtain the unique `__clause_then_N` names of the clause bodies.
interner: &'b mut AirInterner,
// Known data types, used to look up the constructors of a matched type.
data_types: &'b IndexMap<&'a DataTypeKey, &'a TypedDataType>,
// Row item (with an empty path) cloned whenever a row has to be padded
// with extra columns.
wild_card_pattern: RowItem<'a>,
}
impl<'a, 'b> TreeGen<'a, 'b> {
pub fn new(
interner: &'b mut AirInterner,
data_types: &'b IndexMap<&'a DataTypeKey, &'a TypedDataType>,
wild_card_pattern: &'a TypedPattern,
) -> Self {
TreeGen {
interner,
data_types,
wild_card_pattern: RowItem {
path: vec![],
pattern: wild_card_pattern,
},
}
}
pub fn build_tree(
mut self,
subject_name: &String,
subject_tipo: &Rc<Type>,
clauses: &'a [TypedClause],
) -> DecisionTree<'a> {
let mut hoistables = IndexMap::new();
let rows = clauses
.iter()
.enumerate()
.map(|(index, clause)| {
let (assign, row_items) =
self.map_pattern_to_row(&clause.pattern, subject_tipo, vec![]);
self.interner.intern(format!("__clause_then_{}", index));
let clause_then_name = self
.interner
.lookup_interned(&format!("__clause_then_{}", index));
hoistables.insert(clause_then_name.clone(), (vec![], &clause.then));
let row = Row {
assigns: assign.into_iter().collect_vec(),
columns: row_items,
then: clause_then_name,
};
self.interner.pop_text(format!("__clause_then_{}", index));
row
})
.collect_vec();
let mut tree = self.do_build_tree(
subject_name,
subject_tipo,
PatternMatrix { rows },
&mut hoistables,
);
let scope_map = tree.get_hoist_paths(hoistables.keys().collect_vec());
let mut name_paths = scope_map
.into_iter()
.sorted_by(|a, b| a.1.cmp(&b.1))
.map(|(name, path)| (name.clone(), path))
.collect_vec();
tree.hoist_by_path(&mut Scope::new(), &mut name_paths, &mut hoistables);
// Do hoisting of thens here
tree
}
/// Recursively compiles a pattern matrix into a decision tree.
///
/// Each call picks one column (see `highest_occurrence`), splits the rows into
/// one specialized matrix per distinct `CaseTest` found in that column plus a
/// default matrix of the rows that had a wildcard there, and recurses on each
/// of them. When no column is left to test, the first row wins and becomes a
/// `HoistedLeaf`; the first time a clause is reached its bindings are also
/// recorded in `then_map`.
///
/// `subject_name` is only forwarded to the recursive calls.
///
/// # Panics
///
/// Panics if `matrix` has no rows or if its rows have differing column counts.
fn do_build_tree(
&mut self,
subject_name: &String,
subject_tipo: &Rc<Type>,
matrix: PatternMatrix<'a>,
then_map: &mut IndexMap<String, (Vec<Assigned>, &'a TypedExpr)>,
) -> DecisionTree<'a> {
let column_length = matrix.rows[0].columns.len();
// First step: make sure all rows have the same number of columns,
// otherwise something went wrong
assert!(matrix
.rows
.iter()
.all(|row| { row.columns.len() == column_length }));
// Find which column has the most important pattern
let occurrence_col = highest_occurrence(&matrix, column_length);
let Some(occurrence_col) = occurrence_col else {
// No more patterns to match on so we grab the first default row and return that
let mut fallback = matrix.rows;
let row = fallback.swap_remove(0);
let Some((assigns, _)) = then_map.get_mut(&row.then) else {
unreachable!()
};
// Only record bindings while none are stored yet for this clause
if assigns.is_empty() {
*assigns = row.assigns.clone();
}
return DecisionTree::HoistedLeaf(row.then, row.assigns);
};
let mut longest_elems_no_tail = None;
let mut longest_elems_with_tail = None;
let mut has_list_pattern = false;
// List patterns are special so we need more information on length:
// the longest element count among patterns without a tail and among
// patterns with a tail, in the selected column
matrix.rows.iter().for_each(|item| {
let col = &item.columns[occurrence_col];
match col.pattern {
Pattern::List { elements, tail, .. } => {
has_list_pattern = true;
if tail.is_none() {
match longest_elems_no_tail {
Some(elems_count) => {
if elems_count < elements.len() {
longest_elems_no_tail = Some(elements.len());
}
}
None => {
longest_elems_no_tail = Some(elements.len());
}
}
} else {
match longest_elems_with_tail {
Some(elems_count) => {
if elems_count < elements.len() {
longest_elems_with_tail = Some(elements.len());
}
}
None => {
longest_elems_with_tail = Some(elements.len());
}
}
}
}
_ => (),
}
});
// Path of the value being switched on, taken from the first row
let path = matrix
.rows
.get(0)
.unwrap()
.columns
.get(occurrence_col)
.map(|col| col.path.clone())
.unwrap_or(vec![]);
let specialized_tipo = get_tipo_by_path(subject_tipo.clone(), &path);
// Split the rows, in order, into the default matrix (wildcard in the
// selected column) and one matrix per case test
let (default_matrix, specialized_matrices) = matrix.rows.into_iter().fold(
(vec![], vec![]),
|(mut default_matrix, mut case_matrices): (Vec<Row>, Vec<(CaseTest, Vec<Row>)>),
mut row| {
// For example in the case of matching on []
if row.columns.is_empty() {
default_matrix.push(row);
return (default_matrix, case_matrices);
}
let col = row.columns.remove(occurrence_col);
// `remaining_patts` holds the bindings and columns produced by the
// sub-patterns of the selected pattern
let (case, remaining_patts) = match col.pattern {
Pattern::Var { .. } | Pattern::Discard { .. } => (CaseTest::Wild, vec![]),
Pattern::Int { value, .. } => (CaseTest::Int(value.clone()), vec![]),
Pattern::ByteArray { value, .. } => (CaseTest::Bytes(value.clone()), vec![]),
Pattern::List { elements, tail, .. } => (
if tail.is_none() {
CaseTest::List(elements.len())
} else {
CaseTest::ListWithTail(elements.len())
},
elements
.iter()
.chain(tail.as_ref().map(|tail| tail.as_ref()))
.enumerate()
.with_position()
.map(|elem| match elem {
Position::First((index, element))
| Position::Middle((index, element))
| Position::Only((index, element)) => {
let mut item_path = col.path.clone();
item_path.push(Path::List(index));
self.map_pattern_to_row(element, subject_tipo, item_path)
}
// The last item is the tail pattern when the list has one
Position::Last((index, element)) => {
if tail.is_none() {
let mut item_path = col.path.clone();
item_path.push(Path::List(index));
self.map_pattern_to_row(element, subject_tipo, item_path)
} else {
let mut item_path = col.path.clone();
item_path.push(Path::ListTail(index));
self.map_pattern_to_row(element, subject_tipo, item_path)
}
}
})
.collect_vec(),
),
Pattern::Constructor {
name,
arguments,
tipo,
..
} => {
let data_type =
lookup_data_type_by_tipo(&self.data_types, &specialized_tipo).unwrap();
// The case test is the position of the constructor within its data type
let (constr_index, _) = data_type
.constructors
.iter()
.enumerate()
.find(|(_, dt)| &dt.name == name)
.unwrap();
(
CaseTest::Constr(constr_index),
arguments
.iter()
.enumerate()
.map(|(index, arg)| {
let mut item_path = col.path.clone();
item_path.push(Path::Constr(tipo.clone(), index));
self.map_pattern_to_row(&arg.value, subject_tipo, item_path)
})
.collect_vec(),
)
}
// These never appear as a row item: `map_pattern_to_row` flattens
// them into their components
Pattern::Tuple { .. } | Pattern::Pair { .. } | Pattern::Assign { .. } => {
unreachable!("{:#?}", col.pattern)
}
};
// Assert path is the same for each specialized row
assert!(path == col.path || matches!(case, CaseTest::Wild));
// expand assigns by newly added ones
row.assigns
.extend(remaining_patts.iter().flat_map(|x| x.0.clone()));
// Add inner patterns to existing row
let mut new_cols = remaining_patts.into_iter().flat_map(|x| x.1).collect_vec();
// Pop off tail so that it aligns more easily with other list patterns
if matches!(case, CaseTest::ListWithTail(_)) {
new_cols.pop();
}
let added_columns = new_cols.len();
new_cols.extend(row.columns);
row.columns = new_cols;
if let CaseTest::Wild = case {
// A wildcard row goes to the default matrix and to every case
// matrix created so far, padded with wildcard columns up to that
// matrix's width
let current_wild_cols = row.columns.len();
default_matrix.push(row.clone());
case_matrices.iter_mut().for_each(|(_, matrix)| {
let mut row = row.clone();
let total_cols = matrix[0].columns.len();
if total_cols != 0 {
let added_columns = total_cols - current_wild_cols;
for _ in 0..added_columns {
row.columns.insert(0, self.wild_card_pattern.clone());
}
matrix.push(row);
}
});
} else if let CaseTest::ListWithTail(case_length) = case {
// Lists with a tail are a special case: the row is also added to every
// longer list case, up to the longest one, because a list of any greater
// length could match as well depending on the inner patterns.
// See tests below for an example
if let Some(longest_elems_no_tail) = longest_elems_no_tail {
for elem_count in case_length..=longest_elems_no_tail {
let case = CaseTest::List(elem_count);
let mut row = row.clone();
// Pad with wildcards for the elements covered by the tail
for _ in 0..(elem_count - case_length) {
row.columns
.insert(case_length, self.wild_card_pattern.clone());
}
self.insert_case(
&mut case_matrices,
case,
&default_matrix,
row,
added_columns,
);
}
}
// This row has a tail, so the scan above recorded at least one
let Some(longest_elems_with_tail) = longest_elems_with_tail else {
unreachable!()
};
for elem_count in case_length..=longest_elems_with_tail {
let case = CaseTest::ListWithTail(elem_count);
let mut row = row.clone();
for _ in 0..(elem_count - case_length) {
row.columns
.insert(case_length, self.wild_card_pattern.clone());
}
self.insert_case(
&mut case_matrices,
case,
&default_matrix,
row,
added_columns,
);
}
} else {
self.insert_case(
&mut case_matrices,
case,
&default_matrix,
row,
added_columns,
);
}
(default_matrix, case_matrices)
},
);
let default_matrix = PatternMatrix {
rows: default_matrix,
};
if has_list_pattern {
// Since the list_tail case might cover the rest of the possible matches extensively
// then fallback is optional here
let fallback_option = if default_matrix.rows.is_empty() {
None
} else {
Some(
self.do_build_tree(
subject_name,
subject_tipo,
// Since everything after this point had a wild card on or above
// the row for the selected column in front. Then we ignore the
// cases and continue to check other columns.
default_matrix,
then_map,
)
.into(),
)
};
// Separate the tail cases from the fixed-length cases
let (tail_cases, cases): (Vec<_>, Vec<_>) = specialized_matrices
.into_iter()
.partition(|(case, _)| matches!(case, CaseTest::ListWithTail(_)));
DecisionTree::ListSwitch {
path,
cases: cases
.into_iter()
.map(|x| {
(
x.0,
self.do_build_tree(
subject_name,
subject_tipo,
PatternMatrix { rows: x.1 },
then_map,
),
)
})
.collect_vec(),
tail_cases: tail_cases
.into_iter()
.map(|x| {
(
x.0,
self.do_build_tree(
subject_name,
subject_tipo,
PatternMatrix { rows: x.1 },
then_map,
),
)
})
.collect_vec(),
default: fallback_option,
}
} else {
let fallback_option = if default_matrix.rows.is_empty() {
None
} else {
Some(
self.do_build_tree(
subject_name,
subject_tipo,
// Since everything after this point had a wild card on or above
// the row for the selected column in front. Then we ignore the
// cases and continue to check other columns.
default_matrix,
then_map,
)
.into(),
)
};
DecisionTree::Switch {
path,
cases: specialized_matrices
.into_iter()
.map(|x| {
(
x.0,
self.do_build_tree(
subject_name,
subject_tipo,
PatternMatrix { rows: x.1 },
then_map,
),
)
})
.collect_vec(),
default: fallback_option.into(),
}
}
}
/// Flattens `pattern`, located at `path` inside the subject, into the
/// variable bindings it introduces and the row columns it occupies.
///
/// Pairs, tuples and single-constructor (or never) data types are expanded
/// into one group of columns per component. Every other pattern is kept as
/// is and repeated `new_columns_added` times, the number of columns the type
/// at `path` occupies.
///
/// An `as` pattern contributes its own binding *and* the bindings of the
/// pattern it wraps, so that e.g. `(a, b) as t` binds `t`, `a` and `b`.
///
/// # Panics
///
/// Panics if a constructor pattern's type has no known data type.
fn map_pattern_to_row(
    &self,
    pattern: &'a TypedPattern,
    subject_tipo: &Rc<Type>,
    path: Vec<Path>,
) -> (Vec<Assigned>, Vec<RowItem<'a>>) {
    let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path);

    let new_columns_added = if current_tipo.is_pair() {
        PAIR_NEW_COLUMNS
    } else if current_tipo.is_tuple() {
        let Type::Tuple { elems, .. } = current_tipo.as_ref() else {
            unreachable!("{:#?}", current_tipo)
        };
        elems.len()
    } else if let Some(data) = lookup_data_type_by_tipo(self.data_types, &current_tipo) {
        if data.constructors.len() == 1 {
            data.constructors[0].arguments.len()
        } else if data.is_never() {
            0
        } else {
            MIN_NEW_COLUMNS
        }
    } else {
        MIN_NEW_COLUMNS
    };

    match pattern {
        Pattern::Var { name, .. } => (
            vec![Assigned {
                path: path.clone(),
                assigned: name.clone(),
            }],
            vec![RowItem { pattern, path }; new_columns_added],
        ),
        Pattern::Assign {
            name,
            pattern: inner,
            ..
        } => {
            let (inner_assigns, columns) =
                self.map_pattern_to_row(inner, subject_tipo, path.clone());

            // Keep the bindings of the wrapped pattern; this is the only
            // place where they are produced for eagerly expanded patterns.
            let mut assigns = Vec::with_capacity(inner_assigns.len() + 1);
            assigns.push(Assigned {
                path,
                assigned: name.clone(),
            });
            assigns.extend(inner_assigns);

            (assigns, columns)
        }
        Pattern::Int { .. }
        | Pattern::ByteArray { .. }
        | Pattern::Discard { .. }
        | Pattern::List { .. } => (vec![], vec![RowItem { pattern, path }; new_columns_added]),
        Pattern::Constructor {
            arguments, tipo, ..
        } => {
            let data_type = lookup_data_type_by_tipo(self.data_types, &current_tipo).unwrap();

            let is_transparent =
                data_type.opaque && data_type.constructors[0].arguments.len() == 1;

            if data_type.constructors.len() == 1 || data_type.is_never() {
                // Nothing to test on the constructor itself: expand straight
                // into its arguments.
                let mut assigns = vec![];
                let mut columns = vec![];

                for (index, arg) in arguments.iter().enumerate() {
                    let mut item_path = path.clone();
                    item_path.push(if is_transparent {
                        Path::OpaqueConstr(tipo.clone())
                    } else {
                        Path::Constr(tipo.clone(), index)
                    });

                    let (arg_assigns, arg_columns) =
                        self.map_pattern_to_row(&arg.value, subject_tipo, item_path);

                    assigns.extend(arg_assigns);
                    columns.extend(arg_columns);
                }

                (assigns, columns)
            } else {
                (vec![], vec![RowItem { pattern, path }; new_columns_added])
            }
        }
        Pattern::Pair { fst, snd, .. } => {
            let mut fst_path = path.clone();
            fst_path.push(Path::Pair(0));

            let mut snd_path = path;
            snd_path.push(Path::Pair(1));

            let (mut assigns, mut columns) = self.map_pattern_to_row(fst, subject_tipo, fst_path);
            let (snd_assigns, snd_columns) = self.map_pattern_to_row(snd, subject_tipo, snd_path);

            assigns.extend(snd_assigns);
            columns.extend(snd_columns);

            (assigns, columns)
        }
        Pattern::Tuple { elems, .. } => {
            let mut assigns = vec![];
            let mut columns = vec![];

            for (index, item) in elems.iter().enumerate() {
                let mut item_path = path.clone();
                item_path.push(Path::Tuple(index));

                let (item_assigns, item_columns) =
                    self.map_pattern_to_row(item, subject_tipo, item_path);

                assigns.extend(item_assigns);
                columns.extend(item_columns);
            }

            (assigns, columns)
        }
    }
}
/// Adds `new_row` to the matrix of `case`.
///
/// When `case` has no matrix yet, one is created from the rows of
/// `default_matrix`, each padded in front with `added_columns` wildcard
/// columns, followed by `new_row`.
fn insert_case(
    &self,
    case_matrices: &mut Vec<(CaseTest, Vec<Row<'a>>)>,
    case: CaseTest,
    default_matrix: &Vec<Row<'a>>,
    new_row: Row<'a>,
    added_columns: usize,
) {
    if let Some(existing) = case_matrices.iter_mut().find(|entry| entry.0 == case) {
        existing.1.push(new_row);
    } else {
        let mut rows = default_matrix.clone();

        for row in rows.iter_mut() {
            let mut widened = Vec::with_capacity(added_columns + row.columns.len());
            widened.extend(
                std::iter::repeat(self.wild_card_pattern.clone()).take(added_columns),
            );
            widened.append(&mut row.columns);
            row.columns = widened;
        }

        rows.push(new_row);
        case_matrices.push((case, rows));
    }
}
}
/// Follows `path` from `subject_tipo` and returns the type found there.
///
/// Pair, tuple and list steps select among the inner types of the current
/// type; a list-tail step keeps the current type; constructor steps take the
/// argument type from the type stored in the step. If the result is a type
/// variable linked to another type, the link is followed.
///
/// # Panics
///
/// Panics if an index is out of range for the inner/argument types, or if a
/// constructor step's type has no argument types.
pub fn get_tipo_by_path(mut subject_tipo: Rc<Type>, path: &[Path]) -> Rc<Type> {
    for step in path {
        subject_tipo = match step {
            Path::ListTail(_) => subject_tipo,
            Path::List(_) => subject_tipo.get_inner_types().swap_remove(0),
            Path::Pair(index) | Path::Tuple(index) => {
                subject_tipo.get_inner_types().swap_remove(*index)
            }
            Path::Constr(tipo, index) => tipo.arg_types().unwrap().swap_remove(*index),
            Path::OpaqueConstr(tipo) => tipo.arg_types().unwrap().swap_remove(0),
        };
    }

    // Target of the link when the result is a linked type variable.
    let linked = match subject_tipo.as_ref() {
        Type::Var { tipo, .. } => match &*tipo.borrow() {
            TypeVar::Link { tipo } => Some(tipo.clone()),
            TypeVar::Unbound { .. } | TypeVar::Generic { .. } => None,
        },
        _ => None,
    };

    match linked {
        Some(target) => get_tipo_by_path(target, &[]),
        None => subject_tipo,
    }
}
/// Whether `pattern` is a variable or a discard, looking through any number
/// of enclosing `as` patterns.
fn match_wild_card(pattern: &TypedPattern) -> bool {
    let mut current = pattern;

    while let Pattern::Assign { pattern: inner, .. } = current {
        current = &**inner;
    }

    matches!(current, Pattern::Var { .. } | Pattern::Discard { .. })
}
// Finds the column with the most non-wildcard patterns before its first
// wildcard, counting rows from the top. Ties go to the leftmost column.
// Returns `None` when no column has any such pattern, i.e. when every column
// of the first row is a wildcard (or there are no columns at all).
fn highest_occurrence(matrix: &PatternMatrix, column_length: usize) -> Option<usize> {
    let mut occurrences = vec![Occurrence::default(); column_length];

    for row in &matrix.rows {
        for (column_index, row_item) in row.columns.iter().enumerate() {
            let Some(tally) = occurrences.get_mut(column_index) else {
                unreachable!()
            };

            if tally.passed_wild_card || match_wild_card(row_item.pattern) {
                tally.passed_wild_card = true;
            } else {
                tally.amount += 1;
            }
        }
    }

    let mut best_index = 0;
    let mut best_amount = 0;

    for (index, tally) in occurrences.iter().enumerate() {
        // Strictly greater, so the leftmost column wins a tie.
        if tally.amount > best_amount {
            best_index = index;
            best_amount = tally.amount;
        }
    }

    (best_amount != 0).then_some(best_index)
}
#[cfg(test)]
mod tester {
use std::collections::HashMap;
use indexmap::IndexMap;
use crate::{
ast::{
Definition, ModuleKind, Span, TraceLevel, Tracing, TypedModule, TypedPattern,
UntypedModule,
},
builtins,
expr::TypedExpr,
gen_uplc::{decision_tree::TreeGen, interner::AirInterner},
parser,
tipo::error::{Error, Warning},
utils, IdGenerator,
};
/// Parses `source_code` as a library module, panicking on a parse error.
fn parse(source_code: &str) -> UntypedModule {
    let (module, _) =
        parser::module(source_code, ModuleKind::Lib).expect("Failed to parse module");

    module
}
/// Type-checks `ast` as module `test/project`, with the prelude, the builtin
/// module and every module of `extra` available as dependencies.
///
/// Extra modules are always inferred with verbose tracing and their warnings
/// are discarded; only the warnings of `ast` are returned.
fn check_module(
    ast: UntypedModule,
    extra: Vec<(String, UntypedModule)>,
    kind: ModuleKind,
    tracing: Tracing,
) -> Result<(Vec<Warning>, TypedModule), (Vec<Warning>, Error)> {
    let id_gen = IdGenerator::new();

    let mut module_types = HashMap::new();
    module_types.insert("aiken".to_string(), builtins::prelude(&id_gen));
    module_types.insert("aiken/builtin".to_string(), builtins::plutus(&id_gen));

    for (package, module) in extra {
        let mut dependency_warnings = vec![];

        let typed_dependency = module
            .infer(
                &id_gen,
                kind,
                &package,
                &module_types,
                Tracing::All(TraceLevel::Verbose),
                &mut dependency_warnings,
                None,
            )
            .expect("extra dependency did not compile");

        module_types.insert(package, typed_dependency.type_info.clone());
    }

    let mut warnings = vec![];

    let outcome = ast.infer(
        &id_gen,
        kind,
        "test/project",
        &module_types,
        tracing,
        &mut warnings,
        None,
    );

    match outcome {
        Ok(typed_module) => Ok((warnings, typed_module)),
        Err(error) => Err((warnings, error)),
    }
}
/// Type-checks `ast` as a stand-alone library module with verbose tracing.
fn check(ast: UntypedModule) -> Result<(Vec<Warning>, TypedModule), (Vec<Warning>, Error)> {
    let no_dependencies = Vec::new();

    check_module(ast, no_dependencies, ModuleKind::Lib, Tracing::verbose())
}
/// Smoke test: builds and prints the tree for fixed-length list patterns.
/// It makes no assertion on the resulting tree.
#[test]
fn thing() {
    let source_code = r#"
test thing(){
when [1, 2, 3] is {
[] -> False
[4] -> fail
[a, 2, b] -> True
_ -> False
}
}
"#;

    let (_, typed_module) = check(parse(source_code)).unwrap();

    let Definition::Test(test_fn) = &typed_module.definitions[0] else {
        panic!()
    };

    let TypedExpr::When {
        clauses, subject, ..
    } = &test_fn.body
    else {
        panic!()
    };

    let data_types = IndexMap::new();
    let wild_card = TypedPattern::Discard {
        name: "_".to_string(),
        location: Span::empty(),
    };
    let mut interner = AirInterner::new();
    let subject_name = "subject".to_string();

    let tree = TreeGen::new(&mut interner, &data_types, &wild_card).build_tree(
        &subject_name,
        &subject.tipo(),
        clauses,
    );

    println!("{}", tree);
}
/// Smoke test: builds and prints the tree for a tuple subject mixing
/// integer, byte-array and list patterns. It makes no assertion on the tree.
#[test]
fn thing2() {
    let source_code = r#"
test thing(){
when (1,2,#"",[]) is {
(a,b,#"", []) -> True
(1,b,#"", [1]) -> False
(3,b,#"aa", _) -> 2 == 2
_ -> 1 == 1
}
}
"#;

    let (_, typed_module) = check(parse(source_code)).unwrap();

    let Definition::Test(test_fn) = &typed_module.definitions[0] else {
        panic!()
    };

    let TypedExpr::When {
        clauses, subject, ..
    } = &test_fn.body
    else {
        panic!()
    };

    let data_types = IndexMap::new();
    let wild_card = TypedPattern::Discard {
        name: "_".to_string(),
        location: Span::empty(),
    };
    let mut interner = AirInterner::new();
    let subject_name = "subject".to_string();

    let tree = TreeGen::new(&mut interner, &data_types, &wild_card).build_tree(
        &subject_name,
        &subject.tipo(),
        clauses,
    );

    println!("{}", tree);
}
/// Smoke test: builds and prints the tree for a tuple subject whose list
/// column mixes patterns with and without a tail. It makes no assertion on
/// the tree.
#[test]
fn thing3() {
    let source_code = r#"
test thing(){
when (1,2,#"",[]) is {
(2,b,#"", []) -> 4 == 4
(a,b,#"", [2, ..y]) -> True
(1,b,#"", [a]) -> False
(3,b,#"aa", [x, y, ..z]) -> 2 == 2
_ -> 1 == 1
}
}
"#;

    let (_, typed_module) = check(parse(source_code)).unwrap();

    let Definition::Test(test_fn) = &typed_module.definitions[0] else {
        panic!()
    };

    let TypedExpr::When {
        clauses, subject, ..
    } = &test_fn.body
    else {
        panic!()
    };

    let data_types = IndexMap::new();
    let wild_card = TypedPattern::Discard {
        name: "_".to_string(),
        location: Span::empty(),
    };
    let mut interner = AirInterner::new();
    let subject_name = "subject".to_string();

    let tree = TreeGen::new(&mut interner, &data_types, &wild_card).build_tree(
        &subject_name,
        &subject.tipo(),
        clauses,
    );

    println!("{}", tree);
}
/// Smoke test: same shape as `thing3` plus a clause using an `as` pattern
/// inside a list. It makes no assertion on the tree.
#[test]
fn thing4() {
    let source_code = r#"
test thing(){
when (1,2,#"",[]) is {
(2,b,#"", []) -> 4 == 4
(a,b,#"", [2, ..y]) -> True
(1,b,#"", [a]) -> False
(3,b,#"aa", [x, y, ..z]) -> 2 == 2
(3,b, c, [x, 3 as q]) -> fail
_ -> 1 == 1
}
}
"#;

    let (_, typed_module) = check(parse(source_code)).unwrap();

    let Definition::Test(test_fn) = &typed_module.definitions[0] else {
        panic!()
    };

    let TypedExpr::When {
        clauses, subject, ..
    } = &test_fn.body
    else {
        panic!()
    };

    let data_types = IndexMap::new();
    let wild_card = TypedPattern::Discard {
        name: "_".to_string(),
        location: Span::empty(),
    };
    let mut interner = AirInterner::new();
    let subject_name = "subject".to_string();

    let tree = TreeGen::new(&mut interner, &data_types, &wild_card).build_tree(
        &subject_name,
        &subject.tipo(),
        clauses,
    );

    println!("{}", tree);
}
/// Smoke test: builds and prints the tree for constructor patterns, using the
/// prelude data types for constructor lookup. It makes no assertion on the
/// tree.
#[test]
fn thing5() {
    let source_code = r#"
test thing(){
when (1,[],#"",None) is {
(2,b,#"", Some(Seeded { choices: #"", .. })) -> 4 == 4
(a,b,#"", None) -> True
(1,b,#"", Some(Seeded{ choices: #"", ..})) -> False
(3,b,#"aa", Some(Replayed(..))) -> 2 == 2
(3,b, c, y) -> fail
_ -> 1 == 1
}
}
"#;

    let (_, typed_module) = check(parse(source_code)).unwrap();

    let Definition::Test(test_fn) = &typed_module.definitions[0] else {
        panic!()
    };

    let TypedExpr::When {
        clauses, subject, ..
    } = &test_fn.body
    else {
        panic!()
    };

    let id_gen = IdGenerator::new();
    let prelude_types = builtins::prelude_data_types(&id_gen);
    let data_types = utils::indexmap::as_ref_values(&prelude_types);

    let wild_card = TypedPattern::Discard {
        name: "_".to_string(),
        location: Span::empty(),
    };
    let mut interner = AirInterner::new();
    let subject_name = "subject".to_string();

    let tree = TreeGen::new(&mut interner, &data_types, &wild_card).build_tree(
        &subject_name,
        &subject.tipo(),
        clauses,
    );

    println!("{}", tree);
}
}