Fix: issue crash in code gen with incorrect column length in decision trees (#1069)
* Fix: Deeply nested assignments would offset the new columns count calculation. Now we track relevant columns and their path to ensure each row has wildcards if they don't contain the relevant column * Add test plus clippy fix * Clippy fix * New version clippy fix
This commit is contained in:
@@ -881,7 +881,7 @@ pub enum Located<'a> {
|
||||
Annotation(&'a Annotation),
|
||||
}
|
||||
|
||||
impl<'a> Located<'a> {
|
||||
impl Located<'_> {
|
||||
pub fn definition_location(&self) -> Option<DefinitionLocation<'_>> {
|
||||
match self {
|
||||
Self::Expression(expression) => expression.definition_location(),
|
||||
@@ -1996,7 +1996,7 @@ impl<'de> serde::Deserialize<'de> for Bls12_381Point {
|
||||
{
|
||||
struct FieldVisitor;
|
||||
|
||||
impl<'de> serde::de::Visitor<'de> for FieldVisitor {
|
||||
impl serde::de::Visitor<'_> for FieldVisitor {
|
||||
type Value = Field;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
|
||||
@@ -3,7 +3,7 @@ use pretty::RcDoc;
|
||||
use std::{cmp::Ordering, fmt::Display, rc::Rc};
|
||||
|
||||
use indexmap::IndexMap;
|
||||
use itertools::{Itertools, Position};
|
||||
use itertools::{Either, Itertools, Position};
|
||||
|
||||
use crate::{
|
||||
ast::{DataTypeKey, Pattern, TypedClause, TypedDataType, TypedPattern},
|
||||
@@ -12,10 +12,6 @@ use crate::{
|
||||
|
||||
use super::{interner::AirInterner, tree::AirTree};
|
||||
|
||||
const PAIR_NEW_COLUMNS: usize = 2;
|
||||
|
||||
const MIN_NEW_COLUMNS: usize = 1;
|
||||
|
||||
#[derive(Clone, Default, Copy)]
|
||||
struct Occurrence {
|
||||
passed_wild_card: bool,
|
||||
@@ -71,6 +67,26 @@ impl PartialEq for Path {
|
||||
|
||||
impl Eq for Path {}
|
||||
|
||||
impl PartialOrd for Path {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for Path {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
match (self, other) {
|
||||
(Path::Pair(a), Path::Pair(b))
|
||||
| (Path::Tuple(a), Path::Tuple(b))
|
||||
| (Path::List(a), Path::List(b))
|
||||
| (Path::ListTail(a), Path::ListTail(b))
|
||||
| (Path::Constr(_, a), Path::Constr(_, b)) => a.cmp(b),
|
||||
(Path::OpaqueConstr(_), Path::OpaqueConstr(_)) => Ordering::Equal,
|
||||
_ => Ordering::Equal,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Assigned {
|
||||
pub path: Vec<Path>,
|
||||
@@ -568,7 +584,7 @@ impl<'a> DecisionTree<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Display for DecisionTree<'a> {
|
||||
impl Display for DecisionTree<'_> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.to_pretty())
|
||||
}
|
||||
@@ -603,33 +619,86 @@ impl<'a, 'b> TreeGen<'a, 'b> {
|
||||
) -> DecisionTree<'a> {
|
||||
let mut hoistables = IndexMap::new();
|
||||
|
||||
let rows = clauses
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, clause)| {
|
||||
// Assigns are split out from patterns so they can be handled
|
||||
// outside of the tree algorithm
|
||||
let (assign, row_items) =
|
||||
self.map_pattern_to_row(&clause.pattern, subject_tipo, vec![]);
|
||||
let mut columns_added = vec![];
|
||||
|
||||
self.interner.intern(format!("__clause_then_{}", index));
|
||||
let clause_then_name = self
|
||||
.interner
|
||||
.lookup_interned(&format!("__clause_then_{}", index));
|
||||
let rows = {
|
||||
let rows_initial = clauses
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, clause)| {
|
||||
// Assigns are split out from patterns so they can be handled
|
||||
// outside of the tree algorithm
|
||||
let (assign, row_items) =
|
||||
self.map_pattern_to_row(&clause.pattern, subject_tipo, vec![]);
|
||||
|
||||
hoistables.insert(clause_then_name.clone(), (vec![], &clause.then));
|
||||
self.interner.intern(format!("__clause_then_{}", index));
|
||||
let clause_then_name = self
|
||||
.interner
|
||||
.lookup_interned(&format!("__clause_then_{}", index));
|
||||
|
||||
let row = Row {
|
||||
assigns: assign.into_iter().collect_vec(),
|
||||
columns: row_items,
|
||||
then: clause_then_name,
|
||||
};
|
||||
hoistables.insert(clause_then_name.clone(), (vec![], &clause.then));
|
||||
|
||||
self.interner.pop_text(format!("__clause_then_{}", index));
|
||||
// Some good ol' mutation to track added columns per relevant path
|
||||
// relevant path indicating a column that has a pattern to test at some point in
|
||||
// one of the rows
|
||||
row_items.iter().for_each(|col| {
|
||||
if !columns_added.contains(&col.path) {
|
||||
columns_added.push(col.path.clone());
|
||||
}
|
||||
});
|
||||
|
||||
row
|
||||
})
|
||||
.collect_vec();
|
||||
let row = Row {
|
||||
assigns: assign.into_iter().collect_vec(),
|
||||
columns: row_items,
|
||||
then: clause_then_name,
|
||||
};
|
||||
|
||||
self.interner.pop_text(format!("__clause_then_{}", index));
|
||||
|
||||
row
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
columns_added = columns_added
|
||||
.into_iter()
|
||||
.sorted_by(|a, b| {
|
||||
let mut a = a.clone();
|
||||
let mut b = b.clone();
|
||||
|
||||
// It's impossible for duplicates since we check before insertion
|
||||
while let Ordering::Equal = a.first().cmp(&b.first()) {
|
||||
a.remove(0);
|
||||
b.remove(0);
|
||||
}
|
||||
|
||||
a.first().cmp(&b.first())
|
||||
})
|
||||
.map(|col| {
|
||||
// remove opaqueconstr paths since they are just type information
|
||||
col.into_iter()
|
||||
.filter(|a| !matches!(a, Path::OpaqueConstr(_)))
|
||||
.collect_vec()
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
rows_initial
|
||||
.into_iter()
|
||||
.map(|mut row| {
|
||||
for (index, path) in columns_added.iter().enumerate() {
|
||||
if !row
|
||||
.columns
|
||||
.get(index)
|
||||
.map(|col| col.path == *path)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
row.columns.insert(index, self.wild_card_pattern.clone());
|
||||
}
|
||||
}
|
||||
|
||||
row
|
||||
})
|
||||
.collect_vec()
|
||||
};
|
||||
|
||||
let mut tree = self.do_build_tree(subject_tipo, PatternMatrix { rows }, &mut hoistables);
|
||||
|
||||
@@ -654,6 +723,7 @@ impl<'a, 'b> TreeGen<'a, 'b> {
|
||||
then_map: &mut IndexMap<String, (Vec<Assigned>, &'a TypedExpr)>,
|
||||
) -> DecisionTree<'a> {
|
||||
let column_length = matrix.rows[0].columns.len();
|
||||
|
||||
// First step make sure all rows have same number of columns
|
||||
// or something went wrong
|
||||
assert!(matrix
|
||||
@@ -734,6 +804,7 @@ impl<'a, 'b> TreeGen<'a, 'b> {
|
||||
.unwrap();
|
||||
|
||||
let specialized_tipo = get_tipo_by_path(subject_tipo.clone(), &path);
|
||||
let mut relevant_columns: Vec<(CaseTest, Vec<Vec<Path>>)> = vec![];
|
||||
|
||||
// Time to split on the matrices based on case to test for or lack of
|
||||
let (default_matrix, specialized_matrices) = matrix.rows.into_iter().fold(
|
||||
@@ -841,37 +912,37 @@ impl<'a, 'b> TreeGen<'a, 'b> {
|
||||
// Add inner patterns to existing row
|
||||
let mut new_cols = remaining_patts.into_iter().flat_map(|x| x.1).collect_vec();
|
||||
|
||||
// To align number of columns we pop off the tail since it can
|
||||
// never include a pattern besides wild card
|
||||
if matches!(case, CaseTest::ListWithTail(_)) {
|
||||
new_cols.pop();
|
||||
}
|
||||
let new_paths = new_cols.iter().map(|col| col.path.clone()).collect_vec();
|
||||
|
||||
let added_columns = new_cols.len();
|
||||
|
||||
// Pop off tail so that it aligns more easily with other list patterns
|
||||
if let Some(a) = relevant_columns.iter_mut().find(|a| a.0 == case) {
|
||||
new_paths.iter().for_each(|col| {
|
||||
if !a.1.contains(col) {
|
||||
a.1.push(col.clone());
|
||||
}
|
||||
});
|
||||
} else {
|
||||
relevant_columns.push((case.clone(), new_paths.clone()));
|
||||
};
|
||||
|
||||
new_cols.extend(row.columns);
|
||||
|
||||
row.columns = new_cols;
|
||||
|
||||
if let CaseTest::Wild = case {
|
||||
let current_wild_cols = row.columns.len();
|
||||
default_matrix.push(row.clone());
|
||||
|
||||
case_matrices.iter_mut().for_each(|(_, matrix)| {
|
||||
let mut row = row.clone();
|
||||
let total_cols = matrix[0].columns.len();
|
||||
case_matrices.iter_mut().for_each(|(case, matrix)| {
|
||||
let Some(a) = relevant_columns.iter_mut().find(|a| a.0 == *case) else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
if total_cols != 0 {
|
||||
let added_columns = total_cols - current_wild_cols;
|
||||
|
||||
for _ in 0..added_columns {
|
||||
row.columns.insert(0, self.wild_card_pattern.clone());
|
||||
new_paths.iter().for_each(|col| {
|
||||
if !a.1.contains(col) {
|
||||
a.1.push(col.clone());
|
||||
}
|
||||
});
|
||||
|
||||
matrix.push(row);
|
||||
}
|
||||
matrix.push(row.clone());
|
||||
});
|
||||
} else if let CaseTest::ListWithTail(tail_case_length) = case {
|
||||
// For lists with tail it's a special case where we also add it to existing patterns
|
||||
@@ -886,20 +957,30 @@ impl<'a, 'b> TreeGen<'a, 'b> {
|
||||
for elem_count in tail_case_length..=longest_elems_no_tail {
|
||||
let case = CaseTest::List(elem_count);
|
||||
|
||||
let mut row = row.clone();
|
||||
// paths first
|
||||
if let Some(a) = relevant_columns.iter_mut().find(|a| a.0 == case) {
|
||||
new_paths.iter().for_each(|col| {
|
||||
if !a.1.contains(col) {
|
||||
a.1.push(col.clone());
|
||||
}
|
||||
});
|
||||
} else {
|
||||
relevant_columns.push((case.clone(), new_paths.clone()));
|
||||
};
|
||||
|
||||
for _ in 0..(elem_count - tail_case_length) {
|
||||
row.columns
|
||||
.insert(tail_case_length, self.wild_card_pattern.clone());
|
||||
let row = row.clone();
|
||||
|
||||
// now insertion into the appropriate matrix
|
||||
if let Some(entry) =
|
||||
case_matrices.iter_mut().find(|item| item.0 == case)
|
||||
{
|
||||
entry.1.push(row);
|
||||
} else {
|
||||
let mut rows = default_matrix.to_vec();
|
||||
|
||||
rows.push(row);
|
||||
case_matrices.push((case, rows));
|
||||
}
|
||||
|
||||
self.insert_case(
|
||||
&mut case_matrices,
|
||||
case,
|
||||
&default_matrix,
|
||||
row,
|
||||
added_columns,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -907,35 +988,123 @@ impl<'a, 'b> TreeGen<'a, 'b> {
|
||||
for elem_count in tail_case_length..=longest_elems_with_tail {
|
||||
let case = CaseTest::ListWithTail(elem_count);
|
||||
|
||||
let mut row = row.clone();
|
||||
if let Some(a) = relevant_columns.iter_mut().find(|a| a.0 == case) {
|
||||
new_paths.iter().for_each(|col| {
|
||||
if !a.1.contains(col) {
|
||||
a.1.push(col.clone());
|
||||
}
|
||||
});
|
||||
} else {
|
||||
relevant_columns.push((case.clone(), new_paths.clone()));
|
||||
};
|
||||
|
||||
for _ in 0..(elem_count - tail_case_length) {
|
||||
row.columns
|
||||
.insert(tail_case_length, self.wild_card_pattern.clone());
|
||||
let row = row.clone();
|
||||
|
||||
if let Some(entry) = case_matrices.iter_mut().find(|item| item.0 == case) {
|
||||
entry.1.push(row);
|
||||
} else {
|
||||
let mut rows = default_matrix.clone();
|
||||
|
||||
rows.push(row);
|
||||
case_matrices.push((case, rows));
|
||||
}
|
||||
|
||||
self.insert_case(
|
||||
&mut case_matrices,
|
||||
case,
|
||||
&default_matrix,
|
||||
row,
|
||||
added_columns,
|
||||
);
|
||||
}
|
||||
} else if let Some(entry) = case_matrices.iter_mut().find(|item| item.0 == case) {
|
||||
entry.1.push(row);
|
||||
} else {
|
||||
self.insert_case(
|
||||
&mut case_matrices,
|
||||
case,
|
||||
&default_matrix,
|
||||
row,
|
||||
added_columns,
|
||||
);
|
||||
let mut rows = default_matrix.clone();
|
||||
|
||||
rows.push(row);
|
||||
case_matrices.push((case, rows));
|
||||
}
|
||||
|
||||
(default_matrix, case_matrices)
|
||||
},
|
||||
);
|
||||
|
||||
let (default_relevant_cols, relevant_columns): (Vec<_>, Vec<_>) = relevant_columns
|
||||
.into_iter()
|
||||
.map(|(case, paths)| {
|
||||
(
|
||||
case,
|
||||
paths
|
||||
.into_iter()
|
||||
.sorted_by(|a, b| {
|
||||
let mut a = a.clone();
|
||||
let mut b = b.clone();
|
||||
|
||||
// It's impossible for duplicates since we check before insertion
|
||||
while let Ordering::Equal = a.first().cmp(&b.first()) {
|
||||
a.remove(0);
|
||||
b.remove(0);
|
||||
}
|
||||
|
||||
a.first().cmp(&b.first())
|
||||
})
|
||||
.map(|col| {
|
||||
// remove opaqueconstr paths since they are just type information
|
||||
col.into_iter()
|
||||
.filter(|a| !matches!(a, Path::OpaqueConstr(_)))
|
||||
.collect_vec()
|
||||
})
|
||||
.collect_vec(),
|
||||
)
|
||||
})
|
||||
.partition_map(|(case, paths)| match case {
|
||||
CaseTest::Wild => Either::Left(paths),
|
||||
_ => Either::Right((case, paths)),
|
||||
});
|
||||
|
||||
let default_matrix = default_matrix
|
||||
.into_iter()
|
||||
.map(|mut row| {
|
||||
for (index, path) in default_relevant_cols[0].iter().enumerate() {
|
||||
if !row
|
||||
.columns
|
||||
.get(index)
|
||||
.map(|col| col.path == *path)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
row.columns.insert(index, self.wild_card_pattern.clone());
|
||||
}
|
||||
}
|
||||
|
||||
row
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let specialized_matrices = specialized_matrices
|
||||
.into_iter()
|
||||
.map(|(case, matrix)| {
|
||||
let Some((_, relevant_cols)) = relevant_columns
|
||||
.iter()
|
||||
.find(|(relevant_case, _)| relevant_case == &case)
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
(
|
||||
case,
|
||||
matrix
|
||||
.into_iter()
|
||||
.map(|mut row| {
|
||||
for (index, path) in relevant_cols.iter().enumerate() {
|
||||
if !row
|
||||
.columns
|
||||
.get(index)
|
||||
.map(|col| col.path == *path)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
row.columns.insert(index, self.wild_card_pattern.clone());
|
||||
}
|
||||
}
|
||||
|
||||
row
|
||||
})
|
||||
.collect_vec(),
|
||||
)
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let default_matrix = PatternMatrix {
|
||||
rows: default_matrix,
|
||||
};
|
||||
@@ -1001,39 +1170,13 @@ impl<'a, 'b> TreeGen<'a, 'b> {
|
||||
) -> (Vec<Assigned>, Vec<RowItem<'a>>) {
|
||||
let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path);
|
||||
|
||||
let new_columns_added = if current_tipo.is_pair() {
|
||||
PAIR_NEW_COLUMNS
|
||||
} else if current_tipo.is_tuple() {
|
||||
let Type::Tuple { elems, .. } = current_tipo.as_ref() else {
|
||||
unreachable!("{:#?}", current_tipo)
|
||||
};
|
||||
elems.len()
|
||||
} else if let Some(data) = lookup_data_type_by_tipo(self.data_types, ¤t_tipo) {
|
||||
if data.constructors.len() == 1 {
|
||||
data.constructors[0].arguments.len()
|
||||
} else if data.is_never() {
|
||||
0
|
||||
} else {
|
||||
MIN_NEW_COLUMNS
|
||||
}
|
||||
} else {
|
||||
MIN_NEW_COLUMNS
|
||||
};
|
||||
|
||||
match pattern {
|
||||
Pattern::Var { name, .. } => (
|
||||
vec![Assigned {
|
||||
path: path.clone(),
|
||||
assigned: name.clone(),
|
||||
}],
|
||||
vec![RowItem {
|
||||
pattern,
|
||||
path: path.clone(),
|
||||
}]
|
||||
.into_iter()
|
||||
.cycle()
|
||||
.take(new_columns_added)
|
||||
.collect_vec(),
|
||||
vec![],
|
||||
),
|
||||
|
||||
Pattern::Assign { name, pattern, .. } => {
|
||||
@@ -1049,19 +1192,14 @@ impl<'a, 'b> TreeGen<'a, 'b> {
|
||||
);
|
||||
(assigns, patts)
|
||||
}
|
||||
Pattern::Int { .. }
|
||||
| Pattern::ByteArray { .. }
|
||||
| Pattern::Discard { .. }
|
||||
| Pattern::List { .. } => (
|
||||
Pattern::Discard { .. } => (vec![], vec![]),
|
||||
|
||||
Pattern::Int { .. } | Pattern::ByteArray { .. } | Pattern::List { .. } => (
|
||||
vec![],
|
||||
vec![RowItem {
|
||||
pattern,
|
||||
path: path.clone(),
|
||||
}]
|
||||
.into_iter()
|
||||
.cycle()
|
||||
.take(new_columns_added)
|
||||
.collect_vec(),
|
||||
}],
|
||||
),
|
||||
|
||||
Pattern::Constructor {
|
||||
@@ -1101,11 +1239,7 @@ impl<'a, 'b> TreeGen<'a, 'b> {
|
||||
vec![RowItem {
|
||||
pattern,
|
||||
path: path.clone(),
|
||||
}]
|
||||
.into_iter()
|
||||
.cycle()
|
||||
.take(new_columns_added)
|
||||
.collect_vec(),
|
||||
}],
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1126,6 +1260,7 @@ impl<'a, 'b> TreeGen<'a, 'b> {
|
||||
|
||||
(assigns, patts)
|
||||
}
|
||||
|
||||
Pattern::Tuple { elems, .. } => {
|
||||
elems
|
||||
.iter()
|
||||
@@ -1146,30 +1281,6 @@ impl<'a, 'b> TreeGen<'a, 'b> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_case(
|
||||
&self,
|
||||
case_matrices: &mut Vec<(CaseTest, Vec<Row<'a>>)>,
|
||||
case: CaseTest,
|
||||
default_matrix: &[Row<'a>],
|
||||
new_row: Row<'a>,
|
||||
added_columns: usize,
|
||||
) {
|
||||
if let Some(entry) = case_matrices.iter_mut().find(|item| item.0 == case) {
|
||||
entry.1.push(new_row);
|
||||
} else {
|
||||
let mut rows = default_matrix.to_vec();
|
||||
|
||||
for _ in 0..added_columns {
|
||||
for row in &mut rows {
|
||||
row.columns.insert(0, self.wild_card_pattern.clone());
|
||||
}
|
||||
}
|
||||
|
||||
rows.push(new_row);
|
||||
case_matrices.push((case, rows));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> {
|
||||
|
||||
@@ -622,7 +622,7 @@ pub struct Counterexample<'a> {
|
||||
pub cache: Cache<'a, PlutusData>,
|
||||
}
|
||||
|
||||
impl<'a> Counterexample<'a> {
|
||||
impl Counterexample<'_> {
|
||||
fn consider(&mut self, choices: &[u8]) -> bool {
|
||||
if choices == self.choices {
|
||||
return true;
|
||||
|
||||
@@ -1284,7 +1284,7 @@ fn suggest_pattern(
|
||||
}
|
||||
}
|
||||
|
||||
fn suggest_generic(name: &String, expected: usize) -> String {
|
||||
fn suggest_generic(name: &str, expected: usize) -> String {
|
||||
if expected == 0 {
|
||||
return name.to_doc().to_pretty_string(70);
|
||||
}
|
||||
|
||||
@@ -206,7 +206,7 @@ impl Printer {
|
||||
}
|
||||
}
|
||||
|
||||
fn qualify_type_name(module: &String, typ_name: &str) -> Document<'static> {
|
||||
fn qualify_type_name(module: &str, typ_name: &str) -> Document<'static> {
|
||||
if module.is_empty() {
|
||||
docvec!["aiken.", Document::String(typ_name.to_string())]
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user