diff --git a/crates/flat/src/encoder.rs b/crates/flat/src/encoder.rs index 229efca0..436b7c42 100644 --- a/crates/flat/src/encoder.rs +++ b/crates/flat/src/encoder.rs @@ -43,10 +43,12 @@ impl Encoder { /// Encode a `bool` value. pub fn bool(&mut self, x: bool) -> &mut Self { if x { - self.one() + self.one(); } else { - self.zero() + self.zero(); } + + self } pub fn bytes(&mut self, arr: &[u8]) -> Result<&mut Self, String> { @@ -73,17 +75,15 @@ impl Encoder { Ok(self) } - fn zero(&mut self) -> &mut Self { + fn zero(&mut self) { if self.used_bits == 7 { self.next_word(); } else { self.used_bits += 1; } - - self } - fn one(&mut self) -> &mut Self { + fn one(&mut self) { if self.used_bits == 7 { self.current_byte |= 1; self.next_word(); @@ -91,8 +91,6 @@ impl Encoder { self.current_byte |= 128 >> self.used_bits; self.used_bits += 1; } - - self } fn byte_unaligned(&mut self, x: u8) { @@ -139,23 +137,54 @@ impl Encoder { } } - fn bits(&mut self, num_bits: i64, val: u8) { - self.used_bits += num_bits; - let unused_bits = 8 - self.used_bits; - match unused_bits { - x if x > 0 => { - self.current_byte |= val << x; + pub fn encode_list_with(&mut self, list: Vec) -> Result<(), String> { + for item in list { + self.one(); + self.encode(item)?; + } + self.zero(); + Ok(()) + } + + pub fn bits(&mut self, num_bits: i64, val: u8) { + match (num_bits, val) { + (1, 0) => self.zero(), + (1, 1) => self.one(), + (2, 0) => { + self.zero(); + self.zero(); } - x if x == 0 => { - self.current_byte |= val; - self.next_word(); + (2, 1) => { + self.zero(); + self.one(); } - x => { - let used = -x; - self.current_byte |= val >> used; - self.next_word(); - self.current_byte = val << (8 - used); - self.used_bits = used; + (2, 2) => { + self.one(); + self.zero(); + } + (2, 3) => { + self.one(); + self.one(); + } + (_, _) => { + self.used_bits += num_bits; + let unused_bits = 8 - self.used_bits; + match unused_bits { + x if x > 0 => { + self.current_byte |= val << x; + } + x if x == 0 => { + self.current_byte |= val; + self.next_word(); + } + x => { + let used = -x; + self.current_byte |= val >> used; + self.next_word(); + self.current_byte = val << (8 - used); + self.used_bits = used; + } + } } } } diff --git a/crates/uplc/src/ast.rs b/crates/uplc/src/ast.rs index b18043fe..f788deb6 100644 --- a/crates/uplc/src/ast.rs +++ b/crates/uplc/src/ast.rs @@ -1,5 +1,9 @@ +use flat::en::{Encode, Encoder}; + use crate::builtins::DefaultFunction; +const TERM_TAG_WIDTH: u32 = 4; + #[derive(Debug)] pub struct Program { pub version: String, @@ -32,6 +36,22 @@ pub enum Term { Builtin(DefaultFunction), } +pub fn encode_term_tag(tag: u8, e: &mut Encoder) -> Result<(), String> { + safe_encode_bits(TERM_TAG_WIDTH, tag, e) +} + +pub fn safe_encode_bits(num_bits: u32, byte: u8, e: &mut Encoder) -> Result<(), String> { + if 2_u8.pow(num_bits) < byte { + Err(format!( + "Overflow detected, cannot fit {} in {} bits.", + byte, num_bits + )) + } else { + e.bits(num_bits as i64, byte); + Ok(()) + } +} + #[derive(Debug, Clone)] pub enum Constant { // TODO: figure out the right size for this @@ -48,3 +68,43 @@ pub enum Constant { // tag: 5 Bool(bool), } + +impl Encode for Program { + fn encode(&self, e: &mut Encoder) -> Result<(), String> { + self.version.encode(e)?; + self.term.encode(e)?; + + Ok(()) + } +} + +impl Encode for Term { + fn encode(&self, e: &mut Encoder) -> Result<(), String> { + match self { + Term::Constant(constant) => { + encode_term_tag(4, e)?; + constant.encode(e)?; + } + rest => { + todo!("Implement: {:?}", rest) + } + } + + Ok(()) + } +} + +impl Encode for &Constant { + fn encode(&self, e: &mut Encoder) -> Result<(), String> { + match self { + Constant::Integer(_) => todo!(), + Constant::ByteString(bytes) => bytes.encode(e)?, + Constant::String(s) => s.encode(e)?, + Constant::Char(c) => c.encode(e)?, + Constant::Unit => todo!(), + Constant::Bool(b) => b.encode(e)?, + } + + Ok(()) + } +}