diff --git a/crates/lang/src/uplc.rs b/crates/lang/src/uplc.rs index 89f62973..25901570 100644 --- a/crates/lang/src/uplc.rs +++ b/crates/lang/src/uplc.rs @@ -6,12 +6,13 @@ use uplc::{ ast::{Constant, Name, Program, Term, Type as UplcType, Unique}, builtins::DefaultFunction, parser::interner::Interner, + BigInt, PlutusData, }; use crate::{ ast::{AssignmentKind, BinOp, DataType, Function, Pattern, Span, TypedArg, TypedPattern}, expr::TypedExpr, - tipo::{self, ModuleValueConstructor, Type, ValueConstructorVariant}, + tipo::{self, ModuleValueConstructor, Type, ValueConstructor, ValueConstructorVariant}, }; #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -96,7 +97,7 @@ pub struct DataTypeKey { pub type ConstrUsageKey = String; -#[derive(Clone, Eq, PartialEq, Hash)] +#[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct FunctionAccessKey { pub module_name: String, pub function_name: String, @@ -123,6 +124,7 @@ pub struct CodeGenerator<'a> { uplc_data_holder_lookup: IndexMap, uplc_data_constr_lookup: IndexMap, uplc_data_usage_holder_lookup: IndexMap, + function_recurse_lookup: IndexMap, functions: &'a HashMap, TypedExpr>>, // type_aliases: &'a HashMap<(String, String), &'a TypeAlias>>, data_types: &'a HashMap>>, @@ -144,6 +146,7 @@ impl<'a> CodeGenerator<'a> { uplc_data_holder_lookup: IndexMap::new(), uplc_data_constr_lookup: IndexMap::new(), uplc_data_usage_holder_lookup: IndexMap::new(), + function_recurse_lookup: IndexMap::new(), functions, // type_aliases, data_types, @@ -241,8 +244,6 @@ impl<'a> CodeGenerator<'a> { term, }; - println!("{}", program.to_pretty()); - let mut interner = Interner::new(); interner.program(&mut program); @@ -286,15 +287,15 @@ impl<'a> CodeGenerator<'a> { }) .unwrap(); - self.recurse_scope_level(&func_def.body, scope_level.clone()); - self.uplc_function_holder_lookup.insert( FunctionAccessKey { module_name: module, function_name: name, }, - scope_level, + scope_level.clone(), ); + + self.recurse_scope_level(&func_def.body, scope_level); } else if scope_level.is_less_than( self.uplc_function_holder_lookup .get(&FunctionAccessKey { @@ -467,7 +468,7 @@ impl<'a> CodeGenerator<'a> { } } TypedExpr::ModuleSelect { constructor, .. } => match constructor { - ModuleValueConstructor::Record { .. } => todo!(), + ModuleValueConstructor::Record { .. } => {} ModuleValueConstructor::Fn { module, name, .. } => { if self .uplc_function_holder_lookup @@ -804,7 +805,51 @@ impl<'a> CodeGenerator<'a> { text: format!("{module}_{name}"), unique: 0.into(), }), - ValueConstructorVariant::Record { .. } => todo!(), + ValueConstructorVariant::Record { + name: constr_name, .. + } => { + let data_type_key = match &*constructor.tipo { + Type::App { module, name, .. } => DataTypeKey { + module_name: module.to_string(), + defined_type: name.to_string(), + }, + Type::Fn { .. } => todo!(), + Type::Var { .. } => todo!(), + }; + + if let Some(data_type) = self.data_types.get(&data_type_key) { + let (constr_index, _constr) = data_type + .constructors + .iter() + .enumerate() + .find(|(_, x)| x.name == *constr_name) + .unwrap(); + + Term::Apply { + function: Term::Builtin(DefaultFunction::ConstrData).into(), + argument: Term::Apply { + function: Term::Apply { + function: Term::Builtin(DefaultFunction::MkPairData) + .into(), + argument: Term::Constant(Constant::Data( + PlutusData::BigInt(BigInt::Int( + (constr_index as i128).try_into().unwrap(), + )), + )) + .into(), + } + .into(), + argument: Term::Constant(Constant::Data( + PlutusData::Array(vec![]), + )) + .into(), + } + .into(), + } + } else { + todo!() + } + } } } } @@ -891,102 +936,305 @@ impl<'a> CodeGenerator<'a> { TypedExpr::Call { fun, args, tipo, .. } => { - if let ( - Type::App { module, name, .. }, - TypedExpr::Var { - name: constr_name, .. - }, - ) = (&**tipo, &**fun) - { - let mut term: Term = - Term::Constant(Constant::ProtoList(uplc::ast::Type::Data, vec![])); + match (&**tipo, &**fun) { + ( + Type::App { + name: tipo_name, .. + }, + TypedExpr::Var { + constructor: ValueConstructor { variant, .. }, + .. + }, + ) => match variant { + ValueConstructorVariant::LocalVariable { .. } => todo!(), + ValueConstructorVariant::ModuleConstant { .. } => todo!(), + ValueConstructorVariant::ModuleFn { name, module, .. } => { + let func_key = FunctionAccessKey { + module_name: module.to_string(), + function_name: name.to_string(), + }; + if let Some(val) = self.function_recurse_lookup.get(&func_key) { + self.function_recurse_lookup.insert(func_key, *val + 1); + } else { + self.function_recurse_lookup.insert(func_key, 1); + } + let mut term = + self.recurse_code_gen(fun, scope_level.scope_increment(1)); - if let Some(data_type) = self.data_types.get(&DataTypeKey { - module_name: module.to_string(), - defined_type: name.to_string(), - }) { - let constr = data_type - .constructors - .iter() - .find(|x| x.name == *constr_name) - .unwrap(); + for (i, arg) in args.iter().enumerate() { + term = Term::Apply { + function: term.into(), + argument: self + .recurse_code_gen( + &arg.value, + scope_level.scope_increment(i as i32 + 2), + ) + .into(), + }; + } + term + } + ValueConstructorVariant::Record { + name: constr_name, + module, + .. + } => { + let mut term: Term = + Term::Constant(Constant::ProtoList(uplc::ast::Type::Data, vec![])); - let arg_to_data: Vec<(bool, Term)> = constr - .arguments - .iter() - .map(|x| { - if let Type::App { name, .. } = &*x.tipo { - if name == "ByteArray" { - (true, Term::Builtin(DefaultFunction::BData)) - } else if name == "Int" { - (true, Term::Builtin(DefaultFunction::IData)) - } else { - (false, Term::Constant(Constant::Unit)) - } - } else { - unreachable!() - } - }) - .collect(); + if let Some(data_type) = self.data_types.get(&DataTypeKey { + module_name: module.to_string(), + defined_type: tipo_name.to_string(), + }) { + let (constr_index, constr) = data_type + .constructors + .iter() + .enumerate() + .find(|(_, x)| x.name == *constr_name) + .unwrap(); - for (i, arg) in args.iter().enumerate().rev() { - let arg_term = self.recurse_code_gen( - &arg.value, - scope_level.scope_increment(i as i32 + 1), - ); - - term = Term::Apply { - function: Term::Apply { - function: Term::Force( - Term::Builtin(DefaultFunction::MkCons).into(), - ) - .into(), - argument: if arg_to_data[i].0 { - Term::Apply { - function: arg_to_data[i].1.clone().into(), - argument: arg_term.into(), + // TODO: order arguments by data type field map + let arg_to_data: Vec<(bool, Term)> = constr + .arguments + .iter() + .map(|x| { + if let Type::App { name, .. } = &*x.tipo { + if name == "ByteArray" { + (true, Term::Builtin(DefaultFunction::BData)) + } else if name == "Int" { + (true, Term::Builtin(DefaultFunction::IData)) + } else { + (false, Term::Constant(Constant::Unit)) + } + } else { + unreachable!() } - .into() - } else { - arg_term.into() - }, - } - .into(), - argument: term.into(), - }; - } - term - } else { - let mut term = self.recurse_code_gen(fun, scope_level.scope_increment(1)); + }) + .collect(); - for (i, arg) in args.iter().enumerate() { - term = Term::Apply { - function: term.into(), - argument: self - .recurse_code_gen( + for (i, arg) in args.iter().enumerate().rev() { + let arg_term = self.recurse_code_gen( &arg.value, - scope_level.scope_increment(i as i32 + 2), - ) - .into(), - }; - } - term - } - } else { - let mut term = self.recurse_code_gen(fun, scope_level.scope_increment(1)); + scope_level.scope_increment(i as i32 + 1), + ); - for (i, arg) in args.iter().enumerate() { - term = Term::Apply { - function: term.into(), - argument: self - .recurse_code_gen( - &arg.value, - scope_level.scope_increment(i as i32 + 2), - ) - .into(), - }; + term = Term::Apply { + function: Term::Apply { + function: Term::Force( + Term::Builtin(DefaultFunction::MkCons).into(), + ) + .into(), + argument: if arg_to_data[i].0 { + Term::Apply { + function: arg_to_data[i].1.clone().into(), + argument: arg_term.into(), + } + .into() + } else { + arg_term.into() + }, + } + .into(), + argument: term.into(), + }; + } + + term = Term::Apply { + function: Term::Builtin(DefaultFunction::ConstrData).into(), + argument: Term::Apply { + function: Term::Apply { + function: Term::Builtin(DefaultFunction::MkPairData) + .into(), + argument: Term::Constant(Constant::Data( + PlutusData::BigInt(BigInt::Int( + (constr_index as i128).try_into().unwrap(), + )), + )) + .into(), + } + .into(), + argument: Term::Apply { + function: Term::Builtin(DefaultFunction::ListData) + .into(), + argument: term.into(), + } + .into(), + } + .into(), + }; + + term + } else { + let mut term = + self.recurse_code_gen(fun, scope_level.scope_increment(1)); + + for (i, arg) in args.iter().enumerate() { + term = Term::Apply { + function: term.into(), + argument: self + .recurse_code_gen( + &arg.value, + scope_level.scope_increment(i as i32 + 2), + ) + .into(), + }; + } + term + } + } + }, + + ( + Type::App { + name: tipo_name, .. + }, + TypedExpr::ModuleSelect { + constructor, + module_name: module, + .. + }, + ) => { + match constructor { + ModuleValueConstructor::Constant { .. } => todo!(), + ModuleValueConstructor::Fn { name, module, .. } => { + let func_key = FunctionAccessKey { + module_name: module.to_string(), + function_name: name.to_string(), + }; + if let Some(val) = self.function_recurse_lookup.get(&func_key) { + self.function_recurse_lookup.insert(func_key, *val + 1); + } else { + self.function_recurse_lookup.insert(func_key, 1); + } + let mut term = + self.recurse_code_gen(fun, scope_level.scope_increment(1)); + + for (i, arg) in args.iter().enumerate() { + term = Term::Apply { + function: term.into(), + argument: self + .recurse_code_gen( + &arg.value, + scope_level.scope_increment(i as i32 + 2), + ) + .into(), + }; + } + term + } + ModuleValueConstructor::Record { + name: constr_name, .. + } => { + let mut term: Term = Term::Constant(Constant::ProtoList( + uplc::ast::Type::Data, + vec![], + )); + + if let Some(data_type) = self.data_types.get(&DataTypeKey { + module_name: module.to_string(), + defined_type: tipo_name.to_string(), + }) { + let (constr_index, constr) = data_type + .constructors + .iter() + .enumerate() + .find(|(_, x)| x.name == *constr_name) + .unwrap(); + + // TODO: order arguments by data type field map + let arg_to_data: Vec<(bool, Term)> = constr + .arguments + .iter() + .map(|x| { + if let Type::App { name, .. } = &*x.tipo { + if name == "ByteArray" { + (true, Term::Builtin(DefaultFunction::BData)) + } else if name == "Int" { + (true, Term::Builtin(DefaultFunction::IData)) + } else { + (false, Term::Constant(Constant::Unit)) + } + } else { + unreachable!() + } + }) + .collect(); + + for (i, arg) in args.iter().enumerate().rev() { + let arg_term = self.recurse_code_gen( + &arg.value, + scope_level.scope_increment(i as i32 + 1), + ); + + term = Term::Apply { + function: Term::Apply { + function: Term::Force( + Term::Builtin(DefaultFunction::MkCons).into(), + ) + .into(), + argument: if arg_to_data[i].0 { + Term::Apply { + function: arg_to_data[i].1.clone().into(), + argument: arg_term.into(), + } + .into() + } else { + arg_term.into() + }, + } + .into(), + argument: term.into(), + }; + } + + term = Term::Apply { + function: Term::Builtin(DefaultFunction::ConstrData).into(), + argument: Term::Apply { + function: Term::Apply { + function: Term::Builtin( + DefaultFunction::MkPairData, + ) + .into(), + argument: Term::Constant(Constant::Data( + PlutusData::BigInt(BigInt::Int( + (constr_index as i128).try_into().unwrap(), + )), + )) + .into(), + } + .into(), + argument: Term::Apply { + function: Term::Builtin(DefaultFunction::ListData) + .into(), + argument: term.into(), + } + .into(), + } + .into(), + }; + + term + } else { + let mut term = + self.recurse_code_gen(fun, scope_level.scope_increment(1)); + + for (i, arg) in args.iter().enumerate() { + term = Term::Apply { + function: term.into(), + argument: self + .recurse_code_gen( + &arg.value, + scope_level.scope_increment(i as i32 + 2), + ) + .into(), + }; + } + term + } + } + } } - term + _ => todo!(), } } TypedExpr::BinOp { @@ -2114,7 +2362,6 @@ impl<'a> CodeGenerator<'a> { scope_level: ScopeLevels, ) -> Term { let mut term = current_term; - // attempt to insert function definitions where needed for func_key in self.uplc_function_holder_lookup.clone().keys() { if scope_level.is_less_than( @@ -2126,11 +2373,71 @@ impl<'a> CodeGenerator<'a> { ) { let func_def = self.functions.get(func_key).unwrap(); + let current_called = *self.function_recurse_lookup.get(func_key).unwrap_or(&0); + let mut function_body = self.recurse_code_gen( &func_def.body, scope_level.scope_increment_sequence(func_def.arguments.len() as i32), ); + let recurse_called = *self.function_recurse_lookup.get(func_key).unwrap_or(&0); + + if recurse_called > current_called { + for arg in func_def.arguments.iter().rev() { + function_body = Term::Lambda { + parameter_name: Name { + text: arg.arg_name.get_variable_name().unwrap_or("_").to_string(), + unique: Unique::new(0), + }, + body: Rc::new(function_body), + } + } + + function_body = Term::Lambda { + parameter_name: Name { + text: format!("{}_{}", func_key.module_name, func_key.function_name), + unique: 0.into(), + }, + body: function_body.into(), + }; + + let mut recurse_term = Term::Apply { + function: Term::Var(Name { + text: "recurse".to_string(), + unique: 0.into(), + }) + .into(), + argument: Term::Var(Name { + text: "recurse".into(), + unique: 0.into(), + }) + .into(), + }; + + for arg in func_def.arguments.iter() { + recurse_term = Term::Apply { + function: recurse_term.into(), + argument: Term::Var(Name { + text: arg.arg_name.get_variable_name().unwrap_or("_").to_string(), + unique: 0.into(), + }) + .into(), + }; + } + + function_body = Term::Apply { + function: Term::Lambda { + parameter_name: Name { + text: "recurse".into(), + unique: 0.into(), + }, + body: recurse_term.into(), + } + .into(), + argument: function_body.into(), + } + } + for arg in func_def.arguments.iter().rev() { function_body = Term::Lambda { parameter_name: Name { diff --git a/examples/sample/validators/swap.ak b/examples/sample/validators/swap.ak index 08209a7f..20e4a512 100644 --- a/examples/sample/validators/swap.ak +++ b/examples/sample/validators/swap.ak @@ -28,20 +28,20 @@ pub fn final_check(z: Int) { z < 4 } +pub fn incrementor(counter: Int, target: Int) -> Int { + if counter == target { + target + } else { + incrementor(counter + 1, target) + } +} + pub fn spend( datum: sample.Datum, rdmr: Redeemer, ctx: spend.ScriptContext, ) -> Bool { - let x = datum.rdmr - let y = [datum.fin, 2, 3] - let z = [1, ..y] - when z is { - [] -> False - [a] -> a == 1 - [a, b] -> b == 2 - [a, b, c] -> a > 1 - [a, b, c, ..d] -> b > 1 - _other -> True - } + let x = Sell + let z = incrementor(0, 4) == 4 + z }