fixing bugs and edge cases

This commit is contained in:
microproofs 2023-07-26 16:48:43 -04:00 committed by Kasey
parent 2f4319f162
commit 58b327e5b3
3 changed files with 148 additions and 90 deletions

View File

@ -173,11 +173,8 @@ impl<'a> CodeGenerator<'a> {
let full_tree = self.hoist_functions_to_validator(air_tree); let full_tree = self.hoist_functions_to_validator(air_tree);
// optimizations on air tree // optimizations on air tree
let full_vec = full_tree.to_vec(); let full_vec = full_tree.to_vec();
println!("FULL VEC {:#?}", full_vec);
let term = self.uplc_code_gen(full_vec); let term = self.uplc_code_gen(full_vec);
self.finalize(term) self.finalize(term)
@ -197,8 +194,8 @@ impl<'a> CodeGenerator<'a> {
term, term,
}; };
println!("Program: {}", program.to_pretty());
program = aiken_optimize_and_intern(program); program = aiken_optimize_and_intern(program);
// println!("PROGRAM {}", program.to_pretty());
// This is very important to call here. // This is very important to call here.
// If this isn't done, re-using the same instance // If this isn't done, re-using the same instance
@ -1865,7 +1862,10 @@ impl<'a> CodeGenerator<'a> {
*complex_clause = *complex_clause || elem_props.complex_clause; *complex_clause = *complex_clause || elem_props.complex_clause;
air_elems.push(statement); air_elems.push(statement);
list_tail = Some((tail, elem_name.to_string())); if &elem_name != "_" {
list_tail = Some((tail, elem_name.to_string()));
}
defined_heads.push(elem_name) defined_heads.push(elem_name)
}); });
@ -1882,6 +1882,7 @@ impl<'a> CodeGenerator<'a> {
defined_heads defined_heads
.into_iter() .into_iter()
.zip(defined_tails.into_iter()) .zip(defined_tails.into_iter())
.filter(|(head, _)| head != "_")
.map(|(head, tail)| (tail, head)) .map(|(head, tail)| (tail, head))
.collect_vec(), .collect_vec(),
list_tail, list_tail,
@ -2130,8 +2131,6 @@ impl<'a> CodeGenerator<'a> {
), ),
); );
} }
println!("WE GOT SEQUENCE {:#?}", sequence);
(AirTree::void(), AirTree::UnhoistedSequence(sequence)) (AirTree::void(), AirTree::UnhoistedSequence(sequence))
} }
} }
@ -2384,9 +2383,10 @@ impl<'a> CodeGenerator<'a> {
.get(&variant_name) .get(&variant_name)
.unwrap_or_else(|| panic!("Missing Function Variant Definition")); .unwrap_or_else(|| panic!("Missing Function Variant Definition"));
if let UserFunction::Function(body, deps) = function { if let UserFunction::Function { body, deps, params } = function {
let mut hoist_body = body.clone(); let mut hoist_body = body.clone();
let mut hoist_deps = deps.clone(); let mut hoist_deps = deps.clone();
let params = params.clone();
let mut tree_path = tree_path.clone(); let mut tree_path = tree_path.clone();
@ -2407,11 +2407,20 @@ impl<'a> CodeGenerator<'a> {
.get_mut(&variant_name) .get_mut(&variant_name)
.expect("Missing Function Variant Definition"); .expect("Missing Function Variant Definition");
*function = UserFunction::Function(hoist_body, hoist_deps); if params.is_empty() {
validator_hoistable.push((key, variant_name));
}
*function = UserFunction::Function {
body: hoist_body,
deps: hoist_deps,
params,
};
} else { } else {
todo!("Deal with Link later") todo!("Deal with Link later")
} }
} }
validator_hoistable.dedup();
// First we need to sort functions by dependencies // First we need to sort functions by dependencies
// here's also where we deal with mutual recursion // here's also where we deal with mutual recursion
@ -2433,7 +2442,7 @@ impl<'a> CodeGenerator<'a> {
.unwrap_or_else(|| panic!("Missing Function Variant Definition")); .unwrap_or_else(|| panic!("Missing Function Variant Definition"));
// TODO: change this part to handle mutual recursion // TODO: change this part to handle mutual recursion
if let UserFunction::Function(_, deps) = function { if let UserFunction::Function { deps, .. } = function {
if function_has_params { if function_has_params {
for (dep_generic_func, dep_variant) in deps.iter() { for (dep_generic_func, dep_variant) in deps.iter() {
if !(dep_generic_func == &generic_func && dep_variant == &variant) { if !(dep_generic_func == &generic_func && dep_variant == &variant) {
@ -2501,9 +2510,13 @@ impl<'a> CodeGenerator<'a> {
>, >,
hoisted_functions: &mut Vec<(FunctionAccessKey, String)>, hoisted_functions: &mut Vec<(FunctionAccessKey, String)>,
) { ) {
if let UserFunction::Function(body, func_deps) = function { if let UserFunction::Function {
body,
deps: func_deps,
params,
} = function
{
let mut body = body.clone(); let mut body = body.clone();
let node_to_edit = air_tree.find_air_tree_node(tree_path);
let (key, variant) = key_var; let (key, variant) = key_var;
@ -2513,28 +2526,7 @@ impl<'a> CodeGenerator<'a> {
.any(|(dep_key, dep_variant)| dep_key == key && dep_variant == variant); .any(|(dep_key, dep_variant)| dep_key == key && dep_variant == variant);
// first grab dependencies // first grab dependencies
let func_params = self let func_params = params;
.functions
.get(key)
.map(|func| {
func.arguments
.iter()
.map(|func_arg| {
func_arg
.arg_name
.get_variable_name()
.unwrap_or("_")
.to_string()
})
.collect_vec()
})
.unwrap_or_else(|| {
let Some(CodeGenFunction::Function { params, .. }) =
self.code_gen_functions.get(&key.function_name)
else { unreachable!() };
params.clone()
});
let params_empty = func_params.is_empty(); let params_empty = func_params.is_empty();
@ -2551,7 +2543,7 @@ impl<'a> CodeGenerator<'a> {
&key.function_name, &key.function_name,
&key.module_name, &key.module_name,
variant, variant,
func_params, func_params.clone(),
is_recursive, is_recursive,
body, body,
); );
@ -2564,6 +2556,7 @@ impl<'a> CodeGenerator<'a> {
hoisted_functions, hoisted_functions,
functions_to_hoist, functions_to_hoist,
); );
let node_to_edit = air_tree.find_air_tree_node(tree_path);
// now hoist full function onto validator tree // now hoist full function onto validator tree
*node_to_edit = function_deps.hoist_over(body.hoist_over(node_to_edit.clone())); *node_to_edit = function_deps.hoist_over(body.hoist_over(node_to_edit.clone()));
@ -2630,34 +2623,15 @@ impl<'a> CodeGenerator<'a> {
// In the case of zero args, we need to hoist the dependency function to the top of the zero arg function // In the case of zero args, we need to hoist the dependency function to the top of the zero arg function
if &dep_path.common_ancestor(func_path) == func_path || params_empty { if &dep_path.common_ancestor(func_path) == func_path || params_empty {
let dependent_params = self let UserFunction::Function { body: mut dep_air_tree, deps: dependency_deps, params: dependent_params } =
.functions
.get(&dep_key)
.map(|dep_func| {
dep_func
.arguments
.iter()
.map(|func_arg| {
func_arg
.arg_name
.get_variable_name()
.unwrap_or("_")
.to_string()
})
.collect_vec()
})
.unwrap_or_else(|| {
let Some(CodeGenFunction::Function { params, .. }) =
self.code_gen_functions.get(&dep_key.function_name)
else { unreachable!() };
params.clone()
});
let UserFunction::Function(mut dep_air_tree, dependency_deps) =
dep_function.clone() dep_function.clone()
else { unreachable!() }; else { unreachable!() };
if dependent_params.is_empty() {
// continue for zero arg functions. They are treated like global hoists.
continue;
}
let is_dependent_recursive = dependency_deps let is_dependent_recursive = dependency_deps
.iter() .iter()
.any(|(key, variant)| &dep_key == key && &dep_variant == variant); .any(|(key, variant)| &dep_key == key && &dep_variant == variant);
@ -2666,8 +2640,8 @@ impl<'a> CodeGenerator<'a> {
.iter() .iter()
.filter(|(dep_k, dep_v)| !(dep_k == &dep_key && dep_v == &dep_variant)) .filter(|(dep_k, dep_v)| !(dep_k == &dep_key && dep_v == &dep_variant))
.filter(|(dep_k, dep_v)| { .filter(|(dep_k, dep_v)| {
!params_empty params_empty
&& !hoisted_functions || !hoisted_functions
.iter() .iter()
.any(|(generic, variant)| generic == dep_k && variant == dep_v) .any(|(generic, variant)| generic == dep_k && variant == dep_v)
}) })
@ -2690,10 +2664,7 @@ impl<'a> CodeGenerator<'a> {
)); ));
deps_vec.extend(dependency_deps_to_add); deps_vec.extend(dependency_deps_to_add);
hoisted_functions.push((dep_key.clone(), dep_variant.clone()));
if !params_empty {
hoisted_functions.push((dep_key.clone(), dep_variant.clone()));
}
} }
} }
@ -2800,7 +2771,7 @@ impl<'a> CodeGenerator<'a> {
let (path, _) = func_variants.get_mut("").unwrap(); let (path, _) = func_variants.get_mut("").unwrap();
*path = path.common_ancestor(tree_path); *path = path.common_ancestor(tree_path);
} else { } else {
let CodeGenFunction::Function{ body, .. } = code_gen_func let CodeGenFunction::Function{ body, params } = code_gen_func
else { unreachable!() }; else { unreachable!() };
let mut function_variant_path = IndexMap::new(); let mut function_variant_path = IndexMap::new();
@ -2809,7 +2780,11 @@ impl<'a> CodeGenerator<'a> {
"".to_string(), "".to_string(),
( (
tree_path.clone(), tree_path.clone(),
UserFunction::Function(body.clone(), vec![]), UserFunction::Function {
body: body.clone(),
deps: vec![],
params: params.clone(),
},
), ),
); );
@ -2831,17 +2806,20 @@ impl<'a> CodeGenerator<'a> {
let mut function_def_types = function_def let mut function_def_types = function_def
.arguments .arguments
.iter() .iter()
.map(|arg| &arg.tipo) .map(|arg| convert_opaque_type(&arg.tipo, &self.data_types))
.collect_vec(); .collect_vec();
function_def_types.push(&function_def.return_type); function_def_types.push(convert_opaque_type(
&function_def.return_type,
&self.data_types,
));
let mono_types: IndexMap<u64, Arc<Type>> = if !function_def_types.is_empty() { let mono_types: IndexMap<u64, Arc<Type>> = if !function_def_types.is_empty() {
function_def_types function_def_types
.into_iter() .into_iter()
.zip(function_var_types.into_iter()) .zip(function_var_types.into_iter())
.flat_map(|(func_tipo, var_tipo)| { .flat_map(|(func_tipo, var_tipo)| {
get_generic_id_and_type(func_tipo, &var_tipo) get_generic_id_and_type(&func_tipo, &var_tipo)
}) })
.collect() .collect()
} else { } else {
@ -2867,29 +2845,45 @@ impl<'a> CodeGenerator<'a> {
if let Some((path, _)) = func_variants.get_mut(&variant) { if let Some((path, _)) = func_variants.get_mut(&variant) {
*path = path.common_ancestor(tree_path); *path = path.common_ancestor(tree_path);
} else { } else {
let params = function_def
.arguments
.iter()
.map(|arg| {
arg.arg_name.get_variable_name().unwrap_or("_").to_string()
})
.collect_vec();
let mut function_air_tree_body = self.build(&function_def.body); let mut function_air_tree_body = self.build(&function_def.body);
function_air_tree_body.traverse_tree_with(&mut |air_tree, _| { function_air_tree_body.traverse_tree_with(&mut |air_tree, _| {
monomorphize(air_tree, &mono_types);
erase_opaque_type_operations(air_tree, &self.data_types); erase_opaque_type_operations(air_tree, &self.data_types);
monomorphize(air_tree, &mono_types);
}); });
func_variants.insert( func_variants.insert(
variant, variant,
( (
tree_path.clone(), tree_path.clone(),
UserFunction::Function(function_air_tree_body, vec![]), UserFunction::Function {
body: function_air_tree_body,
deps: vec![],
params,
},
), ),
); );
} }
} else { } else {
let params = function_def
.arguments
.iter()
.map(|arg| arg.arg_name.get_variable_name().unwrap_or("_").to_string())
.collect_vec();
let mut function_air_tree_body = self.build(&function_def.body); let mut function_air_tree_body = self.build(&function_def.body);
function_air_tree_body.traverse_tree_with(&mut |air_tree, _| { function_air_tree_body.traverse_tree_with(&mut |air_tree, _| {
monomorphize(air_tree, &mono_types);
erase_opaque_type_operations(air_tree, &self.data_types); erase_opaque_type_operations(air_tree, &self.data_types);
monomorphize(air_tree, &mono_types);
}); });
let mut function_variant_path = IndexMap::new(); let mut function_variant_path = IndexMap::new();
@ -2898,7 +2892,11 @@ impl<'a> CodeGenerator<'a> {
variant, variant,
( (
tree_path.clone(), tree_path.clone(),
UserFunction::Function(function_air_tree_body, vec![]), UserFunction::Function {
body: function_air_tree_body,
deps: vec![],
params,
},
), ),
); );

View File

@ -38,7 +38,11 @@ pub enum CodeGenFunction {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum UserFunction { pub enum UserFunction {
Function(AirTree, Vec<(FunctionAccessKey, String)>), Function {
body: AirTree,
deps: Vec<(FunctionAccessKey, String)>,
params: Vec<String>,
},
Link(String), Link(String),
} }
@ -542,7 +546,7 @@ pub fn erase_opaque_type_operations(
} }
} }
AirExpression::RecordAccess { tipo, record, .. } => { AirExpression::RecordAccess { tipo, record, .. } => {
if check_replaceable_opaque_type(tipo, data_types) { if check_replaceable_opaque_type(&record.return_type(), data_types) {
*air_tree = (**record).clone(); *air_tree = (**record).clone();
} }
} }

View File

@ -74,6 +74,8 @@ fn assert_uplc(source_code: &str, expected: Term<Name>, should_fail: bool) {
let expected = optimize::aiken_optimize_and_intern(expected); let expected = optimize::aiken_optimize_and_intern(expected);
// println!("expected: {}", expected.to_pretty());
let expected: Program<DeBruijn> = expected.try_into().unwrap(); let expected: Program<DeBruijn> = expected.try_into().unwrap();
assert_eq!(debruijn_program.to_pretty(), expected.to_pretty()); assert_eq!(debruijn_program.to_pretty(), expected.to_pretty());
@ -477,7 +479,7 @@ fn acceptance_test_5_direct_2_heads() {
when xs is { when xs is {
[] -> None [] -> None
[a] -> Some(xs) [a] -> Some(xs)
[a, b, ..c] -> Some([a,b]) [a, b, ..] -> Some([a,b])
} }
} }
@ -495,11 +497,58 @@ fn acceptance_test_5_direct_2_heads() {
Term::var("xs") Term::var("xs")
.delayed_choose_list( .delayed_choose_list(
Term::Constant(Constant::Data(Data::constr(1, vec![])).into()), Term::Constant(Constant::Data(Data::constr(1, vec![])).into()),
Term::constr_data().apply(Term::integer(0.into())).apply( Term::var("tail_1")
Term::mk_cons() .delayed_choose_list(
.apply(Term::head_list().apply(Term::var("xs"))) Term::constr_data()
.apply(Term::empty_list()), .apply(Term::integer(0.into()))
), .apply(
Term::mk_cons()
.apply(Term::list_data().apply(Term::var("xs")))
.apply(Term::empty_list()),
)
.lambda("a")
.apply(
Term::un_i_data().apply(
Term::head_list().apply(Term::var("xs")),
),
),
Term::constr_data()
.apply(Term::integer(0.into()))
.apply(
Term::mk_cons()
.apply(
Term::list_data().apply(
Term::mk_cons()
.apply(
Term::i_data()
.apply(Term::var("a")),
)
.apply(
Term::mk_cons()
.apply(
Term::i_data().apply(
Term::var("b"),
),
)
.apply(Term::empty_list()),
),
),
)
.apply(Term::empty_list()),
)
.lambda("b")
.apply(Term::un_i_data().apply(
Term::head_list().apply(Term::var("tail_1")),
))
.lambda("a")
.apply(
Term::un_i_data().apply(
Term::head_list().apply(Term::var("xs")),
),
),
)
.lambda("tail_1")
.apply(Term::tail_list().apply(Term::var("xs"))),
) )
.lambda("xs"), .lambda("xs"),
) )
@ -510,7 +559,14 @@ fn acceptance_test_5_direct_2_heads() {
])), ])),
) )
.apply(Term::Constant( .apply(Term::Constant(
Constant::Data(Data::constr(0, vec![Data::integer(1.into())])).into(), Constant::Data(Data::constr(
0,
vec![Data::list(vec![
Data::integer(1.into()),
Data::integer(2.into()),
])],
))
.into(),
)), )),
false, false,
); );
@ -749,14 +805,14 @@ fn acceptance_test_7_unzip() {
.apply(Term::var("unzip")) .apply(Term::var("unzip"))
.apply(Term::var("rest")), .apply(Term::var("rest")),
) )
.lambda("a")
.apply(Term::un_i_data().apply(
Term::fst_pair().apply(Term::var("head_pair")),
))
.lambda("b") .lambda("b")
.apply(Term::un_b_data().apply( .apply(Term::un_b_data().apply(
Term::snd_pair().apply(Term::var("head_pair")), Term::snd_pair().apply(Term::var("head_pair")),
)) ))
.lambda("a")
.apply(Term::un_i_data().apply(
Term::fst_pair().apply(Term::var("head_pair")),
))
.lambda("rest") .lambda("rest")
.apply(Term::tail_list().apply(Term::var("xs"))) .apply(Term::tail_list().apply(Term::var("xs")))
.lambda("head_pair") .lambda("head_pair")