finish up curry optimization on builtins

This commit is contained in:
microproofs 2024-02-15 17:21:41 -05:00 committed by Kasey
parent 58d586c5cf
commit 3b55a32583
1 changed files with 127 additions and 74 deletions

View File

@ -1,6 +1,6 @@
use std::{ use std::{
cmp::Ordering, cmp::Ordering,
iter::{self, Peekable}, iter::{self},
rc::Rc, rc::Rc,
vec, vec,
}; };
@ -13,6 +13,7 @@ use pallas::ledger::primitives::babbage::{BigInt, PlutusData};
use crate::{ use crate::{
ast::{Constant, Data, Name, Program, Term, Type}, ast::{Constant, Data, Name, Program, Term, Type},
builtins::DefaultFunction, builtins::DefaultFunction,
parser::interner::Interner,
}; };
#[derive(Eq, Hash, PartialEq, Clone, Debug, PartialOrd)] #[derive(Eq, Hash, PartialEq, Clone, Debug, PartialOrd)]
@ -268,6 +269,22 @@ pub struct UplcNode {
term: Term<Name>, term: Term<Name>,
} }
#[derive(Eq, Hash, PartialEq, Clone, Debug)]
pub struct CurriedName {
func_name: String,
id_vec: Vec<usize>,
}
impl CurriedName {
pub fn len(&self) -> usize {
self.id_vec.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
#[derive(PartialEq, Clone, Debug)] #[derive(PartialEq, Clone, Debug)]
pub enum CurriedArgs { pub enum CurriedArgs {
TwoArgs { TwoArgs {
@ -673,46 +690,6 @@ impl Term<Name> {
} }
} }
} }
fn traverse_to_scope<'a, 'b, I>(&'a mut self, scope: &mut Peekable<I>) -> &'a mut Self
where
I: Iterator<Item = &'b ScopePath>,
{
if scope.peek().is_none() {
self
} else {
match self {
Term::Apply { function, argument } => {
let scope_path = scope.next();
match scope_path.unwrap() {
ScopePath::FUNC => Self::traverse_to_scope(Rc::make_mut(function), scope),
ScopePath::ARG => Self::traverse_to_scope(Rc::make_mut(argument), scope),
}
}
Term::Delay(d) => {
let d = Rc::make_mut(d);
// First we recurse further to reduce the inner terms before coming back up to the Delay
Self::traverse_to_scope(d, scope)
}
Term::Lambda { body, .. } => {
let body = Rc::make_mut(body);
// First we recurse further to reduce the inner terms before coming back up to the Delay
Self::traverse_to_scope(body, scope)
}
Term::Force(f) => {
let f = Rc::make_mut(f);
// First we recurse further to reduce the inner terms before coming back up to the Delay
Self::traverse_to_scope(f, scope)
}
Term::Case { .. } => todo!(),
Term::Constr { .. } => todo!(),
_ => unreachable!("Incorrect scope path"),
}
}
}
} }
impl Program<Name> { impl Program<Name> {
@ -732,12 +709,6 @@ impl Program<Name> {
} }
} }
fn find_node_by_scope(&mut self, scope: &Scope) -> &mut Term<Name> {
let term = &mut self.term;
let iter = &mut scope.scope.iter().peekable();
term.traverse_to_scope(iter)
}
pub fn lambda_reducer(self) -> Self { pub fn lambda_reducer(self) -> Self {
let mut lambda_applied_ids = vec![]; let mut lambda_applied_ids = vec![];
@ -1103,10 +1074,11 @@ impl Program<Name> {
// WIP // WIP
pub fn builtin_curry_reducer(self) -> Self { pub fn builtin_curry_reducer(self) -> Self {
let mut curried_terms = vec![]; let mut curried_terms = vec![];
let mut id_mapped_curry_terms: IndexMap<Vec<usize>, (Scope, Term<Name>, bool)> = let mut id_mapped_curry_terms: IndexMap<CurriedName, (Scope, Term<Name>, bool)> =
IndexMap::new(); IndexMap::new();
let mut curry_applied_ids = vec![]; let mut curry_applied_ids = vec![];
let mut other_thing: IndexMap<Scope, Vec<(String, Term<Name>)>> = IndexMap::new(); let mut scope_mapped_to_term: IndexMap<Scope, Vec<(CurriedName, Term<Name>)>> =
IndexMap::new();
let a = self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term { let a = self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term {
Term::Builtin(func) => { Term::Builtin(func) => {
@ -1156,34 +1128,40 @@ impl Program<Name> {
id_only_vec.push(node.curried_id); id_only_vec.push(node.curried_id);
let curry_name = CurriedName {
func_name: func.aiken_name(),
id_vec: id_only_vec,
};
if let Some((map_scope, _, multi_occurrences)) = if let Some((map_scope, _, multi_occurrences)) =
id_mapped_curry_terms.get_mut(&id_only_vec) id_mapped_curry_terms.get_mut(&curry_name)
{ {
*map_scope = map_scope.common_ancestor(&scope); *map_scope = map_scope.common_ancestor(&scope);
*multi_occurrences = true; *multi_occurrences = true;
} else if id_vec.is_empty() { } else if id_vec.is_empty() {
id_mapped_curry_terms.insert( id_mapped_curry_terms.insert(
id_only_vec, curry_name,
(scope.clone(), Term::Builtin(*func).apply(node.term), false), (scope.clone(), Term::Builtin(*func).apply(node.term), false),
); );
} else { } else {
let var_name = id_vec let var_name = format!(
"{}_{}",
func.aiken_name(),
id_vec
.iter() .iter()
.map(|item| item.curried_id.to_string()) .map(|item| item.curried_id.to_string())
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join("_"); .join("_")
);
id_mapped_curry_terms.insert( id_mapped_curry_terms.insert(
id_only_vec, curry_name,
(scope.clone(), Term::var(var_name).apply(node.term), false), (scope.clone(), Term::var(var_name).apply(node.term), false),
); );
} }
curry_applied_ids.push(node.applied_id);
} }
} }
} }
Term::Constr { .. } => todo!(), Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(), Term::Case { .. } => todo!(),
_ => {} _ => {}
@ -1197,25 +1175,100 @@ impl Program<Name> {
id_mapped_curry_terms id_mapped_curry_terms
.into_iter() .into_iter()
.filter(|(_, (_, _, multi_occurrence))| *multi_occurrence) .filter(|(_, (_, _, multi_occurrence))| *multi_occurrence)
.for_each(|(key, val)| { .for_each(|(key, val)| match scope_mapped_to_term.get_mut(&val.0) {
let name = key.into_iter().map(|item| item.to_string()).join("_"); Some(list) => {
let insert_position = list
.iter()
.position(|(list_key, _)| key.len() <= list_key.len())
.unwrap_or(list.len());
match other_thing.get_mut(&val.0) { list.insert(insert_position, (key, val.1));
Some(list) => list.push((name, val.1)), }
None => { None => {
other_thing.insert(val.0, vec![(name, val.1)]); scope_mapped_to_term.insert(val.0, vec![(key, val.1)]);
}
});
let mut b = a.traverse_uplc_with(&mut |id, term, arg_stack, scope| match term {
Term::Builtin(func) => {
if func.can_curry_builtin() {
let Some(curried_builtin) =
curried_terms.iter().find(|curry| curry.func == *func)
else {
return;
};
let builtin_args = BuiltinArgs::args_from_arg_stack(
arg_stack,
func.is_order_agnostic_builtin(),
);
let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else {
return;
};
let id_str = id_vec
.iter()
.map(|item| item.curried_id.to_string())
.collect::<Vec<String>>()
.join("_");
id_vec.iter().for_each(|item| {
curry_applied_ids.push(item.applied_id);
});
let name = format!("{}_{}", func.aiken_name(), id_str);
*term = Term::var(name);
}
}
Term::Apply { function, .. } => {
let id = id.unwrap();
if curry_applied_ids.contains(&id) {
*term = function.as_ref().clone();
}
if let Some(insert_list) = scope_mapped_to_term.remove(scope) {
for (key, val) in insert_list.into_iter().rev() {
let name = format!(
"{}_{}",
key.func_name,
key.id_vec
.into_iter()
.map(|item| item.to_string())
.join("_")
);
*term = term.clone().lambda(name).apply(val);
}
}
}
Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(),
_ => {
if let Some(insert_list) = scope_mapped_to_term.remove(scope) {
for (key, val) in insert_list.into_iter().rev() {
let name = format!(
"{}_{}",
key.func_name,
key.id_vec
.into_iter()
.map(|item| item.to_string())
.join("_")
);
*term = term.clone().lambda(name).apply(val);
}
} }
} }
}); });
// other_thing let mut interner = Interner::new();
// .into_iter()
// .sorted_by(|item1, item2| item1.0.partial_cmp(&item2.0).expect("HOWWW?"))
// .for_each(|(scope, nodes_to_insert)| {
// // b.get_s
// });
todo!() interner.program(&mut b);
b
} }
} }