feat: finish curry optmization, improve inline optimization further, and add a subtract integer to add integer conversion

This commit is contained in:
microproofs 2024-02-19 11:31:20 -05:00 committed by Kasey
parent 7d8fdc0f22
commit 62963f7fc2
2 changed files with 85 additions and 63 deletions

View File

@ -1,32 +1,21 @@
use crate::{ use crate::ast::{Name, Program};
ast::{Name, NamedDeBruijn, Program},
parser::interner::Interner,
};
pub mod shrinker; pub mod shrinker;
pub fn aiken_optimize_and_intern(program: Program<Name>) -> Program<Name> { pub fn aiken_optimize_and_intern(program: Program<Name>) -> Program<Name> {
let mut program = program.builtin_force_reducer();
let mut interner = Interner::new();
interner.program(&mut program);
// Use conversion to Debruijn to prevent optimizations from affecting shadowing
let program_named: Program<NamedDeBruijn> = program.try_into().unwrap();
let program: Program<Name> = program_named.try_into().unwrap();
program program
.builtin_force_reducer()
.lambda_reducer() .lambda_reducer()
.inline_reducer() .inline_reducer()
.lambda_reducer() .lambda_reducer()
.inline_reducer() .inline_reducer()
.force_delay_reducer() .force_delay_reducer()
.cast_data_reducer() .cast_data_reducer()
.convert_arithmetic_ops()
.builtin_curry_reducer() .builtin_curry_reducer()
.lambda_reducer() .lambda_reducer()
.inline_reducer() .inline_reducer()
.builtin_curry_reducer()
.lambda_reducer() .lambda_reducer()
.inline_reducer() .inline_reducer()
} }

View File

@ -1,9 +1,4 @@
use std::{ use std::{cmp::Ordering, iter, ops::Neg, rc::Rc, vec};
cmp::Ordering,
iter::{self},
rc::Rc,
vec,
};
use indexmap::IndexMap; use indexmap::IndexMap;
use itertools::Itertools; use itertools::Itertools;
@ -11,7 +6,7 @@ use itertools::Itertools;
use pallas::ledger::primitives::babbage::{BigInt, PlutusData}; use pallas::ledger::primitives::babbage::{BigInt, PlutusData};
use crate::{ use crate::{
ast::{Constant, Data, Name, Program, Term, Type}, ast::{Constant, Data, Name, NamedDeBruijn, Program, Term, Type},
builtins::DefaultFunction, builtins::DefaultFunction,
parser::interner::Interner, parser::interner::Interner,
}; };
@ -90,6 +85,18 @@ impl Default for IdGen {
} }
} }
fn id_vec_function_to_var(func_name: &str, id_vec: &[usize]) -> String {
format!(
"__{}_{}_curried",
func_name,
id_vec
.iter()
.map(|item| item.to_string())
.collect::<Vec<String>>()
.join("_")
)
}
#[derive(PartialEq, PartialOrd, Default, Debug, Clone)] #[derive(PartialEq, PartialOrd, Default, Debug, Clone)]
pub struct VarLookup { pub struct VarLookup {
found: bool, found: bool,
@ -158,6 +165,10 @@ impl DefaultFunction {
DefaultFunction::AddInteger DefaultFunction::AddInteger
| DefaultFunction::SubtractInteger | DefaultFunction::SubtractInteger
| DefaultFunction::MultiplyInteger | DefaultFunction::MultiplyInteger
| DefaultFunction::DivideInteger
| DefaultFunction::ModInteger
| DefaultFunction::QuotientInteger
| DefaultFunction::RemainderInteger
| DefaultFunction::EqualsInteger | DefaultFunction::EqualsInteger
| DefaultFunction::EqualsByteString | DefaultFunction::EqualsByteString
| DefaultFunction::EqualsString | DefaultFunction::EqualsString
@ -172,6 +183,7 @@ impl DefaultFunction {
| DefaultFunction::IndexByteString | DefaultFunction::IndexByteString
| DefaultFunction::LessThanEqualsByteString | DefaultFunction::LessThanEqualsByteString
| DefaultFunction::LessThanByteString | DefaultFunction::LessThanByteString
| DefaultFunction::AppendString
| DefaultFunction::Bls12_381_G1_Add | DefaultFunction::Bls12_381_G1_Add
| DefaultFunction::Bls12_381_G2_Add | DefaultFunction::Bls12_381_G2_Add
| DefaultFunction::ConstrData | DefaultFunction::ConstrData
@ -695,6 +707,8 @@ impl Term<Name> {
with, with,
); );
scope.pop();
with(Some(apply_id), self, vec![], scope); with(Some(apply_id), self, vec![], scope);
} }
Term::Delay(d) => { Term::Delay(d) => {
@ -843,10 +857,18 @@ impl Program<Name> {
}); });
} }
Program { let mut program = Program {
version: program.version, version: program.version,
term, term,
} };
let mut interner = Interner::new();
interner.program(&mut program);
let program = Program::<NamedDeBruijn>::try_from(program).unwrap();
Program::<Name>::try_from(program).unwrap()
} }
pub fn inline_reducer(self) -> Self { pub fn inline_reducer(self) -> Self {
@ -1098,6 +1120,42 @@ impl Program<Name> {
}) })
} }
// Converts subtract integer with a constant to add integer with a negative constant
pub fn convert_arithmetic_ops(self) -> Self {
let mut constants_to_flip = vec![];
self.traverse_uplc_with(&mut |id, term, arg_stack, _scope| match term {
Term::Apply { argument, .. } => {
let id = id.unwrap();
if constants_to_flip.contains(&id) {
let Term::Constant(c) = Rc::make_mut(argument) else {
unreachable!();
};
let Constant::Integer(i) = c.as_ref() else {
unreachable!();
};
*c = Constant::Integer(i.neg()).into();
}
}
Term::Builtin(d @ DefaultFunction::SubtractInteger) => {
if arg_stack.len() == d.arity() {
let Some((apply_id, Term::Constant(_))) = arg_stack.last() else {
return;
};
constants_to_flip.push(*apply_id);
*term = Term::Builtin(DefaultFunction::AddInteger);
}
}
Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(),
_ => {}
})
}
pub fn builtin_curry_reducer(self) -> Self { pub fn builtin_curry_reducer(self) -> Self {
let mut curried_terms = vec![]; let mut curried_terms = vec![];
let mut id_mapped_curry_terms: IndexMap<CurriedName, (Scope, Term<Name>, bool)> = let mut id_mapped_curry_terms: IndexMap<CurriedName, (Scope, Term<Name>, bool)> =
@ -1118,14 +1176,9 @@ impl Program<Name> {
let builtin_args = let builtin_args =
BuiltinArgs::args_from_arg_stack(arg_stack, is_order_agnostic); BuiltinArgs::args_from_arg_stack(arg_stack, is_order_agnostic);
let mut scope = scope.clone();
// Get upper scope of the function plus args // Get upper scope of the function plus args
// So for example if the scope is [.., ARG, ARG, FUNC] // So for example if the scope is [.., ARG, ARG, FUNC]
// we want to pop off the last 3 to get the scope right above the function applications // we want to pop off the last 3 to get the scope right above the function applications
for _ in 0..func.arity() {
scope = scope.pop();
}
// First we see if we have already curried this builtin before // First we see if we have already curried this builtin before
let mut id_vec = if let Some(curried_builtin) = curried_terms let mut id_vec = if let Some(curried_builtin) = curried_terms
@ -1164,7 +1217,7 @@ impl Program<Name> {
if let Some((map_scope, _, multi_occurrences)) = if let Some((map_scope, _, multi_occurrences)) =
id_mapped_curry_terms.get_mut(&curry_name) id_mapped_curry_terms.get_mut(&curry_name)
{ {
*map_scope = map_scope.common_ancestor(&scope); *map_scope = map_scope.common_ancestor(scope);
*multi_occurrences = true; *multi_occurrences = true;
} else if id_vec.is_empty() { } else if id_vec.is_empty() {
id_mapped_curry_terms.insert( id_mapped_curry_terms.insert(
@ -1172,14 +1225,9 @@ impl Program<Name> {
(scope.clone(), Term::Builtin(*func).apply(node.term), false), (scope.clone(), Term::Builtin(*func).apply(node.term), false),
); );
} else { } else {
let var_name = format!( let var_name = id_vec_function_to_var(
"{}_{}", &func.aiken_name(),
func.aiken_name(), &id_vec.iter().map(|item| item.curried_id).collect_vec(),
id_vec
.iter()
.map(|item| item.curried_id.to_string())
.collect::<Vec<String>>()
.join("_")
); );
id_mapped_curry_terms.insert( id_mapped_curry_terms.insert(
@ -1218,7 +1266,7 @@ impl Program<Name> {
let mut step_b = step_a.traverse_uplc_with(&mut |id, term, arg_stack, scope| match term { let mut step_b = step_a.traverse_uplc_with(&mut |id, term, arg_stack, scope| match term {
Term::Builtin(func) => { Term::Builtin(func) => {
if func.can_curry_builtin() { if func.can_curry_builtin() && arg_stack.len() == func.arity() {
let Some(curried_builtin) = let Some(curried_builtin) =
curried_terms.iter().find(|curry| curry.func == *func) curried_terms.iter().find(|curry| curry.func == *func)
else { else {
@ -1247,18 +1295,15 @@ impl Program<Name> {
return; return;
} }
let id_str = id_vec let name = id_vec_function_to_var(
.iter() &func.aiken_name(),
.map(|item| item.curried_id.to_string()) &id_vec.iter().map(|item| item.curried_id).collect_vec(),
.collect::<Vec<String>>() );
.join("_");
id_vec.iter().for_each(|item| { id_vec.iter().for_each(|item| {
curry_applied_ids.push(item.applied_id); curry_applied_ids.push(item.applied_id);
}); });
let name = format!("{}_{}", func.aiken_name(), id_str);
*term = Term::var(name); *term = Term::var(name);
} }
} }
@ -1271,14 +1316,7 @@ impl Program<Name> {
if let Some(insert_list) = scope_mapped_to_term.remove(scope) { if let Some(insert_list) = scope_mapped_to_term.remove(scope) {
for (key, val) in insert_list.into_iter().rev() { for (key, val) in insert_list.into_iter().rev() {
let name = format!( let name = id_vec_function_to_var(&key.func_name, &key.id_vec);
"{}_{}",
key.func_name,
key.id_vec
.into_iter()
.map(|item| item.to_string())
.join("_")
);
if var_occurrences(term, Name::text(&name).into()).found { if var_occurrences(term, Name::text(&name).into()).found {
*term = term.clone().lambda(name).apply(val); *term = term.clone().lambda(name).apply(val);
@ -1291,14 +1329,7 @@ impl Program<Name> {
_ => { _ => {
if let Some(insert_list) = scope_mapped_to_term.remove(scope) { if let Some(insert_list) = scope_mapped_to_term.remove(scope) {
for (key, val) in insert_list.into_iter().rev() { for (key, val) in insert_list.into_iter().rev() {
let name = format!( let name = id_vec_function_to_var(&key.func_name, &key.id_vec);
"{}_{}",
key.func_name,
key.id_vec
.into_iter()
.map(|item| item.to_string())
.join("_")
);
if var_occurrences(term, Name::text(&name).into()).found { if var_occurrences(term, Name::text(&name).into()).found {
*term = term.clone().lambda(name).apply(val); *term = term.clone().lambda(name).apply(val);
@ -1312,7 +1343,9 @@ impl Program<Name> {
interner.program(&mut step_b); interner.program(&mut step_b);
step_b let step_b = Program::<NamedDeBruijn>::try_from(step_b).unwrap();
Program::<Name>::try_from(step_b).unwrap()
} }
} }