Fix: crash in code gen caused by incorrect column length in decision trees (#1069)

* Fix: Deeply nested assignments would offset the new-column count calculation. We now track the relevant columns and their paths, and pad each row with a wildcard wherever it lacks a relevant column.

* Add test plus clippy fix

* Clippy fix

* New version clippy fix
This commit is contained in:
Kasey
2024-12-05 11:02:19 +07:00
committed by GitHub
parent a9675fedc6
commit 86ec3b2924
13 changed files with 392 additions and 183 deletions

View File

@@ -881,7 +881,7 @@ pub enum Located<'a> {
Annotation(&'a Annotation),
}
impl<'a> Located<'a> {
impl Located<'_> {
pub fn definition_location(&self) -> Option<DefinitionLocation<'_>> {
match self {
Self::Expression(expression) => expression.definition_location(),
@@ -1996,7 +1996,7 @@ impl<'de> serde::Deserialize<'de> for Bls12_381Point {
{
struct FieldVisitor;
impl<'de> serde::de::Visitor<'de> for FieldVisitor {
impl serde::de::Visitor<'_> for FieldVisitor {
type Value = Field;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {

View File

@@ -3,7 +3,7 @@ use pretty::RcDoc;
use std::{cmp::Ordering, fmt::Display, rc::Rc};
use indexmap::IndexMap;
use itertools::{Itertools, Position};
use itertools::{Either, Itertools, Position};
use crate::{
ast::{DataTypeKey, Pattern, TypedClause, TypedDataType, TypedPattern},
@@ -12,10 +12,6 @@ use crate::{
use super::{interner::AirInterner, tree::AirTree};
const PAIR_NEW_COLUMNS: usize = 2;
const MIN_NEW_COLUMNS: usize = 1;
#[derive(Clone, Default, Copy)]
struct Occurrence {
passed_wild_card: bool,
@@ -71,6 +67,26 @@ impl PartialEq for Path {
impl Eq for Path {}
// Partial ordering for `Path`: always `Some`, delegating to the total order
// defined by `Ord for Path` so the two comparisons can never disagree.
impl PartialOrd for Path {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
// Ordering used to sort column paths.
//
// Two paths of the same `Pair`, `Tuple`, `List` or `ListTail` variant are
// ordered by their payload. Two `Constr` paths are ordered by their second
// field only; the first field is ignored. Every other combination (two
// `OpaqueConstr` paths, or two paths of different variants) compares as
// `Ordering::Equal`.
//
// NOTE(review): the payload is presumably a positional index; confirm against
// the `Path` definition.
impl Ord for Path {
    fn cmp(&self, other: &Self) -> Ordering {
        match (self, other) {
            (Path::Pair(lhs), Path::Pair(rhs))
            | (Path::Tuple(lhs), Path::Tuple(rhs))
            | (Path::List(lhs), Path::List(rhs))
            | (Path::ListTail(lhs), Path::ListTail(rhs))
            | (Path::Constr(_, lhs), Path::Constr(_, rhs)) => lhs.cmp(rhs),
            // Opaque constructors and mismatched variants carry nothing to
            // compare here, so they tie.
            _ => Ordering::Equal,
        }
    }
}
#[derive(Clone, Debug)]
pub struct Assigned {
pub path: Vec<Path>,
@@ -568,7 +584,7 @@ impl<'a> DecisionTree<'a> {
}
}
impl<'a> Display for DecisionTree<'a> {
impl Display for DecisionTree<'_> {
// Writes the string produced by `self.to_pretty()` to the formatter `f`.
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.to_pretty())
}
@@ -603,33 +619,86 @@ impl<'a, 'b> TreeGen<'a, 'b> {
) -> DecisionTree<'a> {
let mut hoistables = IndexMap::new();
let rows = clauses
.iter()
.enumerate()
.map(|(index, clause)| {
// Assigns are split out from patterns so they can be handled
// outside of the tree algorithm
let (assign, row_items) =
self.map_pattern_to_row(&clause.pattern, subject_tipo, vec![]);
let mut columns_added = vec![];
self.interner.intern(format!("__clause_then_{}", index));
let clause_then_name = self
.interner
.lookup_interned(&format!("__clause_then_{}", index));
let rows = {
let rows_initial = clauses
.iter()
.enumerate()
.map(|(index, clause)| {
// Assigns are split out from patterns so they can be handled
// outside of the tree algorithm
let (assign, row_items) =
self.map_pattern_to_row(&clause.pattern, subject_tipo, vec![]);
hoistables.insert(clause_then_name.clone(), (vec![], &clause.then));
self.interner.intern(format!("__clause_then_{}", index));
let clause_then_name = self
.interner
.lookup_interned(&format!("__clause_then_{}", index));
let row = Row {
assigns: assign.into_iter().collect_vec(),
columns: row_items,
then: clause_then_name,
};
hoistables.insert(clause_then_name.clone(), (vec![], &clause.then));
self.interner.pop_text(format!("__clause_then_{}", index));
// Some good ol' mutation to track added columns per relevant path
// relevant path indicating a column that has a pattern to test at some point in
// one of the rows
row_items.iter().for_each(|col| {
if !columns_added.contains(&col.path) {
columns_added.push(col.path.clone());
}
});
row
})
.collect_vec();
let row = Row {
assigns: assign.into_iter().collect_vec(),
columns: row_items,
then: clause_then_name,
};
self.interner.pop_text(format!("__clause_then_{}", index));
row
})
.collect_vec();
columns_added = columns_added
.into_iter()
.sorted_by(|a, b| {
let mut a = a.clone();
let mut b = b.clone();
// It's impossible for duplicates since we check before insertion
while let Ordering::Equal = a.first().cmp(&b.first()) {
a.remove(0);
b.remove(0);
}
a.first().cmp(&b.first())
})
.map(|col| {
// remove opaqueconstr paths since they are just type information
col.into_iter()
.filter(|a| !matches!(a, Path::OpaqueConstr(_)))
.collect_vec()
})
.collect_vec();
rows_initial
.into_iter()
.map(|mut row| {
for (index, path) in columns_added.iter().enumerate() {
if !row
.columns
.get(index)
.map(|col| col.path == *path)
.unwrap_or(false)
{
row.columns.insert(index, self.wild_card_pattern.clone());
}
}
row
})
.collect_vec()
};
let mut tree = self.do_build_tree(subject_tipo, PatternMatrix { rows }, &mut hoistables);
@@ -654,6 +723,7 @@ impl<'a, 'b> TreeGen<'a, 'b> {
then_map: &mut IndexMap<String, (Vec<Assigned>, &'a TypedExpr)>,
) -> DecisionTree<'a> {
let column_length = matrix.rows[0].columns.len();
// First step make sure all rows have same number of columns
// or something went wrong
assert!(matrix
@@ -734,6 +804,7 @@ impl<'a, 'b> TreeGen<'a, 'b> {
.unwrap();
let specialized_tipo = get_tipo_by_path(subject_tipo.clone(), &path);
let mut relevant_columns: Vec<(CaseTest, Vec<Vec<Path>>)> = vec![];
// Time to split on the matrices based on case to test for or lack of
let (default_matrix, specialized_matrices) = matrix.rows.into_iter().fold(
@@ -841,37 +912,37 @@ impl<'a, 'b> TreeGen<'a, 'b> {
// Add inner patterns to existing row
let mut new_cols = remaining_patts.into_iter().flat_map(|x| x.1).collect_vec();
// To align number of columns we pop off the tail since it can
// never include a pattern besides wild card
if matches!(case, CaseTest::ListWithTail(_)) {
new_cols.pop();
}
let new_paths = new_cols.iter().map(|col| col.path.clone()).collect_vec();
let added_columns = new_cols.len();
// Pop off tail so that it aligns more easily with other list patterns
if let Some(a) = relevant_columns.iter_mut().find(|a| a.0 == case) {
new_paths.iter().for_each(|col| {
if !a.1.contains(col) {
a.1.push(col.clone());
}
});
} else {
relevant_columns.push((case.clone(), new_paths.clone()));
};
new_cols.extend(row.columns);
row.columns = new_cols;
if let CaseTest::Wild = case {
let current_wild_cols = row.columns.len();
default_matrix.push(row.clone());
case_matrices.iter_mut().for_each(|(_, matrix)| {
let mut row = row.clone();
let total_cols = matrix[0].columns.len();
case_matrices.iter_mut().for_each(|(case, matrix)| {
let Some(a) = relevant_columns.iter_mut().find(|a| a.0 == *case) else {
unreachable!()
};
if total_cols != 0 {
let added_columns = total_cols - current_wild_cols;
for _ in 0..added_columns {
row.columns.insert(0, self.wild_card_pattern.clone());
new_paths.iter().for_each(|col| {
if !a.1.contains(col) {
a.1.push(col.clone());
}
});
matrix.push(row);
}
matrix.push(row.clone());
});
} else if let CaseTest::ListWithTail(tail_case_length) = case {
// For lists with tail it's a special case where we also add it to existing patterns
@@ -886,20 +957,30 @@ impl<'a, 'b> TreeGen<'a, 'b> {
for elem_count in tail_case_length..=longest_elems_no_tail {
let case = CaseTest::List(elem_count);
let mut row = row.clone();
// paths first
if let Some(a) = relevant_columns.iter_mut().find(|a| a.0 == case) {
new_paths.iter().for_each(|col| {
if !a.1.contains(col) {
a.1.push(col.clone());
}
});
} else {
relevant_columns.push((case.clone(), new_paths.clone()));
};
for _ in 0..(elem_count - tail_case_length) {
row.columns
.insert(tail_case_length, self.wild_card_pattern.clone());
let row = row.clone();
// now insertion into the appropriate matrix
if let Some(entry) =
case_matrices.iter_mut().find(|item| item.0 == case)
{
entry.1.push(row);
} else {
let mut rows = default_matrix.to_vec();
rows.push(row);
case_matrices.push((case, rows));
}
self.insert_case(
&mut case_matrices,
case,
&default_matrix,
row,
added_columns,
);
}
}
@@ -907,35 +988,123 @@ impl<'a, 'b> TreeGen<'a, 'b> {
for elem_count in tail_case_length..=longest_elems_with_tail {
let case = CaseTest::ListWithTail(elem_count);
let mut row = row.clone();
if let Some(a) = relevant_columns.iter_mut().find(|a| a.0 == case) {
new_paths.iter().for_each(|col| {
if !a.1.contains(col) {
a.1.push(col.clone());
}
});
} else {
relevant_columns.push((case.clone(), new_paths.clone()));
};
for _ in 0..(elem_count - tail_case_length) {
row.columns
.insert(tail_case_length, self.wild_card_pattern.clone());
let row = row.clone();
if let Some(entry) = case_matrices.iter_mut().find(|item| item.0 == case) {
entry.1.push(row);
} else {
let mut rows = default_matrix.clone();
rows.push(row);
case_matrices.push((case, rows));
}
self.insert_case(
&mut case_matrices,
case,
&default_matrix,
row,
added_columns,
);
}
} else if let Some(entry) = case_matrices.iter_mut().find(|item| item.0 == case) {
entry.1.push(row);
} else {
self.insert_case(
&mut case_matrices,
case,
&default_matrix,
row,
added_columns,
);
let mut rows = default_matrix.clone();
rows.push(row);
case_matrices.push((case, rows));
}
(default_matrix, case_matrices)
},
);
let (default_relevant_cols, relevant_columns): (Vec<_>, Vec<_>) = relevant_columns
.into_iter()
.map(|(case, paths)| {
(
case,
paths
.into_iter()
.sorted_by(|a, b| {
let mut a = a.clone();
let mut b = b.clone();
// It's impossible for duplicates since we check before insertion
while let Ordering::Equal = a.first().cmp(&b.first()) {
a.remove(0);
b.remove(0);
}
a.first().cmp(&b.first())
})
.map(|col| {
// remove opaqueconstr paths since they are just type information
col.into_iter()
.filter(|a| !matches!(a, Path::OpaqueConstr(_)))
.collect_vec()
})
.collect_vec(),
)
})
.partition_map(|(case, paths)| match case {
CaseTest::Wild => Either::Left(paths),
_ => Either::Right((case, paths)),
});
let default_matrix = default_matrix
.into_iter()
.map(|mut row| {
for (index, path) in default_relevant_cols[0].iter().enumerate() {
if !row
.columns
.get(index)
.map(|col| col.path == *path)
.unwrap_or(false)
{
row.columns.insert(index, self.wild_card_pattern.clone());
}
}
row
})
.collect_vec();
let specialized_matrices = specialized_matrices
.into_iter()
.map(|(case, matrix)| {
let Some((_, relevant_cols)) = relevant_columns
.iter()
.find(|(relevant_case, _)| relevant_case == &case)
else {
unreachable!()
};
(
case,
matrix
.into_iter()
.map(|mut row| {
for (index, path) in relevant_cols.iter().enumerate() {
if !row
.columns
.get(index)
.map(|col| col.path == *path)
.unwrap_or(false)
{
row.columns.insert(index, self.wild_card_pattern.clone());
}
}
row
})
.collect_vec(),
)
})
.collect_vec();
let default_matrix = PatternMatrix {
rows: default_matrix,
};
@@ -1001,39 +1170,13 @@ impl<'a, 'b> TreeGen<'a, 'b> {
) -> (Vec<Assigned>, Vec<RowItem<'a>>) {
let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path);
let new_columns_added = if current_tipo.is_pair() {
PAIR_NEW_COLUMNS
} else if current_tipo.is_tuple() {
let Type::Tuple { elems, .. } = current_tipo.as_ref() else {
unreachable!("{:#?}", current_tipo)
};
elems.len()
} else if let Some(data) = lookup_data_type_by_tipo(self.data_types, &current_tipo) {
if data.constructors.len() == 1 {
data.constructors[0].arguments.len()
} else if data.is_never() {
0
} else {
MIN_NEW_COLUMNS
}
} else {
MIN_NEW_COLUMNS
};
match pattern {
Pattern::Var { name, .. } => (
vec![Assigned {
path: path.clone(),
assigned: name.clone(),
}],
vec![RowItem {
pattern,
path: path.clone(),
}]
.into_iter()
.cycle()
.take(new_columns_added)
.collect_vec(),
vec![],
),
Pattern::Assign { name, pattern, .. } => {
@@ -1049,19 +1192,14 @@ impl<'a, 'b> TreeGen<'a, 'b> {
);
(assigns, patts)
}
Pattern::Int { .. }
| Pattern::ByteArray { .. }
| Pattern::Discard { .. }
| Pattern::List { .. } => (
Pattern::Discard { .. } => (vec![], vec![]),
Pattern::Int { .. } | Pattern::ByteArray { .. } | Pattern::List { .. } => (
vec![],
vec![RowItem {
pattern,
path: path.clone(),
}]
.into_iter()
.cycle()
.take(new_columns_added)
.collect_vec(),
}],
),
Pattern::Constructor {
@@ -1101,11 +1239,7 @@ impl<'a, 'b> TreeGen<'a, 'b> {
vec![RowItem {
pattern,
path: path.clone(),
}]
.into_iter()
.cycle()
.take(new_columns_added)
.collect_vec(),
}],
)
}
}
@@ -1126,6 +1260,7 @@ impl<'a, 'b> TreeGen<'a, 'b> {
(assigns, patts)
}
Pattern::Tuple { elems, .. } => {
elems
.iter()
@@ -1146,30 +1281,6 @@ impl<'a, 'b> TreeGen<'a, 'b> {
}
}
}
// Adds `new_row` to the matrix registered for `case`.
//
// If `case_matrices` already holds an entry for `case`, the row is appended
// to it unchanged. Otherwise a new matrix is created from a copy of
// `default_matrix`, with `added_columns` wildcard columns inserted at the
// front of each copied row, then `new_row` is appended and the matrix is
// registered under `case`.
fn insert_case(
    &self,
    case_matrices: &mut Vec<(CaseTest, Vec<Row<'a>>)>,
    case: CaseTest,
    default_matrix: &[Row<'a>],
    new_row: Row<'a>,
    added_columns: usize,
) {
    // Fast path: the case already has a matrix, so just append.
    if let Some((_, existing_rows)) = case_matrices.iter_mut().find(|entry| entry.0 == case) {
        existing_rows.push(new_row);
        return;
    }

    // First row for this case: seed from the default rows, padding each one
    // at the front with wildcards. `new_row` itself is not padded.
    let mut seeded_rows = default_matrix.to_vec();
    for seeded in seeded_rows.iter_mut() {
        for _ in 0..added_columns {
            seeded.columns.insert(0, self.wild_card_pattern.clone());
        }
    }

    seeded_rows.push(new_row);
    case_matrices.push((case, seeded_rows));
}
}
pub fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> {

View File

@@ -622,7 +622,7 @@ pub struct Counterexample<'a> {
pub cache: Cache<'a, PlutusData>,
}
impl<'a> Counterexample<'a> {
impl Counterexample<'_> {
fn consider(&mut self, choices: &[u8]) -> bool {
if choices == self.choices {
return true;

View File

@@ -1284,7 +1284,7 @@ fn suggest_pattern(
}
}
fn suggest_generic(name: &String, expected: usize) -> String {
fn suggest_generic(name: &str, expected: usize) -> String {
if expected == 0 {
return name.to_doc().to_pretty_string(70);
}

View File

@@ -206,7 +206,7 @@ impl Printer {
}
}
fn qualify_type_name(module: &String, typ_name: &str) -> Document<'static> {
fn qualify_type_name(module: &str, typ_name: &str) -> Document<'static> {
if module.is_empty() {
docvec!["aiken.", Document::String(typ_name.to_string())]
} else {