diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs index 292b2055..0eca13e0 100644 --- a/crates/uplc/src/optimize/shrinker.rs +++ b/crates/uplc/src/optimize/shrinker.rs @@ -1273,7 +1273,7 @@ impl Program { pub fn builtin_curry_reducer(self) -> Self { let mut curried_terms = vec![]; - let mut id_mapped_curry_terms: IndexMap, bool)> = + let mut id_mapped_curry_terms: IndexMap, usize)> = IndexMap::new(); let mut curry_applied_ids = vec![]; let mut scope_mapped_to_term: IndexMap)>> = @@ -1341,15 +1341,15 @@ impl Program { id_vec: id_only_vec, }; - if let Some((map_scope, _, multi_occurrences)) = + if let Some((map_scope, _, occurrences)) = id_mapped_curry_terms.get_mut(&curry_name) { *map_scope = map_scope.common_ancestor(scope); - *multi_occurrences = true; + *occurrences += 1; } else if id_vec.is_empty() { id_mapped_curry_terms.insert( curry_name, - (scope.clone(), Term::Builtin(*func).apply(node.term), false), + (scope.clone(), Term::Builtin(*func).apply(node.term), 1), ); } else { let var_name = id_vec_function_to_var( @@ -1359,7 +1359,7 @@ impl Program { id_mapped_curry_terms.insert( curry_name, - (scope.clone(), Term::var(var_name).apply(node.term), false), + (scope.clone(), Term::var(var_name).apply(node.term), 1), ); } } @@ -1372,7 +1372,8 @@ impl Program { id_mapped_curry_terms .into_iter() - .filter(|(_, (_, _, multi_occurrence))| *multi_occurrence) + // Only hoist for occurrences greater than 2 + .filter(|(_, (_, _, occurrences))| *occurrences > 2) .for_each(|(key, val)| { final_ids.insert(key.id_vec.clone(), ()); @@ -2129,7 +2130,12 @@ mod tests { .apply(Term::var("y")), ) .lambda("y") - .apply(Term::integer(5.into())), + .apply( + Term::add_integer() + .apply(Term::var("g")) + .apply(Term::integer(1.into())), + ) + .lambda("g"), }; let expected = Program { @@ -2138,8 +2144,44 @@ mod tests { .apply(Term::var("x")) .lambda("x") .apply(Term::var("add_one_curried").apply(Term::var("y"))) + .lambda("y") + .apply(Term::var("add_one_curried").apply(Term::var("g"))) .lambda("add_one_curried") .apply(Term::add_integer().apply(Term::integer(1.into()))) + .lambda("g"), + }; + + compare_optimization(expected, program, |p| p.builtin_curry_reducer()); + } + + #[test] + fn curry_reducer_test_2() { + let program: Program = Program { + version: (1, 0, 0), + term: Term::add_integer() + .apply(Term::var("x")) + .apply(Term::integer(1.into())) + .lambda("x") + .apply( + Term::add_integer() + .apply(Term::integer(1.into())) + .apply(Term::var("y")), + ) + .lambda("y") + .apply(Term::integer(5.into())), + }; + + let expected = Program { + version: (1, 0, 0), + term: Term::add_integer() + .apply(Term::var("x")) + .apply(Term::integer(1.into())) + .lambda("x") + .apply( + Term::add_integer() + .apply(Term::integer(1.into())) + .apply(Term::var("y")), + ) .lambda("y") .apply(Term::integer(5.into())), }; @@ -2148,7 +2190,7 @@ mod tests { } #[test] - fn curry_reducer_test_2() { + fn curry_reducer_test_3() { let program: Program = Program { version: (1, 0, 0), term: Term::var("equivalence") @@ -2340,7 +2382,11 @@ mod tests { Term::var("equals_integer_0_curried") .apply(Term::var(CONSTR_INDEX_EXPOSER).apply(Term::var("tuple_index_0"))) .if_then_else( - Term::var("equals_integer_0_tuple_index_1_tag_curried") + Term::var("equals_integer_0_curried") + .apply( + Term::var(CONSTR_INDEX_EXPOSER) + .apply(Term::var("tuple_index_1")), + ) .if_then_else( Term::var("equals_integer_0_curried") .apply( @@ -2431,7 +2477,11 @@ mod tests { .force() .lambda("clauses_delayed") .apply( - Term::var("equals_integer_1_tuple_index_0_tag_curried") + Term::var("equals_integer_1_curried") + .apply( + Term::var(CONSTR_INDEX_EXPOSER) + .apply(Term::var("tuple_index_0")), + ) .if_then_else( Term::var("equals_integer_1_curried") .apply( @@ -2449,9 +2499,17 @@ mod tests { .force() .lambda("clauses_delayed") .apply( - Term::var("equals_integer_1_tuple_index_0_tag_curried") + Term::var("equals_integer_1_curried") + .apply( + Term::var(CONSTR_INDEX_EXPOSER) + .apply(Term::var("tuple_index_0")), + ) .if_then_else( - Term::var("equals_integer_0_tuple_index_1_tag_curried") + Term::var("equals_integer_0_curried") + .apply( + Term::var(CONSTR_INDEX_EXPOSER) + .apply(Term::var("tuple_index_1")), + ) .if_then_else( Term::bool(false).delay(), Term::var("clauses_delayed"), @@ -2465,21 +2523,10 @@ mod tests { .apply(Term::bool(false).delay()) .delay(), ) - .lambda("equals_integer_1_tuple_index_0_tag_curried") - .apply( - Term::var("equals_integer_1_curried").apply( - Term::var(CONSTR_INDEX_EXPOSER) - .apply(Term::var("tuple_index_0")), - ), - ) .lambda("equals_integer_1_curried") .apply(Term::equals_integer().apply(Term::integer(1.into()))) .delay(), ) - .lambda("equals_integer_0_tuple_index_1_tag_curried") - .apply(Term::var("equals_integer_0_curried").apply( - Term::var(CONSTR_INDEX_EXPOSER).apply(Term::var("tuple_index_1")), - )) .lambda("equals_integer_0_curried") .apply(Term::equals_integer().apply(Term::integer(0.into()))) .lambda("tuple_index_0")