diff --git a/crates/flat/src/decode.rs b/crates/flat/src/decode.rs index 92031d56..e74ecc7f 100644 --- a/crates/flat/src/decode.rs +++ b/crates/flat/src/decode.rs @@ -17,6 +17,24 @@ impl Decode<'_> for Vec { } } +impl Decode<'_> for u8 { + fn decode(d: &mut Decoder) -> Result { + d.u8() + } +} + +impl Decode<'_> for isize { + fn decode(d: &mut Decoder) -> Result { + d.integer() + } +} + +impl Decode<'_> for usize { + fn decode(d: &mut Decoder) -> Result { + d.word() + } +} + impl Decode<'_> for char { fn decode(d: &mut Decoder) -> Result { d.char() @@ -31,7 +49,6 @@ impl Decode<'_> for String { impl Decode<'_> for bool { fn decode(d: &mut Decoder) -> Result { - let b = d.bool(); - Ok(b) + d.bool() } } diff --git a/crates/flat/src/decoder.rs b/crates/flat/src/decoder.rs index 95ca1362..e05cb621 100644 --- a/crates/flat/src/decoder.rs +++ b/crates/flat/src/decoder.rs @@ -1,4 +1,4 @@ -use crate::decode::Decode; +use crate::{decode::Decode, zigzag}; pub struct Decoder<'b> { buffer: &'b [u8], @@ -19,11 +19,19 @@ impl<'b> Decoder<'b> { T::decode(self) } - pub fn bool(&mut self) -> bool { + pub fn integer(&mut self) -> Result { + Ok(zigzag::to_isize(self.word()?)) + } + + pub fn bool(&mut self) -> Result { let current_byte = self.buffer[self.pos]; let b = 0 != (current_byte & (128 >> self.used_bits)); self.increment_buffer_by_bit(); - b + Ok(b) + } + + pub fn u8(&mut self) -> Result { + self.bits8(8) } pub fn bytes(&mut self) -> Result, String> { @@ -48,7 +56,7 @@ impl<'b> Decoder<'b> { Ok(()) } - fn word(&mut self) -> Result { + pub fn word(&mut self) -> Result { let mut leading_bit = 1; let mut final_word: usize = 0; let mut shl: usize = 0; @@ -63,6 +71,17 @@ impl<'b> Decoder<'b> { Ok(final_word) } + pub fn decode_list_with>( + &mut self, + decoder_func: for<'r> fn(&'r mut Decoder) -> Result, + ) -> Result, String> { + let mut vec_array: Vec = Vec::new(); + while self.bit()? { + vec_array.push(decoder_func(self)?) + } + Ok(vec_array) + } + fn zero(&mut self) -> Result { let current_bit = self.bit()?; Ok(!current_bit) @@ -102,7 +121,7 @@ impl<'b> Decoder<'b> { "Decoder.bits8: incorrect value of num_bits - must be less than 9".to_string(), ); } - self.ensure_bytes(num_bits)?; + self.ensure_bits(num_bits)?; let unused_bits = 8 - self.used_bits as usize; let leading_zeroes = 8 - num_bits; let r = (self.buffer[self.pos] << self.used_bits as usize) >> leading_zeroes; @@ -112,12 +131,12 @@ impl<'b> Decoder<'b> { } else { r }; - self.drop_bits(8); + self.drop_bits(num_bits); Ok(x) } fn ensure_bytes(&mut self, required_bytes: usize) -> Result<(), String> { - if required_bytes > self.buffer.len() - self.pos { + if required_bytes as isize > self.buffer.len() as isize - self.pos as isize { return Err(format!( "DecoderState: Not enough data available: {:#?} - required bytes {}", self.buffer, required_bytes @@ -126,6 +145,18 @@ impl<'b> Decoder<'b> { Ok(()) } + fn ensure_bits(&mut self, required_bits: usize) -> Result<(), String> { + if required_bits as isize + > (self.buffer.len() as isize - self.pos as isize) * 8 - self.used_bits as isize + { + return Err(format!( + "DecoderState: Not enough data available: {:#?} - required bits {}", + self.buffer, required_bits + )); + } + Ok(()) + } + fn drop_bits(&mut self, num_bits: usize) { let all_used_bits = num_bits as i64 + self.used_bits; self.used_bits = all_used_bits % 8; diff --git a/crates/uplc/src/builtins.rs b/crates/uplc/src/builtins.rs index 9231c494..638ed6f1 100644 --- a/crates/uplc/src/builtins.rs +++ b/crates/uplc/src/builtins.rs @@ -78,3 +78,115 @@ pub enum DefaultFunction { MkNilData = 52, MkNilPairData = 53, } + +impl TryFrom for DefaultFunction { + fn try_from(v: u8) -> Result { + match v { + v if v == DefaultFunction::AddInteger as u8 => Ok(DefaultFunction::AddInteger), + v if v == DefaultFunction::SubtractInteger as u8 => { + Ok(DefaultFunction::SubtractInteger) + } + v if v == DefaultFunction::MultiplyInteger as u8 => { + Ok(DefaultFunction::MultiplyInteger) + } + v if v == DefaultFunction::DivideInteger as u8 => Ok(DefaultFunction::DivideInteger), + v if v == DefaultFunction::QuotientInteger as u8 => { + Ok(DefaultFunction::QuotientInteger) + } + v if v == DefaultFunction::RemainderInteger as u8 => { + Ok(DefaultFunction::RemainderInteger) + } + v if v == DefaultFunction::ModInteger as u8 => Ok(DefaultFunction::ModInteger), + v if v == DefaultFunction::EqualsInteger as u8 => Ok(DefaultFunction::EqualsInteger), + v if v == DefaultFunction::LessThanInteger as u8 => { + Ok(DefaultFunction::LessThanInteger) + } + v if v == DefaultFunction::LessThanEqualsInteger as u8 => { + Ok(DefaultFunction::LessThanEqualsInteger) + } + // ByteString functions + v if v == DefaultFunction::AppendByteString as u8 => { + Ok(DefaultFunction::AppendByteString) + } + v if v == DefaultFunction::ConsByteString as u8 => Ok(DefaultFunction::ConsByteString), + v if v == DefaultFunction::SliceByteString as u8 => { + Ok(DefaultFunction::SliceByteString) + } + v if v == DefaultFunction::LengthOfByteString as u8 => { + Ok(DefaultFunction::LengthOfByteString) + } + v if v == DefaultFunction::IndexByteString as u8 => { + Ok(DefaultFunction::IndexByteString) + } + v if v == DefaultFunction::EqualsByteString as u8 => { + Ok(DefaultFunction::EqualsByteString) + } + v if v == DefaultFunction::LessThanByteString as u8 => { + Ok(DefaultFunction::LessThanByteString) + } + v if v == DefaultFunction::LessThanEqualsByteString as u8 => { + Ok(DefaultFunction::LessThanEqualsByteString) + } + // Cryptography and hash functions + v if v == DefaultFunction::Sha2_256 as u8 => Ok(DefaultFunction::Sha2_256), + v if v == DefaultFunction::Sha3_256 as u8 => Ok(DefaultFunction::Sha3_256), + v if v == DefaultFunction::Blake2b_256 as u8 => Ok(DefaultFunction::Blake2b_256), + v if v == DefaultFunction::VerifySignature as u8 => { + Ok(DefaultFunction::VerifySignature) + } + v if v == DefaultFunction::VerifyEcdsaSecp256k1Signature as u8 => { + Ok(DefaultFunction::VerifyEcdsaSecp256k1Signature) + } + v if v == DefaultFunction::VerifySchnorrSecp256k1Signature as u8 => { + Ok(DefaultFunction::VerifySchnorrSecp256k1Signature) + } + // String functions + v if v == DefaultFunction::AppendString as u8 => Ok(DefaultFunction::AppendString), + v if v == DefaultFunction::EqualsString as u8 => Ok(DefaultFunction::EqualsString), + v if v == DefaultFunction::EncodeUtf8 as u8 => Ok(DefaultFunction::EncodeUtf8), + v if v == DefaultFunction::DecodeUtf8 as u8 => Ok(DefaultFunction::DecodeUtf8), + // Bool function + v if v == DefaultFunction::IfThenElse as u8 => Ok(DefaultFunction::IfThenElse), + // Unit function + v if v == DefaultFunction::ChooseUnit as u8 => Ok(DefaultFunction::ChooseUnit), + // Tracing function + v if v == DefaultFunction::Trace as u8 => Ok(DefaultFunction::Trace), + // Pairs functions + v if v == DefaultFunction::FstPair as u8 => Ok(DefaultFunction::FstPair), + v if v == DefaultFunction::SndPair as u8 => Ok(DefaultFunction::SndPair), + // List functions + v if v == DefaultFunction::ChooseList as u8 => Ok(DefaultFunction::ChooseList), + v if v == DefaultFunction::MkCons as u8 => Ok(DefaultFunction::MkCons), + v if v == DefaultFunction::HeadList as u8 => Ok(DefaultFunction::HeadList), + v if v == DefaultFunction::TailList as u8 => Ok(DefaultFunction::TailList), + v if v == DefaultFunction::NullList as u8 => Ok(DefaultFunction::NullList), + // Data functions + // It is convenient to have a "choosing" function for a data type that has more than two + // constructors to get pattern matching over it and we may end up having multiple such data + // types, hence we include the name of the data type as a suffix. + v if v == DefaultFunction::ChooseData as u8 => Ok(DefaultFunction::ChooseData), + v if v == DefaultFunction::ConstrData as u8 => Ok(DefaultFunction::ConstrData), + v if v == DefaultFunction::MapData as u8 => Ok(DefaultFunction::MapData), + v if v == DefaultFunction::ListData as u8 => Ok(DefaultFunction::ListData), + v if v == DefaultFunction::IData as u8 => Ok(DefaultFunction::IData), + v if v == DefaultFunction::BData as u8 => Ok(DefaultFunction::BData), + v if v == DefaultFunction::UnConstrData as u8 => Ok(DefaultFunction::UnConstrData), + v if v == DefaultFunction::UnMapData as u8 => Ok(DefaultFunction::UnMapData), + v if v == DefaultFunction::UnListData as u8 => Ok(DefaultFunction::UnListData), + v if v == DefaultFunction::UnIData as u8 => Ok(DefaultFunction::UnIData), + v if v == DefaultFunction::UnBData as u8 => Ok(DefaultFunction::UnBData), + v if v == DefaultFunction::EqualsData as u8 => Ok(DefaultFunction::EqualsData), + v if v == DefaultFunction::SerialiseData as u8 => Ok(DefaultFunction::SerialiseData), + // Misc constructors + // Constructors that we need for constructing e.g. Data. Polymorphic builtin + // constructors are often problematic (See note [Representable built-in + // functions over polymorphic built-in types]) + v if v == DefaultFunction::MkPairData as u8 => Ok(DefaultFunction::MkPairData), + v if v == DefaultFunction::MkNilData as u8 => Ok(DefaultFunction::MkNilData), + v if v == DefaultFunction::MkNilPairData as u8 => Ok(DefaultFunction::MkNilPairData), + _ => Err("Default Function not found".to_string()), + } + } + + type Error = String; +} diff --git a/crates/uplc/src/flat.rs b/crates/uplc/src/flat.rs index 5fc020f9..ccce887c 100644 --- a/crates/uplc/src/flat.rs +++ b/crates/uplc/src/flat.rs @@ -48,8 +48,10 @@ impl Encode for Program { } impl<'b> Decode<'b> for Program { - fn decode(_d: &mut Decoder) -> Result { - todo!() + fn decode(d: &mut Decoder) -> Result { + let version = (usize::decode(d)?, usize::decode(d)?, usize::decode(d)?); + let term = Term::decode(d)?; + Ok(Program { version, term }) } } @@ -104,8 +106,22 @@ impl Encode for Term { } impl<'b> Decode<'b> for Term { - fn decode(_d: &mut Decoder) -> Result { - todo!() + fn decode(d: &mut Decoder) -> Result { + match decode_term_tag(d)? { + 0 => Ok(Term::Var(String::decode(d)?)), + 1 => Ok(Term::Delay(Box::new(Term::decode(d)?))), + 2 => todo!(), + 3 => Ok(Term::Apply { + function: Box::new(Term::decode(d)?), + argument: Box::new(Term::decode(d)?), + }), + // Need size limit for Constant + 4 => Ok(Term::Constant(Constant::decode(d)?)), + 5 => Ok(Term::Force(Box::new(Term::decode(d)?))), + 6 => Ok(Term::Error), + 7 => Ok(Term::Builtin(DefaultFunction::decode(d)?)), + x => Err(format!("Unknown term constructor tag: {}", x)), + } } } @@ -146,8 +162,17 @@ impl Encode for &Constant { } impl<'b> Decode<'b> for Constant { - fn decode(_d: &mut Decoder) -> Result { - todo!() + fn decode(d: &mut Decoder) -> Result { + match decode_constant(d)? { + 0 => Ok(Constant::Integer(isize::decode(d)?)), + 1 => Ok(Constant::ByteString(Vec::::decode(d)?)), + 2 => Ok(Constant::String( + String::from_utf8(Vec::::decode(d)?).unwrap(), + )), + 3 => Ok(Constant::Unit), + 4 => Ok(Constant::Bool(bool::decode(d)?)), + x => Err(format!("Unknown constant constructor tag: {}", x)), + } } } @@ -160,8 +185,9 @@ impl Encode for DefaultFunction { } impl<'b> Decode<'b> for DefaultFunction { - fn decode(_d: &mut Decoder) -> Result { - todo!() + fn decode(d: &mut Decoder) -> Result { + let builtin_tag = d.bits8(BUILTIN_TAG_WIDTH as usize)?; + builtin_tag.try_into() } } @@ -169,6 +195,10 @@ fn encode_term_tag(tag: u8, e: &mut Encoder) -> Result<(), String> { safe_encode_bits(TERM_TAG_WIDTH, tag, e) } +fn decode_term_tag(d: &mut Decoder) -> Result { + d.bits8(TERM_TAG_WIDTH as usize) +} + fn safe_encode_bits(num_bits: u32, byte: u8, e: &mut Encoder) -> Result<(), String> { if 2_u8.pow(num_bits) < byte { Err(format!( @@ -185,10 +215,26 @@ pub fn encode_constant(tag: u8, e: &mut Encoder) -> Result<(), String> { e.encode_list_with(encode_constant_tag, [tag].to_vec()) } +pub fn decode_constant(d: &mut Decoder) -> Result { + let u8_list = d.decode_list_with(decode_constant_tag)?; + if u8_list.len() > 1 { + Err( + "Improper encoding on constant tag. Should be list of one item encoded in 4 bits" + .to_string(), + ) + } else { + Ok(u8_list[0]) + } +} + pub fn encode_constant_tag(tag: u8, e: &mut Encoder) -> Result<(), String> { safe_encode_bits(CONST_TAG_WIDTH, tag, e) } +pub fn decode_constant_tag(d: &mut Decoder) -> Result { + d.bits8(CONST_TAG_WIDTH as usize) +} + #[cfg(test)] mod test { use super::{Constant, Program, Term};