refactor: uplc code gen uses shorthand methods

This commit is contained in:
Kasey White 2023-03-17 22:05:04 -04:00 committed by Lucas
parent 74a7a2f214
commit ef3862ade8
2 changed files with 161 additions and 287 deletions

View File

@ -3999,24 +3999,16 @@ impl<'a> CodeGenerator<'a> {
fn gen_uplc(&mut self, ir: Air, arg_stack: &mut Vec<Term<Name>>) { fn gen_uplc(&mut self, ir: Air, arg_stack: &mut Vec<Term<Name>>) {
match ir { match ir {
Air::Int { value, .. } => { Air::Int { value, .. } => {
let integer = value.parse().unwrap(); arg_stack.push(Term::integer(value.parse().unwrap()));
let term = Term::Constant(UplcConstant::Integer(integer).into());
arg_stack.push(term);
} }
Air::String { value, .. } => { Air::String { value, .. } => {
let term = Term::Constant(UplcConstant::String(value).into()); arg_stack.push(Term::string(value));
arg_stack.push(term);
} }
Air::ByteArray { bytes, .. } => { Air::ByteArray { bytes, .. } => {
let term = Term::Constant(UplcConstant::ByteString(bytes).into()); arg_stack.push(Term::byte_string(bytes));
arg_stack.push(term);
} }
Air::Bool { value, .. } => { Air::Bool { value, .. } => {
let term = Term::Constant(UplcConstant::Bool(value).into()); arg_stack.push(Term::bool(value));
arg_stack.push(term);
} }
Air::Var { Air::Var {
name, name,
@ -4214,8 +4206,7 @@ impl<'a> CodeGenerator<'a> {
.take(names.len()) .take(names.len())
.collect_vec(); .collect_vec();
term = apply_wrap( term = list_access_to_uplc(
list_access_to_uplc(
&names, &names,
&id_list, &id_list,
tail, tail,
@ -4224,9 +4215,8 @@ impl<'a> CodeGenerator<'a> {
inner_types, inner_types,
check_last_item, check_last_item,
true, true,
), )
value, .apply(value);
);
arg_stack.push(term); arg_stack.push(term);
} }
@ -4239,66 +4229,21 @@ impl<'a> CodeGenerator<'a> {
let mut term = arg_stack.pop().unwrap(); let mut term = arg_stack.pop().unwrap();
if let Some((tail_var, tail_name)) = tail { if let Some((tail_var, tail_name)) = tail {
term = apply_wrap( term = term
Term::Lambda { .lambda(tail_name)
parameter_name: Name { .apply(Term::tail_list().apply(Term::var(tail_var)));
text: tail_name,
unique: 0.into(),
}
.into(),
body: term.into(),
},
apply_wrap(
Term::Builtin(DefaultFunction::TailList).force(),
Term::Var(
Name {
text: tail_var,
unique: 0.into(),
}
.into(),
),
),
);
} }
for (tail_var, head_name) in tail_head_names.into_iter().rev() { for (tail_var, head_name) in tail_head_names.into_iter().rev() {
let head_list = if tipo.is_map() { let head_list = if tipo.is_map() {
apply_wrap( Term::head_list().apply(Term::var(tail_var))
Term::Force(Term::Builtin(DefaultFunction::HeadList).into()),
Term::Var(
Name {
text: tail_var,
unique: 0.into(),
}
.into(),
),
)
} else { } else {
convert_data_to_type( convert_data_to_type(
apply_wrap( Term::head_list().apply(Term::var(tail_var)),
Term::Builtin(DefaultFunction::HeadList).force(),
Term::Var(
Name {
text: tail_var,
unique: 0.into(),
}
.into(),
),
),
&tipo.get_inner_types()[0], &tipo.get_inner_types()[0],
) )
}; };
term = apply_wrap( term = term.lambda(head_name).apply(head_list);
Term::Lambda {
parameter_name: Name {
text: head_name,
unique: 0.into(),
}
.into(),
body: term.into(),
},
head_list,
);
} }
arg_stack.push(term); arg_stack.push(term);
@ -4307,14 +4252,7 @@ impl<'a> CodeGenerator<'a> {
let mut term = arg_stack.pop().unwrap(); let mut term = arg_stack.pop().unwrap();
for param in params.iter().rev() { for param in params.iter().rev() {
term = Term::Lambda { term = term.lambda(param);
parameter_name: Name {
text: param.clone(),
unique: 0.into(),
}
.into(),
body: term.into(),
};
} }
arg_stack.push(term); arg_stack.push(term);
@ -4326,7 +4264,7 @@ impl<'a> CodeGenerator<'a> {
for _ in 0..count { for _ in 0..count {
let arg = arg_stack.pop().unwrap(); let arg = arg_stack.pop().unwrap();
term = apply_wrap(term, arg); term = term.apply(arg);
} }
arg_stack.push(term); arg_stack.push(term);
} else { } else {
@ -4395,7 +4333,7 @@ impl<'a> CodeGenerator<'a> {
for (index, arg) in arg_vec.into_iter().enumerate() { for (index, arg) in arg_vec.into_iter().enumerate() {
let arg = if matches!(func, DefaultFunction::ChooseData) && index > 0 { let arg = if matches!(func, DefaultFunction::ChooseData) && index > 0 {
Term::Delay(arg.into()) arg.delay()
} else { } else {
arg arg
}; };
@ -4409,29 +4347,13 @@ impl<'a> CodeGenerator<'a> {
let temp_var = format!("__item_{}", self.id_gen.next()); let temp_var = format!("__item_{}", self.id_gen.next());
if count == 0 { if count == 0 {
term = apply_wrap( term = term.apply(Term::var(temp_var.clone()));
term,
Term::Var(
Name {
text: temp_var.clone(),
unique: 0.into(),
}
.into(),
),
);
} }
term = convert_data_to_type(term, &tipo); term = convert_data_to_type(term, &tipo);
if count == 0 { if count == 0 {
term = Term::Lambda { term = term.lambda(temp_var);
parameter_name: Name {
text: temp_var,
unique: 0.into(),
}
.into(),
body: term.into(),
};
} }
} }
DefaultFunction::UnConstrData => { DefaultFunction::UnConstrData => {
@ -4440,72 +4362,23 @@ impl<'a> CodeGenerator<'a> {
let temp_var = format!("__item_{}", self.id_gen.next()); let temp_var = format!("__item_{}", self.id_gen.next());
if count == 0 { if count == 0 {
term = apply_wrap( term = term.apply(Term::var(temp_var.clone()));
term,
Term::Var(
Name {
text: temp_var.clone(),
unique: 0.into(),
}
.into(),
),
);
} }
term = apply_wrap( term = Term::mk_pair_data()
Term::Lambda { .apply(
parameter_name: Name { Term::i_data()
text: temp_tuple.clone(), .apply(Term::fst_pair().apply(Term::var(temp_tuple.clone()))),
unique: 0.into(),
}
.into(),
body: apply_wrap(
apply_wrap(
Term::Builtin(DefaultFunction::MkPairData),
apply_wrap(
Term::Builtin(DefaultFunction::IData),
apply_wrap(
Term::Builtin(DefaultFunction::FstPair)
.force()
.force(),
Term::Var(
Name {
text: temp_tuple.clone(),
unique: 0.into(),
}
.into(),
),
),
),
),
apply_wrap(
Term::Builtin(DefaultFunction::ListData),
apply_wrap(
Term::Builtin(DefaultFunction::SndPair).force().force(),
Term::Var(
Name {
text: temp_tuple,
unique: 0.into(),
}
.into(),
),
),
),
) )
.into(), .apply(
}, Term::list_data()
term, .apply(Term::snd_pair().apply(Term::var(temp_tuple.clone()))),
); )
.lambda(temp_tuple)
.apply(term);
if count == 0 { if count == 0 {
term = Term::Lambda { term = term.lambda(temp_var);
parameter_name: Name {
text: temp_var,
unique: 0.into(),
}
.into(),
body: term.into(),
};
} }
} }
DefaultFunction::MkCons => { DefaultFunction::MkCons => {
@ -4521,29 +4394,11 @@ impl<'a> CodeGenerator<'a> {
if count == 0 { if count == 0 {
for (index, temp_var) in temp_vars.iter().enumerate() { for (index, temp_var) in temp_vars.iter().enumerate() {
term = apply_wrap( term = term.apply(if index > 0 {
term, Term::var(temp_var.clone()).delay()
if index > 0 {
Term::Delay(
Term::Var(
Name {
text: temp_var.clone(),
unique: 0.into(),
}
.into(),
)
.into(),
)
} else { } else {
Term::Var( Term::var(temp_var.clone())
Name { });
text: temp_var.clone(),
unique: 0.into(),
}
.into(),
)
},
);
} }
} }
@ -4551,14 +4406,7 @@ impl<'a> CodeGenerator<'a> {
if count == 0 { if count == 0 {
for temp_var in temp_vars.into_iter().rev() { for temp_var in temp_vars.into_iter().rev() {
term = Term::Lambda { term = term.lambda(temp_var);
parameter_name: Name {
text: temp_var,
unique: 0.into(),
}
.into(),
body: term.into(),
};
} }
} }
} }
@ -4570,66 +4418,42 @@ impl<'a> CodeGenerator<'a> {
let left = arg_stack.pop().unwrap(); let left = arg_stack.pop().unwrap();
let right = arg_stack.pop().unwrap(); let right = arg_stack.pop().unwrap();
let default_builtin = if tipo.is_int() { let builtin = if tipo.is_int() {
DefaultFunction::EqualsInteger Term::equals_integer()
} else if tipo.is_string() { } else if tipo.is_string() {
DefaultFunction::EqualsString Term::equals_string()
} else if tipo.is_bytearray() { } else if tipo.is_bytearray() {
DefaultFunction::EqualsByteString Term::equals_bytestring()
} else { } else {
DefaultFunction::EqualsData Term::equals_data()
}; };
let term = match name { let term = match name {
BinOp::And => delayed_if_else( BinOp::And => left.delayed_if_else(right, Term::bool(false)),
left, BinOp::Or => left.delayed_if_else(Term::bool(true), right),
right,
Term::Constant(UplcConstant::Bool(false).into()),
),
BinOp::Or => delayed_if_else(
left,
Term::Constant(UplcConstant::Bool(true).into()),
right,
),
BinOp::Eq => { BinOp::Eq => {
if tipo.is_bool() { if tipo.is_bool() {
let term = delayed_if_else( let term = left.delayed_if_else(
left,
right.clone(), right.clone(),
if_else( right.if_else(Term::bool(false), Term::bool(true)),
right,
Term::Constant(UplcConstant::Bool(false).into()),
Term::Constant(UplcConstant::Bool(true).into()),
),
); );
arg_stack.push(term); arg_stack.push(term);
return; return;
} else if tipo.is_map() { } else if tipo.is_map() {
let term = apply_wrap( let term = builtin
apply_wrap( .apply(Term::map_data().apply(left))
default_builtin.into(), .apply(Term::map_data().apply(right));
apply_wrap(DefaultFunction::MapData.into(), left),
),
apply_wrap(DefaultFunction::MapData.into(), right),
);
arg_stack.push(term); arg_stack.push(term);
return; return;
} else if tipo.is_tuple() } else if tipo.is_tuple()
&& matches!(tipo.get_uplc_type(), UplcType::Pair(_, _)) && matches!(tipo.get_uplc_type(), UplcType::Pair(_, _))
{ {
let term = apply_wrap( let term = builtin
apply_wrap( .apply(
default_builtin.into(), Term::map_data().apply(
apply_wrap( Term::mk_cons().apply(left).apply(Term::Constant(
DefaultFunction::MapData.into(),
apply_wrap(
apply_wrap(
Term::Builtin(DefaultFunction::MkCons).force(),
left,
),
Term::Constant(
UplcConstant::ProtoList( UplcConstant::ProtoList(
UplcType::Pair( UplcType::Pair(
UplcType::Data.into(), UplcType::Data.into(),
@ -4638,18 +4462,12 @@ impl<'a> CodeGenerator<'a> {
vec![], vec![],
) )
.into(), .into(),
)),
), ),
), )
), .apply(
), Term::map_data().apply(
apply_wrap( Term::mk_cons().apply(right).apply(Term::Constant(
DefaultFunction::MapData.into(),
apply_wrap(
apply_wrap(
Term::Builtin(DefaultFunction::MkCons).force(),
right,
),
Term::Constant(
UplcConstant::ProtoList( UplcConstant::ProtoList(
UplcType::Pair( UplcType::Pair(
UplcType::Data.into(), UplcType::Data.into(),
@ -4658,30 +4476,25 @@ impl<'a> CodeGenerator<'a> {
vec![], vec![],
) )
.into(), .into(),
), )),
),
), ),
); );
arg_stack.push(term); arg_stack.push(term);
return; return;
} else if tipo.is_list() || tipo.is_tuple() { } else if tipo.is_list() || tipo.is_tuple() {
let term = apply_wrap( let term = builtin
apply_wrap( .apply(Term::list_data().apply(left))
default_builtin.into(), .apply(Term::list_data().apply(right));
apply_wrap(DefaultFunction::ListData.into(), left),
),
apply_wrap(DefaultFunction::ListData.into(), right),
);
arg_stack.push(term); arg_stack.push(term);
return; return;
} else if tipo.is_void() { } else if tipo.is_void() {
arg_stack.push(Term::Constant(UplcConstant::Bool(true).into())); arg_stack.push(Term::bool(true));
return; return;
} }
apply_wrap(apply_wrap(default_builtin.into(), left), right) builtin.apply(left).apply(right)
} }
BinOp::NotEq => { BinOp::NotEq => {
if tipo.is_bool() { if tipo.is_bool() {
@ -4700,7 +4513,7 @@ impl<'a> CodeGenerator<'a> {
let term = if_else( let term = if_else(
apply_wrap( apply_wrap(
apply_wrap( apply_wrap(
default_builtin.into(), builtin,
apply_wrap(DefaultFunction::MapData.into(), left), apply_wrap(DefaultFunction::MapData.into(), left),
), ),
apply_wrap(DefaultFunction::MapData.into(), right), apply_wrap(DefaultFunction::MapData.into(), right),
@ -4716,7 +4529,7 @@ impl<'a> CodeGenerator<'a> {
{ {
let mut term = apply_wrap( let mut term = apply_wrap(
apply_wrap( apply_wrap(
default_builtin.into(), builtin,
apply_wrap( apply_wrap(
DefaultFunction::MapData.into(), DefaultFunction::MapData.into(),
apply_wrap( apply_wrap(
@ -4769,7 +4582,7 @@ impl<'a> CodeGenerator<'a> {
let term = if_else( let term = if_else(
apply_wrap( apply_wrap(
apply_wrap( apply_wrap(
default_builtin.into(), builtin,
apply_wrap(DefaultFunction::ListData.into(), left), apply_wrap(DefaultFunction::ListData.into(), left),
), ),
apply_wrap(DefaultFunction::ListData.into(), right), apply_wrap(DefaultFunction::ListData.into(), right),
@ -4786,7 +4599,7 @@ impl<'a> CodeGenerator<'a> {
} }
if_else( if_else(
apply_wrap(apply_wrap(default_builtin.into(), left), right), apply_wrap(apply_wrap(builtin, left), right),
Term::Constant(UplcConstant::Bool(false).into()), Term::Constant(UplcConstant::Bool(false).into()),
Term::Constant(UplcConstant::Bool(true).into()), Term::Constant(UplcConstant::Bool(true).into()),
) )

View File

@ -38,31 +38,92 @@ impl Term<Name> {
Term::Constant(Constant::Integer(i).into()) Term::Constant(Constant::Integer(i).into())
} }
pub fn string(s: String) -> Self {
Term::Constant(Constant::String(s).into())
}
pub fn byte_string(b: Vec<u8>) -> Self {
Term::Constant(Constant::ByteString(b).into())
}
pub fn bool(b: bool) -> Self {
Term::Constant(Constant::Bool(b).into())
}
pub fn constr_data() -> Self { pub fn constr_data() -> Self {
Term::Builtin(DefaultFunction::ConstrData) Term::Builtin(DefaultFunction::ConstrData)
} }
pub fn map_data() -> Self {
Term::Builtin(DefaultFunction::MapData)
}
pub fn list_data() -> Self {
Term::Builtin(DefaultFunction::ListData)
}
pub fn b_data() -> Self {
Term::Builtin(DefaultFunction::BData)
}
pub fn i_data() -> Self {
Term::Builtin(DefaultFunction::IData)
}
pub fn equals_integer() -> Self { pub fn equals_integer() -> Self {
Term::Builtin(DefaultFunction::EqualsInteger) Term::Builtin(DefaultFunction::EqualsInteger)
} }
pub fn equals_string() -> Self {
Term::Builtin(DefaultFunction::EqualsString)
}
pub fn equals_bytestring() -> Self {
Term::Builtin(DefaultFunction::EqualsByteString)
}
pub fn equals_data() -> Self {
Term::Builtin(DefaultFunction::EqualsData)
}
pub fn head_list() -> Self { pub fn head_list() -> Self {
Term::Builtin(DefaultFunction::HeadList).force() Term::Builtin(DefaultFunction::HeadList).force()
} }
pub fn tail_list() -> Self {
Term::Builtin(DefaultFunction::TailList).force()
}
pub fn mk_cons() -> Self {
Term::Builtin(DefaultFunction::MkCons).force()
}
pub fn fst_pair() -> Self {
Term::Builtin(DefaultFunction::FstPair).force().force()
}
pub fn snd_pair() -> Self {
Term::Builtin(DefaultFunction::SndPair).force().force()
}
pub fn mk_pair_data() -> Self {
Term::Builtin(DefaultFunction::MkPairData)
}
pub fn if_else(self, then_term: Self, else_term: Self) -> Self {
Term::Builtin(DefaultFunction::IfThenElse)
.force()
.apply(self)
.apply(then_term)
.apply(else_term)
}
pub fn delayed_if_else(self, then_term: Self, else_term: Self) -> Self { pub fn delayed_if_else(self, then_term: Self, else_term: Self) -> Self {
Term::Apply { Term::Builtin(DefaultFunction::IfThenElse)
function: Term::Apply { .force()
function: Term::Apply { .apply(self)
function: Term::Builtin(DefaultFunction::IfThenElse).force().into(), .apply(then_term.delay())
argument: self.into(), .apply(else_term.delay())
}
.into(),
argument: Term::Delay(then_term.into()).into(),
}
.into(),
argument: Term::Delay(else_term.into()).into(),
}
.force() .force()
} }
} }