diff --git a/crates/lang/src/ir.rs b/crates/lang/src/ir.rs index 44870c5f..e21e8b86 100644 --- a/crates/lang/src/ir.rs +++ b/crates/lang/src/ir.rs @@ -138,10 +138,13 @@ pub enum IR { count: usize, tipo: Arc, subject_name: String, + complex_clause: bool, }, ClauseGuard { scope: Vec, + subject_name: String, + tipo: Arc, }, Discard { @@ -242,7 +245,7 @@ impl IR { | IR::Lam { scope, .. } | IR::When { scope, .. } | IR::Clause { scope, .. } - | IR::ClauseGuard { scope } + | IR::ClauseGuard { scope, .. } | IR::Discard { scope } | IR::Finally { scope } | IR::If { scope, .. } diff --git a/crates/lang/src/uplc_two.rs b/crates/lang/src/uplc_two.rs index 5165264d..ae5459a5 100644 --- a/crates/lang/src/uplc_two.rs +++ b/crates/lang/src/uplc_two.rs @@ -37,6 +37,12 @@ pub struct FuncComponents { // } +pub struct ClauseComplexity { + subject_var_name: String, + needs_subject_var: bool, + is_complex_clause: bool, +} + pub struct CodeGenerator<'a> { defined_functions: HashMap, functions: &'a HashMap, TypedExpr>>, @@ -259,7 +265,7 @@ impl<'a> CodeGenerator<'a> { // assuming one subject at the moment let subject = subjects[0].clone(); - let mut needs_constr_var = false; + let mut needs_subject_var = false; if let Some((last_clause, clauses)) = clauses.split_last() { let mut clauses_vec = vec![]; @@ -268,28 +274,48 @@ impl<'a> CodeGenerator<'a> { for clause in clauses { let mut scope = scope.clone(); scope.push(self.id_gen.next()); + let mut clause_subject_vec = vec![]; + + let mut clause_complexity = ClauseComplexity { + subject_var_name: constr_var.clone(), + needs_subject_var: false, + is_complex_clause: false, + }; + + self.build_ir(&clause.then, &mut clauses_vec, scope.clone()); + + self.when_ir( + &clause.pattern[0], + &mut clause_subject_vec, + &mut clauses_vec, + &subject.tipo(), + &mut clause_complexity, + scope.clone(), + ); pattern_vec.push(IR::Clause { scope: scope.clone(), count: 2, tipo: subject.tipo().clone(), subject_name: subject_name.clone(), + complex_clause: clause_complexity.is_complex_clause, }); - self.build_ir(&clause.then, &mut clauses_vec, scope.clone()); - self.when_ir( - &clause.pattern[0], - &mut pattern_vec, - &mut clauses_vec, - &subject.tipo(), - constr_var.clone(), - &mut needs_constr_var, - scope, - ); + pattern_vec.append(&mut clause_subject_vec); + + if clause_complexity.needs_subject_var { + needs_subject_var = true; + } } let last_pattern = &last_clause.pattern[0]; + let mut final_clause_complexity = ClauseComplexity { + subject_var_name: constr_var.clone(), + needs_subject_var: false, + is_complex_clause: false, + }; + let mut final_scope = scope.clone(); final_scope.push(self.id_gen.next()); pattern_vec.push(IR::Finally { @@ -302,12 +328,11 @@ impl<'a> CodeGenerator<'a> { &mut pattern_vec, &mut clauses_vec, &subject.tipo(), - constr_var.clone(), - &mut needs_constr_var, + &mut final_clause_complexity, final_scope, ); - if needs_constr_var { + if needs_subject_var || final_clause_complexity.needs_subject_var { ir_stack.push(IR::Lam { scope: scope.clone(), name: constr_var.clone(), @@ -460,8 +485,7 @@ impl<'a> CodeGenerator<'a> { pattern_vec: &mut Vec, values: &mut Vec, tipo: &Type, - constr_var: String, - needs_constr_var: &mut bool, + when_complexity: &mut ClauseComplexity, scope: Vec, ) { match pattern { @@ -500,8 +524,8 @@ impl<'a> CodeGenerator<'a> { .. } => { let mut needs_access_to_constr_var = false; - let mut needs_clause_guard = false; + for arg in arguments { match arg.value { Pattern::Var { .. } => { @@ -550,18 +574,12 @@ impl<'a> CodeGenerator<'a> { location: Span::empty(), }, ), - name: constr_var, + name: when_complexity.subject_var_name.clone(), scope: scope.clone(), }]; // if only one constructor, no need to check if data_type.constructors.len() > 1 { - if needs_clause_guard { - pattern_vec.push(IR::ClauseGuard { - scope: scope.clone(), - }); - } - // push constructor Index pattern_vec.push(IR::Int { value: index.to_string(), @@ -569,8 +587,12 @@ impl<'a> CodeGenerator<'a> { }); } + if needs_clause_guard { + when_complexity.is_complex_clause = true; + } + if needs_access_to_constr_var { - *needs_constr_var = true; + when_complexity.needs_subject_var = true; self.when_recursive_ir(pattern, pattern_vec, &mut new_vec, tipo, scope); pattern_vec.append(values); @@ -717,18 +739,48 @@ impl<'a> CodeGenerator<'a> { .. } => { let id = self.id_gen.next(); - let constr_name = format!("{constr_name}_{id}"); + let constr_var_name = format!("{constr_name}_{id}"); + let data_type_key = match tipo.as_ref() { + Type::Fn { ret, .. } => match &**ret { + Type::App { module, name, .. } => DataTypeKey { + module_name: module.clone(), + defined_type: name.clone(), + }, + _ => unreachable!(), + }, + Type::App { module, name, .. } => DataTypeKey { + module_name: module.clone(), + defined_type: name.clone(), + }, + _ => unreachable!(), + }; + + let data_type = self.data_types.get(&data_type_key).unwrap(); + + if data_type.constructors.len() > 1 { + nested_pattern.push(IR::ClauseGuard { + scope: scope.clone(), + tipo: tipo.clone(), + subject_name: constr_var_name.clone(), + }); + } + + let mut clause_complexity = ClauseComplexity { + subject_var_name: constr_var_name.clone(), + needs_subject_var: false, + is_complex_clause: false, + }; + self.when_ir( a, &mut nested_pattern, &mut vec![], tipo, - constr_name.clone(), - &mut false, + &mut clause_complexity, scope.clone(), ); - (false, constr_name) + (false, constr_var_name) } _ => todo!(), }; @@ -775,18 +827,47 @@ impl<'a> CodeGenerator<'a> { .. } => { let id = self.id_gen.next(); - let constr_name = format!("{constr_name}_{id}"); + let constr_var_name = format!("{constr_name}_{id}"); + let data_type_key = match tipo.as_ref() { + Type::Fn { ret, .. } => match &**ret { + Type::App { module, name, .. } => DataTypeKey { + module_name: module.clone(), + defined_type: name.clone(), + }, + _ => unreachable!(), + }, + Type::App { module, name, .. } => DataTypeKey { + module_name: module.clone(), + defined_type: name.clone(), + }, + _ => unreachable!(), + }; + + let data_type = self.data_types.get(&data_type_key).unwrap(); + + if data_type.constructors.len() > 1 { + nested_pattern.push(IR::ClauseGuard { + scope: scope.clone(), + tipo: tipo.clone(), + subject_name: constr_var_name.clone(), + }); + } + let mut clause_complexity = ClauseComplexity { + subject_var_name: constr_var_name.clone(), + needs_subject_var: false, + is_complex_clause: false, + }; + self.when_ir( a, &mut nested_pattern, &mut vec![], tipo, - constr_name.clone(), - &mut false, + &mut clause_complexity, scope.clone(), ); - (false, constr_name) + (false, constr_var_name) } _ => todo!(), }; @@ -1426,7 +1507,14 @@ impl<'a> CodeGenerator<'a> { .into(), argument: right.into(), }, - BinOp::GtEqInt => todo!(), + BinOp::GtEqInt => Term::Apply { + function: Term::Apply { + function: Term::Builtin(DefaultFunction::LessThanEqualsInteger).into(), + argument: right.into(), + } + .into(), + argument: left.into(), + }, BinOp::GtInt => Term::Apply { function: Term::Apply { function: Term::Builtin(DefaultFunction::LessThanInteger).into(), @@ -1736,7 +1824,13 @@ impl<'a> CodeGenerator<'a> { arg_stack.push(term); } - IR::ClauseGuard { .. } => todo!(), + IR::ClauseGuard { .. } => { + let _condition = arg_stack.pop().unwrap(); + + let _then = arg_stack.pop().unwrap(); + + todo!(); + } IR::Finally { .. } => { let _clause = arg_stack.pop().unwrap(); } diff --git a/examples/sample/validators/swap.ak b/examples/sample/validators/swap.ak index 3a321933..bceb48f9 100644 --- a/examples/sample/validators/swap.ak +++ b/examples/sample/validators/swap.ak @@ -47,7 +47,7 @@ pub fn who(a: ByteArray) -> ByteArray { } pub type Datum { - Offer { prices: List(Int), asset_class: ByteArray, other_thing: Redeemer } + Offer { prices: List(Int), asset_class: ByteArray, other_thing: Datum } Sell Hold(Int) } @@ -58,11 +58,11 @@ pub fn spend(datum: Datum, _rdmr: Nil, _ctx: Nil) -> Bool { Offer { prices: p, asset_class: ac, - other_thing: Redeemer { - other_thing: Redeemer { signer: nested_signer, amount, .. }, + other_thing: Offer { + other_thing: Offer { asset_class: nested_signer, prices: amounts, .. }, .. }, - } -> True + } -> 1 == 1 _ -> False } }