diff --git a/crates/aiken-lang/src/uplc.rs b/crates/aiken-lang/src/uplc.rs index 79850d43..b01fea5c 100644 --- a/crates/aiken-lang/src/uplc.rs +++ b/crates/aiken-lang/src/uplc.rs @@ -3999,24 +3999,16 @@ impl<'a> CodeGenerator<'a> { fn gen_uplc(&mut self, ir: Air, arg_stack: &mut Vec>) { match ir { Air::Int { value, .. } => { - let integer = value.parse().unwrap(); - - let term = Term::Constant(UplcConstant::Integer(integer).into()); - - arg_stack.push(term); + arg_stack.push(Term::integer(value.parse().unwrap())); } Air::String { value, .. } => { - let term = Term::Constant(UplcConstant::String(value).into()); - - arg_stack.push(term); + arg_stack.push(Term::string(value)); } Air::ByteArray { bytes, .. } => { - let term = Term::Constant(UplcConstant::ByteString(bytes).into()); - arg_stack.push(term); + arg_stack.push(Term::byte_string(bytes)); } Air::Bool { value, .. } => { - let term = Term::Constant(UplcConstant::Bool(value).into()); - arg_stack.push(term); + arg_stack.push(Term::bool(value)); } Air::Var { name, @@ -4214,19 +4206,17 @@ impl<'a> CodeGenerator<'a> { .take(names.len()) .collect_vec(); - term = apply_wrap( - list_access_to_uplc( - &names, - &id_list, - tail, - 0, - term, - inner_types, - check_last_item, - true, - ), - value, - ); + term = list_access_to_uplc( + &names, + &id_list, + tail, + 0, + term, + inner_types, + check_last_item, + true, + ) + .apply(value); arg_stack.push(term); } @@ -4239,66 +4229,21 @@ impl<'a> CodeGenerator<'a> { let mut term = arg_stack.pop().unwrap(); if let Some((tail_var, tail_name)) = tail { - term = apply_wrap( - Term::Lambda { - parameter_name: Name { - text: tail_name, - unique: 0.into(), - } - .into(), - body: term.into(), - }, - apply_wrap( - Term::Builtin(DefaultFunction::TailList).force(), - Term::Var( - Name { - text: tail_var, - unique: 0.into(), - } - .into(), - ), - ), - ); + term = term + .lambda(tail_name) + .apply(Term::tail_list().apply(Term::var(tail_var))); } for (tail_var, head_name) in tail_head_names.into_iter().rev() { let head_list = if tipo.is_map() { - apply_wrap( - Term::Force(Term::Builtin(DefaultFunction::HeadList).into()), - Term::Var( - Name { - text: tail_var, - unique: 0.into(), - } - .into(), - ), - ) + Term::head_list().apply(Term::var(tail_var)) } else { convert_data_to_type( - apply_wrap( - Term::Builtin(DefaultFunction::HeadList).force(), - Term::Var( - Name { - text: tail_var, - unique: 0.into(), - } - .into(), - ), - ), + Term::head_list().apply(Term::var(tail_var)), &tipo.get_inner_types()[0], ) }; - term = apply_wrap( - Term::Lambda { - parameter_name: Name { - text: head_name, - unique: 0.into(), - } - .into(), - body: term.into(), - }, - head_list, - ); + term = term.lambda(head_name).apply(head_list); } arg_stack.push(term); @@ -4307,14 +4252,7 @@ impl<'a> CodeGenerator<'a> { let mut term = arg_stack.pop().unwrap(); for param in params.iter().rev() { - term = Term::Lambda { - parameter_name: Name { - text: param.clone(), - unique: 0.into(), - } - .into(), - body: term.into(), - }; + term = term.lambda(param); } arg_stack.push(term); @@ -4326,7 +4264,7 @@ impl<'a> CodeGenerator<'a> { for _ in 0..count { let arg = arg_stack.pop().unwrap(); - term = apply_wrap(term, arg); + term = term.apply(arg); } arg_stack.push(term); } else { @@ -4395,7 +4333,7 @@ impl<'a> CodeGenerator<'a> { for (index, arg) in arg_vec.into_iter().enumerate() { let arg = if matches!(func, DefaultFunction::ChooseData) && index > 0 { - Term::Delay(arg.into()) + arg.delay() } else { arg }; @@ -4409,29 +4347,13 @@ impl<'a> CodeGenerator<'a> { let temp_var = format!("__item_{}", self.id_gen.next()); if count == 0 { - term = apply_wrap( - term, - Term::Var( - Name { - text: temp_var.clone(), - unique: 0.into(), - } - .into(), - ), - ); + term = term.apply(Term::var(temp_var.clone())); } term = convert_data_to_type(term, &tipo); if count == 0 { - term = Term::Lambda { - parameter_name: Name { - text: temp_var, - unique: 0.into(), - } - .into(), - body: term.into(), - }; + term = term.lambda(temp_var); } } DefaultFunction::UnConstrData => { @@ -4440,72 +4362,23 @@ impl<'a> CodeGenerator<'a> { let temp_var = format!("__item_{}", self.id_gen.next()); if count == 0 { - term = apply_wrap( - term, - Term::Var( - Name { - text: temp_var.clone(), - unique: 0.into(), - } - .into(), - ), - ); + term = term.apply(Term::var(temp_var.clone())); } - term = apply_wrap( - Term::Lambda { - parameter_name: Name { - text: temp_tuple.clone(), - unique: 0.into(), - } - .into(), - body: apply_wrap( - apply_wrap( - Term::Builtin(DefaultFunction::MkPairData), - apply_wrap( - Term::Builtin(DefaultFunction::IData), - apply_wrap( - Term::Builtin(DefaultFunction::FstPair) - .force() - .force(), - Term::Var( - Name { - text: temp_tuple.clone(), - unique: 0.into(), - } - .into(), - ), - ), - ), - ), - apply_wrap( - Term::Builtin(DefaultFunction::ListData), - apply_wrap( - Term::Builtin(DefaultFunction::SndPair).force().force(), - Term::Var( - Name { - text: temp_tuple, - unique: 0.into(), - } - .into(), - ), - ), - ), - ) - .into(), - }, - term, - ); + term = Term::mk_pair_data() + .apply( + Term::i_data() + .apply(Term::fst_pair().apply(Term::var(temp_tuple.clone()))), + ) + .apply( + Term::list_data() + .apply(Term::snd_pair().apply(Term::var(temp_tuple.clone()))), + ) + .lambda(temp_tuple) + .apply(term); if count == 0 { - term = Term::Lambda { - parameter_name: Name { - text: temp_var, - unique: 0.into(), - } - .into(), - body: term.into(), - }; + term = term.lambda(temp_var); } } DefaultFunction::MkCons => { @@ -4521,29 +4394,11 @@ impl<'a> CodeGenerator<'a> { if count == 0 { for (index, temp_var) in temp_vars.iter().enumerate() { - term = apply_wrap( - term, - if index > 0 { - Term::Delay( - Term::Var( - Name { - text: temp_var.clone(), - unique: 0.into(), - } - .into(), - ) - .into(), - ) - } else { - Term::Var( - Name { - text: temp_var.clone(), - unique: 0.into(), - } - .into(), - ) - }, - ); + term = term.apply(if index > 0 { + Term::var(temp_var.clone()).delay() + } else { + Term::var(temp_var.clone()) + }); } } @@ -4551,14 +4406,7 @@ impl<'a> CodeGenerator<'a> { if count == 0 { for temp_var in temp_vars.into_iter().rev() { - term = Term::Lambda { - parameter_name: Name { - text: temp_var, - unique: 0.into(), - } - .into(), - body: term.into(), - }; + term = term.lambda(temp_var); } } } @@ -4570,86 +4418,42 @@ impl<'a> CodeGenerator<'a> { let left = arg_stack.pop().unwrap(); let right = arg_stack.pop().unwrap(); - let default_builtin = if tipo.is_int() { - DefaultFunction::EqualsInteger + let builtin = if tipo.is_int() { + Term::equals_integer() } else if tipo.is_string() { - DefaultFunction::EqualsString + Term::equals_string() } else if tipo.is_bytearray() { - DefaultFunction::EqualsByteString + Term::equals_bytestring() } else { - DefaultFunction::EqualsData + Term::equals_data() }; let term = match name { - BinOp::And => delayed_if_else( - left, - right, - Term::Constant(UplcConstant::Bool(false).into()), - ), - BinOp::Or => delayed_if_else( - left, - Term::Constant(UplcConstant::Bool(true).into()), - right, - ), - + BinOp::And => left.delayed_if_else(right, Term::bool(false)), + BinOp::Or => left.delayed_if_else(Term::bool(true), right), BinOp::Eq => { if tipo.is_bool() { - let term = delayed_if_else( - left, + let term = left.delayed_if_else( right.clone(), - if_else( - right, - Term::Constant(UplcConstant::Bool(false).into()), - Term::Constant(UplcConstant::Bool(true).into()), - ), + right.if_else(Term::bool(false), Term::bool(true)), ); + arg_stack.push(term); return; } else if tipo.is_map() { - let term = apply_wrap( - apply_wrap( - default_builtin.into(), - apply_wrap(DefaultFunction::MapData.into(), left), - ), - apply_wrap(DefaultFunction::MapData.into(), right), - ); + let term = builtin + .apply(Term::map_data().apply(left)) + .apply(Term::map_data().apply(right)); arg_stack.push(term); return; } else if tipo.is_tuple() && matches!(tipo.get_uplc_type(), UplcType::Pair(_, _)) { - let term = apply_wrap( - apply_wrap( - default_builtin.into(), - apply_wrap( - DefaultFunction::MapData.into(), - apply_wrap( - apply_wrap( - Term::Builtin(DefaultFunction::MkCons).force(), - left, - ), - Term::Constant( - UplcConstant::ProtoList( - UplcType::Pair( - UplcType::Data.into(), - UplcType::Data.into(), - ), - vec![], - ) - .into(), - ), - ), - ), - ), - apply_wrap( - DefaultFunction::MapData.into(), - apply_wrap( - apply_wrap( - Term::Builtin(DefaultFunction::MkCons).force(), - right, - ), - Term::Constant( + let term = builtin + .apply( + Term::map_data().apply( + Term::mk_cons().apply(left).apply(Term::Constant( UplcConstant::ProtoList( UplcType::Pair( UplcType::Data.into(), @@ -4658,30 +4462,39 @@ impl<'a> CodeGenerator<'a> { vec![], ) .into(), - ), + )), ), - ), - ); + ) + .apply( + Term::map_data().apply( + Term::mk_cons().apply(right).apply(Term::Constant( + UplcConstant::ProtoList( + UplcType::Pair( + UplcType::Data.into(), + UplcType::Data.into(), + ), + vec![], + ) + .into(), + )), + ), + ); arg_stack.push(term); return; } else if tipo.is_list() || tipo.is_tuple() { - let term = apply_wrap( - apply_wrap( - default_builtin.into(), - apply_wrap(DefaultFunction::ListData.into(), left), - ), - apply_wrap(DefaultFunction::ListData.into(), right), - ); + let term = builtin + .apply(Term::list_data().apply(left)) + .apply(Term::list_data().apply(right)); arg_stack.push(term); return; } else if tipo.is_void() { - arg_stack.push(Term::Constant(UplcConstant::Bool(true).into())); + arg_stack.push(Term::bool(true)); return; } - apply_wrap(apply_wrap(default_builtin.into(), left), right) + builtin.apply(left).apply(right) } BinOp::NotEq => { if tipo.is_bool() { @@ -4700,7 +4513,7 @@ impl<'a> CodeGenerator<'a> { let term = if_else( apply_wrap( apply_wrap( - default_builtin.into(), + builtin, apply_wrap(DefaultFunction::MapData.into(), left), ), apply_wrap(DefaultFunction::MapData.into(), right), @@ -4716,7 +4529,7 @@ impl<'a> CodeGenerator<'a> { { let mut term = apply_wrap( apply_wrap( - default_builtin.into(), + builtin, apply_wrap( DefaultFunction::MapData.into(), apply_wrap( @@ -4769,7 +4582,7 @@ impl<'a> CodeGenerator<'a> { let term = if_else( apply_wrap( apply_wrap( - default_builtin.into(), + builtin, apply_wrap(DefaultFunction::ListData.into(), left), ), apply_wrap(DefaultFunction::ListData.into(), right), @@ -4786,7 +4599,7 @@ impl<'a> CodeGenerator<'a> { } if_else( - apply_wrap(apply_wrap(default_builtin.into(), left), right), + apply_wrap(apply_wrap(builtin, left), right), Term::Constant(UplcConstant::Bool(false).into()), Term::Constant(UplcConstant::Bool(true).into()), ) diff --git a/crates/uplc/src/ast/builder.rs b/crates/uplc/src/ast/builder.rs index 9bef8037..28db6ed7 100644 --- a/crates/uplc/src/ast/builder.rs +++ b/crates/uplc/src/ast/builder.rs @@ -38,32 +38,93 @@ impl Term { Term::Constant(Constant::Integer(i).into()) } + pub fn string(s: String) -> Self { + Term::Constant(Constant::String(s).into()) + } + + pub fn byte_string(b: Vec) -> Self { + Term::Constant(Constant::ByteString(b).into()) + } + + pub fn bool(b: bool) -> Self { + Term::Constant(Constant::Bool(b).into()) + } + pub fn constr_data() -> Self { Term::Builtin(DefaultFunction::ConstrData) } + pub fn map_data() -> Self { + Term::Builtin(DefaultFunction::MapData) + } + + pub fn list_data() -> Self { + Term::Builtin(DefaultFunction::ListData) + } + + pub fn b_data() -> Self { + Term::Builtin(DefaultFunction::BData) + } + + pub fn i_data() -> Self { + Term::Builtin(DefaultFunction::IData) + } + pub fn equals_integer() -> Self { Term::Builtin(DefaultFunction::EqualsInteger) } + pub fn equals_string() -> Self { + Term::Builtin(DefaultFunction::EqualsString) + } + + pub fn equals_bytestring() -> Self { + Term::Builtin(DefaultFunction::EqualsByteString) + } + + pub fn equals_data() -> Self { + Term::Builtin(DefaultFunction::EqualsData) + } + pub fn head_list() -> Self { Term::Builtin(DefaultFunction::HeadList).force() } + pub fn tail_list() -> Self { + Term::Builtin(DefaultFunction::TailList).force() + } + + pub fn mk_cons() -> Self { + Term::Builtin(DefaultFunction::MkCons).force() + } + + pub fn fst_pair() -> Self { + Term::Builtin(DefaultFunction::FstPair).force().force() + } + + pub fn snd_pair() -> Self { + Term::Builtin(DefaultFunction::SndPair).force().force() + } + + pub fn mk_pair_data() -> Self { + Term::Builtin(DefaultFunction::MkPairData) + } + + pub fn if_else(self, then_term: Self, else_term: Self) -> Self { + Term::Builtin(DefaultFunction::IfThenElse) + .force() + .apply(self) + .apply(then_term) + .apply(else_term) + } + pub fn delayed_if_else(self, then_term: Self, else_term: Self) -> Self { - Term::Apply { - function: Term::Apply { - function: Term::Apply { - function: Term::Builtin(DefaultFunction::IfThenElse).force().into(), - argument: self.into(), - } - .into(), - argument: Term::Delay(then_term.into()).into(), - } - .into(), - argument: Term::Delay(else_term.into()).into(), - } - .force() + Term::Builtin(DefaultFunction::IfThenElse) + .force() + .apply(self) + .apply(then_term.delay()) + .apply(else_term.delay()) + .force() } }