diff --git a/crates/lang/src/ir.rs b/crates/lang/src/ir.rs index 2c6c9cae..199dab98 100644 --- a/crates/lang/src/ir.rs +++ b/crates/lang/src/ir.rs @@ -96,6 +96,10 @@ pub enum IR { count: usize, }, + Lam { + name: String, + }, + // Try { // tipo: Arc, // value: Box, @@ -104,6 +108,7 @@ pub enum IR { // }, When { count: usize, + tipo: Arc, subject_name: String, }, @@ -136,6 +141,7 @@ pub enum IR { FieldsExpose { count: usize, + indices: Vec, }, // ModuleSelect { diff --git a/crates/lang/src/uplc_two.rs b/crates/lang/src/uplc_two.rs index afcd86e9..905055f9 100644 --- a/crates/lang/src/uplc_two.rs +++ b/crates/lang/src/uplc_two.rs @@ -1,5 +1,6 @@ -use std::{collections::HashMap, fmt::format, ops::Deref, sync::Arc}; +use std::{collections::HashMap, ops::Deref, sync::Arc}; +use itertools::Itertools; use uplc::{ ast::{ builder::{self, CONSTR_FIELDS_EXPOSER, CONSTR_GET_FIELD}, @@ -194,40 +195,80 @@ impl<'a> CodeGenerator<'a> { subjects, clauses, .. } => { let subject_name = format!("__subject_name_{}", self.id_gen.next()); + let constr_var = format!("__constr_name_{}", self.id_gen.next()); // assuming one subject at the moment - ir_stack.push(IR::When { - count: clauses.len() + 1, - subject_name: subject_name.clone(), - }); - let subject = subjects[0].clone(); - - self.build_ir(&subject, ir_stack); + let mut needs_constr_var = false; if let Some((last_clause, clauses)) = clauses.split_last() { let mut clauses_vec = vec![]; let mut pattern_vec = vec![]; for clause in clauses { - self.build_ir(&clause.then, &mut clauses_vec); + pattern_vec.push(IR::Clause { + count: 2, + tipo: subject.tipo().clone(), + subject_name: subject_name.clone(), + }); + self.build_ir(&clause.then, &mut clauses_vec); self.when_ir( &clause.pattern[0], &mut pattern_vec, &mut clauses_vec, &subject.tipo(), - subject_name.clone(), - ) + constr_var.clone(), + &mut needs_constr_var, + ); + } + + let last_pattern = &last_clause.pattern[0]; + pattern_vec.push(IR::Finally); + + self.build_ir(&last_clause.then, &mut clauses_vec); + self.when_ir( + last_pattern, + &mut pattern_vec, + &mut clauses_vec, + &subject.tipo(), + constr_var.clone(), + &mut needs_constr_var, + ); + + if needs_constr_var { + ir_stack.push(IR::Lam { + name: constr_var.clone(), + }); + + self.build_ir(&subject, ir_stack); + + ir_stack.push(IR::When { + count: clauses.len() + 1, + subject_name, + tipo: subject.tipo(), + }); + + ir_stack.push(IR::Var { + constructor: ValueConstructor::public( + subject.tipo(), + ValueConstructorVariant::LocalVariable { + location: Span::empty(), + }, + ), + name: constr_var, + }) + } else { + ir_stack.push(IR::When { + count: clauses.len() + 1, + subject_name, + tipo: subject.tipo(), + }); + + self.build_ir(&subject, ir_stack); } ir_stack.append(&mut pattern_vec); - - let last_pattern = &last_clause.pattern[0]; - ir_stack.push(IR::Finally); - - self.build_ir(&last_clause.then, &mut clauses_vec); - self.pattern_ir(last_pattern, ir_stack, &mut clauses_vec); }; } TypedExpr::If { .. } => todo!(), @@ -313,7 +354,7 @@ impl<'a> CodeGenerator<'a> { } fn assignment_ir( - &self, + &mut self, pattern: &Pattern>, pattern_vec: &mut Vec, value_vec: &mut Vec, @@ -342,21 +383,16 @@ impl<'a> CodeGenerator<'a> { } fn when_ir( - &self, + &mut self, pattern: &Pattern>, pattern_vec: &mut Vec, values: &mut Vec, tipo: &Type, - subject_name: String, + constr_var: String, + needs_constr_var: &mut bool, ) { match pattern { Pattern::Int { value, .. } => { - pattern_vec.push(IR::Clause { - count: 2, - tipo: tipo.clone().into(), - subject_name, - }); - pattern_vec.push(IR::Int { value: value.clone(), }); @@ -369,17 +405,48 @@ impl<'a> CodeGenerator<'a> { Pattern::Assign { .. } => todo!(), Pattern::Discard { .. } => unreachable!(), Pattern::List { .. } => todo!(), - Pattern::Constructor { .. } => todo!(), + Pattern::Constructor { arguments, .. } => { + let mut needs_access_to_constr_var = false; + for arg in arguments { + match arg.value { + Pattern::Var { .. } + | Pattern::List { .. } + | Pattern::Constructor { .. } => { + needs_access_to_constr_var = true; + } + _ => {} + } + } + + let mut new_vec = vec![IR::Var { + constructor: ValueConstructor::public( + tipo.clone().into(), + ValueConstructorVariant::LocalVariable { + location: Span::empty(), + }, + ), + name: constr_var, + }]; + + if needs_access_to_constr_var { + *needs_constr_var = true; + new_vec.append(values); + + self.pattern_ir(pattern, pattern_vec, &mut new_vec); + } else { + self.pattern_ir(pattern, pattern_vec, values); + } + } } } fn pattern_ir( - &self, + &mut self, pattern: &Pattern>, pattern_vec: &mut Vec, values: &mut Vec, ) { - match pattern { + match dbg!(pattern) { Pattern::Int { .. } => todo!(), Pattern::String { .. } => todo!(), Pattern::Var { .. } => todo!(), @@ -440,7 +507,101 @@ impl<'a> CodeGenerator<'a> { pattern_vec.append(values); pattern_vec.append(&mut elements_vec); } - Pattern::Constructor { .. } => todo!(), + Pattern::Constructor { + is_record, + name: constr_name, + arguments, + constructor, + tipo, + .. + } => { + if *is_record { + let data_type_key = match tipo.as_ref() { + Type::Fn { ret, .. } => match &**ret { + Type::App { module, name, .. } => DataTypeKey { + module_name: module.clone(), + defined_type: name.clone(), + }, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + let data_type = self.data_types.get(&data_type_key).unwrap(); + let (index, constructor_type) = data_type + .constructors + .iter() + .enumerate() + .find(|(_, dt)| &dt.name == constr_name) + .unwrap(); + + let field_map = match constructor { + tipo::PatternConstructor::Record { field_map, .. } => { + field_map.clone().unwrap() + } + }; + + let mut type_map: HashMap> = HashMap::new(); + + for arg in &constructor_type.arguments { + let label = arg.label.clone().unwrap(); + let field_type = arg.tipo.clone(); + + type_map.insert(label, field_type); + } + + let arguments_index = arguments + .iter() + .map(|item| { + let label = item.label.clone().unwrap_or_default(); + let field_index = field_map.fields.get(&label).unwrap_or(&0); + let (discard, var_name) = match &item.value { + Pattern::Var { name, .. } => (false, name.clone()), + Pattern::Discard { .. } => (true, "".to_string()), + Pattern::List { .. } => todo!(), + Pattern::Constructor { .. } => todo!(), + _ => todo!(), + }; + + (label, var_name, *field_index, discard) + }) + .filter(|(_, _, _, discard)| !discard) + .sorted_by(|item1, item2| item1.2.cmp(&item2.2)) + .collect::>(); + + // push constructor Index + pattern_vec.push(IR::Int { + value: index.to_string(), + }); + if !arguments_index.is_empty() { + pattern_vec.push(IR::FieldsExpose { + count: arguments_index.len() + 2, + indices: arguments_index + .iter() + .map(|(_, _, index, _)| *index) + .collect_vec(), + }); + + for arg in arguments_index { + let field_label = arg.0; + let field_type = type_map.get(&field_label).unwrap(); + let field_var = arg.1; + pattern_vec.push(IR::Var { + constructor: ValueConstructor::public( + field_type.clone(), + ValueConstructorVariant::LocalVariable { + location: Span::empty(), + }, + ), + name: field_var, + }) + } + } + pattern_vec.append(values); + } else { + println!("todo"); + } + } } } @@ -807,7 +968,7 @@ impl<'a> CodeGenerator<'a> { arg_stack.push(term); } - IR::DefineFunc { func_name, .. } => { + IR::DefineFunc { .. } => { let _body = arg_stack.pop().unwrap(); todo!() @@ -815,11 +976,34 @@ impl<'a> CodeGenerator<'a> { IR::DefineConst { .. } => todo!(), IR::DefineConstrFields { .. } => todo!(), IR::DefineConstrFieldAccess { .. } => todo!(), - IR::When { .. } => todo!(), + IR::Lam { .. } => todo!(), + IR::When { + subject_name, tipo, .. + } => { + let subject = arg_stack.pop().unwrap(); + + let mut term = arg_stack.pop().unwrap(); + + term = if tipo.is_int() { + Term::Apply { + function: Term::Lambda { + parameter_name: Name { + text: subject_name, + unique: 0.into(), + }, + body: term.into(), + } + .into(), + argument: subject.into(), + } + } else { + todo!() + }; + + arg_stack.push(term); + } IR::Clause { - count, - tipo, - subject_name, + tipo, subject_name, .. } => { // clause to compare let clause = arg_stack.pop().unwrap(); @@ -831,17 +1015,45 @@ impl<'a> CodeGenerator<'a> { let mut term = arg_stack.pop().unwrap(); let checker = if tipo.is_int() { - DefaultFunction::EqualsInteger.into() + Term::Apply { + function: DefaultFunction::EqualsInteger.into(), + argument: Term::Var(Name { + text: subject_name, + unique: 0.into(), + }) + .into(), + } } else if tipo.is_bytearray() { - DefaultFunction::EqualsByteString.into() + Term::Apply { + function: DefaultFunction::EqualsByteString.into(), + argument: Term::Var(Name { + text: subject_name, + unique: 0.into(), + }) + .into(), + } } else if tipo.is_bool() { todo!() } else if tipo.is_string() { - DefaultFunction::EqualsString.into() + Term::Apply { + function: DefaultFunction::EqualsString.into(), + argument: Term::Var(Name { + text: subject_name, + unique: 0.into(), + }) + .into(), + } } else if tipo.is_list() { todo!() } else { - DefaultFunction::EqualsData.into() + Term::Apply { + function: DefaultFunction::EqualsInteger.into(), + argument: Term::Var(Name { + text: subject_name, + unique: 0.into(), + }) + .into(), + } }; term = Term::Apply { @@ -849,24 +1061,16 @@ impl<'a> CodeGenerator<'a> { function: Term::Apply { function: Term::Force(DefaultFunction::IfThenElse.into()).into(), argument: Term::Apply { - function: Term::Apply { - function: checker, - argument: Term::Var(Name { - text: subject_name, - unique: 0.into(), - }) - .into(), - } - .into(), + function: checker.into(), argument: clause.into(), } .into(), } .into(), - argument: body.into(), + argument: Term::Delay(body.into()).into(), } .into(), - argument: term.into(), + argument: Term::Delay(term.into()).into(), } .force_wrap(); diff --git a/examples/sample/validators/swap.ak b/examples/sample/validators/swap.ak index 7f22cf84..8ec4eb27 100644 --- a/examples/sample/validators/swap.ak +++ b/examples/sample/validators/swap.ak @@ -44,13 +44,15 @@ pub fn who(a: ByteArray) -> ByteArray { } pub type Datum { - thing: Int, - stuff: Int, + Offer { price: Int, asset_class: ByteArray } + Sell + Hold(Int) } pub fn spend(datum: Datum, _rdmr: Nil, _ctx: Nil) -> Bool { - when datum.thing is { - 0 -> True - _ -> False + when datum is { + Offer { price, .. } -> price > 0 + Hold(less) -> less < 0 + Sell -> False } }