change currying to happen with 3 or more occurrences

This commit is contained in:
microproofs 2024-03-07 00:01:24 -05:00 committed by Kasey
parent e9122de061
commit 1edd1a1fa3
1 changed files with 70 additions and 23 deletions

View File

@ -1273,7 +1273,7 @@ impl Program<Name> {
pub fn builtin_curry_reducer(self) -> Self {
let mut curried_terms = vec![];
let mut id_mapped_curry_terms: IndexMap<CurriedName, (Scope, Term<Name>, bool)> =
let mut id_mapped_curry_terms: IndexMap<CurriedName, (Scope, Term<Name>, usize)> =
IndexMap::new();
let mut curry_applied_ids = vec![];
let mut scope_mapped_to_term: IndexMap<Scope, Vec<(CurriedName, Term<Name>)>> =
@ -1341,15 +1341,15 @@ impl Program<Name> {
id_vec: id_only_vec,
};
if let Some((map_scope, _, multi_occurrences)) =
if let Some((map_scope, _, occurrences)) =
id_mapped_curry_terms.get_mut(&curry_name)
{
*map_scope = map_scope.common_ancestor(scope);
*multi_occurrences = true;
*occurrences += 1;
} else if id_vec.is_empty() {
id_mapped_curry_terms.insert(
curry_name,
(scope.clone(), Term::Builtin(*func).apply(node.term), false),
(scope.clone(), Term::Builtin(*func).apply(node.term), 1),
);
} else {
let var_name = id_vec_function_to_var(
@ -1359,7 +1359,7 @@ impl Program<Name> {
id_mapped_curry_terms.insert(
curry_name,
(scope.clone(), Term::var(var_name).apply(node.term), false),
(scope.clone(), Term::var(var_name).apply(node.term), 1),
);
}
}
@ -1372,7 +1372,8 @@ impl Program<Name> {
id_mapped_curry_terms
.into_iter()
.filter(|(_, (_, _, multi_occurrence))| *multi_occurrence)
// Only hoist for occurrences greater than 2
.filter(|(_, (_, _, occurrences))| *occurrences > 2)
.for_each(|(key, val)| {
final_ids.insert(key.id_vec.clone(), ());
@ -2129,7 +2130,12 @@ mod tests {
.apply(Term::var("y")),
)
.lambda("y")
.apply(Term::integer(5.into())),
.apply(
Term::add_integer()
.apply(Term::var("g"))
.apply(Term::integer(1.into())),
)
.lambda("g"),
};
let expected = Program {
@ -2138,8 +2144,44 @@ mod tests {
.apply(Term::var("x"))
.lambda("x")
.apply(Term::var("add_one_curried").apply(Term::var("y")))
.lambda("y")
.apply(Term::var("add_one_curried").apply(Term::var("g")))
.lambda("add_one_curried")
.apply(Term::add_integer().apply(Term::integer(1.into())))
.lambda("g"),
};
compare_optimization(expected, program, |p| p.builtin_curry_reducer());
}
#[test]
fn curry_reducer_test_2() {
let program: Program<Name> = Program {
version: (1, 0, 0),
term: Term::add_integer()
.apply(Term::var("x"))
.apply(Term::integer(1.into()))
.lambda("x")
.apply(
Term::add_integer()
.apply(Term::integer(1.into()))
.apply(Term::var("y")),
)
.lambda("y")
.apply(Term::integer(5.into())),
};
let expected = Program {
version: (1, 0, 0),
term: Term::add_integer()
.apply(Term::var("x"))
.apply(Term::integer(1.into()))
.lambda("x")
.apply(
Term::add_integer()
.apply(Term::integer(1.into()))
.apply(Term::var("y")),
)
.lambda("y")
.apply(Term::integer(5.into())),
};
@ -2148,7 +2190,7 @@ mod tests {
}
#[test]
fn curry_reducer_test_2() {
fn curry_reducer_test_3() {
let program: Program<Name> = Program {
version: (1, 0, 0),
term: Term::var("equivalence")
@ -2340,7 +2382,11 @@ mod tests {
Term::var("equals_integer_0_curried")
.apply(Term::var(CONSTR_INDEX_EXPOSER).apply(Term::var("tuple_index_0")))
.if_then_else(
Term::var("equals_integer_0_tuple_index_1_tag_curried")
Term::var("equals_integer_0_curried")
.apply(
Term::var(CONSTR_INDEX_EXPOSER)
.apply(Term::var("tuple_index_1")),
)
.if_then_else(
Term::var("equals_integer_0_curried")
.apply(
@ -2431,7 +2477,11 @@ mod tests {
.force()
.lambda("clauses_delayed")
.apply(
Term::var("equals_integer_1_tuple_index_0_tag_curried")
Term::var("equals_integer_1_curried")
.apply(
Term::var(CONSTR_INDEX_EXPOSER)
.apply(Term::var("tuple_index_0")),
)
.if_then_else(
Term::var("equals_integer_1_curried")
.apply(
@ -2449,9 +2499,17 @@ mod tests {
.force()
.lambda("clauses_delayed")
.apply(
Term::var("equals_integer_1_tuple_index_0_tag_curried")
Term::var("equals_integer_1_curried")
.apply(
Term::var(CONSTR_INDEX_EXPOSER)
.apply(Term::var("tuple_index_0")),
)
.if_then_else(
Term::var("equals_integer_0_tuple_index_1_tag_curried")
Term::var("equals_integer_0_curried")
.apply(
Term::var(CONSTR_INDEX_EXPOSER)
.apply(Term::var("tuple_index_1")),
)
.if_then_else(
Term::bool(false).delay(),
Term::var("clauses_delayed"),
@ -2465,21 +2523,10 @@ mod tests {
.apply(Term::bool(false).delay())
.delay(),
)
.lambda("equals_integer_1_tuple_index_0_tag_curried")
.apply(
Term::var("equals_integer_1_curried").apply(
Term::var(CONSTR_INDEX_EXPOSER)
.apply(Term::var("tuple_index_0")),
),
)
.lambda("equals_integer_1_curried")
.apply(Term::equals_integer().apply(Term::integer(1.into())))
.delay(),
)
.lambda("equals_integer_0_tuple_index_1_tag_curried")
.apply(Term::var("equals_integer_0_curried").apply(
Term::var(CONSTR_INDEX_EXPOSER).apply(Term::var("tuple_index_1")),
))
.lambda("equals_integer_0_curried")
.apply(Term::equals_integer().apply(Term::integer(0.into())))
.lambda("tuple_index_0")