feat: support self recursion functions and fix making constrs

This commit is contained in:
Kasey White 2022-11-24 15:14:01 -05:00 committed by Kasey White
parent 09e77e1918
commit 6babebde28
2 changed files with 416 additions and 109 deletions

View File

@ -6,12 +6,13 @@ use uplc::{
ast::{Constant, Name, Program, Term, Type as UplcType, Unique},
builtins::DefaultFunction,
parser::interner::Interner,
BigInt, PlutusData,
};
use crate::{
ast::{AssignmentKind, BinOp, DataType, Function, Pattern, Span, TypedArg, TypedPattern},
expr::TypedExpr,
tipo::{self, ModuleValueConstructor, Type, ValueConstructorVariant},
tipo::{self, ModuleValueConstructor, Type, ValueConstructor, ValueConstructorVariant},
};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
@ -96,7 +97,7 @@ pub struct DataTypeKey {
pub type ConstrUsageKey = String;
#[derive(Clone, Eq, PartialEq, Hash)]
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct FunctionAccessKey {
pub module_name: String,
pub function_name: String,
@ -123,6 +124,7 @@ pub struct CodeGenerator<'a> {
uplc_data_holder_lookup: IndexMap<ConstrFieldKey, ScopedExpr>,
uplc_data_constr_lookup: IndexMap<DataTypeKey, ScopeLevels>,
uplc_data_usage_holder_lookup: IndexMap<ConstrUsageKey, ScopeLevels>,
function_recurse_lookup: IndexMap<FunctionAccessKey, usize>,
functions: &'a HashMap<FunctionAccessKey, &'a Function<Arc<tipo::Type>, TypedExpr>>,
// type_aliases: &'a HashMap<(String, String), &'a TypeAlias<Arc<tipo::Type>>>,
data_types: &'a HashMap<DataTypeKey, &'a DataType<Arc<tipo::Type>>>,
@ -144,6 +146,7 @@ impl<'a> CodeGenerator<'a> {
uplc_data_holder_lookup: IndexMap::new(),
uplc_data_constr_lookup: IndexMap::new(),
uplc_data_usage_holder_lookup: IndexMap::new(),
function_recurse_lookup: IndexMap::new(),
functions,
// type_aliases,
data_types,
@ -241,8 +244,6 @@ impl<'a> CodeGenerator<'a> {
term,
};
println!("{}", program.to_pretty());
let mut interner = Interner::new();
interner.program(&mut program);
@ -286,15 +287,15 @@ impl<'a> CodeGenerator<'a> {
})
.unwrap();
self.recurse_scope_level(&func_def.body, scope_level.clone());
self.uplc_function_holder_lookup.insert(
FunctionAccessKey {
module_name: module,
function_name: name,
},
scope_level,
scope_level.clone(),
);
self.recurse_scope_level(&func_def.body, scope_level);
} else if scope_level.is_less_than(
self.uplc_function_holder_lookup
.get(&FunctionAccessKey {
@ -467,7 +468,7 @@ impl<'a> CodeGenerator<'a> {
}
}
TypedExpr::ModuleSelect { constructor, .. } => match constructor {
ModuleValueConstructor::Record { .. } => todo!(),
ModuleValueConstructor::Record { .. } => {}
ModuleValueConstructor::Fn { module, name, .. } => {
if self
.uplc_function_holder_lookup
@ -804,7 +805,51 @@ impl<'a> CodeGenerator<'a> {
text: format!("{module}_{name}"),
unique: 0.into(),
}),
ValueConstructorVariant::Record { .. } => todo!(),
ValueConstructorVariant::Record {
name: constr_name, ..
} => {
let data_type_key = match &*constructor.tipo {
Type::App { module, name, .. } => DataTypeKey {
module_name: module.to_string(),
defined_type: name.to_string(),
},
Type::Fn { .. } => todo!(),
Type::Var { .. } => todo!(),
};
if let Some(data_type) = self.data_types.get(&data_type_key) {
let (constr_index, _constr) = data_type
.constructors
.iter()
.enumerate()
.find(|(_, x)| x.name == *constr_name)
.unwrap();
Term::Apply {
function: Term::Builtin(DefaultFunction::ConstrData).into(),
argument: Term::Apply {
function: Term::Apply {
function: Term::Builtin(DefaultFunction::MkPairData)
.into(),
argument: Term::Constant(Constant::Data(
PlutusData::BigInt(BigInt::Int(
(constr_index as i128).try_into().unwrap(),
)),
))
.into(),
}
.into(),
argument: Term::Constant(Constant::Data(
PlutusData::Array(vec![]),
))
.into(),
}
.into(),
}
} else {
todo!()
}
}
}
}
}
@ -891,102 +936,305 @@ impl<'a> CodeGenerator<'a> {
TypedExpr::Call {
fun, args, tipo, ..
} => {
if let (
Type::App { module, name, .. },
TypedExpr::Var {
name: constr_name, ..
},
) = (&**tipo, &**fun)
{
let mut term: Term<Name> =
Term::Constant(Constant::ProtoList(uplc::ast::Type::Data, vec![]));
match (&**tipo, &**fun) {
(
Type::App {
name: tipo_name, ..
},
TypedExpr::Var {
constructor: ValueConstructor { variant, .. },
..
},
) => match variant {
ValueConstructorVariant::LocalVariable { .. } => todo!(),
ValueConstructorVariant::ModuleConstant { .. } => todo!(),
ValueConstructorVariant::ModuleFn { name, module, .. } => {
let func_key = FunctionAccessKey {
module_name: module.to_string(),
function_name: name.to_string(),
};
if let Some(val) = self.function_recurse_lookup.get(&func_key) {
self.function_recurse_lookup.insert(func_key, *val + 1);
} else {
self.function_recurse_lookup.insert(func_key, 1);
}
let mut term =
self.recurse_code_gen(fun, scope_level.scope_increment(1));
if let Some(data_type) = self.data_types.get(&DataTypeKey {
module_name: module.to_string(),
defined_type: name.to_string(),
}) {
let constr = data_type
.constructors
.iter()
.find(|x| x.name == *constr_name)
.unwrap();
for (i, arg) in args.iter().enumerate() {
term = Term::Apply {
function: term.into(),
argument: self
.recurse_code_gen(
&arg.value,
scope_level.scope_increment(i as i32 + 2),
)
.into(),
};
}
term
}
ValueConstructorVariant::Record {
name: constr_name,
module,
..
} => {
let mut term: Term<Name> =
Term::Constant(Constant::ProtoList(uplc::ast::Type::Data, vec![]));
let arg_to_data: Vec<(bool, Term<Name>)> = constr
.arguments
.iter()
.map(|x| {
if let Type::App { name, .. } = &*x.tipo {
if name == "ByteArray" {
(true, Term::Builtin(DefaultFunction::BData))
} else if name == "Int" {
(true, Term::Builtin(DefaultFunction::IData))
} else {
(false, Term::Constant(Constant::Unit))
}
} else {
unreachable!()
}
})
.collect();
if let Some(data_type) = self.data_types.get(&DataTypeKey {
module_name: module.to_string(),
defined_type: tipo_name.to_string(),
}) {
let (constr_index, constr) = data_type
.constructors
.iter()
.enumerate()
.find(|(_, x)| x.name == *constr_name)
.unwrap();
for (i, arg) in args.iter().enumerate().rev() {
let arg_term = self.recurse_code_gen(
&arg.value,
scope_level.scope_increment(i as i32 + 1),
);
term = Term::Apply {
function: Term::Apply {
function: Term::Force(
Term::Builtin(DefaultFunction::MkCons).into(),
)
.into(),
argument: if arg_to_data[i].0 {
Term::Apply {
function: arg_to_data[i].1.clone().into(),
argument: arg_term.into(),
// TODO: order arguments by data type field map
let arg_to_data: Vec<(bool, Term<Name>)> = constr
.arguments
.iter()
.map(|x| {
if let Type::App { name, .. } = &*x.tipo {
if name == "ByteArray" {
(true, Term::Builtin(DefaultFunction::BData))
} else if name == "Int" {
(true, Term::Builtin(DefaultFunction::IData))
} else {
(false, Term::Constant(Constant::Unit))
}
} else {
unreachable!()
}
.into()
} else {
arg_term.into()
},
}
.into(),
argument: term.into(),
};
}
term
} else {
let mut term = self.recurse_code_gen(fun, scope_level.scope_increment(1));
})
.collect();
for (i, arg) in args.iter().enumerate() {
term = Term::Apply {
function: term.into(),
argument: self
.recurse_code_gen(
for (i, arg) in args.iter().enumerate().rev() {
let arg_term = self.recurse_code_gen(
&arg.value,
scope_level.scope_increment(i as i32 + 2),
)
.into(),
};
}
term
}
} else {
let mut term = self.recurse_code_gen(fun, scope_level.scope_increment(1));
scope_level.scope_increment(i as i32 + 1),
);
for (i, arg) in args.iter().enumerate() {
term = Term::Apply {
function: term.into(),
argument: self
.recurse_code_gen(
&arg.value,
scope_level.scope_increment(i as i32 + 2),
)
.into(),
};
term = Term::Apply {
function: Term::Apply {
function: Term::Force(
Term::Builtin(DefaultFunction::MkCons).into(),
)
.into(),
argument: if arg_to_data[i].0 {
Term::Apply {
function: arg_to_data[i].1.clone().into(),
argument: arg_term.into(),
}
.into()
} else {
arg_term.into()
},
}
.into(),
argument: term.into(),
};
}
term = Term::Apply {
function: Term::Builtin(DefaultFunction::ConstrData).into(),
argument: Term::Apply {
function: Term::Apply {
function: Term::Builtin(DefaultFunction::MkPairData)
.into(),
argument: Term::Constant(Constant::Data(
PlutusData::BigInt(BigInt::Int(
(constr_index as i128).try_into().unwrap(),
)),
))
.into(),
}
.into(),
argument: Term::Apply {
function: Term::Builtin(DefaultFunction::ListData)
.into(),
argument: term.into(),
}
.into(),
}
.into(),
};
term
} else {
let mut term =
self.recurse_code_gen(fun, scope_level.scope_increment(1));
for (i, arg) in args.iter().enumerate() {
term = Term::Apply {
function: term.into(),
argument: self
.recurse_code_gen(
&arg.value,
scope_level.scope_increment(i as i32 + 2),
)
.into(),
};
}
term
}
}
},
(
Type::App {
name: tipo_name, ..
},
TypedExpr::ModuleSelect {
constructor,
module_name: module,
..
},
) => {
match constructor {
ModuleValueConstructor::Constant { .. } => todo!(),
ModuleValueConstructor::Fn { name, module, .. } => {
let func_key = FunctionAccessKey {
module_name: module.to_string(),
function_name: name.to_string(),
};
if let Some(val) = self.function_recurse_lookup.get(&func_key) {
self.function_recurse_lookup.insert(func_key, *val + 1);
} else {
self.function_recurse_lookup.insert(func_key, 1);
}
let mut term =
self.recurse_code_gen(fun, scope_level.scope_increment(1));
for (i, arg) in args.iter().enumerate() {
term = Term::Apply {
function: term.into(),
argument: self
.recurse_code_gen(
&arg.value,
scope_level.scope_increment(i as i32 + 2),
)
.into(),
};
}
term
}
ModuleValueConstructor::Record {
name: constr_name, ..
} => {
let mut term: Term<Name> = Term::Constant(Constant::ProtoList(
uplc::ast::Type::Data,
vec![],
));
if let Some(data_type) = self.data_types.get(&DataTypeKey {
module_name: module.to_string(),
defined_type: tipo_name.to_string(),
}) {
let (constr_index, constr) = data_type
.constructors
.iter()
.enumerate()
.find(|(_, x)| x.name == *constr_name)
.unwrap();
// TODO: order arguments by data type field map
let arg_to_data: Vec<(bool, Term<Name>)> = constr
.arguments
.iter()
.map(|x| {
if let Type::App { name, .. } = &*x.tipo {
if name == "ByteArray" {
(true, Term::Builtin(DefaultFunction::BData))
} else if name == "Int" {
(true, Term::Builtin(DefaultFunction::IData))
} else {
(false, Term::Constant(Constant::Unit))
}
} else {
unreachable!()
}
})
.collect();
for (i, arg) in args.iter().enumerate().rev() {
let arg_term = self.recurse_code_gen(
&arg.value,
scope_level.scope_increment(i as i32 + 1),
);
term = Term::Apply {
function: Term::Apply {
function: Term::Force(
Term::Builtin(DefaultFunction::MkCons).into(),
)
.into(),
argument: if arg_to_data[i].0 {
Term::Apply {
function: arg_to_data[i].1.clone().into(),
argument: arg_term.into(),
}
.into()
} else {
arg_term.into()
},
}
.into(),
argument: term.into(),
};
}
term = Term::Apply {
function: Term::Builtin(DefaultFunction::ConstrData).into(),
argument: Term::Apply {
function: Term::Apply {
function: Term::Builtin(
DefaultFunction::MkPairData,
)
.into(),
argument: Term::Constant(Constant::Data(
PlutusData::BigInt(BigInt::Int(
(constr_index as i128).try_into().unwrap(),
)),
))
.into(),
}
.into(),
argument: Term::Apply {
function: Term::Builtin(DefaultFunction::ListData)
.into(),
argument: term.into(),
}
.into(),
}
.into(),
};
term
} else {
let mut term =
self.recurse_code_gen(fun, scope_level.scope_increment(1));
for (i, arg) in args.iter().enumerate() {
term = Term::Apply {
function: term.into(),
argument: self
.recurse_code_gen(
&arg.value,
scope_level.scope_increment(i as i32 + 2),
)
.into(),
};
}
term
}
}
}
}
term
_ => todo!(),
}
}
TypedExpr::BinOp {
@ -2114,7 +2362,6 @@ impl<'a> CodeGenerator<'a> {
scope_level: ScopeLevels,
) -> Term<Name> {
let mut term = current_term;
// attempt to insert function definitions where needed
for func_key in self.uplc_function_holder_lookup.clone().keys() {
if scope_level.is_less_than(
@ -2126,11 +2373,71 @@ impl<'a> CodeGenerator<'a> {
) {
let func_def = self.functions.get(func_key).unwrap();
let current_called = *self.function_recurse_lookup.get(func_key).unwrap_or(&0);
let mut function_body = self.recurse_code_gen(
&func_def.body,
scope_level.scope_increment_sequence(func_def.arguments.len() as i32),
);
let recurse_called = *self.function_recurse_lookup.get(func_key).unwrap_or(&0);
if recurse_called > current_called {
for arg in func_def.arguments.iter().rev() {
function_body = Term::Lambda {
parameter_name: Name {
text: arg.arg_name.get_variable_name().unwrap_or("_").to_string(),
unique: Unique::new(0),
},
body: Rc::new(function_body),
}
}
function_body = Term::Lambda {
parameter_name: Name {
text: format!("{}_{}", func_key.module_name, func_key.function_name),
unique: 0.into(),
},
body: function_body.into(),
};
let mut recurse_term = Term::Apply {
function: Term::Var(Name {
text: "recurse".to_string(),
unique: 0.into(),
})
.into(),
argument: Term::Var(Name {
text: "recurse".into(),
unique: 0.into(),
})
.into(),
};
for arg in func_def.arguments.iter() {
recurse_term = Term::Apply {
function: recurse_term.into(),
argument: Term::Var(Name {
text: arg.arg_name.get_variable_name().unwrap_or("_").to_string(),
unique: 0.into(),
})
.into(),
};
}
function_body = Term::Apply {
function: Term::Lambda {
parameter_name: Name {
text: "recurse".into(),
unique: 0.into(),
},
body: recurse_term.into(),
}
.into(),
argument: function_body.into(),
}
}
for arg in func_def.arguments.iter().rev() {
function_body = Term::Lambda {
parameter_name: Name {

View File

@ -28,20 +28,20 @@ pub fn final_check(z: Int) {
z < 4
}
pub fn incrementor(counter: Int, target: Int) -> Int {
if counter == target {
target
} else {
incrementor(counter + 1, target)
}
}
pub fn spend(
datum: sample.Datum,
rdmr: Redeemer,
ctx: spend.ScriptContext,
) -> Bool {
let x = datum.rdmr
let y = [datum.fin, 2, 3]
let z = [1, ..y]
when z is {
[] -> False
[a] -> a == 1
[a, b] -> b == 2
[a, b, c] -> a > 1
[a, b, c, ..d] -> b > 1
_other -> True
}
let x = Sell
let z = incrementor(0, 4) == 4
z
}