change currying to happen with 3 or more occurrences

This commit is contained in:
microproofs 2024-03-07 00:01:24 -05:00 committed by Kasey
parent e9122de061
commit 1edd1a1fa3
1 changed files with 70 additions and 23 deletions

View File

@ -1273,7 +1273,7 @@ impl Program<Name> {
pub fn builtin_curry_reducer(self) -> Self { pub fn builtin_curry_reducer(self) -> Self {
let mut curried_terms = vec![]; let mut curried_terms = vec![];
let mut id_mapped_curry_terms: IndexMap<CurriedName, (Scope, Term<Name>, bool)> = let mut id_mapped_curry_terms: IndexMap<CurriedName, (Scope, Term<Name>, usize)> =
IndexMap::new(); IndexMap::new();
let mut curry_applied_ids = vec![]; let mut curry_applied_ids = vec![];
let mut scope_mapped_to_term: IndexMap<Scope, Vec<(CurriedName, Term<Name>)>> = let mut scope_mapped_to_term: IndexMap<Scope, Vec<(CurriedName, Term<Name>)>> =
@ -1341,15 +1341,15 @@ impl Program<Name> {
id_vec: id_only_vec, id_vec: id_only_vec,
}; };
if let Some((map_scope, _, multi_occurrences)) = if let Some((map_scope, _, occurrences)) =
id_mapped_curry_terms.get_mut(&curry_name) id_mapped_curry_terms.get_mut(&curry_name)
{ {
*map_scope = map_scope.common_ancestor(scope); *map_scope = map_scope.common_ancestor(scope);
*multi_occurrences = true; *occurrences += 1;
} else if id_vec.is_empty() { } else if id_vec.is_empty() {
id_mapped_curry_terms.insert( id_mapped_curry_terms.insert(
curry_name, curry_name,
(scope.clone(), Term::Builtin(*func).apply(node.term), false), (scope.clone(), Term::Builtin(*func).apply(node.term), 1),
); );
} else { } else {
let var_name = id_vec_function_to_var( let var_name = id_vec_function_to_var(
@ -1359,7 +1359,7 @@ impl Program<Name> {
id_mapped_curry_terms.insert( id_mapped_curry_terms.insert(
curry_name, curry_name,
(scope.clone(), Term::var(var_name).apply(node.term), false), (scope.clone(), Term::var(var_name).apply(node.term), 1),
); );
} }
} }
@ -1372,7 +1372,8 @@ impl Program<Name> {
id_mapped_curry_terms id_mapped_curry_terms
.into_iter() .into_iter()
.filter(|(_, (_, _, multi_occurrence))| *multi_occurrence) // Only hoist for occurrences greater than 2
.filter(|(_, (_, _, occurrences))| *occurrences > 2)
.for_each(|(key, val)| { .for_each(|(key, val)| {
final_ids.insert(key.id_vec.clone(), ()); final_ids.insert(key.id_vec.clone(), ());
@ -2129,7 +2130,12 @@ mod tests {
.apply(Term::var("y")), .apply(Term::var("y")),
) )
.lambda("y") .lambda("y")
.apply(Term::integer(5.into())), .apply(
Term::add_integer()
.apply(Term::var("g"))
.apply(Term::integer(1.into())),
)
.lambda("g"),
}; };
let expected = Program { let expected = Program {
@ -2138,8 +2144,44 @@ mod tests {
.apply(Term::var("x")) .apply(Term::var("x"))
.lambda("x") .lambda("x")
.apply(Term::var("add_one_curried").apply(Term::var("y"))) .apply(Term::var("add_one_curried").apply(Term::var("y")))
.lambda("y")
.apply(Term::var("add_one_curried").apply(Term::var("g")))
.lambda("add_one_curried") .lambda("add_one_curried")
.apply(Term::add_integer().apply(Term::integer(1.into()))) .apply(Term::add_integer().apply(Term::integer(1.into())))
.lambda("g"),
};
compare_optimization(expected, program, |p| p.builtin_curry_reducer());
}
#[test]
fn curry_reducer_test_2() {
let program: Program<Name> = Program {
version: (1, 0, 0),
term: Term::add_integer()
.apply(Term::var("x"))
.apply(Term::integer(1.into()))
.lambda("x")
.apply(
Term::add_integer()
.apply(Term::integer(1.into()))
.apply(Term::var("y")),
)
.lambda("y")
.apply(Term::integer(5.into())),
};
let expected = Program {
version: (1, 0, 0),
term: Term::add_integer()
.apply(Term::var("x"))
.apply(Term::integer(1.into()))
.lambda("x")
.apply(
Term::add_integer()
.apply(Term::integer(1.into()))
.apply(Term::var("y")),
)
.lambda("y") .lambda("y")
.apply(Term::integer(5.into())), .apply(Term::integer(5.into())),
}; };
@ -2148,7 +2190,7 @@ mod tests {
} }
#[test] #[test]
fn curry_reducer_test_2() { fn curry_reducer_test_3() {
let program: Program<Name> = Program { let program: Program<Name> = Program {
version: (1, 0, 0), version: (1, 0, 0),
term: Term::var("equivalence") term: Term::var("equivalence")
@ -2340,7 +2382,11 @@ mod tests {
Term::var("equals_integer_0_curried") Term::var("equals_integer_0_curried")
.apply(Term::var(CONSTR_INDEX_EXPOSER).apply(Term::var("tuple_index_0"))) .apply(Term::var(CONSTR_INDEX_EXPOSER).apply(Term::var("tuple_index_0")))
.if_then_else( .if_then_else(
Term::var("equals_integer_0_tuple_index_1_tag_curried") Term::var("equals_integer_0_curried")
.apply(
Term::var(CONSTR_INDEX_EXPOSER)
.apply(Term::var("tuple_index_1")),
)
.if_then_else( .if_then_else(
Term::var("equals_integer_0_curried") Term::var("equals_integer_0_curried")
.apply( .apply(
@ -2431,7 +2477,11 @@ mod tests {
.force() .force()
.lambda("clauses_delayed") .lambda("clauses_delayed")
.apply( .apply(
Term::var("equals_integer_1_tuple_index_0_tag_curried") Term::var("equals_integer_1_curried")
.apply(
Term::var(CONSTR_INDEX_EXPOSER)
.apply(Term::var("tuple_index_0")),
)
.if_then_else( .if_then_else(
Term::var("equals_integer_1_curried") Term::var("equals_integer_1_curried")
.apply( .apply(
@ -2449,9 +2499,17 @@ mod tests {
.force() .force()
.lambda("clauses_delayed") .lambda("clauses_delayed")
.apply( .apply(
Term::var("equals_integer_1_tuple_index_0_tag_curried") Term::var("equals_integer_1_curried")
.apply(
Term::var(CONSTR_INDEX_EXPOSER)
.apply(Term::var("tuple_index_0")),
)
.if_then_else( .if_then_else(
Term::var("equals_integer_0_tuple_index_1_tag_curried") Term::var("equals_integer_0_curried")
.apply(
Term::var(CONSTR_INDEX_EXPOSER)
.apply(Term::var("tuple_index_1")),
)
.if_then_else( .if_then_else(
Term::bool(false).delay(), Term::bool(false).delay(),
Term::var("clauses_delayed"), Term::var("clauses_delayed"),
@ -2465,21 +2523,10 @@ mod tests {
.apply(Term::bool(false).delay()) .apply(Term::bool(false).delay())
.delay(), .delay(),
) )
.lambda("equals_integer_1_tuple_index_0_tag_curried")
.apply(
Term::var("equals_integer_1_curried").apply(
Term::var(CONSTR_INDEX_EXPOSER)
.apply(Term::var("tuple_index_0")),
),
)
.lambda("equals_integer_1_curried") .lambda("equals_integer_1_curried")
.apply(Term::equals_integer().apply(Term::integer(1.into()))) .apply(Term::equals_integer().apply(Term::integer(1.into())))
.delay(), .delay(),
) )
.lambda("equals_integer_0_tuple_index_1_tag_curried")
.apply(Term::var("equals_integer_0_curried").apply(
Term::var(CONSTR_INDEX_EXPOSER).apply(Term::var("tuple_index_1")),
))
.lambda("equals_integer_0_curried") .lambda("equals_integer_0_curried")
.apply(Term::equals_integer().apply(Term::integer(0.into()))) .apply(Term::equals_integer().apply(Term::integer(0.into())))
.lambda("tuple_index_0") .lambda("tuple_index_0")