refactor: uplc code gen uses shorthand methods

This commit is contained in:
Kasey White 2023-03-17 22:05:04 -04:00 committed by Lucas
parent 74a7a2f214
commit ef3862ade8
2 changed files with 161 additions and 287 deletions

View File

@ -3999,24 +3999,16 @@ impl<'a> CodeGenerator<'a> {
fn gen_uplc(&mut self, ir: Air, arg_stack: &mut Vec<Term<Name>>) {
match ir {
Air::Int { value, .. } => {
let integer = value.parse().unwrap();
let term = Term::Constant(UplcConstant::Integer(integer).into());
arg_stack.push(term);
arg_stack.push(Term::integer(value.parse().unwrap()));
}
Air::String { value, .. } => {
let term = Term::Constant(UplcConstant::String(value).into());
arg_stack.push(term);
arg_stack.push(Term::string(value));
}
Air::ByteArray { bytes, .. } => {
let term = Term::Constant(UplcConstant::ByteString(bytes).into());
arg_stack.push(term);
arg_stack.push(Term::byte_string(bytes));
}
Air::Bool { value, .. } => {
let term = Term::Constant(UplcConstant::Bool(value).into());
arg_stack.push(term);
arg_stack.push(Term::bool(value));
}
Air::Var {
name,
@ -4214,19 +4206,17 @@ impl<'a> CodeGenerator<'a> {
.take(names.len())
.collect_vec();
term = apply_wrap(
list_access_to_uplc(
&names,
&id_list,
tail,
0,
term,
inner_types,
check_last_item,
true,
),
value,
);
term = list_access_to_uplc(
&names,
&id_list,
tail,
0,
term,
inner_types,
check_last_item,
true,
)
.apply(value);
arg_stack.push(term);
}
@ -4239,66 +4229,21 @@ impl<'a> CodeGenerator<'a> {
let mut term = arg_stack.pop().unwrap();
if let Some((tail_var, tail_name)) = tail {
term = apply_wrap(
Term::Lambda {
parameter_name: Name {
text: tail_name,
unique: 0.into(),
}
.into(),
body: term.into(),
},
apply_wrap(
Term::Builtin(DefaultFunction::TailList).force(),
Term::Var(
Name {
text: tail_var,
unique: 0.into(),
}
.into(),
),
),
);
term = term
.lambda(tail_name)
.apply(Term::tail_list().apply(Term::var(tail_var)));
}
for (tail_var, head_name) in tail_head_names.into_iter().rev() {
let head_list = if tipo.is_map() {
apply_wrap(
Term::Force(Term::Builtin(DefaultFunction::HeadList).into()),
Term::Var(
Name {
text: tail_var,
unique: 0.into(),
}
.into(),
),
)
Term::head_list().apply(Term::var(tail_var))
} else {
convert_data_to_type(
apply_wrap(
Term::Builtin(DefaultFunction::HeadList).force(),
Term::Var(
Name {
text: tail_var,
unique: 0.into(),
}
.into(),
),
),
Term::head_list().apply(Term::var(tail_var)),
&tipo.get_inner_types()[0],
)
};
term = apply_wrap(
Term::Lambda {
parameter_name: Name {
text: head_name,
unique: 0.into(),
}
.into(),
body: term.into(),
},
head_list,
);
term = term.lambda(head_name).apply(head_list);
}
arg_stack.push(term);
@ -4307,14 +4252,7 @@ impl<'a> CodeGenerator<'a> {
let mut term = arg_stack.pop().unwrap();
for param in params.iter().rev() {
term = Term::Lambda {
parameter_name: Name {
text: param.clone(),
unique: 0.into(),
}
.into(),
body: term.into(),
};
term = term.lambda(param);
}
arg_stack.push(term);
@ -4326,7 +4264,7 @@ impl<'a> CodeGenerator<'a> {
for _ in 0..count {
let arg = arg_stack.pop().unwrap();
term = apply_wrap(term, arg);
term = term.apply(arg);
}
arg_stack.push(term);
} else {
@ -4395,7 +4333,7 @@ impl<'a> CodeGenerator<'a> {
for (index, arg) in arg_vec.into_iter().enumerate() {
let arg = if matches!(func, DefaultFunction::ChooseData) && index > 0 {
Term::Delay(arg.into())
arg.delay()
} else {
arg
};
@ -4409,29 +4347,13 @@ impl<'a> CodeGenerator<'a> {
let temp_var = format!("__item_{}", self.id_gen.next());
if count == 0 {
term = apply_wrap(
term,
Term::Var(
Name {
text: temp_var.clone(),
unique: 0.into(),
}
.into(),
),
);
term = term.apply(Term::var(temp_var.clone()));
}
term = convert_data_to_type(term, &tipo);
if count == 0 {
term = Term::Lambda {
parameter_name: Name {
text: temp_var,
unique: 0.into(),
}
.into(),
body: term.into(),
};
term = term.lambda(temp_var);
}
}
DefaultFunction::UnConstrData => {
@ -4440,72 +4362,23 @@ impl<'a> CodeGenerator<'a> {
let temp_var = format!("__item_{}", self.id_gen.next());
if count == 0 {
term = apply_wrap(
term,
Term::Var(
Name {
text: temp_var.clone(),
unique: 0.into(),
}
.into(),
),
);
term = term.apply(Term::var(temp_var.clone()));
}
term = apply_wrap(
Term::Lambda {
parameter_name: Name {
text: temp_tuple.clone(),
unique: 0.into(),
}
.into(),
body: apply_wrap(
apply_wrap(
Term::Builtin(DefaultFunction::MkPairData),
apply_wrap(
Term::Builtin(DefaultFunction::IData),
apply_wrap(
Term::Builtin(DefaultFunction::FstPair)
.force()
.force(),
Term::Var(
Name {
text: temp_tuple.clone(),
unique: 0.into(),
}
.into(),
),
),
),
),
apply_wrap(
Term::Builtin(DefaultFunction::ListData),
apply_wrap(
Term::Builtin(DefaultFunction::SndPair).force().force(),
Term::Var(
Name {
text: temp_tuple,
unique: 0.into(),
}
.into(),
),
),
),
)
.into(),
},
term,
);
term = Term::mk_pair_data()
.apply(
Term::i_data()
.apply(Term::fst_pair().apply(Term::var(temp_tuple.clone()))),
)
.apply(
Term::list_data()
.apply(Term::snd_pair().apply(Term::var(temp_tuple.clone()))),
)
.lambda(temp_tuple)
.apply(term);
if count == 0 {
term = Term::Lambda {
parameter_name: Name {
text: temp_var,
unique: 0.into(),
}
.into(),
body: term.into(),
};
term = term.lambda(temp_var);
}
}
DefaultFunction::MkCons => {
@ -4521,29 +4394,11 @@ impl<'a> CodeGenerator<'a> {
if count == 0 {
for (index, temp_var) in temp_vars.iter().enumerate() {
term = apply_wrap(
term,
if index > 0 {
Term::Delay(
Term::Var(
Name {
text: temp_var.clone(),
unique: 0.into(),
}
.into(),
)
.into(),
)
} else {
Term::Var(
Name {
text: temp_var.clone(),
unique: 0.into(),
}
.into(),
)
},
);
term = term.apply(if index > 0 {
Term::var(temp_var.clone()).delay()
} else {
Term::var(temp_var.clone())
});
}
}
@ -4551,14 +4406,7 @@ impl<'a> CodeGenerator<'a> {
if count == 0 {
for temp_var in temp_vars.into_iter().rev() {
term = Term::Lambda {
parameter_name: Name {
text: temp_var,
unique: 0.into(),
}
.into(),
body: term.into(),
};
term = term.lambda(temp_var);
}
}
}
@ -4570,86 +4418,42 @@ impl<'a> CodeGenerator<'a> {
let left = arg_stack.pop().unwrap();
let right = arg_stack.pop().unwrap();
let default_builtin = if tipo.is_int() {
DefaultFunction::EqualsInteger
let builtin = if tipo.is_int() {
Term::equals_integer()
} else if tipo.is_string() {
DefaultFunction::EqualsString
Term::equals_string()
} else if tipo.is_bytearray() {
DefaultFunction::EqualsByteString
Term::equals_bytestring()
} else {
DefaultFunction::EqualsData
Term::equals_data()
};
let term = match name {
BinOp::And => delayed_if_else(
left,
right,
Term::Constant(UplcConstant::Bool(false).into()),
),
BinOp::Or => delayed_if_else(
left,
Term::Constant(UplcConstant::Bool(true).into()),
right,
),
BinOp::And => left.delayed_if_else(right, Term::bool(false)),
BinOp::Or => left.delayed_if_else(Term::bool(true), right),
BinOp::Eq => {
if tipo.is_bool() {
let term = delayed_if_else(
left,
let term = left.delayed_if_else(
right.clone(),
if_else(
right,
Term::Constant(UplcConstant::Bool(false).into()),
Term::Constant(UplcConstant::Bool(true).into()),
),
right.if_else(Term::bool(false), Term::bool(true)),
);
arg_stack.push(term);
return;
} else if tipo.is_map() {
let term = apply_wrap(
apply_wrap(
default_builtin.into(),
apply_wrap(DefaultFunction::MapData.into(), left),
),
apply_wrap(DefaultFunction::MapData.into(), right),
);
let term = builtin
.apply(Term::map_data().apply(left))
.apply(Term::map_data().apply(right));
arg_stack.push(term);
return;
} else if tipo.is_tuple()
&& matches!(tipo.get_uplc_type(), UplcType::Pair(_, _))
{
let term = apply_wrap(
apply_wrap(
default_builtin.into(),
apply_wrap(
DefaultFunction::MapData.into(),
apply_wrap(
apply_wrap(
Term::Builtin(DefaultFunction::MkCons).force(),
left,
),
Term::Constant(
UplcConstant::ProtoList(
UplcType::Pair(
UplcType::Data.into(),
UplcType::Data.into(),
),
vec![],
)
.into(),
),
),
),
),
apply_wrap(
DefaultFunction::MapData.into(),
apply_wrap(
apply_wrap(
Term::Builtin(DefaultFunction::MkCons).force(),
right,
),
Term::Constant(
let term = builtin
.apply(
Term::map_data().apply(
Term::mk_cons().apply(left).apply(Term::Constant(
UplcConstant::ProtoList(
UplcType::Pair(
UplcType::Data.into(),
@ -4658,30 +4462,39 @@ impl<'a> CodeGenerator<'a> {
vec![],
)
.into(),
),
)),
),
),
);
)
.apply(
Term::map_data().apply(
Term::mk_cons().apply(right).apply(Term::Constant(
UplcConstant::ProtoList(
UplcType::Pair(
UplcType::Data.into(),
UplcType::Data.into(),
),
vec![],
)
.into(),
)),
),
);
arg_stack.push(term);
return;
} else if tipo.is_list() || tipo.is_tuple() {
let term = apply_wrap(
apply_wrap(
default_builtin.into(),
apply_wrap(DefaultFunction::ListData.into(), left),
),
apply_wrap(DefaultFunction::ListData.into(), right),
);
let term = builtin
.apply(Term::list_data().apply(left))
.apply(Term::list_data().apply(right));
arg_stack.push(term);
return;
} else if tipo.is_void() {
arg_stack.push(Term::Constant(UplcConstant::Bool(true).into()));
arg_stack.push(Term::bool(true));
return;
}
apply_wrap(apply_wrap(default_builtin.into(), left), right)
builtin.apply(left).apply(right)
}
BinOp::NotEq => {
if tipo.is_bool() {
@ -4700,7 +4513,7 @@ impl<'a> CodeGenerator<'a> {
let term = if_else(
apply_wrap(
apply_wrap(
default_builtin.into(),
builtin,
apply_wrap(DefaultFunction::MapData.into(), left),
),
apply_wrap(DefaultFunction::MapData.into(), right),
@ -4716,7 +4529,7 @@ impl<'a> CodeGenerator<'a> {
{
let mut term = apply_wrap(
apply_wrap(
default_builtin.into(),
builtin,
apply_wrap(
DefaultFunction::MapData.into(),
apply_wrap(
@ -4769,7 +4582,7 @@ impl<'a> CodeGenerator<'a> {
let term = if_else(
apply_wrap(
apply_wrap(
default_builtin.into(),
builtin,
apply_wrap(DefaultFunction::ListData.into(), left),
),
apply_wrap(DefaultFunction::ListData.into(), right),
@ -4786,7 +4599,7 @@ impl<'a> CodeGenerator<'a> {
}
if_else(
apply_wrap(apply_wrap(default_builtin.into(), left), right),
apply_wrap(apply_wrap(builtin, left), right),
Term::Constant(UplcConstant::Bool(false).into()),
Term::Constant(UplcConstant::Bool(true).into()),
)

View File

@ -38,32 +38,93 @@ impl Term<Name> {
Term::Constant(Constant::Integer(i).into())
}
pub fn string(s: String) -> Self {
Term::Constant(Constant::String(s).into())
}
pub fn byte_string(b: Vec<u8>) -> Self {
Term::Constant(Constant::ByteString(b).into())
}
pub fn bool(b: bool) -> Self {
Term::Constant(Constant::Bool(b).into())
}
pub fn constr_data() -> Self {
Term::Builtin(DefaultFunction::ConstrData)
}
pub fn map_data() -> Self {
Term::Builtin(DefaultFunction::MapData)
}
pub fn list_data() -> Self {
Term::Builtin(DefaultFunction::ListData)
}
pub fn b_data() -> Self {
Term::Builtin(DefaultFunction::BData)
}
pub fn i_data() -> Self {
Term::Builtin(DefaultFunction::IData)
}
pub fn equals_integer() -> Self {
Term::Builtin(DefaultFunction::EqualsInteger)
}
pub fn equals_string() -> Self {
Term::Builtin(DefaultFunction::EqualsString)
}
pub fn equals_bytestring() -> Self {
Term::Builtin(DefaultFunction::EqualsByteString)
}
pub fn equals_data() -> Self {
Term::Builtin(DefaultFunction::EqualsData)
}
pub fn head_list() -> Self {
Term::Builtin(DefaultFunction::HeadList).force()
}
pub fn tail_list() -> Self {
Term::Builtin(DefaultFunction::TailList).force()
}
pub fn mk_cons() -> Self {
Term::Builtin(DefaultFunction::MkCons).force()
}
pub fn fst_pair() -> Self {
Term::Builtin(DefaultFunction::FstPair).force().force()
}
pub fn snd_pair() -> Self {
Term::Builtin(DefaultFunction::SndPair).force().force()
}
pub fn mk_pair_data() -> Self {
Term::Builtin(DefaultFunction::MkPairData)
}
pub fn if_else(self, then_term: Self, else_term: Self) -> Self {
Term::Builtin(DefaultFunction::IfThenElse)
.force()
.apply(self)
.apply(then_term)
.apply(else_term)
}
pub fn delayed_if_else(self, then_term: Self, else_term: Self) -> Self {
Term::Apply {
function: Term::Apply {
function: Term::Apply {
function: Term::Builtin(DefaultFunction::IfThenElse).force().into(),
argument: self.into(),
}
.into(),
argument: Term::Delay(then_term.into()).into(),
}
.into(),
argument: Term::Delay(else_term.into()).into(),
}
.force()
Term::Builtin(DefaultFunction::IfThenElse)
.force()
.apply(self)
.apply(then_term.delay())
.apply(else_term.delay())
.force()
}
}