diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs index fc4f95ef..e9dff2a2 100644 --- a/crates/uplc/src/optimize/shrinker.rs +++ b/crates/uplc/src/optimize/shrinker.rs @@ -199,6 +199,146 @@ impl DefaultFunction { | DefaultFunction::ConstrData ) } + + pub fn is_error_safe(self, arg_stack: &[&Term]) -> bool { + match self { + DefaultFunction::AddInteger + | DefaultFunction::SubtractInteger + | DefaultFunction::MultiplyInteger + | DefaultFunction::EqualsInteger + | DefaultFunction::LessThanInteger + | DefaultFunction::LessThanEqualsInteger => arg_stack.iter().all(|arg| { + if let Term::Constant(c) = arg { + matches!(c.as_ref(), Constant::Integer(_)) + } else { + false + } + }), + DefaultFunction::DivideInteger + | DefaultFunction::ModInteger + | DefaultFunction::QuotientInteger + | DefaultFunction::RemainderInteger => arg_stack.iter().all(|arg| { + if let Term::Constant(c) = arg { + if let Constant::Integer(i) = c.as_ref() { + *i != 0.into() + } else { + false + } + } else { + false + } + }), + DefaultFunction::EqualsByteString + | DefaultFunction::AppendByteString + | DefaultFunction::LessThanEqualsByteString + | DefaultFunction::LessThanByteString => arg_stack.iter().all(|arg| { + if let Term::Constant(c) = arg { + matches!(c.as_ref(), Constant::ByteString(_)) + } else { + false + } + }), + + DefaultFunction::ConsByteString => { + if let (Term::Constant(c), Term::Constant(c2)) = (&arg_stack[0], &arg_stack[1]) { + if let (Constant::Integer(i), Constant::ByteString(_)) = + (c.as_ref(), c2.as_ref()) + { + i >= &0.into() && i < &255.into() + } else { + false + } + } else { + false + } + } + + DefaultFunction::SliceByteString => { + if let (Term::Constant(c), Term::Constant(c2), Term::Constant(c3)) = + (&arg_stack[0], &arg_stack[1], &arg_stack[2]) + { + matches!( + (c.as_ref(), c2.as_ref(), c3.as_ref()), + ( + Constant::Integer(_), + Constant::Integer(_), + Constant::ByteString(_) + ) + ) + } else { + false + } + } + + DefaultFunction::IndexByteString => { + if let (Term::Constant(c), Term::Constant(c2)) = (&arg_stack[0], &arg_stack[1]) { + if let (Constant::ByteString(bs), Constant::Integer(i)) = + (c.as_ref(), c2.as_ref()) + { + i >= &0.into() && i < &bs.len().into() + } else { + false + } + } else { + false + } + } + + DefaultFunction::EqualsString | DefaultFunction::AppendString => { + arg_stack.iter().all(|arg| { + if let Term::Constant(c) = arg { + matches!(c.as_ref(), Constant::String(_)) + } else { + false + } + }) + } + + DefaultFunction::EqualsData => arg_stack.iter().all(|arg| { + if let Term::Constant(c) = arg { + matches!(c.as_ref(), Constant::Data(_)) + } else { + false + } + }), + + DefaultFunction::Bls12_381_G1_Equal | DefaultFunction::Bls12_381_G1_Add => { + arg_stack.iter().all(|arg| { + if let Term::Constant(c) = arg { + matches!(c.as_ref(), Constant::Bls12_381G1Element(_)) + } else { + false + } + }) + } + + DefaultFunction::Bls12_381_G2_Equal | DefaultFunction::Bls12_381_G2_Add => { + arg_stack.iter().all(|arg| { + if let Term::Constant(c) = arg { + matches!(c.as_ref(), Constant::Bls12_381G2Element(_)) + } else { + false + } + }) + } + + DefaultFunction::ConstrData => { + if let (Term::Constant(c), Term::Constant(c2)) = (&arg_stack[0], &arg_stack[1]) { + if let (Constant::Integer(i), Constant::ProtoList(Type::Data, _)) = + (c.as_ref(), c2.as_ref()) + { + i >= &0.into() + } else { + false + } + } else { + false + } + } + + _ => false, + } + } } #[derive(PartialEq, Clone, Debug)] @@ -219,10 +359,12 @@ pub enum BuiltinArgs { } impl BuiltinArgs { - fn args_from_arg_stack(stack: Vec<(usize, Term)>, is_order_agnostic: bool) -> Self { + fn args_from_arg_stack(stack: Vec<(usize, Term)>, func: DefaultFunction) -> Self { + let error_safe = func.is_error_safe(&stack.iter().map(|(_, term)| term).collect_vec()); + let mut ordered_arg_stack = stack.into_iter().sorted_by(|(_, arg1), (_, arg2)| { // sort by constant first if the builtin is order agnostic - if is_order_agnostic { + if func.is_order_agnostic_builtin() { if matches!(arg1, Term::Constant(_)) == matches!(arg2, Term::Constant(_)) { Ordering::Equal } else if matches!(arg1, Term::Constant(_)) { @@ -235,23 +377,35 @@ impl BuiltinArgs { } }); - if ordered_arg_stack.len() == 2 && is_order_agnostic { + if ordered_arg_stack.len() == 2 && func.is_order_agnostic_builtin() { // This is the special case where the order of args is irrelevant to the builtin // An example is addInteger or multiplyInteger BuiltinArgs::TwoArgsAnyOrder { fst: ordered_arg_stack.next().unwrap(), - snd: ordered_arg_stack.next(), + snd: if error_safe { + ordered_arg_stack.next() + } else { + None + }, } } else if ordered_arg_stack.len() == 2 { BuiltinArgs::TwoArgs { fst: ordered_arg_stack.next().unwrap(), - snd: ordered_arg_stack.next(), + snd: if error_safe { + ordered_arg_stack.next() + } else { + None + }, } } else { BuiltinArgs::ThreeArgs { fst: ordered_arg_stack.next().unwrap(), snd: ordered_arg_stack.next(), - thd: ordered_arg_stack.next(), + thd: if error_safe { + ordered_arg_stack.next() + } else { + None + }, } } } @@ -855,7 +1009,7 @@ impl Program { pub fn lambda_reducer(self) -> Self { let mut lambda_applied_ids = vec![]; - self.traverse_uplc_with(false, &mut |id, term, mut arg_stack, _scope| { + self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| { match term { Term::Apply { function, .. } => { // We are applying some arg so now we unwrap the id of the applied arg @@ -905,7 +1059,7 @@ impl Program { pub fn builtin_force_reducer(self) -> Self { let mut builtin_map = IndexMap::new(); - let program = self.traverse_uplc_with(false, &mut |_id, term, _arg_stack, _scope| { + let program = self.traverse_uplc_with(true, &mut |_id, term, _arg_stack, _scope| { if let Term::Force(f) = term { let f = Rc::make_mut(f); match f { @@ -965,7 +1119,7 @@ impl Program { pub fn identity_reducer(self) -> Self { let mut identity_applied_ids = vec![]; - self.traverse_uplc_with(false, &mut |id, term, mut arg_stack, _scope| { + self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| { match term { Term::Apply { function, .. } => { // We are applying some arg so now we unwrap the id of the applied arg @@ -1074,7 +1228,7 @@ impl Program { pub fn inline_reducer(self) -> Self { let mut lambda_applied_ids = vec![]; - self.traverse_uplc_with(false, &mut |id, term, mut arg_stack, _scope| match term { + self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| match term { Term::Apply { function, .. } => { // We are applying some arg so now we unwrap the id of the applied arg let id = id.unwrap(); @@ -1140,7 +1294,7 @@ impl Program { } pub fn force_delay_reducer(self) -> Self { - self.traverse_uplc_with(false, &mut |_id, term, _arg_stack, _scope| { + self.traverse_uplc_with(true, &mut |_id, term, _arg_stack, _scope| { if let Term::Force(f) = term { let f = f.as_ref(); @@ -1152,7 +1306,7 @@ impl Program { } pub fn remove_no_inlines(self) -> Self { - self.traverse_uplc_with(false, &mut |_, term, _, _| match term { + self.traverse_uplc_with(true, &mut |_, term, _, _| match term { Term::Lambda { parameter_name, body, @@ -1162,7 +1316,7 @@ impl Program { } pub fn inline_constr_ops(self) -> Self { - self.traverse_uplc_with(false, &mut |_, term, _, _| { + self.traverse_uplc_with(true, &mut |_, term, _, _| { if let Term::Apply { function, argument } = term { if let Term::Var(name) = function.as_ref() { if name.text == CONSTR_FIELDS_EXPOSER { @@ -1184,7 +1338,7 @@ impl Program { pub fn cast_data_reducer(self) -> Self { let mut applied_ids = vec![]; - self.traverse_uplc_with(false, &mut |id, term, mut arg_stack, _scope| { + self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| { match term { Term::Apply { function, .. } => { // We are apply some arg so now we unwrap the id of the applied arg @@ -1312,7 +1466,7 @@ impl Program { pub fn convert_arithmetic_ops(self) -> Self { let mut constants_to_flip = vec![]; - self.traverse_uplc_with(false, &mut |id, term, arg_stack, _scope| match term { + self.traverse_uplc_with(true, &mut |id, term, arg_stack, _scope| match term { Term::Apply { argument, .. } => { let id = id.unwrap(); @@ -1360,13 +1514,10 @@ impl Program { self.traverse_uplc_with(false, &mut |_id, term, arg_stack, scope| match term { Term::Builtin(func) => { if func.can_curry_builtin() && arg_stack.len() == func.arity() { - let is_order_agnostic = func.is_order_agnostic_builtin(); - // In the case of order agnostic builtins we want to sort the args by constant first // This gives us the opportunity to curry constants that often pop up in the code - let builtin_args = - BuiltinArgs::args_from_arg_stack(arg_stack, is_order_agnostic); + let builtin_args = BuiltinArgs::args_from_arg_stack(arg_stack, *func); // First we see if we have already curried this builtin before let mut id_vec = if let Some((index, _)) = @@ -1480,10 +1631,7 @@ impl Program { arg_stack.reverse(); } - let builtin_args = BuiltinArgs::args_from_arg_stack( - arg_stack, - func.is_order_agnostic_builtin(), - ); + let builtin_args = BuiltinArgs::args_from_arg_stack(arg_stack, *func); let Some(mut id_vec) = curried_builtin.get_id_args(&builtin_args) else { return; @@ -1594,14 +1742,14 @@ fn var_occurrences( if parameter_name.text == NO_INLINE { var_occurrences(body.as_ref(), search_for, arg_stack, force_stack) .no_inline_if_found() - } else if parameter_name.text != search_for.text - || parameter_name.unique != search_for.unique + } else if parameter_name.text == search_for.text + && parameter_name.unique == search_for.unique { + VarLookup::new() + } else { let not_applied: isize = isize::from(arg_stack.pop().is_none()); var_occurrences(body.as_ref(), search_for, arg_stack, force_stack) .delay_if_found(not_applied) - } else { - VarLookup::new() } } Term::Apply { function, argument } => { @@ -1646,15 +1794,15 @@ fn substitute_var(term: &Term, original: Rc, replace_with: &Term { - if parameter_name.text != original.text || parameter_name.unique != original.unique { + if parameter_name.text == original.text && parameter_name.unique == original.unique { Term::Lambda { parameter_name: parameter_name.clone(), - body: substitute_var(body.as_ref(), original, replace_with).into(), + body: body.clone(), } } else { Term::Lambda { parameter_name: parameter_name.clone(), - body: body.clone(), + body: substitute_var(body.as_ref(), original, replace_with).into(), } } } @@ -1676,15 +1824,15 @@ fn replace_identity_usage(term: &Term, original: Rc) -> Term { parameter_name, body, } => { - if parameter_name.text != original.text || parameter_name.unique != original.unique { + if parameter_name.text == original.text && parameter_name.unique == original.unique { Term::Lambda { parameter_name: parameter_name.clone(), - body: Rc::new(replace_identity_usage(body.as_ref(), original)), + body: body.clone(), } } else { Term::Lambda { parameter_name: parameter_name.clone(), - body: body.clone(), + body: Rc::new(replace_identity_usage(body.as_ref(), original)), } } }