feat: support self recursion functions and fix making constrs
This commit is contained in:
parent
09e77e1918
commit
6babebde28
|
@ -6,12 +6,13 @@ use uplc::{
|
|||
ast::{Constant, Name, Program, Term, Type as UplcType, Unique},
|
||||
builtins::DefaultFunction,
|
||||
parser::interner::Interner,
|
||||
BigInt, PlutusData,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
ast::{AssignmentKind, BinOp, DataType, Function, Pattern, Span, TypedArg, TypedPattern},
|
||||
expr::TypedExpr,
|
||||
tipo::{self, ModuleValueConstructor, Type, ValueConstructorVariant},
|
||||
tipo::{self, ModuleValueConstructor, Type, ValueConstructor, ValueConstructorVariant},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
|
@ -96,7 +97,7 @@ pub struct DataTypeKey {
|
|||
|
||||
pub type ConstrUsageKey = String;
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct FunctionAccessKey {
|
||||
pub module_name: String,
|
||||
pub function_name: String,
|
||||
|
@ -123,6 +124,7 @@ pub struct CodeGenerator<'a> {
|
|||
uplc_data_holder_lookup: IndexMap<ConstrFieldKey, ScopedExpr>,
|
||||
uplc_data_constr_lookup: IndexMap<DataTypeKey, ScopeLevels>,
|
||||
uplc_data_usage_holder_lookup: IndexMap<ConstrUsageKey, ScopeLevels>,
|
||||
function_recurse_lookup: IndexMap<FunctionAccessKey, usize>,
|
||||
functions: &'a HashMap<FunctionAccessKey, &'a Function<Arc<tipo::Type>, TypedExpr>>,
|
||||
// type_aliases: &'a HashMap<(String, String), &'a TypeAlias<Arc<tipo::Type>>>,
|
||||
data_types: &'a HashMap<DataTypeKey, &'a DataType<Arc<tipo::Type>>>,
|
||||
|
@ -144,6 +146,7 @@ impl<'a> CodeGenerator<'a> {
|
|||
uplc_data_holder_lookup: IndexMap::new(),
|
||||
uplc_data_constr_lookup: IndexMap::new(),
|
||||
uplc_data_usage_holder_lookup: IndexMap::new(),
|
||||
function_recurse_lookup: IndexMap::new(),
|
||||
functions,
|
||||
// type_aliases,
|
||||
data_types,
|
||||
|
@ -241,8 +244,6 @@ impl<'a> CodeGenerator<'a> {
|
|||
term,
|
||||
};
|
||||
|
||||
println!("{}", program.to_pretty());
|
||||
|
||||
let mut interner = Interner::new();
|
||||
|
||||
interner.program(&mut program);
|
||||
|
@ -286,15 +287,15 @@ impl<'a> CodeGenerator<'a> {
|
|||
})
|
||||
.unwrap();
|
||||
|
||||
self.recurse_scope_level(&func_def.body, scope_level.clone());
|
||||
|
||||
self.uplc_function_holder_lookup.insert(
|
||||
FunctionAccessKey {
|
||||
module_name: module,
|
||||
function_name: name,
|
||||
},
|
||||
scope_level,
|
||||
scope_level.clone(),
|
||||
);
|
||||
|
||||
self.recurse_scope_level(&func_def.body, scope_level);
|
||||
} else if scope_level.is_less_than(
|
||||
self.uplc_function_holder_lookup
|
||||
.get(&FunctionAccessKey {
|
||||
|
@ -467,7 +468,7 @@ impl<'a> CodeGenerator<'a> {
|
|||
}
|
||||
}
|
||||
TypedExpr::ModuleSelect { constructor, .. } => match constructor {
|
||||
ModuleValueConstructor::Record { .. } => todo!(),
|
||||
ModuleValueConstructor::Record { .. } => {}
|
||||
ModuleValueConstructor::Fn { module, name, .. } => {
|
||||
if self
|
||||
.uplc_function_holder_lookup
|
||||
|
@ -804,7 +805,51 @@ impl<'a> CodeGenerator<'a> {
|
|||
text: format!("{module}_{name}"),
|
||||
unique: 0.into(),
|
||||
}),
|
||||
ValueConstructorVariant::Record { .. } => todo!(),
|
||||
ValueConstructorVariant::Record {
|
||||
name: constr_name, ..
|
||||
} => {
|
||||
let data_type_key = match &*constructor.tipo {
|
||||
Type::App { module, name, .. } => DataTypeKey {
|
||||
module_name: module.to_string(),
|
||||
defined_type: name.to_string(),
|
||||
},
|
||||
Type::Fn { .. } => todo!(),
|
||||
Type::Var { .. } => todo!(),
|
||||
};
|
||||
|
||||
if let Some(data_type) = self.data_types.get(&data_type_key) {
|
||||
let (constr_index, _constr) = data_type
|
||||
.constructors
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, x)| x.name == *constr_name)
|
||||
.unwrap();
|
||||
|
||||
Term::Apply {
|
||||
function: Term::Builtin(DefaultFunction::ConstrData).into(),
|
||||
argument: Term::Apply {
|
||||
function: Term::Apply {
|
||||
function: Term::Builtin(DefaultFunction::MkPairData)
|
||||
.into(),
|
||||
argument: Term::Constant(Constant::Data(
|
||||
PlutusData::BigInt(BigInt::Int(
|
||||
(constr_index as i128).try_into().unwrap(),
|
||||
)),
|
||||
))
|
||||
.into(),
|
||||
}
|
||||
.into(),
|
||||
argument: Term::Constant(Constant::Data(
|
||||
PlutusData::Array(vec![]),
|
||||
))
|
||||
.into(),
|
||||
}
|
||||
.into(),
|
||||
}
|
||||
} else {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -891,102 +936,305 @@ impl<'a> CodeGenerator<'a> {
|
|||
TypedExpr::Call {
|
||||
fun, args, tipo, ..
|
||||
} => {
|
||||
if let (
|
||||
Type::App { module, name, .. },
|
||||
TypedExpr::Var {
|
||||
name: constr_name, ..
|
||||
},
|
||||
) = (&**tipo, &**fun)
|
||||
{
|
||||
let mut term: Term<Name> =
|
||||
Term::Constant(Constant::ProtoList(uplc::ast::Type::Data, vec![]));
|
||||
match (&**tipo, &**fun) {
|
||||
(
|
||||
Type::App {
|
||||
name: tipo_name, ..
|
||||
},
|
||||
TypedExpr::Var {
|
||||
constructor: ValueConstructor { variant, .. },
|
||||
..
|
||||
},
|
||||
) => match variant {
|
||||
ValueConstructorVariant::LocalVariable { .. } => todo!(),
|
||||
ValueConstructorVariant::ModuleConstant { .. } => todo!(),
|
||||
ValueConstructorVariant::ModuleFn { name, module, .. } => {
|
||||
let func_key = FunctionAccessKey {
|
||||
module_name: module.to_string(),
|
||||
function_name: name.to_string(),
|
||||
};
|
||||
if let Some(val) = self.function_recurse_lookup.get(&func_key) {
|
||||
self.function_recurse_lookup.insert(func_key, *val + 1);
|
||||
} else {
|
||||
self.function_recurse_lookup.insert(func_key, 1);
|
||||
}
|
||||
let mut term =
|
||||
self.recurse_code_gen(fun, scope_level.scope_increment(1));
|
||||
|
||||
if let Some(data_type) = self.data_types.get(&DataTypeKey {
|
||||
module_name: module.to_string(),
|
||||
defined_type: name.to_string(),
|
||||
}) {
|
||||
let constr = data_type
|
||||
.constructors
|
||||
.iter()
|
||||
.find(|x| x.name == *constr_name)
|
||||
.unwrap();
|
||||
for (i, arg) in args.iter().enumerate() {
|
||||
term = Term::Apply {
|
||||
function: term.into(),
|
||||
argument: self
|
||||
.recurse_code_gen(
|
||||
&arg.value,
|
||||
scope_level.scope_increment(i as i32 + 2),
|
||||
)
|
||||
.into(),
|
||||
};
|
||||
}
|
||||
term
|
||||
}
|
||||
ValueConstructorVariant::Record {
|
||||
name: constr_name,
|
||||
module,
|
||||
..
|
||||
} => {
|
||||
let mut term: Term<Name> =
|
||||
Term::Constant(Constant::ProtoList(uplc::ast::Type::Data, vec![]));
|
||||
|
||||
let arg_to_data: Vec<(bool, Term<Name>)> = constr
|
||||
.arguments
|
||||
.iter()
|
||||
.map(|x| {
|
||||
if let Type::App { name, .. } = &*x.tipo {
|
||||
if name == "ByteArray" {
|
||||
(true, Term::Builtin(DefaultFunction::BData))
|
||||
} else if name == "Int" {
|
||||
(true, Term::Builtin(DefaultFunction::IData))
|
||||
} else {
|
||||
(false, Term::Constant(Constant::Unit))
|
||||
}
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
if let Some(data_type) = self.data_types.get(&DataTypeKey {
|
||||
module_name: module.to_string(),
|
||||
defined_type: tipo_name.to_string(),
|
||||
}) {
|
||||
let (constr_index, constr) = data_type
|
||||
.constructors
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, x)| x.name == *constr_name)
|
||||
.unwrap();
|
||||
|
||||
for (i, arg) in args.iter().enumerate().rev() {
|
||||
let arg_term = self.recurse_code_gen(
|
||||
&arg.value,
|
||||
scope_level.scope_increment(i as i32 + 1),
|
||||
);
|
||||
|
||||
term = Term::Apply {
|
||||
function: Term::Apply {
|
||||
function: Term::Force(
|
||||
Term::Builtin(DefaultFunction::MkCons).into(),
|
||||
)
|
||||
.into(),
|
||||
argument: if arg_to_data[i].0 {
|
||||
Term::Apply {
|
||||
function: arg_to_data[i].1.clone().into(),
|
||||
argument: arg_term.into(),
|
||||
// TODO: order arguments by data type field map
|
||||
let arg_to_data: Vec<(bool, Term<Name>)> = constr
|
||||
.arguments
|
||||
.iter()
|
||||
.map(|x| {
|
||||
if let Type::App { name, .. } = &*x.tipo {
|
||||
if name == "ByteArray" {
|
||||
(true, Term::Builtin(DefaultFunction::BData))
|
||||
} else if name == "Int" {
|
||||
(true, Term::Builtin(DefaultFunction::IData))
|
||||
} else {
|
||||
(false, Term::Constant(Constant::Unit))
|
||||
}
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
.into()
|
||||
} else {
|
||||
arg_term.into()
|
||||
},
|
||||
}
|
||||
.into(),
|
||||
argument: term.into(),
|
||||
};
|
||||
}
|
||||
term
|
||||
} else {
|
||||
let mut term = self.recurse_code_gen(fun, scope_level.scope_increment(1));
|
||||
})
|
||||
.collect();
|
||||
|
||||
for (i, arg) in args.iter().enumerate() {
|
||||
term = Term::Apply {
|
||||
function: term.into(),
|
||||
argument: self
|
||||
.recurse_code_gen(
|
||||
for (i, arg) in args.iter().enumerate().rev() {
|
||||
let arg_term = self.recurse_code_gen(
|
||||
&arg.value,
|
||||
scope_level.scope_increment(i as i32 + 2),
|
||||
)
|
||||
.into(),
|
||||
};
|
||||
}
|
||||
term
|
||||
}
|
||||
} else {
|
||||
let mut term = self.recurse_code_gen(fun, scope_level.scope_increment(1));
|
||||
scope_level.scope_increment(i as i32 + 1),
|
||||
);
|
||||
|
||||
for (i, arg) in args.iter().enumerate() {
|
||||
term = Term::Apply {
|
||||
function: term.into(),
|
||||
argument: self
|
||||
.recurse_code_gen(
|
||||
&arg.value,
|
||||
scope_level.scope_increment(i as i32 + 2),
|
||||
)
|
||||
.into(),
|
||||
};
|
||||
term = Term::Apply {
|
||||
function: Term::Apply {
|
||||
function: Term::Force(
|
||||
Term::Builtin(DefaultFunction::MkCons).into(),
|
||||
)
|
||||
.into(),
|
||||
argument: if arg_to_data[i].0 {
|
||||
Term::Apply {
|
||||
function: arg_to_data[i].1.clone().into(),
|
||||
argument: arg_term.into(),
|
||||
}
|
||||
.into()
|
||||
} else {
|
||||
arg_term.into()
|
||||
},
|
||||
}
|
||||
.into(),
|
||||
argument: term.into(),
|
||||
};
|
||||
}
|
||||
|
||||
term = Term::Apply {
|
||||
function: Term::Builtin(DefaultFunction::ConstrData).into(),
|
||||
argument: Term::Apply {
|
||||
function: Term::Apply {
|
||||
function: Term::Builtin(DefaultFunction::MkPairData)
|
||||
.into(),
|
||||
argument: Term::Constant(Constant::Data(
|
||||
PlutusData::BigInt(BigInt::Int(
|
||||
(constr_index as i128).try_into().unwrap(),
|
||||
)),
|
||||
))
|
||||
.into(),
|
||||
}
|
||||
.into(),
|
||||
argument: Term::Apply {
|
||||
function: Term::Builtin(DefaultFunction::ListData)
|
||||
.into(),
|
||||
argument: term.into(),
|
||||
}
|
||||
.into(),
|
||||
}
|
||||
.into(),
|
||||
};
|
||||
|
||||
term
|
||||
} else {
|
||||
let mut term =
|
||||
self.recurse_code_gen(fun, scope_level.scope_increment(1));
|
||||
|
||||
for (i, arg) in args.iter().enumerate() {
|
||||
term = Term::Apply {
|
||||
function: term.into(),
|
||||
argument: self
|
||||
.recurse_code_gen(
|
||||
&arg.value,
|
||||
scope_level.scope_increment(i as i32 + 2),
|
||||
)
|
||||
.into(),
|
||||
};
|
||||
}
|
||||
term
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
(
|
||||
Type::App {
|
||||
name: tipo_name, ..
|
||||
},
|
||||
TypedExpr::ModuleSelect {
|
||||
constructor,
|
||||
module_name: module,
|
||||
..
|
||||
},
|
||||
) => {
|
||||
match constructor {
|
||||
ModuleValueConstructor::Constant { .. } => todo!(),
|
||||
ModuleValueConstructor::Fn { name, module, .. } => {
|
||||
let func_key = FunctionAccessKey {
|
||||
module_name: module.to_string(),
|
||||
function_name: name.to_string(),
|
||||
};
|
||||
if let Some(val) = self.function_recurse_lookup.get(&func_key) {
|
||||
self.function_recurse_lookup.insert(func_key, *val + 1);
|
||||
} else {
|
||||
self.function_recurse_lookup.insert(func_key, 1);
|
||||
}
|
||||
let mut term =
|
||||
self.recurse_code_gen(fun, scope_level.scope_increment(1));
|
||||
|
||||
for (i, arg) in args.iter().enumerate() {
|
||||
term = Term::Apply {
|
||||
function: term.into(),
|
||||
argument: self
|
||||
.recurse_code_gen(
|
||||
&arg.value,
|
||||
scope_level.scope_increment(i as i32 + 2),
|
||||
)
|
||||
.into(),
|
||||
};
|
||||
}
|
||||
term
|
||||
}
|
||||
ModuleValueConstructor::Record {
|
||||
name: constr_name, ..
|
||||
} => {
|
||||
let mut term: Term<Name> = Term::Constant(Constant::ProtoList(
|
||||
uplc::ast::Type::Data,
|
||||
vec![],
|
||||
));
|
||||
|
||||
if let Some(data_type) = self.data_types.get(&DataTypeKey {
|
||||
module_name: module.to_string(),
|
||||
defined_type: tipo_name.to_string(),
|
||||
}) {
|
||||
let (constr_index, constr) = data_type
|
||||
.constructors
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, x)| x.name == *constr_name)
|
||||
.unwrap();
|
||||
|
||||
// TODO: order arguments by data type field map
|
||||
let arg_to_data: Vec<(bool, Term<Name>)> = constr
|
||||
.arguments
|
||||
.iter()
|
||||
.map(|x| {
|
||||
if let Type::App { name, .. } = &*x.tipo {
|
||||
if name == "ByteArray" {
|
||||
(true, Term::Builtin(DefaultFunction::BData))
|
||||
} else if name == "Int" {
|
||||
(true, Term::Builtin(DefaultFunction::IData))
|
||||
} else {
|
||||
(false, Term::Constant(Constant::Unit))
|
||||
}
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
for (i, arg) in args.iter().enumerate().rev() {
|
||||
let arg_term = self.recurse_code_gen(
|
||||
&arg.value,
|
||||
scope_level.scope_increment(i as i32 + 1),
|
||||
);
|
||||
|
||||
term = Term::Apply {
|
||||
function: Term::Apply {
|
||||
function: Term::Force(
|
||||
Term::Builtin(DefaultFunction::MkCons).into(),
|
||||
)
|
||||
.into(),
|
||||
argument: if arg_to_data[i].0 {
|
||||
Term::Apply {
|
||||
function: arg_to_data[i].1.clone().into(),
|
||||
argument: arg_term.into(),
|
||||
}
|
||||
.into()
|
||||
} else {
|
||||
arg_term.into()
|
||||
},
|
||||
}
|
||||
.into(),
|
||||
argument: term.into(),
|
||||
};
|
||||
}
|
||||
|
||||
term = Term::Apply {
|
||||
function: Term::Builtin(DefaultFunction::ConstrData).into(),
|
||||
argument: Term::Apply {
|
||||
function: Term::Apply {
|
||||
function: Term::Builtin(
|
||||
DefaultFunction::MkPairData,
|
||||
)
|
||||
.into(),
|
||||
argument: Term::Constant(Constant::Data(
|
||||
PlutusData::BigInt(BigInt::Int(
|
||||
(constr_index as i128).try_into().unwrap(),
|
||||
)),
|
||||
))
|
||||
.into(),
|
||||
}
|
||||
.into(),
|
||||
argument: Term::Apply {
|
||||
function: Term::Builtin(DefaultFunction::ListData)
|
||||
.into(),
|
||||
argument: term.into(),
|
||||
}
|
||||
.into(),
|
||||
}
|
||||
.into(),
|
||||
};
|
||||
|
||||
term
|
||||
} else {
|
||||
let mut term =
|
||||
self.recurse_code_gen(fun, scope_level.scope_increment(1));
|
||||
|
||||
for (i, arg) in args.iter().enumerate() {
|
||||
term = Term::Apply {
|
||||
function: term.into(),
|
||||
argument: self
|
||||
.recurse_code_gen(
|
||||
&arg.value,
|
||||
scope_level.scope_increment(i as i32 + 2),
|
||||
)
|
||||
.into(),
|
||||
};
|
||||
}
|
||||
term
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
term
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
TypedExpr::BinOp {
|
||||
|
@ -2114,7 +2362,6 @@ impl<'a> CodeGenerator<'a> {
|
|||
scope_level: ScopeLevels,
|
||||
) -> Term<Name> {
|
||||
let mut term = current_term;
|
||||
|
||||
// attempt to insert function definitions where needed
|
||||
for func_key in self.uplc_function_holder_lookup.clone().keys() {
|
||||
if scope_level.is_less_than(
|
||||
|
@ -2126,11 +2373,71 @@ impl<'a> CodeGenerator<'a> {
|
|||
) {
|
||||
let func_def = self.functions.get(func_key).unwrap();
|
||||
|
||||
let current_called = *self.function_recurse_lookup.get(func_key).unwrap_or(&0);
|
||||
|
||||
let mut function_body = self.recurse_code_gen(
|
||||
&func_def.body,
|
||||
scope_level.scope_increment_sequence(func_def.arguments.len() as i32),
|
||||
);
|
||||
|
||||
let recurse_called = *self.function_recurse_lookup.get(func_key).unwrap_or(&0);
|
||||
|
||||
if recurse_called > current_called {
|
||||
for arg in func_def.arguments.iter().rev() {
|
||||
function_body = Term::Lambda {
|
||||
parameter_name: Name {
|
||||
text: arg.arg_name.get_variable_name().unwrap_or("_").to_string(),
|
||||
unique: Unique::new(0),
|
||||
},
|
||||
body: Rc::new(function_body),
|
||||
}
|
||||
}
|
||||
|
||||
function_body = Term::Lambda {
|
||||
parameter_name: Name {
|
||||
text: format!("{}_{}", func_key.module_name, func_key.function_name),
|
||||
unique: 0.into(),
|
||||
},
|
||||
body: function_body.into(),
|
||||
};
|
||||
|
||||
let mut recurse_term = Term::Apply {
|
||||
function: Term::Var(Name {
|
||||
text: "recurse".to_string(),
|
||||
unique: 0.into(),
|
||||
})
|
||||
.into(),
|
||||
argument: Term::Var(Name {
|
||||
text: "recurse".into(),
|
||||
unique: 0.into(),
|
||||
})
|
||||
.into(),
|
||||
};
|
||||
|
||||
for arg in func_def.arguments.iter() {
|
||||
recurse_term = Term::Apply {
|
||||
function: recurse_term.into(),
|
||||
argument: Term::Var(Name {
|
||||
text: arg.arg_name.get_variable_name().unwrap_or("_").to_string(),
|
||||
unique: 0.into(),
|
||||
})
|
||||
.into(),
|
||||
};
|
||||
}
|
||||
|
||||
function_body = Term::Apply {
|
||||
function: Term::Lambda {
|
||||
parameter_name: Name {
|
||||
text: "recurse".into(),
|
||||
unique: 0.into(),
|
||||
},
|
||||
body: recurse_term.into(),
|
||||
}
|
||||
.into(),
|
||||
argument: function_body.into(),
|
||||
}
|
||||
}
|
||||
|
||||
for arg in func_def.arguments.iter().rev() {
|
||||
function_body = Term::Lambda {
|
||||
parameter_name: Name {
|
||||
|
|
|
@ -28,20 +28,20 @@ pub fn final_check(z: Int) {
|
|||
z < 4
|
||||
}
|
||||
|
||||
pub fn incrementor(counter: Int, target: Int) -> Int {
|
||||
if counter == target {
|
||||
target
|
||||
} else {
|
||||
incrementor(counter + 1, target)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn spend(
|
||||
datum: sample.Datum,
|
||||
rdmr: Redeemer,
|
||||
ctx: spend.ScriptContext,
|
||||
) -> Bool {
|
||||
let x = datum.rdmr
|
||||
let y = [datum.fin, 2, 3]
|
||||
let z = [1, ..y]
|
||||
when z is {
|
||||
[] -> False
|
||||
[a] -> a == 1
|
||||
[a, b] -> b == 2
|
||||
[a, b, c] -> a > 1
|
||||
[a, b, c, ..d] -> b > 1
|
||||
_other -> True
|
||||
}
|
||||
let x = Sell
|
||||
let z = incrementor(0, 4) == 4
|
||||
z
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue