Define a safer API for unwrap_xxx_or and choose_data_xxx

Cloning a 'Term' is potentially dangerous, so we don't want this to
  happen by mistake. So instead, we pass in var names and turn them into
  terms when necessary.
This commit is contained in:
KtorZ 2024-08-01 16:45:30 +02:00 committed by Kasey
parent 846c16087e
commit aefbc6e1b9
2 changed files with 93 additions and 88 deletions

View File

@ -946,9 +946,7 @@ pub fn softcast_data_to_type_otherwise(
let then_delayed = |v| then.lambda(name).apply(v).delay(); let then_delayed = |v| then.lambda(name).apply(v).delay();
value.as_var("__val", |val| match uplc_type { value.as_var("__val", |val| match uplc_type {
None => val None => Term::choose_data_constr(val, then_delayed, &otherwise_delayed).force(),
.choose_data_constr(then_delayed, &otherwise_delayed)
.force(),
Some(UplcType::Data) => just_then, Some(UplcType::Data) => just_then,
@ -956,59 +954,59 @@ pub fn softcast_data_to_type_otherwise(
unreachable!("attempted to cast Data into Bls12_381MlResult ?!") unreachable!("attempted to cast Data into Bls12_381MlResult ?!")
} }
Some(UplcType::Integer) => val Some(UplcType::Integer) => {
.choose_data_integer(then_delayed, &otherwise_delayed) Term::choose_data_integer(val, then_delayed, &otherwise_delayed).force()
.force(), }
Some(UplcType::ByteString) => val Some(UplcType::ByteString) => {
.choose_data_bytearray(then_delayed, &otherwise_delayed) Term::choose_data_bytearray(val, then_delayed, &otherwise_delayed).force()
.force(), }
Some(UplcType::String) => val Some(UplcType::String) => Term::choose_data_bytearray(
.choose_data_bytearray( val,
|bytes| then_delayed(Term::decode_utf8().apply(bytes)), |bytes| then_delayed(Term::decode_utf8().apply(bytes)),
&otherwise_delayed, &otherwise_delayed,
) )
.force(), .force(),
Some(UplcType::List(_)) if field_type.is_map() => val Some(UplcType::List(_)) if field_type.is_map() => {
.choose_data_map(then_delayed, &otherwise_delayed) Term::choose_data_map(val, then_delayed, &otherwise_delayed).force()
.force(), }
Some(UplcType::List(_)) => val Some(UplcType::List(_)) => {
.choose_data_list(then_delayed, &otherwise_delayed) Term::choose_data_list(val, then_delayed, &otherwise_delayed).force()
.force(), }
Some(UplcType::Bls12_381G1Element) => val Some(UplcType::Bls12_381G1Element) => Term::choose_data_bytearray(
.choose_data_bytearray( val,
|bytes| then_delayed(Term::bls12_381_g1_uncompress().apply(bytes)), |bytes| then_delayed(Term::bls12_381_g1_uncompress().apply(bytes)),
&otherwise_delayed, &otherwise_delayed,
) )
.force(), .force(),
Some(UplcType::Bls12_381G2Element) => val Some(UplcType::Bls12_381G2Element) => Term::choose_data_bytearray(
.choose_data_bytearray( val,
|bytes| then_delayed(Term::bls12_381_g2_uncompress().apply(bytes)), |bytes| then_delayed(Term::bls12_381_g2_uncompress().apply(bytes)),
&otherwise_delayed, &otherwise_delayed,
) )
.force(), .force(),
Some(UplcType::Pair(_, _)) => val Some(UplcType::Pair(_, _)) => Term::choose_data_list(
.choose_data_list( val,
|list| list.unwrap_pair_or(then_delayed, &otherwise_delayed), |list| list.unwrap_pair_or(then_delayed, &otherwise_delayed),
&otherwise_delayed, &otherwise_delayed,
) )
.force(), .force(),
Some(UplcType::Bool) => val Some(UplcType::Bool) => Term::choose_data_constr(
.choose_data_constr( val,
|constr| constr.unwrap_bool_or(then_delayed, &otherwise_delayed), |constr| constr.unwrap_bool_or(then_delayed, &otherwise_delayed),
&otherwise_delayed, &otherwise_delayed,
) )
.force(), .force(),
Some(UplcType::Unit) => val Some(UplcType::Unit) => Term::choose_data_constr(
.choose_data_constr( val,
|constr| constr.unwrap_void_or(then_delayed, &otherwise_delayed), |constr| constr.unwrap_void_or(then_delayed, &otherwise_delayed),
&otherwise_delayed, &otherwise_delayed,
) )

View File

@ -3,6 +3,7 @@ use crate::{
builtins::DefaultFunction, builtins::DefaultFunction,
}; };
use pallas_primitives::alonzo::PlutusData; use pallas_primitives::alonzo::PlutusData;
use std::rc::Rc;
pub const CONSTR_FIELDS_EXPOSER: &str = "__constr_fields_exposer"; pub const CONSTR_FIELDS_EXPOSER: &str = "__constr_fields_exposer";
pub const CONSTR_INDEX_EXPOSER: &str = "__constr_index_exposer"; pub const CONSTR_INDEX_EXPOSER: &str = "__constr_index_exposer";
@ -499,51 +500,53 @@ impl Term<Name> {
/// ``` /// ```
pub fn as_var<F>(self, var_name: &str, callback: F) -> Term<Name> pub fn as_var<F>(self, var_name: &str, callback: F) -> Term<Name>
where where
F: FnOnce(Term<Name>) -> Term<Name>, F: FnOnce(Rc<Name>) -> Term<Name>,
{ {
callback(Term::var(var_name)).lambda(var_name).apply(self) callback(Name::text(var_name).into())
.lambda(var_name)
.apply(self)
} }
/// Continue a computation provided that the current term is a Data-wrapped integer. /// Continue a computation provided that the current term is a Data-wrapped integer.
/// The 'callback' receives an integer constant Term as argument. /// The 'callback' receives an integer constant Term as argument.
pub fn choose_data_integer<F>(self, callback: F, otherwise: &Term<Name>) -> Self pub fn choose_data_integer<F>(var: Rc<Name>, callback: F, otherwise: &Term<Name>) -> Self
where where
F: FnOnce(Term<Name>) -> Term<Name>, F: FnOnce(Term<Name>) -> Term<Name>,
{ {
self.clone().choose_data( Term::Var(var.clone()).choose_data(
otherwise.clone(), otherwise.clone(),
otherwise.clone(), otherwise.clone(),
otherwise.clone(), otherwise.clone(),
callback(Term::un_i_data().apply(self)), callback(Term::un_i_data().apply(Term::Var(var))),
otherwise.clone(), otherwise.clone(),
) )
} }
/// Continue a computation provided that the current term is a Data-wrapped /// Continue a computation provided that the current term is a Data-wrapped
/// bytearray. The 'callback' receives a bytearray constant Term as argument. /// bytearray. The 'callback' receives a bytearray constant Term as argument.
pub fn choose_data_bytearray<F>(self, callback: F, otherwise: &Term<Name>) -> Self pub fn choose_data_bytearray<F>(var: Rc<Name>, callback: F, otherwise: &Term<Name>) -> Self
where where
F: FnOnce(Term<Name>) -> Term<Name>, F: FnOnce(Term<Name>) -> Term<Name>,
{ {
self.clone().choose_data( Term::Var(var.clone()).choose_data(
otherwise.clone(), otherwise.clone(),
otherwise.clone(), otherwise.clone(),
otherwise.clone(), otherwise.clone(),
otherwise.clone(), otherwise.clone(),
callback(Term::un_b_data().apply(self)), callback(Term::un_b_data().apply(Term::Var(var))),
) )
} }
/// Continue a computation provided that the current term is a Data-wrapped /// Continue a computation provided that the current term is a Data-wrapped
/// list. The 'callback' receives a ProtoList Term as argument. /// list. The 'callback' receives a ProtoList Term as argument.
pub fn choose_data_list<F>(self, callback: F, otherwise: &Term<Name>) -> Self pub fn choose_data_list<F>(var: Rc<Name>, callback: F, otherwise: &Term<Name>) -> Self
where where
F: FnOnce(Term<Name>) -> Term<Name>, F: FnOnce(Term<Name>) -> Term<Name>,
{ {
self.clone().choose_data( Term::Var(var.clone()).choose_data(
otherwise.clone(), otherwise.clone(),
otherwise.clone(), otherwise.clone(),
callback(Term::unlist_data().apply(self)), callback(Term::unlist_data().apply(Term::Var(var))),
otherwise.clone(), otherwise.clone(),
otherwise.clone(), otherwise.clone(),
) )
@ -551,13 +554,13 @@ impl Term<Name> {
/// Continue a computation provided that the current term is a Data-wrapped /// Continue a computation provided that the current term is a Data-wrapped
/// list. The 'callback' receives a ProtoMap Term as argument. /// list. The 'callback' receives a ProtoMap Term as argument.
pub fn choose_data_map<F>(self, callback: F, otherwise: &Term<Name>) -> Self pub fn choose_data_map<F>(var: Rc<Name>, callback: F, otherwise: &Term<Name>) -> Self
where where
F: FnOnce(Term<Name>) -> Term<Name>, F: FnOnce(Term<Name>) -> Term<Name>,
{ {
self.clone().choose_data( Term::Var(var.clone()).choose_data(
otherwise.clone(), otherwise.clone(),
callback(Term::unmap_data().apply(self)), callback(Term::unmap_data().apply(Term::Var(var))),
otherwise.clone(), otherwise.clone(),
otherwise.clone(), otherwise.clone(),
otherwise.clone(), otherwise.clone(),
@ -566,12 +569,12 @@ impl Term<Name> {
/// Continue a computation provided that the current term is a Data-wrapped /// Continue a computation provided that the current term is a Data-wrapped
/// constr. The 'callback' receives a Data as argument. /// constr. The 'callback' receives a Data as argument.
pub fn choose_data_constr<F>(self, callback: F, otherwise: &Term<Name>) -> Self pub fn choose_data_constr<F>(var: Rc<Name>, callback: F, otherwise: &Term<Name>) -> Self
where where
F: FnOnce(Term<Name>) -> Term<Name>, F: FnOnce(Term<Name>) -> Term<Name>,
{ {
self.clone().choose_data( Term::Var(var.clone()).choose_data(
callback(self), callback(Term::Var(var)),
otherwise.clone(), otherwise.clone(),
otherwise.clone(), otherwise.clone(),
otherwise.clone(), otherwise.clone(),
@ -591,21 +594,21 @@ impl Term<Name> {
Term::unconstr_data() Term::unconstr_data()
.apply(self) .apply(self)
.as_var("__pair__", |pair| { .as_var("__pair__", |pair| {
Term::snd_pair().apply(pair.clone()).choose_list( Term::snd_pair().apply(Term::Var(pair.clone())).choose_list(
Term::less_than_equals_integer() Term::less_than_equals_integer()
.apply(Term::integer(2.into())) .apply(Term::integer(2.into()))
.apply(Term::fst_pair().apply(pair.clone())) .apply(Term::fst_pair().apply(Term::Var(pair.clone())))
.if_then_else( .if_then_else(
otherwise.clone(), otherwise.clone(),
Term::less_than_integer() Term::less_than_integer()
.apply(Term::fst_pair().apply(pair.clone())) .apply(Term::fst_pair().apply(Term::Var(pair.clone())))
.apply(Term::integer(0.into())) .apply(Term::integer(0.into()))
.if_then_else( .if_then_else(
otherwise.clone(), otherwise.clone(),
callback( callback(
Term::equals_integer() Term::equals_integer()
.apply(Term::integer(1.into())) .apply(Term::integer(1.into()))
.apply(Term::fst_pair().apply(pair)), .apply(Term::fst_pair().apply(Term::Var(pair))),
), ),
), ),
), ),
@ -647,12 +650,14 @@ impl Term<Name> {
F: FnOnce(Term<Name>) -> Term<Name>, F: FnOnce(Term<Name>) -> Term<Name>,
{ {
self.as_var("__list_data", |list| { self.as_var("__list_data", |list| {
let left = Term::head_list().apply(list.clone()); let left = Term::head_list().apply(Term::Var(list.clone()));
list.unwrap_tail_or( Term::unwrap_tail_or(
list,
|tail| { |tail| {
tail.as_var("__tail", |tail| { tail.as_var("__tail", |tail| {
let right = Term::head_list().apply(tail.clone()); let right = Term::head_list().apply(Term::Var(tail.clone()));
tail.unwrap_tail_or( Term::unwrap_tail_or(
tail,
|leftovers| { |leftovers| {
leftovers leftovers
.choose_list( .choose_list(
@ -674,12 +679,14 @@ impl Term<Name> {
} }
/// Continue with the tail of a list, if any; or fallback 'otherwise'. /// Continue with the tail of a list, if any; or fallback 'otherwise'.
pub fn unwrap_tail_or<F>(self, callback: F, otherwise: &Term<Name>) -> Term<Name> pub fn unwrap_tail_or<F>(var: Rc<Name>, callback: F, otherwise: &Term<Name>) -> Term<Name>
where where
F: FnOnce(Term<Name>) -> Term<Name>, F: FnOnce(Term<Name>) -> Term<Name>,
{ {
self.clone() Term::Var(var.clone()).choose_list(
.choose_list(otherwise.clone(), callback(Term::tail_list().apply(self))) otherwise.clone(),
callback(Term::tail_list().apply(Term::Var(var))),
)
} }
} }
@ -749,11 +756,9 @@ mod tests {
#[test] #[test]
fn unwrap_tail_or_0_elems() { fn unwrap_tail_or_0_elems() {
let result = quick_eval( let result = quick_eval(Term::list_values(vec![]).as_var("__tail", |tail| {
Term::list_values(vec![]) Term::unwrap_tail_or(tail, |p| p.delay(), &Term::Error.delay()).force()
.unwrap_tail_or(|p| p.delay(), &Term::Error.delay()) }));
.force(),
);
assert_eq!(result, Err(Error::EvaluationFailure)); assert_eq!(result, Err(Error::EvaluationFailure));
} }
@ -762,8 +767,9 @@ mod tests {
fn unwrap_tail_or_1_elem() { fn unwrap_tail_or_1_elem() {
let result = quick_eval( let result = quick_eval(
Term::list_values(vec![Constant::Data(Data::integer(1.into()))]) Term::list_values(vec![Constant::Data(Data::integer(1.into()))])
.unwrap_tail_or(|p| p.delay(), &Term::Error.delay()) .as_var("__tail", |tail| {
.force(), Term::unwrap_tail_or(tail, |p| p.delay(), &Term::Error.delay()).force()
}),
); );
assert_eq!(result, Ok(Term::list_values(vec![])),); assert_eq!(result, Ok(Term::list_values(vec![])),);
@ -776,8 +782,9 @@ mod tests {
Constant::Data(Data::integer(1.into())), Constant::Data(Data::integer(1.into())),
Constant::Data(Data::integer(2.into())), Constant::Data(Data::integer(2.into())),
]) ])
.unwrap_tail_or(|p| p.delay(), &Term::Error.delay()) .as_var("__tail", |tail| {
.force(), Term::unwrap_tail_or(tail, |p| p.delay(), &Term::Error.delay()).force()
}),
); );
assert_eq!( assert_eq!(