diff --git a/crates/lang/src/uplc.rs b/crates/lang/src/uplc.rs index 689ee903..e6f10c2f 100644 --- a/crates/lang/src/uplc.rs +++ b/crates/lang/src/uplc.rs @@ -2,7 +2,6 @@ use std::{cmp::Ordering, collections::HashMap, rc::Rc, sync::Arc}; use indexmap::IndexMap; -use itertools::Itertools; use uplc::{ ast::{Constant, Name, Program, Term, Unique}, builtins::DefaultFunction, @@ -18,23 +17,28 @@ use crate::{ #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ScopeLevels { scope_tracker: Vec, + field_depth: i32, } impl ScopeLevels { pub fn new() -> Self { ScopeLevels { scope_tracker: vec![0], + field_depth: 0, } } - pub fn is_less_than(&self, other: &ScopeLevels) -> bool { + pub fn is_less_than(&self, other: &ScopeLevels, include_depth: bool) -> bool { if self.scope_tracker.is_empty() && !other.scope_tracker.is_empty() { return true; } else if other.scope_tracker.is_empty() { return false; } - let mut result = self.scope_tracker.len() < other.scope_tracker.len(); + let mut result = self.scope_tracker.len() < other.scope_tracker.len() + || (self.scope_tracker.len() == other.scope_tracker.len() + && include_depth + && self.field_depth < other.field_depth); for (scope_self, scope_other) in self.scope_tracker.iter().zip(other.scope_tracker.iter()) { match scope_self.cmp(scope_other) { @@ -64,6 +68,12 @@ impl ScopeLevels { *new_scope.scope_tracker.last_mut().unwrap() += inc; new_scope } + + pub fn depth_increment(&self, inc: i32) -> ScopeLevels { + let mut new_scope = self.clone(); + new_scope.field_depth += inc; + new_scope + } } impl Default for ScopeLevels { @@ -393,9 +403,9 @@ impl<'a> CodeGenerator<'a> { .uplc_data_usage_holder_lookup .get(&(module.to_string(), name.clone())) { - if scope_level.is_less_than(val) { + if scope_level.is_less_than(val, false) { self.uplc_data_usage_holder_lookup - .insert((module, name.clone()), scope_level.scope_increment(1)); + .insert((module, name.clone()), scope_level); } } } @@ -473,11 +483,12 @@ impl<'a> CodeGenerator<'a> { } } a @ TypedExpr::RecordAccess { label, record, .. } => { - self.recurse_scope_level(record, scope_level.scope_increment(1)); + self.recurse_scope_level(record, scope_level.clone()); let mut is_var = false; let mut current_var_name = "".to_string(); let mut module = "".to_string(); let mut current_record = *record.clone(); + let mut current_scope = scope_level.clone(); while !is_var { match current_record.clone() { TypedExpr::Var { @@ -509,6 +520,7 @@ impl<'a> CodeGenerator<'a> { format!("{label}_field_{current_var_name}") }; current_record = *record.clone(); + current_scope = current_scope.depth_increment(1); } _ => {} } @@ -519,16 +531,16 @@ impl<'a> CodeGenerator<'a> { current_var_name.clone(), label.clone(), )) { - if scope_level.is_less_than(&val.0) { + if current_scope.is_less_than(&val.0, false) { self.uplc_data_holder_lookup.insert( (module.to_string(), current_var_name.clone(), label.clone()), - (scope_level.clone(), a.clone()), + (current_scope.clone(), a.clone()), ); } } else { self.uplc_data_holder_lookup.insert( (module.to_string(), current_var_name.clone(), label.clone()), - (scope_level.clone(), a.clone()), + (current_scope.clone(), a.clone()), ); } @@ -536,13 +548,13 @@ impl<'a> CodeGenerator<'a> { .uplc_data_usage_holder_lookup .get(&(module.to_string(), current_var_name.clone())) { - if scope_level.is_less_than(val) { + if current_scope.is_less_than(val, false) { self.uplc_data_usage_holder_lookup - .insert((module, current_var_name.clone()), scope_level); + .insert((module, current_var_name), current_scope); } } else { self.uplc_data_usage_holder_lookup - .insert((module, current_var_name), scope_level); + .insert((module, current_var_name), current_scope); } } a @ TypedExpr::ModuleSelect { constructor, .. } => match constructor { @@ -579,6 +591,7 @@ impl<'a> CodeGenerator<'a> { .get(&(module.to_string(), name.to_string())) .unwrap() .0, + false, ) { self.uplc_function_holder_lookup.insert( (module.to_string(), name.to_string()), @@ -821,7 +834,7 @@ impl<'a> CodeGenerator<'a> { } .into(), argument: self - .recurse_code_gen(value, scope_level.scope_increment(1)) + .recurse_code_gen(value, scope_level.scope_increment_sequence(1)) .into(), }, @@ -1051,6 +1064,7 @@ impl<'a> CodeGenerator<'a> { .get(func) .unwrap() .0, + false, ) { let func_def = self .functions @@ -1089,48 +1103,48 @@ impl<'a> CodeGenerator<'a> { // Pull out all uplc data holder and data usage, Sort By Scope Level, Then let mut data_holder: Vec<((String, String, String), (bool, ScopeLevels, u64))> = self - .uplc_data_holder_lookup + .uplc_data_usage_holder_lookup .iter() - .filter(|record_scope| scope_level.is_less_than(&record_scope.1 .0)) - .map(|((module, name, label), (scope, expr))| { - let index = match expr { - TypedExpr::RecordAccess { index, .. } => index, - _ => todo!(), - }; + .filter(|record_scope| scope_level.is_less_than(&record_scope.1, false)) + .map(|((module, name), scope)| { ( - (module.to_string(), name.to_string(), label.to_string()), - (false, scope.clone(), *index), + (module.to_string(), name.to_string(), "".to_string()), + (true, scope.clone(), 0), ) }) .collect(); data_holder.extend( - self.uplc_data_usage_holder_lookup + self.uplc_data_holder_lookup .iter() - .filter(|record_scope| scope_level.is_less_than(&record_scope.1)) - .map(|((module, name), scope)| { + .filter(|record_scope| scope_level.is_less_than(&record_scope.1 .0, false)) + .map(|((module, name, label), (scope, expr))| { + let index = match expr { + TypedExpr::RecordAccess { index, .. } => index, + _ => todo!(), + }; ( - (module.to_string(), name.to_string(), "".to_string()), - (true, scope.clone(), 0), + (module.to_string(), name.to_string(), label.to_string()), + (false, scope.clone(), *index), ) }) .collect::>(), ); data_holder.sort_by(|b, d| { - if b.1 .1.is_less_than(&d.1 .1) { + if b.1 .1.is_less_than(&d.1 .1, true) { Ordering::Less - } else if d.1 .1.is_less_than(&b.1 .1) { + } else if d.1 .1.is_less_than(&b.1 .1, true) { Ordering::Greater } else if b.1 .0 && !d.1 .0 { Ordering::Less - } else if !b.1 .0 && d.1 .0 { + } else if d.1 .0 && !b.1 .0 { Ordering::Greater } else { Ordering::Equal } }); - for (key @ (module, name, label), (is_data_usage, _, index)) in data_holder.iter() { + for (key @ (module, name, label), (is_data_usage, _, index)) in data_holder.iter().rev() { if *is_data_usage { term = Term::Apply { function: Term::Lambda { diff --git a/examples/sample/src/sample.ak b/examples/sample/src/sample.ak index 34306fa2..e58385ab 100644 --- a/examples/sample/src/sample.ak +++ b/examples/sample/src/sample.ak @@ -1,10 +1,22 @@ + +pub type Signer { + hash: Int +} + + pub type ScriptContext { - thing: ByteArray + signer: Signer +} + +pub type Redeem { + Buy + Sell } pub type Datum { - something: ScriptContext, - fin: Int + fin: Int, + sc: ScriptContext, + rdmr: Redeem, } pub fn eqInt(a: Int, b: Int) { diff --git a/examples/sample/src/scripts/swap.ak b/examples/sample/src/scripts/swap.ak index d2cf156c..dfc90b3a 100644 --- a/examples/sample/src/scripts/swap.ak +++ b/examples/sample/src/scripts/swap.ak @@ -11,8 +11,9 @@ pub type Redeemer { pub fn spend(datum: sample.Datum, rdmr: Redeemer, ctx: spend.ScriptContext) -> Bool { let y = 2 - let x = 1 - let a = datum.something.thing - let b = 2 - b == 1 + let x = datum.sc.signer + let a = datum.sc.signer.hash + let b = datum.rdmr + let c = 1 + c == 1 }