From f7f68fbafcf55a0084eefbb966894757eecfac6a Mon Sep 17 00:00:00 2001 From: microproofs Date: Fri, 10 Jan 2025 12:11:19 +0700 Subject: [PATCH] Add writeBits back in and use the optimizer to utilize the list conversion --- crates/aiken-lang/src/builtins.rs | 8 +- crates/aiken-lang/src/gen_uplc.rs | 8 +- crates/uplc/src/ast.rs | 2 +- crates/uplc/src/builder.rs | 32 +++ crates/uplc/src/builtins.rs | 14 +- crates/uplc/src/optimize/shrinker.rs | 222 ++++++++++++++++----- examples/acceptance_tests/117/aiken.toml | 2 + examples/acceptance_tests/117/lib/tests.ak | 17 ++ 8 files changed, 238 insertions(+), 67 deletions(-) create mode 100644 examples/acceptance_tests/117/aiken.toml create mode 100644 examples/acceptance_tests/117/lib/tests.ak diff --git a/crates/aiken-lang/src/builtins.rs b/crates/aiken-lang/src/builtins.rs index 08179ba3..346f1dbe 100644 --- a/crates/aiken-lang/src/builtins.rs +++ b/crates/aiken-lang/src/builtins.rs @@ -509,12 +509,8 @@ pub fn plutus(id_gen: &IdGenerator) -> TypeInfo { }; for builtin in DefaultFunction::iter() { - // FIXME: Disabling WriteBits for now, since its signature requires the ability to create - // list of raw integers, which isn't possible through Aiken at the moment. - if !matches!(builtin, DefaultFunction::WriteBits) { - let value = from_default_function(builtin, id_gen); - plutus.values.insert(builtin.aiken_name(), value); - } + let value = from_default_function(builtin, id_gen); + plutus.values.insert(builtin.aiken_name(), value); } let index_tipo = Type::function(vec![Type::data()], Type::int()); diff --git a/crates/aiken-lang/src/gen_uplc.rs b/crates/aiken-lang/src/gen_uplc.rs index 55ba6ec6..2c4c2750 100644 --- a/crates/aiken-lang/src/gen_uplc.rs +++ b/crates/aiken-lang/src/gen_uplc.rs @@ -4028,7 +4028,7 @@ impl<'a> CodeGenerator<'a> { } else { let term = arg_stack.pop().unwrap(); - match term.pierce_no_inlines() { + match term.pierce_no_inlines_ref() { Term::Var(_) => Some(term.force()), Term::Delay(inner_term) => Some(inner_term.as_ref().clone()), Term::Apply { .. } => Some(term.force()), @@ -4356,7 +4356,7 @@ impl<'a> CodeGenerator<'a> { known_data_to_type(term, &tipo) }; - if extract_constant(term.pierce_no_inlines()).is_some() { + if extract_constant(term.pierce_no_inlines_ref()).is_some() { let mut program = self.new_program(term); let mut interner = CodeGenInterner::new(); @@ -4379,7 +4379,7 @@ impl<'a> CodeGenerator<'a> { Air::CastToData { tipo } => { let mut term = arg_stack.pop().unwrap(); - if extract_constant(term.pierce_no_inlines()).is_some() { + if extract_constant(term.pierce_no_inlines_ref()).is_some() { term = builder::convert_type_to_data(term, &tipo); let mut program = self.new_program(term); @@ -4792,7 +4792,7 @@ impl<'a> CodeGenerator<'a> { .apply(term); if arg_vec.iter().all(|item| { - let maybe_const = extract_constant(item.pierce_no_inlines()); + let maybe_const = extract_constant(item.pierce_no_inlines_ref()); maybe_const.is_some() }) { let mut program = self.new_program(term); diff --git a/crates/uplc/src/ast.rs b/crates/uplc/src/ast.rs index 17bcf682..82c78eec 100644 --- a/crates/uplc/src/ast.rs +++ b/crates/uplc/src/ast.rs @@ -517,7 +517,7 @@ impl hash::Hash for Name { impl PartialEq for Name { fn eq(&self, other: &Self) -> bool { - self.unique == other.unique + self.unique == other.unique && self.text == other.text } } diff --git a/crates/uplc/src/builder.rs b/crates/uplc/src/builder.rs index 2b71b1e2..9b8cc6ca 100644 --- a/crates/uplc/src/builder.rs +++ b/crates/uplc/src/builder.rs @@ -9,6 +9,7 @@ pub const CONSTR_FIELDS_EXPOSER: &str = "__constr_fields_exposer"; pub const CONSTR_INDEX_EXPOSER: &str = "__constr_index_exposer"; pub const EXPECT_ON_LIST: &str = "__expect_on_list"; pub const INNER_EXPECT_ON_LIST: &str = "__inner_expect_on_list"; +pub const INDICES_CONVERTER: &str = "__indices_converter"; impl Term where @@ -82,6 +83,10 @@ where Term::Constant(Constant::ProtoList(Type::Data, vals).into()) } + pub fn int_values(vals: Vec) -> Self { + Term::Constant(Constant::ProtoList(Type::Integer, vals).into()) + } + pub fn empty_map() -> Self { Term::Constant( Constant::ProtoList(Type::Pair(Type::Data.into(), Type::Data.into()), vec![]).into(), @@ -546,6 +551,33 @@ impl Term { ) } + pub fn data_list_to_integer_list(self) -> Self { + self.lambda(INDICES_CONVERTER) + .apply(Term::var(INDICES_CONVERTER).apply(Term::var(INDICES_CONVERTER))) + .lambda(INDICES_CONVERTER) + .apply( + Term::var("xs") + .delayed_choose_list( + Term::int_values(vec![]), + Term::mk_cons() + .apply(Term::var("x")) + .apply( + Term::var(INDICES_CONVERTER) + .apply(Term::var(INDICES_CONVERTER)) + .apply(Term::var("rest")), + ) + .lambda("rest") + .apply(Term::tail_list().apply(Term::var("xs"))) + .lambda("x") + .apply( + Term::un_i_data().apply(Term::head_list().apply(Term::var("xs"))), + ), + ) + .lambda("xs") + .lambda(INDICES_CONVERTER), + ) + } + /// Introduce a let-binding for a given term. The callback receives a Term::Var /// whose name matches the given 'var_name'. Handy to re-use a same var across /// multiple lambda expressions. diff --git a/crates/uplc/src/builtins.rs b/crates/uplc/src/builtins.rs index 574bda47..6f315702 100644 --- a/crates/uplc/src/builtins.rs +++ b/crates/uplc/src/builtins.rs @@ -6,7 +6,19 @@ use strum_macros::EnumIter; /// All the possible builtin functions in Untyped Plutus Core. #[repr(u8)] #[allow(non_camel_case_types)] -#[derive(Debug, Clone, PartialEq, Eq, Copy, EnumIter, serde::Serialize, serde::Deserialize)] +#[derive( + Debug, + Clone, + PartialEq, + Eq, + Copy, + EnumIter, + serde::Serialize, + serde::Deserialize, + Hash, + PartialOrd, + Ord, +)] pub enum DefaultFunction { // Integer functions AddInteger = 0, diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs index fb766f62..450047ca 100644 --- a/crates/uplc/src/optimize/shrinker.rs +++ b/crates/uplc/src/optimize/shrinker.rs @@ -1,9 +1,9 @@ use super::interner::CodeGenInterner; use crate::{ ast::{Constant, Data, Name, NamedDeBruijn, Program, Term, Type}, - builder::{CONSTR_FIELDS_EXPOSER, CONSTR_INDEX_EXPOSER}, + builder::{CONSTR_FIELDS_EXPOSER, CONSTR_INDEX_EXPOSER, INDICES_CONVERTER}, builtins::DefaultFunction, - machine::{cost_model::ExBudget, runtime::Compressable}, + machine::{cost_model::ExBudget, runtime::Compressable, value::from_pallas_bigint}, }; use blst::{blst_p1, blst_p2}; use indexmap::IndexMap; @@ -572,28 +572,32 @@ impl CurriedArgs { mut fst_args, mut snd_args, }, - BuiltinArgs::TwoArgsAnyOrder { fst, snd }, + BuiltinArgs::TwoArgsAnyOrder { mut fst, snd }, ) => { - let mut switched = false; - let fst_args = if fst_args.iter_mut().any(|item| item.term == fst.1) { - fst_args + let (switched, fst_args) = if fst_args.iter_mut().any(|item| item.term == fst.1) { + (false, fst_args) } else if fst_args.iter_mut().any(|item| match &snd { Some(snd) => item.term == snd.1, None => false, }) { - switched = true; - fst_args + (true, fst_args) } else { fst_args.push(CurriedNode { id: fst.0, - term: fst.1.clone(), + // Replace the value here instead of cloning since + // switched must be false here + // I use Term::Error.force() since it's not a + // naturally occurring term in code gen. + term: std::mem::replace(&mut fst.1, Term::Error.force()), }); - fst_args + (false, fst_args) }; // If switched then put the first arg in the second arg slot let snd_args = if switched { + assert!(fst.1 != Term::Error.force()); + if snd_args.iter_mut().any(|item| item.term == fst.1) { snd_args } else { @@ -679,12 +683,13 @@ impl CurriedArgs { } } - fn get_id_args(&self, path: &BuiltinArgs) -> Option> { + // TODO: switch clones to memory moves out of path + fn get_id_args(&self, path: BuiltinArgs) -> Option> { match (self, path) { (CurriedArgs::TwoArgs { fst_args, snd_args }, BuiltinArgs::TwoArgs { fst, snd }) => { let arg = fst_args.iter().find(|item| fst.1 == item.term)?; - let Some(arg2) = snd_args.iter().find(|item| match snd { + let Some(arg2) = snd_args.iter().find(|item| match &snd { Some(snd) => item.term == snd.1, None => false, }) else { @@ -721,7 +726,7 @@ impl CurriedArgs { term: arg.term.clone(), }); - let Some(arg2) = snd_args.iter().find(|item| match snd { + let Some(arg2) = snd_args.iter().find(|item| match &snd { Some(snd) => snd.1 == item.term, None => false, }) else { @@ -771,7 +776,7 @@ impl CurriedArgs { ) => { let arg = fst_args.iter().find(|item| fst.1 == item.term)?; - let Some(arg2) = snd_args.iter().find(|item| match snd { + let Some(arg2) = snd_args.iter().find(|item| match &snd { Some(snd) => item.term == snd.1, None => false, }) else { @@ -782,7 +787,7 @@ impl CurriedArgs { }]); }; - let Some(arg3) = thd_args.iter().find(|item| match thd { + let Some(arg3) = thd_args.iter().find(|item| match &thd { Some(thd) => item.term == thd.1, None => false, }) else { @@ -854,7 +859,7 @@ impl CurriedBuiltin { } } - pub fn get_id_args(&self, path: &BuiltinArgs) -> Option> { + pub fn get_id_args(&self, path: BuiltinArgs) -> Option> { self.args.get_id_args(path) } @@ -867,9 +872,11 @@ impl CurriedBuiltin { pub struct Context { pub inlined_apply_ids: Vec, pub constants_to_flip: Vec, - pub builtins_map: IndexMap, + pub write_bits_indices_arg: Vec, + pub builtins_map: IndexMap, pub blst_p1_list: Vec, pub blst_p2_list: Vec, + pub write_bits_convert: bool, pub node_count: usize, } @@ -903,6 +910,7 @@ impl Term { ); let apply_id = id_gen.next_id(); + // Here we must clone since we must leave the original AST alone arg_stack.push(Args::Apply(apply_id, arg.clone())); let func = Rc::make_mut(function); @@ -972,7 +980,7 @@ impl Term { Term::Lambda { parameter_name, body, - } if parameter_name.text == p.text && parameter_name.unique == p.unique => { + } if *parameter_name == p => { let body = Rc::make_mut(body); body.traverse_uplc_with_helper( scope, @@ -1019,15 +1027,28 @@ impl Term { inline_lambda, ); - for branch in branches { - branch.traverse_uplc_with_helper( + if branches.len() == 1 { + // save a potentially big clone + // where currently all cases will be 1 branch + branches[0].traverse_uplc_with_helper( scope, - arg_stack.clone(), + arg_stack, id_gen, with, context, inline_lambda, ); + } else { + for branch in branches { + branch.traverse_uplc_with_helper( + scope, + arg_stack.clone(), + id_gen, + with, + context, + inline_lambda, + ); + } } } Term::Constr { fields, .. } => { @@ -1063,16 +1084,14 @@ impl Term { fn substitute_var(&mut self, original: Rc, replace_with: &Term) { match self { - Term::Var(name) if name.text == original.text && name.unique == original.unique => { + Term::Var(name) if *name == original => { *self = replace_with.clone(); } Term::Delay(body) => Rc::make_mut(body).substitute_var(original, replace_with), Term::Lambda { parameter_name, body, - } if parameter_name.text != original.text - || parameter_name.unique != original.unique => - { + } if *parameter_name != original => { Rc::make_mut(body).substitute_var(original, replace_with); } Term::Apply { function, argument } => { @@ -1097,8 +1116,7 @@ impl Term { parameter_name, body, } => { - if parameter_name.text != original.text || parameter_name.unique != original.unique - { + if *parameter_name != original { Rc::make_mut(body).replace_identity_usage(original.clone()); } } @@ -1113,7 +1131,7 @@ impl Term { return; }; - if name.text == original.text && name.unique == original.unique { + if *name == original { *self = std::mem::replace(arg, Term::Error.force()); } } @@ -1134,7 +1152,7 @@ impl Term { ) -> VarLookup { match self { Term::Var(name) => { - if name.text == search_for.text && name.unique == search_for.unique { + if *name == search_for { VarLookup::new_found() } else { VarLookup::new() @@ -1153,9 +1171,7 @@ impl Term { if parameter_name.text == NO_INLINE { body.var_occurrences(search_for, arg_stack, force_stack) .no_inline_if_found() - } else if parameter_name.text == search_for.text - && parameter_name.unique == search_for.unique - { + } else if *parameter_name == search_for { VarLookup::new() } else { let not_applied = usize::from(arg_stack.pop().is_none()); @@ -1312,7 +1328,10 @@ impl Term { let body = Rc::make_mut(body); context.inlined_apply_ids.push(arg_id); - body.substitute_var(parameter_name.clone(), arg_term.pierce_no_inlines()); + body.substitute_var( + parameter_name.clone(), + arg_term.pierce_no_inlines_ref(), + ); // creates new body that replaces all var occurrences with the arg *self = std::mem::replace(body, Term::Error.force()); } @@ -1343,7 +1362,7 @@ impl Term { } if has_forces { - context.builtins_map.insert(*func as u8, ()); + context.builtins_map.insert(*func, ()); *self = Term::var(func.wrapped_name()); } } @@ -1445,6 +1464,58 @@ impl Term { } } } + // List in Aiken is actually List> + // So now we want to convert writeBits arg List> to List + // Important: Only runs once and at the end. + fn write_bits_convert_arg( + &mut self, + id: Option, + mut arg_stack: Vec, + _scope: &Scope, + context: &mut Context, + ) { + match self { + Term::Apply { argument, .. } => { + let id = id.unwrap(); + + if context.write_bits_indices_arg.contains(&id) { + match Rc::make_mut(argument) { + Term::Constant(constant) => { + let Constant::ProtoList(tipo, items) = Rc::make_mut(constant) else { + unreachable!(); + }; + + assert!(*tipo == Type::Data); + *tipo = Type::Integer; + + for item in items { + let Constant::Data(PlutusData::BigInt(i)) = item else { + unreachable!(); + }; + + *item = Constant::Integer(from_pallas_bigint(i)); + } + } + arg => { + context.write_bits_convert = true; + *arg = Term::var(INDICES_CONVERTER) + .apply(std::mem::replace(arg, Term::Error.force())); + } + } + } + } + + Term::Builtin(DefaultFunction::WriteBits) => { + // first arg not needed + arg_stack.pop(); + + if let Some(Args::Apply(arg_id, _)) = arg_stack.pop() { + context.write_bits_indices_arg.push(arg_id); + } + } + _ => (), + } + } fn identity_reducer( &mut self, @@ -1479,9 +1550,7 @@ impl Term { return false; }; - if identity_var.text == identity_name.text - && identity_var.unique == identity_name.unique - { + if *identity_var == identity_name { // Replace all applied usages of identity with the arg body.replace_identity_usage(parameter_name.clone()); // Have to check if the body still has any occurrences of the parameter @@ -1523,7 +1592,7 @@ impl Term { return false; }; - let arg_term = arg_term.pierce_no_inlines(); + let arg_term = arg_term.pierce_no_inlines_ref(); let body = Rc::make_mut(body); @@ -1540,9 +1609,17 @@ impl Term { | Term::Builtin(_), ); + let force_wrapped_builtin = context + .builtins_map + .keys() + .any(|b| b.wrapped_name() == parameter_name.text); + // This will inline terms that only occur once // if they are guaranteed to execute or can't throw an error by themselves - if var_lookup.occurrences == 1 && (must_execute_condition || cant_throw_condition) { + if !force_wrapped_builtin + && var_lookup.occurrences == 1 + && (must_execute_condition || cant_throw_condition) + { changed = true; body.substitute_var(parameter_name.clone(), arg_term); @@ -1550,7 +1627,7 @@ impl Term { *self = std::mem::replace(body, Term::Error.force()); // This will strip out unused terms that can't throw an error by themselves - } else if !var_lookup.found && cant_throw_condition { + } else if !var_lookup.found && (cant_throw_condition || force_wrapped_builtin) { changed = true; context.inlined_apply_ids.push(arg_id); *self = std::mem::replace(body, Term::Error.force()); @@ -1821,7 +1898,7 @@ impl Term { unreachable!() }; - term.pierce_no_inlines() + term.pierce_no_inlines_ref() }) .collect_vec(); if arg_stack.len() == func.arity() && func.is_error_safe(&args) { @@ -1910,7 +1987,7 @@ impl Term { } } - pub fn pierce_no_inlines(&self) -> &Self { + pub fn pierce_no_inlines_ref(&self) -> &Self { let mut term = self; while let Term::Lambda { @@ -1927,6 +2004,24 @@ impl Term { term } + + pub fn pierce_no_inlines(mut self) -> Self { + let term = &mut self; + + while let Term::Lambda { + parameter_name, + body, + } = term + { + if parameter_name.as_ref().text == NO_INLINE { + *term = std::mem::replace(Rc::make_mut(body), Term::Error.force()); + } else { + break; + } + } + + std::mem::replace(term, Term::Error.force()) + } } impl Program { @@ -1943,9 +2038,11 @@ impl Program { let mut context = Context { inlined_apply_ids: vec![], constants_to_flip: vec![], + write_bits_indices_arg: vec![], builtins_map: IndexMap::new(), blst_p1_list: vec![], blst_p2_list: vec![], + write_bits_convert: false, node_count: 0, }; @@ -2074,19 +2171,34 @@ impl Program { self.traverse_uplc_with(inline_lambda, &mut |id, term, arg_stack, scope, context| { with(id, term, arg_stack, scope, context); term.flip_constants(id, vec![], scope, context); + term.remove_inlined_ids(id, vec![], scope, context); }) .0 } pub fn clean_up(self, case: bool) -> Self { - self.traverse_uplc_with(true, &mut |id, term, _arg_stack, scope, context| { - term.remove_no_inlines(id, vec![], scope, context); - if case { - term.case_constr_apply_reducer(id, vec![], scope, context); - } - }) - .0 + let (mut program, context) = + self.traverse_uplc_with(true, &mut |id, term, arg_stack, scope, context| { + term.remove_no_inlines(id, vec![], scope, context); + term.write_bits_convert_arg(id, arg_stack, scope, context); + + if case { + term.case_constr_apply_reducer(id, vec![], scope, context); + } + }); + + if context.write_bits_convert { + program.term = program.term.data_list_to_integer_list(); + } + + let mut interner = CodeGenInterner::new(); + + interner.program(&mut program); + + let program = Program::::try_from(program).unwrap(); + + Program::::try_from(program).unwrap() } // This one doesn't use the context since it's complicated and traverses the ast twice @@ -2133,13 +2245,13 @@ impl Program { let curried_builtin = curried_builtin.merge_node_by_path(builtin_args.clone()); - let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else { - unreachable!(); - }; - flipped_terms .insert(scope.clone(), curried_builtin.is_flipped(&builtin_args)); + let Some(id_vec) = curried_builtin.get_id_args(builtin_args) else { + unreachable!(); + }; + curried_terms.push(curried_builtin); id_vec @@ -2147,7 +2259,7 @@ impl Program { // Brand new buitlin so we add it to the list let curried_builtin = builtin_args.clone().args_to_curried_args(*func); - let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else { + let Some(id_vec) = curried_builtin.get_id_args(builtin_args) else { unreachable!(); }; @@ -2246,7 +2358,7 @@ impl Program { let builtin_args = BuiltinArgs::args_from_arg_stack(arg_stack, *func); - let Some(mut id_vec) = curried_builtin.get_id_args(&builtin_args) else { + let Some(mut id_vec) = curried_builtin.get_id_args(builtin_args) else { return; }; diff --git a/examples/acceptance_tests/117/aiken.toml b/examples/acceptance_tests/117/aiken.toml new file mode 100644 index 00000000..86dfc5e8 --- /dev/null +++ b/examples/acceptance_tests/117/aiken.toml @@ -0,0 +1,2 @@ +name = "aiken-lang/acceptance_test_117" +version = "0.0.0" diff --git a/examples/acceptance_tests/117/lib/tests.ak b/examples/acceptance_tests/117/lib/tests.ak new file mode 100644 index 00000000..31322759 --- /dev/null +++ b/examples/acceptance_tests/117/lib/tests.ak @@ -0,0 +1,17 @@ +use aiken/builtin.{write_bits} + +test bar() { + let x = + if True { + [0, 1, 2, 3] + } else { + [0, 1] + } + + write_bits(#"f0", x, True) == #"ff" +} + +test baz() { + let x = [0, 1, 2, 3] + write_bits(#"f0", x, True) == #"ff" +}