Inline now handles (if cond then body else error) patterns.

This allows conditions like ```expect x == 1``` to match performance with ```x == 1 && ...```

Also change builtins forcing to accommodate the new case-constr apply optimization
This commit is contained in:
microproofs
2025-01-09 15:41:47 +07:00
parent c130796f49
commit 19b4b9df0f
4 changed files with 298 additions and 127 deletions

View File

@@ -167,6 +167,7 @@ impl DefaultFunction {
)
}
/// For now all of the curry builtins are not forceable
/// Curryable builtins must take in 2 or more arguments
pub fn can_curry_builtin(self) -> bool {
matches!(
self,
@@ -205,7 +206,8 @@ impl DefaultFunction {
| DefaultFunction::MultiplyInteger
| DefaultFunction::EqualsInteger
| DefaultFunction::LessThanInteger
| DefaultFunction::LessThanEqualsInteger => arg_stack.iter().all(|arg| {
| DefaultFunction::LessThanEqualsInteger
| DefaultFunction::IData => arg_stack.iter().all(|arg| {
if let Term::Constant(c) = arg {
matches!(c.as_ref(), Constant::Integer(_))
} else {
@@ -226,10 +228,13 @@ impl DefaultFunction {
false
}
}),
DefaultFunction::EqualsByteString
DefaultFunction::LengthOfByteString
| DefaultFunction::EqualsByteString
| DefaultFunction::AppendByteString
| DefaultFunction::LessThanEqualsByteString
| DefaultFunction::LessThanByteString => arg_stack.iter().all(|arg| {
| DefaultFunction::LessThanByteString
| DefaultFunction::DecodeUtf8
| DefaultFunction::BData => arg_stack.iter().all(|arg| {
if let Term::Constant(c) = arg {
matches!(c.as_ref(), Constant::ByteString(_))
} else {
@@ -282,24 +287,26 @@ impl DefaultFunction {
}
}
DefaultFunction::EqualsString | DefaultFunction::AppendString => {
DefaultFunction::EqualsString
| DefaultFunction::AppendString
| DefaultFunction::EncodeUtf8 => arg_stack.iter().all(|arg| {
if let Term::Constant(c) = arg {
matches!(c.as_ref(), Constant::String(_))
} else {
false
}
}),
DefaultFunction::EqualsData | DefaultFunction::SerialiseData => {
arg_stack.iter().all(|arg| {
if let Term::Constant(c) = arg {
matches!(c.as_ref(), Constant::String(_))
matches!(c.as_ref(), Constant::Data(_))
} else {
false
}
})
}
DefaultFunction::EqualsData => arg_stack.iter().all(|arg| {
if let Term::Constant(c) = arg {
matches!(c.as_ref(), Constant::Data(_))
} else {
false
}
}),
DefaultFunction::Bls12_381_G1_Equal | DefaultFunction::Bls12_381_G1_Add => {
arg_stack.iter().all(|arg| {
if let Term::Constant(c) = arg {
@@ -337,6 +344,10 @@ impl DefaultFunction {
_ => false,
}
}
pub fn wrapped_name(self) -> String {
format!("__{}_wrapped", self.aiken_name())
}
}
#[derive(PartialEq, Clone, Debug)]
@@ -1102,7 +1113,7 @@ impl Term<Name> {
}
}
Term::Delay(body) => {
let not_forced: isize = isize::from(force_stack.pop().is_none());
let not_forced = isize::from(force_stack.pop().is_none());
body.var_occurrences(search_for, arg_stack, force_stack)
.delay_if_found(not_forced)
@@ -1125,11 +1136,34 @@ impl Term<Name> {
}
}
Term::Apply { function, argument } => {
// unwrap apply and add void to arg stack!
arg_stack.push(());
function
.var_occurrences(search_for.clone(), arg_stack, force_stack)
.combine(argument.var_occurrences(search_for, vec![], vec![]))
let apply_var_occurrence_stack = |term: &Term<Name>, arg_stack: Vec<()>| {
term.var_occurrences(search_for.clone(), arg_stack, force_stack)
};
let apply_var_occurrence_no_stack =
|term: &Term<Name>| term.var_occurrences(search_for.clone(), vec![], vec![]);
if let Term::Apply {
function: next_func,
argument: next_arg,
} = function.as_ref()
{
// unwrap apply and add void to arg stack!
arg_stack.push(());
next_func.carry_args_to_branch(
next_arg,
argument,
arg_stack,
apply_var_occurrence_stack,
apply_var_occurrence_no_stack,
)
} else {
apply_var_occurrence_stack(function, arg_stack)
.combine(apply_var_occurrence_no_stack(argument))
}
}
Term::Force(x) => {
force_stack.push(());
@@ -1141,6 +1175,81 @@ impl Term<Name> {
}
}
// This handles the very common case of (if condition then body else error)
// or (if condition then error else body)
// In this case it is fine to treat the body as if it is not delayed
// since the other branch is error
fn carry_args_to_branch(
&self,
then_arg: &Rc<Term<Name>>,
else_arg: &Rc<Term<Name>>,
mut arg_stack: Vec<()>,
var_occurrence_stack: impl FnOnce(&Term<Name>, Vec<()>) -> VarLookup,
var_occurrence_no_stack: impl Fn(&Term<Name>) -> VarLookup,
) -> VarLookup {
let Term::Apply {
function: builtin,
argument: condition,
} = self
else {
return var_occurrence_stack(self, arg_stack)
.combine(var_occurrence_no_stack(then_arg))
.combine(var_occurrence_no_stack(else_arg));
};
// unwrap apply and add void to arg stack!
arg_stack.push(());
let Term::Delay(else_arg) = else_arg.as_ref() else {
return var_occurrence_stack(builtin, arg_stack)
.combine(var_occurrence_no_stack(condition))
.combine(var_occurrence_no_stack(then_arg))
.combine(var_occurrence_no_stack(else_arg));
};
let Term::Delay(then_arg) = then_arg.as_ref() else {
return var_occurrence_stack(builtin, arg_stack)
.combine(var_occurrence_no_stack(condition))
.combine(var_occurrence_no_stack(then_arg))
.combine(var_occurrence_no_stack(else_arg));
};
match builtin.as_ref() {
Term::Var(a)
if a.text == DefaultFunction::IfThenElse.wrapped_name()
|| a.text == DefaultFunction::ChooseList.wrapped_name() =>
{
if matches!(else_arg.as_ref(), Term::Error) {
// Pop 3 args of arg_stack due to branch execution
arg_stack.pop();
arg_stack.pop();
arg_stack.pop();
var_occurrence_no_stack(condition)
.combine(var_occurrence_stack(then_arg, arg_stack))
} else if matches!(then_arg.as_ref(), Term::Error) {
// Pop 3 args of arg_stack due to branch execution
arg_stack.pop();
arg_stack.pop();
arg_stack.pop();
var_occurrence_no_stack(condition)
.combine(var_occurrence_stack(else_arg, arg_stack))
} else {
var_occurrence_stack(builtin, arg_stack)
.combine(var_occurrence_no_stack(condition))
.combine(var_occurrence_no_stack(then_arg))
.combine(var_occurrence_no_stack(else_arg))
}
}
_ => var_occurrence_stack(builtin, arg_stack)
.combine(var_occurrence_no_stack(condition))
.combine(var_occurrence_no_stack(then_arg))
.combine(var_occurrence_no_stack(else_arg)),
}
}
fn lambda_reducer(
&mut self,
_id: Option<usize>,
@@ -1207,7 +1316,7 @@ impl Term<Name> {
if has_forces {
context.builtins_map.insert(*func as u8, ());
*self = Term::var(format!("__{}_wrapped", func.aiken_name()));
*self = Term::var(func.wrapped_name());
}
}
}
@@ -1269,44 +1378,37 @@ impl Term<Name> {
} => {
let body = Rc::make_mut(body);
// pops stack here no matter what
let temp = Term::Error;
if let (
arg_id,
Term::Lambda {
parameter_name: identity_name,
body: identity_body,
},
) = match &arg_stack.pop() {
Some(Args::Apply(
arg_id,
Term::Lambda {
parameter_name: inline_name,
body,
},
)) if inline_name.text == NO_INLINE => (*arg_id, body.as_ref()),
Some(Args::Apply(arg_id, term)) => (*arg_id, term),
_ => (0, &temp),
} {
let Term::Var(identity_var) = identity_body.as_ref() else {
return false;
};
let Some(Args::Apply(arg_id, identity_func)) = arg_stack.pop() else {
return false;
};
if identity_var.text == identity_name.text
&& identity_var.unique == identity_name.unique
let Term::Lambda {
parameter_name: identity_name,
body: identity_body,
} = identity_func.pierce_no_inlines()
else {
return false;
};
let Term::Var(identity_var) = identity_body.as_ref() else {
return false;
};
if identity_var.text == identity_name.text
&& identity_var.unique == identity_name.unique
{
// Replace all applied usages of identity with the arg
body.replace_identity_usage(parameter_name.clone());
// Have to check if the body still has any occurrences of the parameter
// After attempting replacement
if !body
.var_occurrences(parameter_name.clone(), vec![], vec![])
.found
{
// Replace all applied usages of identity with the arg
body.replace_identity_usage(parameter_name.clone());
// Have to check if the body still has any occurrences of the parameter
// After attempting replacement
if !body
.var_occurrences(parameter_name.clone(), vec![], vec![])
.found
{
changed = true;
context.inlined_apply_ids.push(arg_id);
*self = std::mem::replace(body, Term::Error.force());
}
changed = true;
context.inlined_apply_ids.push(arg_id);
*self = std::mem::replace(body, Term::Error.force());
}
}
}
@@ -1333,53 +1435,44 @@ impl Term<Name> {
body,
} => {
// pops stack here no matter what
if let Some(Args::Apply(arg_id, arg_term)) = arg_stack.pop() {
let arg_term = match &arg_term {
Term::Lambda {
parameter_name,
body,
} if parameter_name.text == NO_INLINE => body.as_ref().clone(),
_ => arg_term,
};
let Some(Args::Apply(arg_id, arg_term)) = arg_stack.pop() else {
return false;
};
let body = Rc::make_mut(body);
let arg_term = arg_term.pierce_no_inlines();
let var_lookup = body.var_occurrences(parameter_name.clone(), vec![], vec![]);
let body = Rc::make_mut(body);
let substitute_condition = (var_lookup.delays == 0 && !var_lookup.no_inline)
|| matches!(
&arg_term,
Term::Var(_)
| Term::Constant(_)
| Term::Delay(_)
| Term::Lambda { .. }
| Term::Builtin(_),
);
let var_lookup = body.var_occurrences(parameter_name.clone(), vec![], vec![]);
if var_lookup.occurrences == 1 && substitute_condition {
changed = true;
body.substitute_var(parameter_name.clone(), arg_term.pierce_no_inlines());
let must_execute_condition = var_lookup.delays == 0 && !var_lookup.no_inline;
context.inlined_apply_ids.push(arg_id);
*self = std::mem::replace(body, Term::Error.force());
let cant_throw_condition = matches!(
arg_term,
Term::Var(_)
| Term::Constant(_)
| Term::Delay(_)
| Term::Lambda { .. }
| Term::Builtin(_),
);
// This will strip out unused terms that can't throw an error by themselves
} else if !var_lookup.found
&& matches!(
arg_term,
Term::Var(_)
| Term::Constant(_)
| Term::Delay(_)
| Term::Lambda { .. }
| Term::Builtin(_)
)
{
changed = true;
context.inlined_apply_ids.push(arg_id);
*self = std::mem::replace(body, Term::Error.force());
}
// This will inline terms that only occur once
// if they are guaranteed to execute or can't throw an error by themselves
if var_lookup.occurrences == 1 && (must_execute_condition || cant_throw_condition) {
changed = true;
body.substitute_var(parameter_name.clone(), arg_term);
context.inlined_apply_ids.push(arg_id);
*self = std::mem::replace(body, Term::Error.force());
// This will strip out unused terms that can't throw an error by themselves
} else if !var_lookup.found && cant_throw_condition {
changed = true;
context.inlined_apply_ids.push(arg_id);
*self = std::mem::replace(body, Term::Error.force());
}
}
Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(),
_ => {}
@@ -1647,10 +1740,7 @@ impl Term<Name> {
term.pierce_no_inlines()
})
.collect_vec();
if func.can_curry_builtin()
&& arg_stack.len() == func.arity()
&& func.is_error_safe(&args)
{
if arg_stack.len() == func.arity() && func.is_error_safe(&args) {
changed = true;
let applied_term =
arg_stack
@@ -1671,7 +1761,7 @@ impl Term<Name> {
}
.to_named_debruijn()
.unwrap()
.eval(ExBudget::max())
.eval(ExBudget::default())
.result()
.unwrap()
.try_into()
@@ -1827,13 +1917,17 @@ impl Program<Name> {
for default_func_index in context.builtins_map.keys().sorted().cloned() {
let default_func: DefaultFunction = default_func_index.try_into().unwrap();
term = term
.lambda(format!("__{}_wrapped", default_func.aiken_name()))
.apply(if default_func.force_count() == 1 {
Term::Builtin(default_func).force()
} else {
Term::Builtin(default_func).force().force()
});
term = term.lambda(default_func.wrapped_name());
}
for default_func_index in context.builtins_map.keys().sorted().cloned().rev() {
let default_func: DefaultFunction = default_func_index.try_into().unwrap();
term = term.apply(if default_func.force_count() == 1 {
Term::Builtin(default_func).force()
} else {
Term::Builtin(default_func).force().force()
});
}
let mut program = Program {
@@ -2165,10 +2259,11 @@ fn is_a_builtin_wrapper(term: &Term<Name>) -> bool {
while let Term::Apply { function, argument } = term {
match argument.as_ref() {
Term::Var(name) => arg_names.push(name),
Term::Var(name) => arg_names.push(format!("{}_{}", name.text, name.unique)),
Term::Constant(_) => {}
_ => {
//Break loop, it's not a builtin wrapper function
return false;
}
}
@@ -2178,7 +2273,7 @@ fn is_a_builtin_wrapper(term: &Term<Name>) -> bool {
arg_names.iter().all(|item| names.contains(item)) && matches!(term, Term::Builtin(_))
}
fn pop_lambdas_and_get_names(term: &Term<Name>) -> (Vec<Rc<Name>>, &Term<Name>) {
fn pop_lambdas_and_get_names(term: &Term<Name>) -> (Vec<String>, &Term<Name>) {
let mut names = vec![];
let mut term = term;
@@ -2189,7 +2284,7 @@ fn pop_lambdas_and_get_names(term: &Term<Name>) -> (Vec<Rc<Name>>, &Term<Name>)
} = term
{
if parameter_name.text != NO_INLINE {
names.push(parameter_name.clone());
names.push(format!("{}_{}", parameter_name.text, parameter_name.unique));
}
term = body.as_ref();
}
@@ -2365,11 +2460,11 @@ mod tests {
.lambda("y")
// Forces are automatically applied by builder
.lambda("__cons_list_wrapped")
.apply(Term::mk_cons())
.lambda("__head_list_wrapped")
.apply(Term::head_list())
.lambda("__tail_list_wrapped")
.apply(Term::tail_list()),
.apply(Term::tail_list())
.apply(Term::head_list())
.apply(Term::mk_cons()),
};
compare_optimization(expected, program, |p| p.run_once_pass());
@@ -2453,9 +2548,9 @@ mod tests {
.apply(Term::data(Data::integer(5.into()))),
)
.lambda("__fst_pair_wrapped")
.apply(Term::fst_pair())
.lambda("__snd_pair_wrapped")
.apply(Term::snd_pair()),
.apply(Term::snd_pair())
.apply(Term::fst_pair()),
};
compare_optimization(expected, program, |p| p.run_once_pass());
@@ -2644,6 +2739,80 @@ mod tests {
});
}
#[test]
fn inline_reduce_if_then_else_then() {
let program: Program<Name> = Program {
version: (1, 0, 0),
term: Term::var("__if_then_else_wrapped")
.apply(Term::bool(true))
.apply(Term::sha3_256().apply(Term::var("x")).delay())
.apply(Term::Error.delay())
.force()
.lambda("x")
.apply(Term::sha3_256().apply(Term::byte_string(vec![])))
.lambda("__if_then_else_wrapped")
.apply(Term::Builtin(DefaultFunction::IfThenElse).force()),
};
let expected = Program {
version: (1, 0, 0),
term: Term::var("__if_then_else_wrapped")
.apply(Term::bool(true))
.apply(
Term::sha3_256()
.apply(Term::sha3_256().apply(Term::byte_string(vec![])))
.delay(),
)
.apply(Term::Error.delay())
.force()
.lambda("__if_then_else_wrapped")
.apply(Term::Builtin(DefaultFunction::IfThenElse).force()),
};
compare_optimization(expected, program, |p| {
p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| {
term.inline_reducer(id, arg_stack, scope, context);
})
});
}
#[test]
fn inline_reduce_if_then_else_else() {
let program: Program<Name> = Program {
version: (1, 0, 0),
term: Term::var("__if_then_else_wrapped")
.apply(Term::bool(true))
.apply(Term::Error.delay())
.apply(Term::sha3_256().apply(Term::var("x")).delay())
.force()
.lambda("x")
.apply(Term::sha3_256().apply(Term::byte_string(vec![])))
.lambda("__if_then_else_wrapped")
.apply(Term::Builtin(DefaultFunction::IfThenElse).force()),
};
let expected = Program {
version: (1, 0, 0),
term: Term::var("__if_then_else_wrapped")
.apply(Term::bool(true))
.apply(Term::Error.delay())
.apply(
Term::sha3_256()
.apply(Term::sha3_256().apply(Term::byte_string(vec![])))
.delay(),
)
.force()
.lambda("__if_then_else_wrapped")
.apply(Term::Builtin(DefaultFunction::IfThenElse).force()),
};
compare_optimization(expected, program, |p| {
p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| {
term.inline_reducer(id, arg_stack, scope, context);
})
});
}
#[test]
fn inline_reduce_0_occurrence() {
let program: Program<Name> = Program {