Add writeBits back in and use the optimizer to utilize the list conversion

This commit is contained in:
microproofs 2025-01-10 12:11:19 +07:00
parent 19d0ec23cf
commit f7f68fbafc
No known key found for this signature in database
GPG Key ID: 14F93C84DE6AFD17
8 changed files with 238 additions and 67 deletions

View File

@ -509,12 +509,8 @@ pub fn plutus(id_gen: &IdGenerator) -> TypeInfo {
};
for builtin in DefaultFunction::iter() {
// FIXME: Disabling WriteBits for now, since its signature requires the ability to create
// list of raw integers, which isn't possible through Aiken at the moment.
if !matches!(builtin, DefaultFunction::WriteBits) {
let value = from_default_function(builtin, id_gen);
plutus.values.insert(builtin.aiken_name(), value);
}
let value = from_default_function(builtin, id_gen);
plutus.values.insert(builtin.aiken_name(), value);
}
let index_tipo = Type::function(vec![Type::data()], Type::int());

View File

@ -4028,7 +4028,7 @@ impl<'a> CodeGenerator<'a> {
} else {
let term = arg_stack.pop().unwrap();
match term.pierce_no_inlines() {
match term.pierce_no_inlines_ref() {
Term::Var(_) => Some(term.force()),
Term::Delay(inner_term) => Some(inner_term.as_ref().clone()),
Term::Apply { .. } => Some(term.force()),
@ -4356,7 +4356,7 @@ impl<'a> CodeGenerator<'a> {
known_data_to_type(term, &tipo)
};
if extract_constant(term.pierce_no_inlines()).is_some() {
if extract_constant(term.pierce_no_inlines_ref()).is_some() {
let mut program = self.new_program(term);
let mut interner = CodeGenInterner::new();
@ -4379,7 +4379,7 @@ impl<'a> CodeGenerator<'a> {
Air::CastToData { tipo } => {
let mut term = arg_stack.pop().unwrap();
if extract_constant(term.pierce_no_inlines()).is_some() {
if extract_constant(term.pierce_no_inlines_ref()).is_some() {
term = builder::convert_type_to_data(term, &tipo);
let mut program = self.new_program(term);
@ -4792,7 +4792,7 @@ impl<'a> CodeGenerator<'a> {
.apply(term);
if arg_vec.iter().all(|item| {
let maybe_const = extract_constant(item.pierce_no_inlines());
let maybe_const = extract_constant(item.pierce_no_inlines_ref());
maybe_const.is_some()
}) {
let mut program = self.new_program(term);

View File

@ -517,7 +517,7 @@ impl hash::Hash for Name {
impl PartialEq for Name {
fn eq(&self, other: &Self) -> bool {
self.unique == other.unique
self.unique == other.unique && self.text == other.text
}
}

View File

@ -9,6 +9,7 @@ pub const CONSTR_FIELDS_EXPOSER: &str = "__constr_fields_exposer";
pub const CONSTR_INDEX_EXPOSER: &str = "__constr_index_exposer";
pub const EXPECT_ON_LIST: &str = "__expect_on_list";
pub const INNER_EXPECT_ON_LIST: &str = "__inner_expect_on_list";
pub const INDICES_CONVERTER: &str = "__indices_converter";
impl<T> Term<T>
where
@ -82,6 +83,10 @@ where
Term::Constant(Constant::ProtoList(Type::Data, vals).into())
}
pub fn int_values(vals: Vec<Constant>) -> Self {
Term::Constant(Constant::ProtoList(Type::Integer, vals).into())
}
pub fn empty_map() -> Self {
Term::Constant(
Constant::ProtoList(Type::Pair(Type::Data.into(), Type::Data.into()), vec![]).into(),
@ -546,6 +551,33 @@ impl Term<Name> {
)
}
pub fn data_list_to_integer_list(self) -> Self {
self.lambda(INDICES_CONVERTER)
.apply(Term::var(INDICES_CONVERTER).apply(Term::var(INDICES_CONVERTER)))
.lambda(INDICES_CONVERTER)
.apply(
Term::var("xs")
.delayed_choose_list(
Term::int_values(vec![]),
Term::mk_cons()
.apply(Term::var("x"))
.apply(
Term::var(INDICES_CONVERTER)
.apply(Term::var(INDICES_CONVERTER))
.apply(Term::var("rest")),
)
.lambda("rest")
.apply(Term::tail_list().apply(Term::var("xs")))
.lambda("x")
.apply(
Term::un_i_data().apply(Term::head_list().apply(Term::var("xs"))),
),
)
.lambda("xs")
.lambda(INDICES_CONVERTER),
)
}
/// Introduce a let-binding for a given term. The callback receives a Term::Var
/// whose name matches the given 'var_name'. Handy to re-use a same var across
/// multiple lambda expressions.

View File

@ -6,7 +6,19 @@ use strum_macros::EnumIter;
/// All the possible builtin functions in Untyped Plutus Core.
#[repr(u8)]
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, PartialEq, Eq, Copy, EnumIter, serde::Serialize, serde::Deserialize)]
#[derive(
Debug,
Clone,
PartialEq,
Eq,
Copy,
EnumIter,
serde::Serialize,
serde::Deserialize,
Hash,
PartialOrd,
Ord,
)]
pub enum DefaultFunction {
// Integer functions
AddInteger = 0,

View File

@ -1,9 +1,9 @@
use super::interner::CodeGenInterner;
use crate::{
ast::{Constant, Data, Name, NamedDeBruijn, Program, Term, Type},
builder::{CONSTR_FIELDS_EXPOSER, CONSTR_INDEX_EXPOSER},
builder::{CONSTR_FIELDS_EXPOSER, CONSTR_INDEX_EXPOSER, INDICES_CONVERTER},
builtins::DefaultFunction,
machine::{cost_model::ExBudget, runtime::Compressable},
machine::{cost_model::ExBudget, runtime::Compressable, value::from_pallas_bigint},
};
use blst::{blst_p1, blst_p2};
use indexmap::IndexMap;
@ -572,28 +572,32 @@ impl CurriedArgs {
mut fst_args,
mut snd_args,
},
BuiltinArgs::TwoArgsAnyOrder { fst, snd },
BuiltinArgs::TwoArgsAnyOrder { mut fst, snd },
) => {
let mut switched = false;
let fst_args = if fst_args.iter_mut().any(|item| item.term == fst.1) {
fst_args
let (switched, fst_args) = if fst_args.iter_mut().any(|item| item.term == fst.1) {
(false, fst_args)
} else if fst_args.iter_mut().any(|item| match &snd {
Some(snd) => item.term == snd.1,
None => false,
}) {
switched = true;
fst_args
(true, fst_args)
} else {
fst_args.push(CurriedNode {
id: fst.0,
term: fst.1.clone(),
// Replace the value here instead of cloning since
// switched must be false here
// I use Term::Error.force() since it's not a
// naturally occurring term in code gen.
term: std::mem::replace(&mut fst.1, Term::Error.force()),
});
fst_args
(false, fst_args)
};
// If switched then put the first arg in the second arg slot
let snd_args = if switched {
assert!(fst.1 != Term::Error.force());
if snd_args.iter_mut().any(|item| item.term == fst.1) {
snd_args
} else {
@ -679,12 +683,13 @@ impl CurriedArgs {
}
}
fn get_id_args(&self, path: &BuiltinArgs) -> Option<Vec<UplcNode>> {
// TODO: switch clones to memory moves out of path
fn get_id_args(&self, path: BuiltinArgs) -> Option<Vec<UplcNode>> {
match (self, path) {
(CurriedArgs::TwoArgs { fst_args, snd_args }, BuiltinArgs::TwoArgs { fst, snd }) => {
let arg = fst_args.iter().find(|item| fst.1 == item.term)?;
let Some(arg2) = snd_args.iter().find(|item| match snd {
let Some(arg2) = snd_args.iter().find(|item| match &snd {
Some(snd) => item.term == snd.1,
None => false,
}) else {
@ -721,7 +726,7 @@ impl CurriedArgs {
term: arg.term.clone(),
});
let Some(arg2) = snd_args.iter().find(|item| match snd {
let Some(arg2) = snd_args.iter().find(|item| match &snd {
Some(snd) => snd.1 == item.term,
None => false,
}) else {
@ -771,7 +776,7 @@ impl CurriedArgs {
) => {
let arg = fst_args.iter().find(|item| fst.1 == item.term)?;
let Some(arg2) = snd_args.iter().find(|item| match snd {
let Some(arg2) = snd_args.iter().find(|item| match &snd {
Some(snd) => item.term == snd.1,
None => false,
}) else {
@ -782,7 +787,7 @@ impl CurriedArgs {
}]);
};
let Some(arg3) = thd_args.iter().find(|item| match thd {
let Some(arg3) = thd_args.iter().find(|item| match &thd {
Some(thd) => item.term == thd.1,
None => false,
}) else {
@ -854,7 +859,7 @@ impl CurriedBuiltin {
}
}
pub fn get_id_args(&self, path: &BuiltinArgs) -> Option<Vec<UplcNode>> {
pub fn get_id_args(&self, path: BuiltinArgs) -> Option<Vec<UplcNode>> {
self.args.get_id_args(path)
}
@ -867,9 +872,11 @@ impl CurriedBuiltin {
pub struct Context {
pub inlined_apply_ids: Vec<usize>,
pub constants_to_flip: Vec<usize>,
pub builtins_map: IndexMap<u8, ()>,
pub write_bits_indices_arg: Vec<usize>,
pub builtins_map: IndexMap<DefaultFunction, ()>,
pub blst_p1_list: Vec<blst_p1>,
pub blst_p2_list: Vec<blst_p2>,
pub write_bits_convert: bool,
pub node_count: usize,
}
@ -903,6 +910,7 @@ impl Term<Name> {
);
let apply_id = id_gen.next_id();
// Here we must clone since we must leave the original AST alone
arg_stack.push(Args::Apply(apply_id, arg.clone()));
let func = Rc::make_mut(function);
@ -972,7 +980,7 @@ impl Term<Name> {
Term::Lambda {
parameter_name,
body,
} if parameter_name.text == p.text && parameter_name.unique == p.unique => {
} if *parameter_name == p => {
let body = Rc::make_mut(body);
body.traverse_uplc_with_helper(
scope,
@ -1019,15 +1027,28 @@ impl Term<Name> {
inline_lambda,
);
for branch in branches {
branch.traverse_uplc_with_helper(
if branches.len() == 1 {
// save a potentially big clone
// where currently all cases will be 1 branch
branches[0].traverse_uplc_with_helper(
scope,
arg_stack.clone(),
arg_stack,
id_gen,
with,
context,
inline_lambda,
);
} else {
for branch in branches {
branch.traverse_uplc_with_helper(
scope,
arg_stack.clone(),
id_gen,
with,
context,
inline_lambda,
);
}
}
}
Term::Constr { fields, .. } => {
@ -1063,16 +1084,14 @@ impl Term<Name> {
fn substitute_var(&mut self, original: Rc<Name>, replace_with: &Term<Name>) {
match self {
Term::Var(name) if name.text == original.text && name.unique == original.unique => {
Term::Var(name) if *name == original => {
*self = replace_with.clone();
}
Term::Delay(body) => Rc::make_mut(body).substitute_var(original, replace_with),
Term::Lambda {
parameter_name,
body,
} if parameter_name.text != original.text
|| parameter_name.unique != original.unique =>
{
} if *parameter_name != original => {
Rc::make_mut(body).substitute_var(original, replace_with);
}
Term::Apply { function, argument } => {
@ -1097,8 +1116,7 @@ impl Term<Name> {
parameter_name,
body,
} => {
if parameter_name.text != original.text || parameter_name.unique != original.unique
{
if *parameter_name != original {
Rc::make_mut(body).replace_identity_usage(original.clone());
}
}
@ -1113,7 +1131,7 @@ impl Term<Name> {
return;
};
if name.text == original.text && name.unique == original.unique {
if *name == original {
*self = std::mem::replace(arg, Term::Error.force());
}
}
@ -1134,7 +1152,7 @@ impl Term<Name> {
) -> VarLookup {
match self {
Term::Var(name) => {
if name.text == search_for.text && name.unique == search_for.unique {
if *name == search_for {
VarLookup::new_found()
} else {
VarLookup::new()
@ -1153,9 +1171,7 @@ impl Term<Name> {
if parameter_name.text == NO_INLINE {
body.var_occurrences(search_for, arg_stack, force_stack)
.no_inline_if_found()
} else if parameter_name.text == search_for.text
&& parameter_name.unique == search_for.unique
{
} else if *parameter_name == search_for {
VarLookup::new()
} else {
let not_applied = usize::from(arg_stack.pop().is_none());
@ -1312,7 +1328,10 @@ impl Term<Name> {
let body = Rc::make_mut(body);
context.inlined_apply_ids.push(arg_id);
body.substitute_var(parameter_name.clone(), arg_term.pierce_no_inlines());
body.substitute_var(
parameter_name.clone(),
arg_term.pierce_no_inlines_ref(),
);
// creates new body that replaces all var occurrences with the arg
*self = std::mem::replace(body, Term::Error.force());
}
@ -1343,7 +1362,7 @@ impl Term<Name> {
}
if has_forces {
context.builtins_map.insert(*func as u8, ());
context.builtins_map.insert(*func, ());
*self = Term::var(func.wrapped_name());
}
}
@ -1445,6 +1464,58 @@ impl Term<Name> {
}
}
}
// List<Int> in Aiken is actually List<Data<Int>>
// So now we want to convert writeBits arg List<Data<Int>> to List<Int>
// Important: Only runs once and at the end.
fn write_bits_convert_arg(
&mut self,
id: Option<usize>,
mut arg_stack: Vec<Args>,
_scope: &Scope,
context: &mut Context,
) {
match self {
Term::Apply { argument, .. } => {
let id = id.unwrap();
if context.write_bits_indices_arg.contains(&id) {
match Rc::make_mut(argument) {
Term::Constant(constant) => {
let Constant::ProtoList(tipo, items) = Rc::make_mut(constant) else {
unreachable!();
};
assert!(*tipo == Type::Data);
*tipo = Type::Integer;
for item in items {
let Constant::Data(PlutusData::BigInt(i)) = item else {
unreachable!();
};
*item = Constant::Integer(from_pallas_bigint(i));
}
}
arg => {
context.write_bits_convert = true;
*arg = Term::var(INDICES_CONVERTER)
.apply(std::mem::replace(arg, Term::Error.force()));
}
}
}
}
Term::Builtin(DefaultFunction::WriteBits) => {
// first arg not needed
arg_stack.pop();
if let Some(Args::Apply(arg_id, _)) = arg_stack.pop() {
context.write_bits_indices_arg.push(arg_id);
}
}
_ => (),
}
}
fn identity_reducer(
&mut self,
@ -1479,9 +1550,7 @@ impl Term<Name> {
return false;
};
if identity_var.text == identity_name.text
&& identity_var.unique == identity_name.unique
{
if *identity_var == identity_name {
// Replace all applied usages of identity with the arg
body.replace_identity_usage(parameter_name.clone());
// Have to check if the body still has any occurrences of the parameter
@ -1523,7 +1592,7 @@ impl Term<Name> {
return false;
};
let arg_term = arg_term.pierce_no_inlines();
let arg_term = arg_term.pierce_no_inlines_ref();
let body = Rc::make_mut(body);
@ -1540,9 +1609,17 @@ impl Term<Name> {
| Term::Builtin(_),
);
let force_wrapped_builtin = context
.builtins_map
.keys()
.any(|b| b.wrapped_name() == parameter_name.text);
// This will inline terms that only occur once
// if they are guaranteed to execute or can't throw an error by themselves
if var_lookup.occurrences == 1 && (must_execute_condition || cant_throw_condition) {
if !force_wrapped_builtin
&& var_lookup.occurrences == 1
&& (must_execute_condition || cant_throw_condition)
{
changed = true;
body.substitute_var(parameter_name.clone(), arg_term);
@ -1550,7 +1627,7 @@ impl Term<Name> {
*self = std::mem::replace(body, Term::Error.force());
// This will strip out unused terms that can't throw an error by themselves
} else if !var_lookup.found && cant_throw_condition {
} else if !var_lookup.found && (cant_throw_condition || force_wrapped_builtin) {
changed = true;
context.inlined_apply_ids.push(arg_id);
*self = std::mem::replace(body, Term::Error.force());
@ -1821,7 +1898,7 @@ impl Term<Name> {
unreachable!()
};
term.pierce_no_inlines()
term.pierce_no_inlines_ref()
})
.collect_vec();
if arg_stack.len() == func.arity() && func.is_error_safe(&args) {
@ -1910,7 +1987,7 @@ impl Term<Name> {
}
}
pub fn pierce_no_inlines(&self) -> &Self {
pub fn pierce_no_inlines_ref(&self) -> &Self {
let mut term = self;
while let Term::Lambda {
@ -1927,6 +2004,24 @@ impl Term<Name> {
term
}
pub fn pierce_no_inlines(mut self) -> Self {
let term = &mut self;
while let Term::Lambda {
parameter_name,
body,
} = term
{
if parameter_name.as_ref().text == NO_INLINE {
*term = std::mem::replace(Rc::make_mut(body), Term::Error.force());
} else {
break;
}
}
std::mem::replace(term, Term::Error.force())
}
}
impl Program<Name> {
@ -1943,9 +2038,11 @@ impl Program<Name> {
let mut context = Context {
inlined_apply_ids: vec![],
constants_to_flip: vec![],
write_bits_indices_arg: vec![],
builtins_map: IndexMap::new(),
blst_p1_list: vec![],
blst_p2_list: vec![],
write_bits_convert: false,
node_count: 0,
};
@ -2074,19 +2171,34 @@ impl Program<Name> {
self.traverse_uplc_with(inline_lambda, &mut |id, term, arg_stack, scope, context| {
with(id, term, arg_stack, scope, context);
term.flip_constants(id, vec![], scope, context);
term.remove_inlined_ids(id, vec![], scope, context);
})
.0
}
pub fn clean_up(self, case: bool) -> Self {
self.traverse_uplc_with(true, &mut |id, term, _arg_stack, scope, context| {
term.remove_no_inlines(id, vec![], scope, context);
if case {
term.case_constr_apply_reducer(id, vec![], scope, context);
}
})
.0
let (mut program, context) =
self.traverse_uplc_with(true, &mut |id, term, arg_stack, scope, context| {
term.remove_no_inlines(id, vec![], scope, context);
term.write_bits_convert_arg(id, arg_stack, scope, context);
if case {
term.case_constr_apply_reducer(id, vec![], scope, context);
}
});
if context.write_bits_convert {
program.term = program.term.data_list_to_integer_list();
}
let mut interner = CodeGenInterner::new();
interner.program(&mut program);
let program = Program::<NamedDeBruijn>::try_from(program).unwrap();
Program::<Name>::try_from(program).unwrap()
}
// This one doesn't use the context since it's complicated and traverses the ast twice
@ -2133,13 +2245,13 @@ impl Program<Name> {
let curried_builtin =
curried_builtin.merge_node_by_path(builtin_args.clone());
let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else {
unreachable!();
};
flipped_terms
.insert(scope.clone(), curried_builtin.is_flipped(&builtin_args));
let Some(id_vec) = curried_builtin.get_id_args(builtin_args) else {
unreachable!();
};
curried_terms.push(curried_builtin);
id_vec
@ -2147,7 +2259,7 @@ impl Program<Name> {
// Brand new buitlin so we add it to the list
let curried_builtin = builtin_args.clone().args_to_curried_args(*func);
let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else {
let Some(id_vec) = curried_builtin.get_id_args(builtin_args) else {
unreachable!();
};
@ -2246,7 +2358,7 @@ impl Program<Name> {
let builtin_args = BuiltinArgs::args_from_arg_stack(arg_stack, *func);
let Some(mut id_vec) = curried_builtin.get_id_args(&builtin_args) else {
let Some(mut id_vec) = curried_builtin.get_id_args(builtin_args) else {
return;
};

View File

@ -0,0 +1,2 @@
name = "aiken-lang/acceptance_test_117"
version = "0.0.0"

View File

@ -0,0 +1,17 @@
use aiken/builtin.{write_bits}
test bar() {
let x =
if True {
[0, 1, 2, 3]
} else {
[0, 1]
}
write_bits(#"f0", x, True) == #"ff"
}
test baz() {
let x = [0, 1, 2, 3]
write_bits(#"f0", x, True) == #"ff"
}