diff --git a/.editorconfig b/.editorconfig index 0759674c..4f309038 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,9 +1,13 @@ root = true - -[*.ak] -indent_style = space -indent_size = 2 end_of_line = lf charset = utf-8 trim_trailing_whitespace = true insert_final_newline = true + +[*.ak] +indent_style = space +indent_size = 2 + +[Makefile] +indent_style = tabs +indent_size = 4 diff --git a/crates/cli/src/cmd/check.rs b/crates/cli/src/cmd/check.rs index 2ec59807..15c7b8f7 100644 --- a/crates/cli/src/cmd/check.rs +++ b/crates/cli/src/cmd/check.rs @@ -11,6 +11,10 @@ pub struct Args { #[clap(short, long)] skip_tests: bool, + /// When enabled, also pretty-print test UPLC on failure + #[clap(long)] + debug: bool, + /// Only run tests if their path + name match the given string #[clap(short, long)] match_tests: Option, @@ -20,8 +24,11 @@ pub fn exec( Args { directory, skip_tests, + debug, match_tests, }: Args, ) -> miette::Result<()> { - crate::with_project(directory, |p| p.check(skip_tests, match_tests.clone())) + crate::with_project(directory, |p| { + p.check(skip_tests, match_tests.clone(), debug) + }) } diff --git a/crates/cli/src/lib.rs b/crates/cli/src/lib.rs index 0034bf3e..dc195f14 100644 --- a/crates/cli/src/lib.rs +++ b/crates/cli/src/lib.rs @@ -1,10 +1,7 @@ -use std::{env, path::PathBuf}; +use std::collections::BTreeMap; +use std::{env, path::PathBuf, process}; -use aiken_project::{ - config::Config, - telemetry::{self, TestInfo}, - Project, -}; +use aiken_project::{config::Config, pretty, script::EvalInfo, telemetry, Project}; use miette::IntoDiagnostic; use owo_colors::OwoColorize; use uplc::machine::cost_model::ExBudget; @@ -35,12 +32,20 @@ where if let Err(err) = build_result { err.report(); - - miette::bail!("Failed: {} error(s), {warning_count} warning(s)", err.len(),); - }; - - println!("\nFinished with {warning_count} warning(s)\n"); - + println!("{}", "Summary".purple().bold()); + println!( + " {} error(s), {}", + err.len(), + format!("{warning_count} warning(s)").yellow(), + ); + process::exit(1); + } else { + println!("{}", "Summary".purple().bold()); + println!( + " 0 error, {}", + format!("{warning_count} warning(s)").yellow(), + ); + } Ok(()) } @@ -76,89 +81,149 @@ impl telemetry::EventListener for Terminal { output_path.to_str().unwrap_or("").bright_blue() ); } + telemetry::Event::EvaluatingFunction { results } => { + println!("{}\n", "...Evaluating function".bold().purple()); + + let (max_mem, max_cpu) = find_max_execution_units(&results); + + for eval_info in &results { + println!(" {}", fmt_eval(eval_info, max_mem, max_cpu)) + } + } telemetry::Event::RunningTests => { println!("{}\n", "...Running tests".bold().purple()); } telemetry::Event::FinishedTests { tests } => { - let (max_mem, max_cpu) = tests.iter().fold( - (0, 0), - |(max_mem, max_cpu), TestInfo { spent_budget, .. }| { - if spent_budget.mem >= max_mem && spent_budget.cpu >= max_cpu { - (spent_budget.mem, spent_budget.cpu) - } else if spent_budget.mem > max_mem { - (spent_budget.mem, max_cpu) - } else if spent_budget.cpu > max_cpu { - (max_mem, spent_budget.cpu) - } else { - (max_mem, max_cpu) - } - }, - ); + let (max_mem, max_cpu) = find_max_execution_units(&tests); - let max_mem = max_mem.to_string().len() as i32; - let max_cpu = max_cpu.to_string().len() as i32; - - for test_info in &tests { - println!("{}", fmt_test(test_info, max_mem, max_cpu)) + for (module, infos) in &group_by_module(&tests) { + let first = fmt_test(infos.first().unwrap(), max_mem, max_cpu, false).len(); + println!( + "{} {} {}", + " ┌──".bright_black(), + module.bold().blue(), + pretty::pad_left("".to_string(), first - module.len() - 3, "─") + .bright_black() + ); + for eval_info in infos { + println!( + " {} {}", + "│".bright_black(), + fmt_test(eval_info, max_mem, max_cpu, true) + ) + } + let last = fmt_test(infos.last().unwrap(), max_mem, max_cpu, false).len(); + let summary = fmt_test_summary(infos, false).len(); + println!( + "{} {}\n", + pretty::pad_right(" └".to_string(), last - summary + 5, "─") + .bright_black(), + fmt_test_summary(infos, true), + ); } - - let (n_passed, n_failed) = - tests - .iter() - .fold((0, 0), |(n_passed, n_failed), test_info| { - if test_info.is_passing { - (n_passed + 1, n_failed) - } else { - (n_passed, n_failed + 1) - } - }); - - println!( - "{}", - format!( - "\n Summary: {} test(s), {}; {}.", - tests.len(), - format!("{} passed", n_passed).bright_green(), - format!("{} failed", n_failed).bright_red() - ) - .bold() - ) } } } } -fn fmt_test(test_info: &TestInfo, max_mem: i32, max_cpu: i32) -> String { - let TestInfo { - is_passing, - test, +fn fmt_test(eval_info: &EvalInfo, max_mem: usize, max_cpu: usize, styled: bool) -> String { + let EvalInfo { + success, + script, spent_budget, - } = test_info; + .. + } = eval_info; + + let ExBudget { mem, cpu } = spent_budget; + let mem_pad = pretty::pad_left(mem.to_string(), max_mem, " "); + let cpu_pad = pretty::pad_left(cpu.to_string(), max_cpu, " "); + + format!( + "{} [mem: {}, cpu: {}] {}", + if *success { + pretty::style_if(styled, "PASS".to_string(), |s| s.bold().green().to_string()) + } else { + pretty::style_if(styled, "FAIL".to_string(), |s| s.bold().red().to_string()) + }, + pretty::style_if(styled, mem_pad, |s| s.bright_white().to_string()), + pretty::style_if(styled, cpu_pad, |s| s.bright_white().to_string()), + pretty::style_if(styled, script.name.clone(), |s| s.bright_blue().to_string()), + ) +} + +fn fmt_test_summary(tests: &Vec<&EvalInfo>, styled: bool) -> String { + let (n_passed, n_failed) = tests + .iter() + .fold((0, 0), |(n_passed, n_failed), test_info| { + if test_info.success { + (n_passed + 1, n_failed) + } else { + (n_passed, n_failed + 1) + } + }); + format!( + "{} | {} | {}", + pretty::style_if(styled, format!("{} tests", tests.len()), |s| s + .bold() + .to_string()), + pretty::style_if(styled, format!("{} passed", n_passed), |s| s + .bright_green() + .bold() + .to_string()), + pretty::style_if(styled, format!("{} failed", n_failed), |s| s + .bright_red() + .bold() + .to_string()), + ) +} + +fn fmt_eval(eval_info: &EvalInfo, max_mem: usize, max_cpu: usize) -> String { + let EvalInfo { + output, + script, + spent_budget, + .. + } = eval_info; let ExBudget { mem, cpu } = spent_budget; format!( - " [{}] [mem: {}, cpu: {}] {}::{}", - if *is_passing { - "PASS".bold().green().to_string() - } else { - "FAIL".bold().red().to_string() - }, - pad_left(mem.to_string(), max_mem, " "), - pad_left(cpu.to_string(), max_cpu, " "), - test.module.blue(), - test.name.bright_blue() + " {}::{} [mem: {}, cpu: {}]\n │\n ╰─▶ {}", + script.module.blue(), + script.name.bright_blue(), + pretty::pad_left(mem.to_string(), max_mem, " "), + pretty::pad_left(cpu.to_string(), max_cpu, " "), + output + .as_ref() + .map(|x| format!("{}", x)) + .unwrap_or_else(|| "Error.".to_string()), ) } -fn pad_left(mut text: String, n: i32, delimiter: &str) -> String { - let diff = n - text.len() as i32; - - if diff.is_positive() { - for _ in 0..diff { - text.insert_str(0, delimiter); - } +fn group_by_module(infos: &Vec) -> BTreeMap> { + let mut modules = BTreeMap::new(); + for eval_info in infos { + let xs: &mut Vec<&EvalInfo> = modules.entry(eval_info.script.module.clone()).or_default(); + xs.push(eval_info); } - - text + modules +} + +fn find_max_execution_units(xs: &[EvalInfo]) -> (usize, usize) { + let (max_mem, max_cpu) = xs.iter().fold( + (0, 0), + |(max_mem, max_cpu), EvalInfo { spent_budget, .. }| { + if spent_budget.mem >= max_mem && spent_budget.cpu >= max_cpu { + (spent_budget.mem, spent_budget.cpu) + } else if spent_budget.mem > max_mem { + (spent_budget.mem, max_cpu) + } else if spent_budget.cpu > max_cpu { + (max_mem, spent_budget.cpu) + } else { + (max_mem, max_cpu) + } + }, + ); + + (max_mem.to_string().len(), max_cpu.to_string().len()) } diff --git a/crates/lang/src/air.rs b/crates/lang/src/air.rs index fde5981c..9a8e8ee8 100644 --- a/crates/lang/src/air.rs +++ b/crates/lang/src/air.rs @@ -28,16 +28,13 @@ pub enum Air { scope: Vec, constructor: ValueConstructor, name: String, + variant_name: String, }, - // Fn { - // scope: Vec, - // tipo: Arc, - // is_capture: bool, - // args: Vec>>, - // body: Box, - // return_annotation: Option, - // }, + Fn { + scope: Vec, + params: Vec, + }, List { scope: Vec, count: usize, @@ -67,6 +64,7 @@ pub enum Air { Builtin { scope: Vec, func: DefaultFunction, + tipo: Arc, }, BinOp { @@ -88,6 +86,7 @@ pub enum Air { module_name: String, params: Vec, recursive: bool, + variant_name: String, }, DefineConst { @@ -237,6 +236,7 @@ impl Air { | Air::List { scope, .. } | Air::ListAccessor { scope, .. } | Air::ListExpose { scope, .. } + | Air::Fn { scope, .. } | Air::Call { scope, .. } | Air::Builtin { scope, .. } | Air::BinOp { scope, .. } diff --git a/crates/lang/src/ast.rs b/crates/lang/src/ast.rs index 3afbdd82..f839ca7a 100644 --- a/crates/lang/src/ast.rs +++ b/crates/lang/src/ast.rs @@ -66,10 +66,8 @@ impl UntypedModule { } } -pub type TypedDefinition = Definition, TypedExpr, String, String>; -pub type UntypedDefinition = Definition<(), UntypedExpr, (), ()>; - pub type TypedFunction = Function, TypedExpr>; +pub type UntypedFunction = Function<(), UntypedExpr>; #[derive(Debug, Clone, PartialEq)] pub struct Function { @@ -84,6 +82,24 @@ pub struct Function { pub end_position: usize, } +pub type TypedTypeAlias = TypeAlias>; +pub type UntypedTypeAlias = TypeAlias<()>; + +impl TypedFunction { + pub fn test_hint(&self) -> Option<(BinOp, Box, Box)> { + match &self.body { + TypedExpr::BinOp { + name, + tipo, + left, + right, + .. + } if tipo == &bool() => Some((*name, left.clone(), right.clone())), + _ => None, + } + } +} + #[derive(Debug, Clone, PartialEq)] pub struct TypeAlias { pub alias: String, @@ -95,6 +111,9 @@ pub struct TypeAlias { pub tipo: T, } +pub type TypedDataType = DataType>; +pub type UntypedDataType = DataType<()>; + #[derive(Debug, Clone, PartialEq)] pub struct DataType { pub constructors: Vec>, @@ -107,6 +126,9 @@ pub struct DataType { pub typed_parameters: Vec, } +pub type TypedUse = Use; +pub type UntypedUse = Use<()>; + #[derive(Debug, Clone, PartialEq, Eq)] pub struct Use { pub as_name: Option, @@ -116,6 +138,9 @@ pub struct Use { pub unqualified: Vec, } +pub type TypedModuleConstant = ModuleConstant, String>; +pub type UntypedModuleConstant = ModuleConstant<(), ()>; + #[derive(Debug, Clone, PartialEq)] pub struct ModuleConstant { pub doc: Option, @@ -127,6 +152,9 @@ pub struct ModuleConstant { pub tipo: T, } +pub type TypedDefinition = Definition, TypedExpr, String, String>; +pub type UntypedDefinition = Definition<(), UntypedExpr, (), ()>; + #[derive(Debug, Clone, PartialEq)] pub enum Definition { Fn(Function), diff --git a/crates/lang/src/builder.rs b/crates/lang/src/builder.rs new file mode 100644 index 00000000..522d2ea2 --- /dev/null +++ b/crates/lang/src/builder.rs @@ -0,0 +1,1181 @@ +use std::{cell::RefCell, collections::HashMap, sync::Arc}; + +use itertools::Itertools; +use uplc::{ + ast::{Constant as UplcConstant, Name, Term, Type as UplcType}, + builtins::DefaultFunction, + machine::runtime::convert_constr_to_tag, + BigInt, Constr, KeyValuePairs, PlutusData, +}; + +use crate::{ + air::Air, + ast::{Clause, Constant, Pattern, Span}, + expr::TypedExpr, + tipo::{PatternConstructor, Type, TypeVar, ValueConstructorVariant}, +}; + +#[derive(Clone, Debug)] +pub struct FuncComponents { + pub ir: Vec, + pub dependencies: Vec, + pub args: Vec, + pub recursive: bool, +} + +#[derive(Clone, Eq, Debug, PartialEq, Hash)] +pub struct ConstrFieldKey { + pub local_var: String, + pub field_name: String, +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct DataTypeKey { + pub module_name: String, + pub defined_type: String, +} + +pub type ConstrUsageKey = String; + +#[derive(Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)] +pub struct FunctionAccessKey { + pub module_name: String, + pub function_name: String, + pub variant_name: String, +} + +#[derive(Clone, Debug)] +pub struct ClauseProperties { + pub clause_var_name: String, + pub needs_constr_var: bool, + pub is_complex_clause: bool, + pub current_index: usize, + pub original_subject_name: String, +} + +pub fn convert_type_to_data(term: Term, field_type: &Arc) -> Term { + if field_type.is_bytearray() { + Term::Apply { + function: DefaultFunction::BData.into(), + argument: term.into(), + } + } else if field_type.is_int() { + Term::Apply { + function: DefaultFunction::IData.into(), + argument: term.into(), + } + } else if field_type.is_map() { + Term::Apply { + function: DefaultFunction::MapData.into(), + argument: term.into(), + } + } else if field_type.is_list() { + Term::Apply { + function: DefaultFunction::ListData.into(), + argument: term.into(), + } + } else if field_type.is_string() { + Term::Apply { + function: DefaultFunction::BData.into(), + argument: Term::Apply { + function: DefaultFunction::EncodeUtf8.into(), + argument: term.into(), + } + .into(), + } + } else if field_type.is_tuple() { + match field_type.get_uplc_type() { + UplcType::List(_) => Term::Apply { + function: DefaultFunction::ListData.into(), + argument: term.into(), + }, + UplcType::Pair(_, _) => Term::Apply { + function: Term::Lambda { + parameter_name: Name { + text: "__pair".to_string(), + unique: 0.into(), + }, + body: Term::Apply { + function: DefaultFunction::ListData.into(), + argument: Term::Apply { + function: Term::Apply { + function: Term::Builtin(DefaultFunction::MkCons) + .force_wrap() + .into(), + argument: Term::Apply { + function: Term::Builtin(DefaultFunction::FstPair) + .force_wrap() + .force_wrap() + .into(), + argument: Term::Var(Name { + text: "__pair".to_string(), + unique: 0.into(), + }) + .into(), + } + .into(), + } + .into(), + + argument: Term::Apply { + function: Term::Apply { + function: Term::Builtin(DefaultFunction::MkCons) + .force_wrap() + .into(), + argument: Term::Apply { + function: Term::Builtin(DefaultFunction::SndPair) + .force_wrap() + .force_wrap() + .into(), + argument: Term::Var(Name { + text: "__pair".to_string(), + unique: 0.into(), + }) + .into(), + } + .into(), + } + .into(), + argument: Term::Constant(UplcConstant::ProtoList( + UplcType::Data, + vec![], + )) + .into(), + } + .into(), + } + .into(), + } + .into(), + } + .into(), + argument: term.into(), + }, + _ => unreachable!(), + } + } else if field_type.is_bool() { + Term::Apply { + function: Term::Apply { + function: Term::Apply { + function: Term::Builtin(DefaultFunction::IfThenElse) + .force_wrap() + .into(), + argument: term.into(), + } + .into(), + argument: Term::Constant(UplcConstant::Data(PlutusData::Constr(Constr { + tag: convert_constr_to_tag(1), + any_constructor: None, + fields: vec![], + }))) + .into(), + } + .into(), + argument: Term::Constant(UplcConstant::Data(PlutusData::Constr(Constr { + tag: convert_constr_to_tag(0), + any_constructor: None, + fields: vec![], + }))) + .into(), + } + } else { + term + } +} + +pub fn convert_data_to_type(term: Term, field_type: &Arc) -> Term { + if field_type.is_int() { + Term::Apply { + function: DefaultFunction::UnIData.into(), + argument: term.into(), + } + } else if field_type.is_bytearray() { + Term::Apply { + function: DefaultFunction::UnBData.into(), + argument: term.into(), + } + } else if field_type.is_map() { + Term::Apply { + function: DefaultFunction::UnMapData.into(), + argument: term.into(), + } + } else if field_type.is_list() { + Term::Apply { + function: DefaultFunction::UnListData.into(), + argument: term.into(), + } + } else if field_type.is_string() { + Term::Apply { + function: DefaultFunction::DecodeUtf8.into(), + argument: Term::Apply { + function: DefaultFunction::UnBData.into(), + argument: term.into(), + } + .into(), + } + } else if field_type.is_tuple() { + match field_type.get_uplc_type() { + UplcType::List(_) => Term::Apply { + function: DefaultFunction::UnListData.into(), + argument: term.into(), + }, + UplcType::Pair(_, _) => Term::Apply { + function: Term::Lambda { + parameter_name: Name { + text: "__list_data".to_string(), + unique: 0.into(), + }, + body: Term::Apply { + function: Term::Lambda { + parameter_name: Name { + text: "__tail".to_string(), + unique: 0.into(), + }, + body: Term::Apply { + function: Term::Apply { + function: Term::Builtin(DefaultFunction::MkPairData).into(), + argument: Term::Apply { + function: Term::Builtin(DefaultFunction::HeadList) + .force_wrap() + .into(), + argument: Term::Var(Name { + text: "__list_data".to_string(), + unique: 0.into(), + }) + .into(), + } + .into(), + } + .into(), + argument: Term::Apply { + function: Term::Builtin(DefaultFunction::HeadList) + .force_wrap() + .into(), + argument: Term::Var(Name { + text: "__tail".to_string(), + unique: 0.into(), + }) + .into(), + } + .into(), + } + .into(), + } + .into(), + argument: Term::Apply { + function: Term::Builtin(DefaultFunction::TailList).force_wrap().into(), + argument: Term::Var(Name { + text: "__list_data".to_string(), + unique: 0.into(), + }) + .into(), + } + .into(), + } + .into(), + } + .into(), + argument: Term::Apply { + function: Term::Builtin(DefaultFunction::UnListData) + .force_wrap() + .into(), + argument: term.into(), + } + .into(), + }, + _ => unreachable!(), + } + } else if field_type.is_bool() { + Term::Apply { + function: Term::Apply { + function: Term::Builtin(DefaultFunction::EqualsInteger).into(), + argument: Term::Constant(UplcConstant::Integer(1)).into(), + } + .into(), + argument: Term::Apply { + function: Term::Builtin(DefaultFunction::FstPair) + .force_wrap() + .force_wrap() + .into(), + argument: Term::Apply { + function: Term::Builtin(DefaultFunction::UnConstrData).into(), + argument: term.into(), + } + .into(), + } + .into(), + } + } else { + term + } +} + +pub fn rearrange_clauses( + clauses: Vec, String>>, +) -> Vec, String>> { + let mut sorted_clauses = clauses; + + // if we have a list sort clauses so we can plug holes for cases not covered by clauses + // TODO: while having 10000000 element list is impossible to destructure in plutus budget, + // let's sort clauses by a safer manner + // TODO: how shall tails be weighted? Since any clause after will not run + sorted_clauses.sort_by(|clause1, clause2| { + let clause1_len = match &clause1.pattern[0] { + Pattern::List { elements, tail, .. } => elements.len() + usize::from(tail.is_some()), + _ => 10000000, + }; + let clause2_len = match &clause2.pattern[0] { + Pattern::List { elements, tail, .. } => elements.len() + usize::from(tail.is_some()), + _ => 10000001, + }; + + clause1_len.cmp(&clause2_len) + }); + + let mut elems_len = 0; + let mut final_clauses = sorted_clauses.clone(); + let mut holes_to_fill = vec![]; + let mut assign_plug_in_name = None; + let mut last_clause_index = sorted_clauses.len() - 1; + let mut last_clause_set = false; + + // If we have a catch all, use that. Otherwise use todo which will result in error + // TODO: fill in todo label with description + let plug_in_then = match &sorted_clauses[sorted_clauses.len() - 1].pattern[0] { + Pattern::Var { name, .. } => { + assign_plug_in_name = Some(name); + sorted_clauses[sorted_clauses.len() - 1].clone().then + } + Pattern::Discard { .. } => sorted_clauses[sorted_clauses.len() - 1].clone().then, + _ => TypedExpr::Todo { + location: Span::empty(), + label: None, + tipo: sorted_clauses[sorted_clauses.len() - 1].then.tipo(), + }, + }; + + for (index, clause) in sorted_clauses.iter().enumerate() { + if let Pattern::List { elements, .. } = &clause.pattern[0] { + // found a hole and now we plug it + while elems_len < elements.len() { + let mut discard_elems = vec![]; + + for _ in 0..elems_len { + discard_elems.push(Pattern::Discard { + name: "_".to_string(), + location: Span::empty(), + }); + } + + // If we have a named catch all then in scope the name and create list of discards, otherwise list of discards + let clause_to_fill = if let Some(name) = assign_plug_in_name { + Clause { + location: Span::empty(), + pattern: vec![Pattern::Assign { + name: name.clone(), + location: Span::empty(), + pattern: Pattern::List { + location: Span::empty(), + elements: discard_elems, + tail: None, + } + .into(), + }], + alternative_patterns: vec![], + guard: None, + then: plug_in_then.clone(), + } + } else { + Clause { + location: Span::empty(), + pattern: vec![Pattern::List { + location: Span::empty(), + elements: discard_elems, + tail: None, + }], + alternative_patterns: vec![], + guard: None, + then: plug_in_then.clone(), + } + }; + + holes_to_fill.push((index, clause_to_fill)); + elems_len += 1; + } + } + + // if we have a pattern with no clause guards and a tail then no lists will get past here to other clauses + if let Pattern::List { + elements, + tail: Some(tail), + .. + } = &clause.pattern[0] + { + let mut elements = elements.clone(); + elements.push(*tail.clone()); + if elements + .iter() + .all(|element| matches!(element, Pattern::Var { .. } | Pattern::Discard { .. })) + && !last_clause_set + { + last_clause_index = index; + last_clause_set = true; + } + } + + // If the last condition doesn't have a catch all or tail then add a catch all with a todo + if index == sorted_clauses.len() - 1 { + if let Pattern::List { + elements, + tail: Some(tail), + .. + } = &clause.pattern[0] + { + let mut elements = elements.clone(); + elements.push(*tail.clone()); + if !elements + .iter() + .all(|element| matches!(element, Pattern::Var { .. } | Pattern::Discard { .. })) + { + final_clauses.push(Clause { + location: Span::empty(), + pattern: vec![Pattern::Discard { + name: "_".to_string(), + location: Span::empty(), + }], + alternative_patterns: vec![], + guard: None, + then: plug_in_then.clone(), + }); + } + } + } + + elems_len += 1; + } + + // Encountered a tail so stop there with that as last clause + final_clauses = final_clauses[0..(last_clause_index + 1)].to_vec(); + + // insert hole fillers into clauses + for (index, clause) in holes_to_fill.into_iter().rev() { + final_clauses.insert(index, clause); + } + + final_clauses +} + +pub fn list_access_to_uplc( + names: &[String], + id_list: &[u64], + tail: bool, + current_index: usize, + term: Term, + tipo: &Type, +) -> Term { + let (first, names) = names.split_first().unwrap(); + + let head_list = if tipo.is_map() { + Term::Apply { + function: Term::Force(Term::Builtin(DefaultFunction::HeadList).into()).into(), + argument: Term::Var(Name { + text: format!("tail_index_{}_{}", current_index, id_list[current_index]), + unique: 0.into(), + }) + .into(), + } + } else { + convert_data_to_type( + Term::Apply { + function: Term::Force(Term::Builtin(DefaultFunction::HeadList).into()).into(), + argument: Term::Var(Name { + text: format!("tail_index_{}_{}", current_index, id_list[current_index]), + unique: 0.into(), + }) + .into(), + }, + &tipo.clone().get_inner_types()[0], + ) + }; + + if names.len() == 1 && tail { + Term::Lambda { + parameter_name: Name { + text: format!("tail_index_{}_{}", current_index, id_list[current_index]), + unique: 0.into(), + }, + body: Term::Apply { + function: Term::Lambda { + parameter_name: Name { + text: first.clone(), + unique: 0.into(), + }, + body: Term::Apply { + function: Term::Lambda { + parameter_name: Name { + text: names[0].clone(), + unique: 0.into(), + }, + body: term.into(), + } + .into(), + argument: Term::Apply { + function: Term::Force(Term::Builtin(DefaultFunction::TailList).into()) + .into(), + argument: Term::Var(Name { + text: format!( + "tail_index_{}_{}", + current_index, id_list[current_index] + ), + unique: 0.into(), + }) + .into(), + } + .into(), + } + .into(), + } + .into(), + argument: head_list.into(), + } + .into(), + } + } else if names.is_empty() { + Term::Lambda { + parameter_name: Name { + text: format!("tail_index_{}_{}", current_index, id_list[current_index]), + unique: 0.into(), + }, + body: Term::Apply { + function: Term::Lambda { + parameter_name: Name { + text: first.clone(), + unique: 0.into(), + }, + body: term.into(), + } + .into(), + argument: Term::Apply { + function: Term::Force(Term::Builtin(DefaultFunction::HeadList).into()).into(), + argument: Term::Var(Name { + text: format!("tail_index_{}_{}", current_index, id_list[current_index]), + unique: 0.into(), + }) + .into(), + } + .into(), + } + .into(), + } + } else { + Term::Lambda { + parameter_name: Name { + text: format!("tail_index_{}_{}", current_index, id_list[current_index]), + unique: 0.into(), + }, + body: Term::Apply { + function: Term::Lambda { + parameter_name: Name { + text: first.clone(), + unique: 0.into(), + }, + body: Term::Apply { + function: list_access_to_uplc( + names, + id_list, + tail, + current_index + 1, + term, + tipo, + ) + .into(), + argument: Term::Apply { + function: Term::Force(Term::Builtin(DefaultFunction::TailList).into()) + .into(), + argument: Term::Var(Name { + text: format!( + "tail_index_{}_{}", + current_index, id_list[current_index] + ), + unique: 0.into(), + }) + .into(), + } + .into(), + } + .into(), + } + .into(), + argument: head_list.into(), + } + .into(), + } + } +} + +pub fn get_common_ancestor(scope: &[u64], scope_prev: &[u64]) -> Vec { + let longest_length = if scope.len() >= scope_prev.len() { + scope.len() + } else { + scope_prev.len() + }; + + if *scope == *scope_prev { + return scope.to_vec(); + } + + for index in 0..longest_length { + if scope.get(index).is_none() { + return scope.to_vec(); + } else if scope_prev.get(index).is_none() { + return scope_prev.to_vec(); + } else if scope[index] != scope_prev[index] { + return scope[0..index].to_vec(); + } + } + vec![] +} + +pub fn check_when_pattern_needs( + pattern: &Pattern>, + needs_access_to_constr_var: &mut bool, + needs_clause_guard: &mut bool, +) { + match pattern { + Pattern::Var { .. } => { + *needs_access_to_constr_var = true; + } + Pattern::List { .. } + | Pattern::Constructor { .. } + | Pattern::Tuple { .. } + | Pattern::Int { .. } => { + *needs_access_to_constr_var = true; + *needs_clause_guard = true; + } + Pattern::Discard { .. } => {} + + _ => todo!("{pattern:#?}"), + } +} + +pub fn constants_ir( + literal: &Constant, String>, + ir_stack: &mut Vec, + scope: Vec, +) { + match literal { + Constant::Int { value, .. } => { + ir_stack.push(Air::Int { + scope, + value: value.clone(), + }); + } + Constant::String { value, .. } => { + ir_stack.push(Air::String { + scope, + value: value.clone(), + }); + } + Constant::Tuple { .. } => { + todo!() + } + Constant::List { elements, tipo, .. } => { + ir_stack.push(Air::List { + scope: scope.clone(), + count: elements.len(), + tipo: tipo.clone(), + tail: false, + }); + + for element in elements { + constants_ir(element, ir_stack, scope.clone()); + } + } + Constant::Record { .. } => { + // ir_stack.push(Air::Record { scope, }); + todo!() + } + Constant::ByteArray { bytes, .. } => { + ir_stack.push(Air::ByteArray { + scope, + bytes: bytes.clone(), + }); + } + Constant::Var { .. } => todo!(), + }; +} + +pub fn match_ir_for_recursion( + ir: Air, + insert_var_vec: &mut Vec<(usize, Air)>, + function_access_key: &FunctionAccessKey, + index: usize, +) { + if let Air::Var { + scope, + constructor, + variant_name, + .. + } = ir + { + if let ValueConstructorVariant::ModuleFn { + name: func_name, + module, + .. + } = constructor.clone().variant + { + let var_func_access = FunctionAccessKey { + module_name: module, + function_name: func_name.clone(), + variant_name: variant_name.clone(), + }; + + if function_access_key.clone() == var_func_access { + insert_var_vec.push(( + index, + Air::Var { + scope, + constructor, + name: func_name, + variant_name, + }, + )); + } + } + } +} + +pub fn find_generics_to_replace(tipo: &mut Arc, generic_types: &HashMap>) { + if let Some(id) = tipo.get_generic() { + *tipo = generic_types.get(&id).unwrap().clone(); + } else if tipo.is_generic() { + match &**tipo { + Type::App { + args, + public, + module, + name, + } => { + let mut new_args = vec![]; + for arg in args { + let mut arg = arg.clone(); + find_generics_to_replace(&mut arg, generic_types); + new_args.push(arg); + } + let t = Type::App { + args: new_args, + public: *public, + module: module.clone(), + name: name.clone(), + }; + *tipo = t.into(); + } + Type::Fn { args, ret } => { + let mut new_args = vec![]; + for arg in args { + let mut arg = arg.clone(); + find_generics_to_replace(&mut arg, generic_types); + new_args.push(arg); + } + + let mut ret = ret.clone(); + find_generics_to_replace(&mut ret, generic_types); + + let t = Type::Fn { + args: new_args, + ret, + }; + *tipo = t.into(); + } + Type::Tuple { elems } => { + let mut new_elems = vec![]; + for elem in elems { + let mut elem = elem.clone(); + find_generics_to_replace(&mut elem, generic_types); + new_elems.push(elem); + } + let t = Type::Tuple { elems: new_elems }; + *tipo = t.into(); + } + Type::Var { tipo: var_tipo } => { + let var_type = var_tipo.as_ref().borrow().clone(); + let var_tipo = match var_type { + TypeVar::Unbound { .. } => todo!(), + TypeVar::Link { tipo } => { + let mut tipo = tipo; + find_generics_to_replace(&mut tipo, generic_types); + tipo + } + TypeVar::Generic { .. } => unreachable!(), + }; + + let t = Type::Var { + tipo: RefCell::from(TypeVar::Link { tipo: var_tipo }).into(), + }; + *tipo = t.into() + } + }; + } +} + +pub fn get_generics_and_type(tipo: &Type, param: &Type) -> Vec<(u64, Arc)> { + let mut generics_ids = vec![]; + + if let Some(id) = tipo.get_generic() { + generics_ids.push((id, param.clone().into())); + } + + for (tipo, param_type) in tipo + .get_inner_types() + .iter() + .zip(param.get_inner_types().iter()) + { + generics_ids.append(&mut get_generics_and_type(tipo, param_type)); + } + generics_ids +} + +pub fn get_variant_name(new_name: &mut String, t: &Arc) { + new_name.push_str(&format!( + "_{}", + if t.is_string() { + "string".to_string() + } else if t.is_int() { + "int".to_string() + } else if t.is_bool() { + "bool".to_string() + } else if t.is_map() { + let mut full_type = "map".to_string(); + let pair_type = &t.get_inner_types()[0]; + let fst_type = &pair_type.get_inner_types()[0]; + let snd_type = &pair_type.get_inner_types()[1]; + + get_variant_name(&mut full_type, fst_type); + get_variant_name(&mut full_type, snd_type); + full_type + } else if t.is_list() { + let mut full_type = "list".to_string(); + let list_type = &t.get_inner_types()[0]; + get_variant_name(&mut full_type, list_type); + full_type + } else { + "data".to_string() + } + )); +} + +pub fn convert_constants_to_data(constants: Vec) -> Vec { + let mut new_constants = vec![]; + for constant in constants { + let constant = match constant { + UplcConstant::Integer(i) => { + UplcConstant::Data(PlutusData::BigInt(BigInt::Int((i).try_into().unwrap()))) + } + UplcConstant::ByteString(b) => { + UplcConstant::Data(PlutusData::BoundedBytes(b.try_into().unwrap())) + } + UplcConstant::String(s) => UplcConstant::Data(PlutusData::BoundedBytes( + s.as_bytes().to_vec().try_into().unwrap(), + )), + + UplcConstant::Bool(b) => UplcConstant::Data(PlutusData::Constr(Constr { + tag: u64::from(b), + any_constructor: None, + fields: vec![], + })), + UplcConstant::ProtoList(_, constants) => { + let inner_constants = convert_constants_to_data(constants) + .into_iter() + .map(|constant| match constant { + UplcConstant::Data(d) => d, + _ => todo!(), + }) + .collect_vec(); + + UplcConstant::Data(PlutusData::Array(inner_constants)) + } + UplcConstant::ProtoPair(_, _, left, right) => { + let inner_constants = vec![*left, *right]; + let inner_constants = convert_constants_to_data(inner_constants) + .into_iter() + .map(|constant| match constant { + UplcConstant::Data(d) => d, + _ => todo!(), + }) + .collect_vec(); + + UplcConstant::Data(PlutusData::Map(KeyValuePairs::Def(vec![( + inner_constants[0].clone(), + inner_constants[1].clone(), + )]))) + } + d @ UplcConstant::Data(_) => d, + _ => unreachable!(), + }; + new_constants.push(constant); + } + new_constants +} + +pub fn monomorphize( + ir: Vec, + generic_types: HashMap>, + full_type: &Arc, +) -> (String, Vec) { + let mut new_air = ir.clone(); + let mut new_name = String::new(); + + for (index, ir) in ir.into_iter().enumerate() { + match ir { + Air::Var { + constructor, + scope, + name, + .. + } => { + if constructor.tipo.is_generic() { + let mut tipo = constructor.tipo.clone(); + + find_generics_to_replace(&mut tipo, &generic_types); + + let mut variant = String::new(); + + let mut constructor = constructor.clone(); + constructor.tipo = tipo; + + if let Type::Fn { args, .. } = &*constructor.tipo { + if matches!( + constructor.variant, + ValueConstructorVariant::ModuleFn { .. } + ) { + for arg in args { + get_variant_name(&mut variant, arg); + } + } + } + new_air[index] = Air::Var { + scope, + constructor, + name, + variant_name: variant, + }; + } + } + Air::List { + tipo, + scope, + count, + tail, + } => { + if tipo.is_generic() { + let mut tipo = tipo.clone(); + find_generics_to_replace(&mut tipo, &generic_types); + + new_air[index] = Air::List { + scope, + count, + tipo, + tail, + }; + } + } + Air::ListAccessor { + scope, + tipo, + names, + tail, + } => { + if tipo.is_generic() { + let mut tipo = tipo.clone(); + find_generics_to_replace(&mut tipo, &generic_types); + + new_air[index] = Air::ListAccessor { + scope, + names, + tipo, + tail, + }; + } + } + Air::ListExpose { + scope, + tipo, + tail_head_names, + tail, + } => { + if tipo.is_generic() { + let mut tipo = tipo.clone(); + find_generics_to_replace(&mut tipo, &generic_types); + + new_air[index] = Air::ListExpose { + scope, + tail_head_names, + tipo, + tail, + }; + } + } + + Air::BinOp { + scope, + name, + count, + tipo, + } => { + if tipo.is_generic() { + let mut tipo = tipo.clone(); + find_generics_to_replace(&mut tipo, &generic_types); + + new_air[index] = Air::BinOp { + scope, + name, + tipo, + count, + }; + } + } + Air::Builtin { scope, func, tipo } => { + if tipo.is_generic() { + let mut tipo = tipo.clone(); + find_generics_to_replace(&mut tipo, &generic_types); + + new_air[index] = Air::Builtin { scope, func, tipo }; + } + } + // TODO check on assignment if type is needed + Air::Assignment { .. } => {} + Air::When { + scope, + tipo, + subject_name, + } => { + if tipo.is_generic() { + let mut tipo = tipo.clone(); + find_generics_to_replace(&mut tipo, &generic_types); + + new_air[index] = Air::When { + scope, + subject_name, + tipo, + }; + } + } + Air::Clause { + scope, + tipo, + subject_name, + complex_clause, + } => { + if tipo.is_generic() { + let mut tipo = tipo.clone(); + find_generics_to_replace(&mut tipo, &generic_types); + + new_air[index] = Air::Clause { + scope, + tipo, + subject_name, + complex_clause, + }; + } + } + Air::ListClause { + scope, + tipo, + tail_name, + complex_clause, + next_tail_name, + } => { + if tipo.is_generic() { + let mut tipo = tipo.clone(); + find_generics_to_replace(&mut tipo, &generic_types); + + new_air[index] = Air::ListClause { + scope, + tipo, + tail_name, + complex_clause, + next_tail_name, + }; + } + } + Air::ClauseGuard { + tipo, + scope, + subject_name, + } => { + if tipo.is_generic() { + let mut tipo = tipo.clone(); + find_generics_to_replace(&mut tipo, &generic_types); + + new_air[index] = Air::ClauseGuard { + scope, + subject_name, + tipo, + }; + } + } + Air::RecordAccess { + scope, + index: record_index, + tipo, + } => { + if tipo.is_generic() { + let mut tipo = tipo.clone(); + find_generics_to_replace(&mut tipo, &generic_types); + + new_air[index] = Air::RecordAccess { + scope, + index: record_index, + tipo, + }; + } + } + Air::FieldsExpose { + scope, + count, + indices, + } => { + let mut new_indices = vec![]; + for (ind, name, tipo) in indices { + if tipo.is_generic() { + let mut tipo = tipo.clone(); + find_generics_to_replace(&mut tipo, &generic_types); + } + new_indices.push((ind, name, tipo)); + } + new_air[index] = Air::FieldsExpose { + scope, + count, + indices: new_indices, + }; + } + Air::Tuple { scope, tipo, count } => { + if tipo.is_generic() { + let mut tipo = tipo.clone(); + find_generics_to_replace(&mut tipo, &generic_types); + + new_air[index] = Air::Tuple { scope, count, tipo }; + } + } + Air::Todo { scope, label, tipo } => { + if tipo.is_generic() { + let mut tipo = tipo.clone(); + find_generics_to_replace(&mut tipo, &generic_types); + + new_air[index] = Air::Todo { scope, label, tipo }; + } + } + Air::RecordUpdate { .. } => todo!(), + Air::TupleAccessor { .. } => todo!(), + _ => {} + } + } + + if let Type::Fn { args, .. } = &**full_type { + for arg in args { + get_variant_name(&mut new_name, arg); + } + } + + (new_name, new_air) +} diff --git a/crates/lang/src/lib.rs b/crates/lang/src/lib.rs index a3d8a530..b2be2f0e 100644 --- a/crates/lang/src/lib.rs +++ b/crates/lang/src/lib.rs @@ -5,6 +5,7 @@ use std::sync::{ pub mod air; pub mod ast; +pub mod builder; pub mod builtins; pub mod expr; pub mod format; diff --git a/crates/lang/src/tipo.rs b/crates/lang/src/tipo.rs index def274c8..6b457ff3 100644 --- a/crates/lang/src/tipo.rs +++ b/crates/lang/src/tipo.rs @@ -135,6 +135,14 @@ impl Type { } } + pub fn is_option(&self) -> bool { + match self { + Self::App { module, name, .. } if "Option" == name && module.is_empty() => true, + Self::Var { tipo } => tipo.borrow().is_option(), + _ => false, + } + } + pub fn is_map(&self) -> bool { match self { Self::App { @@ -143,7 +151,7 @@ impl Type { if let Type::Tuple { elems } = &*args[0] { elems.len() == 2 } else if let Type::Var { tipo } = &*args[0] { - matches!(tipo.borrow().get_uplc_type(), UplcType::Pair(_, _)) + matches!(tipo.borrow().get_uplc_type(), Some(UplcType::Pair(_, _))) } else { false } @@ -157,7 +165,42 @@ impl Type { matches!(self, Self::Tuple { .. }) } - pub fn get_inner_type(&self) -> Vec> { + pub fn is_generic(&self) -> bool { + match self { + Type::App { args, .. } => { + let mut is_a_generic = false; + for arg in args { + is_a_generic = is_a_generic || arg.is_generic(); + } + is_a_generic + } + + Type::Var { tipo } => tipo.borrow().is_generic(), + Type::Tuple { elems } => { + let mut is_a_generic = false; + for elem in elems { + is_a_generic = is_a_generic || elem.is_generic(); + } + is_a_generic + } + Type::Fn { args, .. } => { + let mut is_a_generic = false; + for arg in args { + is_a_generic = is_a_generic || arg.is_generic(); + } + is_a_generic + } + } + } + + pub fn get_generic(&self) -> Option { + match self { + Type::Var { tipo } => tipo.borrow().get_generic(), + _ => None, + } + } + + pub fn get_inner_types(&self) -> Vec> { if self.is_list() { match self { Self::App { args, .. } => args.clone(), @@ -169,6 +212,13 @@ impl Type { Self::Tuple { elems } => elems.to_vec(), _ => vec![], } + } else if matches!(self.get_uplc_type(), UplcType::Data) { + match self { + Type::App { args, .. } => args.clone(), + Type::Fn { args, .. } => args.clone(), + Type::Var { tipo } => tipo.borrow().get_inner_type(), + _ => unreachable!(), + } } else { vec![] } @@ -374,6 +424,13 @@ impl TypeVar { } } + pub fn is_option(&self) -> bool { + match self { + Self::Link { tipo } => tipo.is_option(), + _ => false, + } + } + pub fn is_map(&self) -> bool { match self { Self::Link { tipo } => tipo.is_map(), @@ -381,17 +438,33 @@ impl TypeVar { } } + pub fn is_generic(&self) -> bool { + match self { + TypeVar::Generic { .. } => true, + TypeVar::Link { tipo } => tipo.is_generic(), + _ => false, + } + } + + pub fn get_generic(&self) -> Option { + match self { + TypeVar::Generic { id } => Some(*id), + TypeVar::Link { tipo } => tipo.get_generic(), + _ => None, + } + } + pub fn get_inner_type(&self) -> Vec> { match self { - Self::Link { tipo } => tipo.get_inner_type(), + Self::Link { tipo } => tipo.get_inner_types(), _ => vec![], } } - pub fn get_uplc_type(&self) -> UplcType { + pub fn get_uplc_type(&self) -> Option { match self { - Self::Link { tipo } => tipo.get_uplc_type(), - _ => unreachable!(), + Self::Link { tipo } => Some(tipo.get_uplc_type()), + _ => None, } } } diff --git a/crates/lang/src/uplc.rs b/crates/lang/src/uplc.rs index 6cc2f72e..74a470ed 100644 --- a/crates/lang/src/uplc.rs +++ b/crates/lang/src/uplc.rs @@ -8,64 +8,31 @@ use uplc::{ Constant as UplcConstant, Name, Program, Term, Type as UplcType, }, builtins::DefaultFunction, - machine::runtime::convert_constr_to_tag, parser::interner::Interner, - BigInt, Constr, PlutusData, }; use crate::{ air::Air, ast::{ - ArgName, AssignmentKind, BinOp, Clause, Constant, DataType, Function, Pattern, Span, - TypedArg, + ArgName, AssignmentKind, BinOp, Clause, Pattern, Span, TypedArg, TypedDataType, + TypedFunction, + }, + builder::{ + check_when_pattern_needs, constants_ir, convert_constants_to_data, convert_data_to_type, + convert_type_to_data, get_common_ancestor, get_generics_and_type, get_variant_name, + list_access_to_uplc, match_ir_for_recursion, monomorphize, rearrange_clauses, + ClauseProperties, DataTypeKey, FuncComponents, FunctionAccessKey, }, expr::TypedExpr, tipo::{self, PatternConstructor, Type, TypeInfo, ValueConstructor, ValueConstructorVariant}, IdGenerator, }; -#[derive(Clone, Debug)] -pub struct FuncComponents { - ir: Vec, - dependencies: Vec, - args: Vec, - recursive: bool, -} - -#[derive(Clone, Eq, Debug, PartialEq, Hash)] -pub struct ConstrFieldKey { - pub local_var: String, - pub field_name: String, -} - -#[derive(Clone, Debug, Eq, PartialEq, Hash)] -pub struct DataTypeKey { - pub module_name: String, - pub defined_type: String, -} - -pub type ConstrUsageKey = String; - -#[derive(Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)] -pub struct FunctionAccessKey { - pub module_name: String, - pub function_name: String, -} - -#[derive(Clone, Debug)] -pub struct ClauseProperties { - clause_var_name: String, - needs_constr_var: bool, - is_complex_clause: bool, - current_index: usize, - original_subject_name: String, -} - pub struct CodeGenerator<'a> { defined_functions: HashMap, - functions: &'a HashMap, TypedExpr>>, + functions: &'a HashMap, // type_aliases: &'a HashMap<(String, String), &'a TypeAlias>>, - data_types: &'a HashMap>>, + data_types: &'a HashMap, // imports: &'a HashMap<(String, String), &'a Use>, // constants: &'a HashMap<(String, String), &'a ModuleConstant, String>>, module_types: &'a HashMap, @@ -75,9 +42,9 @@ pub struct CodeGenerator<'a> { impl<'a> CodeGenerator<'a> { pub fn new( - functions: &'a HashMap, TypedExpr>>, + functions: &'a HashMap, // type_aliases: &'a HashMap<(String, String), &'a TypeAlias>>, - data_types: &'a HashMap>>, + data_types: &'a HashMap, // imports: &'a HashMap<(String, String), &'a Use>, // constants: &'a HashMap<(String, String), &'a ModuleConstant, String>>, module_types: &'a HashMap, @@ -95,7 +62,12 @@ impl<'a> CodeGenerator<'a> { } } - pub fn generate(&mut self, body: TypedExpr, arguments: Vec) -> Program { + pub fn generate( + &mut self, + body: TypedExpr, + arguments: Vec, + wrap_as_validator: bool, + ) -> Program { let mut ir_stack = vec![]; let scope = vec![self.id_gen.next()]; @@ -112,7 +84,11 @@ impl<'a> CodeGenerator<'a> { } // Wrap the validator body if ifThenElse term unit error - term = builder::final_wrapper(term); + term = if wrap_as_validator { + builder::final_wrapper(term) + } else { + term + }; for arg in arguments.iter().rev() { term = Term::Lambda { @@ -166,20 +142,50 @@ impl<'a> CodeGenerator<'a> { } TypedExpr::Var { constructor, name, .. - } => { - if let ValueConstructorVariant::ModuleConstant { literal, .. } = - &constructor.variant - { + } => match &constructor.variant { + ValueConstructorVariant::ModuleConstant { literal, .. } => { constants_ir(literal, ir_stack, scope); - } else { + } + ValueConstructorVariant::ModuleFn { + builtin: Some(builtin), + .. + } => { + ir_stack.push(Air::Builtin { + scope, + func: *builtin, + tipo: constructor.tipo.clone(), + }); + } + _ => { ir_stack.push(Air::Var { scope, constructor: constructor.clone(), name: name.clone(), + variant_name: String::new(), }); } + }, + TypedExpr::Fn { args, body, .. } => { + let mut func_body = vec![]; + let mut func_scope = scope.clone(); + func_scope.push(self.id_gen.next()); + self.build_ir(body, &mut func_body, func_scope); + let mut arg_names = vec![]; + for arg in args { + let name = arg + .arg_name + .get_variable_name() + .unwrap_or_default() + .to_string(); + arg_names.push(name); + } + + ir_stack.push(Air::Fn { + scope, + params: arg_names, + }); + ir_stack.append(&mut func_body); } - TypedExpr::Fn { .. } => todo!(), TypedExpr::List { elements, tail, @@ -188,11 +194,7 @@ impl<'a> CodeGenerator<'a> { } => { ir_stack.push(Air::List { scope: scope.clone(), - count: if tail.is_some() { - elements.len() + 1 - } else { - elements.len() - }, + count: elements.len(), tipo: tipo.clone(), tail: tail.is_some(), }); @@ -213,7 +215,7 @@ impl<'a> CodeGenerator<'a> { TypedExpr::Call { fun, args, .. } => { ir_stack.push(Air::Call { scope: scope.clone(), - count: args.len() + 1, + count: args.len(), }); let mut scope_fun = scope.clone(); scope_fun.push(self.id_gen.next()); @@ -280,7 +282,6 @@ impl<'a> CodeGenerator<'a> { // assuming one subject at the moment let subject = subjects[0].clone(); - let mut needs_subject_var = false; let clauses = if matches!(clauses[0].pattern[0], Pattern::List { .. }) { rearrange_clauses(clauses.clone()) @@ -289,7 +290,6 @@ impl<'a> CodeGenerator<'a> { }; if let Some((last_clause, clauses)) = clauses.split_last() { - let mut clauses_vec = vec![]; let mut pattern_vec = vec![]; let mut clause_properties = ClauseProperties { @@ -300,66 +300,13 @@ impl<'a> CodeGenerator<'a> { original_subject_name: subject_name.clone(), }; - for (index, clause) in clauses.iter().enumerate() { - // scope per clause is different - let mut scope = scope.clone(); - scope.push(self.id_gen.next()); - - // holds when clause pattern Air - let mut clause_subject_vec = vec![]; - - // reset complex clause setting per clause back to default - clause_properties.is_complex_clause = false; - - self.build_ir(&clause.then, &mut clauses_vec, scope.clone()); - - self.when_ir( - &clause.pattern[0], - &mut clause_subject_vec, - &mut clauses_vec, - &subject.tipo(), - &mut clause_properties, - scope.clone(), - ); - - if clause_properties.needs_constr_var { - needs_subject_var = true; - } - - let subject_name = if clause_properties.current_index == 0 { - subject_name.clone() - } else { - format!("__tail_{}", clause_properties.current_index - 1) - }; - - // Clause is first in Air pattern vec - if subject.tipo().is_list() { - let next_tail = if index == clauses.len() - 1 { - None - } else { - Some(format!("__tail_{}", clause_properties.current_index)) - }; - - pattern_vec.push(Air::ListClause { - scope, - tipo: subject.tipo().clone(), - tail_name: subject_name, - complex_clause: clause_properties.is_complex_clause, - next_tail_name: next_tail, - }); - - clause_properties.current_index += 1; - } else { - pattern_vec.push(Air::Clause { - scope, - tipo: subject.tipo().clone(), - subject_name, - complex_clause: clause_properties.is_complex_clause, - }); - } - - pattern_vec.append(&mut clause_subject_vec); - } + self.handle_each_clause( + &mut pattern_vec, + &mut clause_properties, + clauses, + &subject.tipo(), + scope.clone(), + ); let last_pattern = &last_clause.pattern[0]; @@ -369,18 +316,24 @@ impl<'a> CodeGenerator<'a> { scope: final_scope.clone(), }); - self.build_ir(&last_clause.then, &mut clauses_vec, final_scope.clone()); + let mut final_clause_vec = vec![]; + + self.build_ir( + &last_clause.then, + &mut final_clause_vec, + final_scope.clone(), + ); self.when_ir( last_pattern, &mut pattern_vec, - &mut clauses_vec, + &mut final_clause_vec, &subject.tipo(), &mut clause_properties, final_scope, ); - if needs_subject_var || clause_properties.needs_constr_var { + if clause_properties.needs_constr_var { ir_stack.push(Air::Lam { scope: scope.clone(), name: constr_var.clone(), @@ -406,6 +359,7 @@ impl<'a> CodeGenerator<'a> { }, ), name: constr_var, + variant_name: String::new(), }) } else { ir_stack.push(Air::When { @@ -473,6 +427,7 @@ impl<'a> CodeGenerator<'a> { TypedExpr::ModuleSelect { constructor, module_name, + tipo, .. } => match constructor { tipo::ModuleValueConstructor::Record { .. } => todo!(), @@ -480,6 +435,7 @@ impl<'a> CodeGenerator<'a> { let func = self.functions.get(&FunctionAccessKey { module_name: module_name.clone(), function_name: name.clone(), + variant_name: String::new(), }); if let Some(func) = func { @@ -497,6 +453,7 @@ impl<'a> CodeGenerator<'a> { }, ), name: format!("{module}_{name}"), + variant_name: String::new(), }); } else { let type_info = self.module_types.get(module_name).unwrap(); @@ -508,6 +465,7 @@ impl<'a> CodeGenerator<'a> { ir_stack.push(Air::Builtin { func: builtin, scope, + tipo: tipo.clone(), }); } _ => unreachable!(), @@ -547,6 +505,72 @@ impl<'a> CodeGenerator<'a> { } } + fn handle_each_clause( + &mut self, + ir_stack: &mut Vec, + clause_properties: &mut ClauseProperties, + clauses: &[Clause, String>], + subject_type: &Arc, + scope: Vec, + ) { + for (index, clause) in clauses.iter().enumerate() { + // scope per clause is different + let mut scope = scope.clone(); + scope.push(self.id_gen.next()); + + // holds when clause pattern Air + let mut clause_subject_vec = vec![]; + let mut clauses_vec = vec![]; + + // reset complex clause setting per clause back to default + clause_properties.is_complex_clause = false; + + self.build_ir(&clause.then, &mut clauses_vec, scope.clone()); + + self.when_ir( + &clause.pattern[0], + &mut clause_subject_vec, + &mut clauses_vec, + subject_type, + clause_properties, + scope.clone(), + ); + + let subject_name = if clause_properties.current_index == 0 { + clause_properties.original_subject_name.clone() + } else { + format!("__tail_{}", clause_properties.current_index - 1) + }; + + // Clause is last in Air pattern vec + if subject_type.is_list() { + let next_tail = if index == clauses.len() - 1 { + None + } else { + Some(format!("__tail_{}", clause_properties.current_index)) + }; + + ir_stack.push(Air::ListClause { + scope, + tipo: subject_type.clone(), + tail_name: subject_name, + complex_clause: clause_properties.is_complex_clause, + next_tail_name: next_tail, + }); + + clause_properties.current_index += 1; + } else { + ir_stack.push(Air::Clause { + scope, + tipo: subject_type.clone(), + complex_clause: clause_properties.is_complex_clause, + subject_name, + }); + } + ir_stack.append(&mut clause_subject_vec); + } + } + fn when_ir( &mut self, pattern: &Pattern>, @@ -584,6 +608,7 @@ impl<'a> CodeGenerator<'a> { }, ), name: clause_properties.original_subject_name.clone(), + variant_name: String::new(), }); pattern_vec.append(values); } @@ -603,6 +628,7 @@ impl<'a> CodeGenerator<'a> { }, ), name: clause_properties.original_subject_name.clone(), + variant_name: String::new(), }); new_vec.append(values); @@ -656,14 +682,11 @@ impl<'a> CodeGenerator<'a> { name: constr_name, .. } => { - let mut needs_access_to_constr_var = false; - let mut needs_clause_guard = false; - for arg in arguments { check_when_pattern_needs( &arg.value, - &mut needs_access_to_constr_var, - &mut needs_clause_guard, + &mut clause_properties.needs_constr_var, + &mut clause_properties.is_complex_clause, ); } @@ -700,6 +723,7 @@ impl<'a> CodeGenerator<'a> { ), name: clause_properties.clause_var_name.clone(), scope: scope.clone(), + variant_name: String::new(), }]; // if only one constructor, no need to check @@ -711,13 +735,7 @@ impl<'a> CodeGenerator<'a> { }); } - if needs_clause_guard { - clause_properties.is_complex_clause = true; - } - - if needs_access_to_constr_var { - clause_properties.needs_constr_var = true; - + if clause_properties.needs_constr_var { self.when_recursive_ir( pattern, pattern_vec, @@ -1165,12 +1183,13 @@ impl<'a> CodeGenerator<'a> { ), name: item_name, scope: scope.clone(), + variant_name: String::new(), }); self.pattern_ir( a, &mut elements_vec, &mut var_vec, - &tipo.get_inner_type()[0], + &tipo.get_inner_types()[0], scope.clone(), ); } @@ -1271,6 +1290,7 @@ impl<'a> CodeGenerator<'a> { }, ), name: constr_name.clone(), + variant_name: String::new(), }], tipo, scope.clone(), @@ -1336,6 +1356,7 @@ impl<'a> CodeGenerator<'a> { }, ), name: constr_name.clone(), + variant_name: String::new(), }], tipo, scope.clone(), @@ -1398,12 +1419,13 @@ impl<'a> CodeGenerator<'a> { ), name: item_name, scope: scope.clone(), + variant_name: String::new(), }); self.pattern_ir( a, &mut elements_vec, &mut var_vec, - &tipo.get_inner_type()[0], + &tipo.get_inner_types()[0], scope.clone(), ); } @@ -1422,6 +1444,465 @@ impl<'a> CodeGenerator<'a> { } } + fn define_ir(&mut self, ir_stack: &mut Vec) { + let mut func_components = IndexMap::new(); + let mut func_index_map = IndexMap::new(); + + let recursion_func_map = IndexMap::new(); + + self.define_recurse_ir( + ir_stack, + &mut func_components, + &mut func_index_map, + recursion_func_map, + ); + + let mut insert_var_vec = vec![]; + let mut final_func_dep_ir = IndexMap::new(); + for func in func_index_map.clone() { + if self.defined_functions.contains_key(&func.0) { + continue; + } + let mut funt_comp = func_components.get(&func.0).unwrap().clone(); + let func_scope = func_index_map.get(&func.0).unwrap(); + + let mut dep_ir = vec![]; + + // deal with function dependencies + while let Some(dependency) = funt_comp.dependencies.pop() { + if self.defined_functions.contains_key(&dependency) + || func_components.get(&dependency).is_none() + { + continue; + } + + let depend_comp = func_components.get(&dependency).unwrap(); + + let dep_scope = func_index_map.get(&dependency).unwrap(); + + if get_common_ancestor(dep_scope, func_scope) == func_scope.clone() { + funt_comp + .dependencies + .extend(depend_comp.dependencies.clone()); + + let mut temp_ir = vec![Air::DefineFunc { + scope: func_scope.clone(), + func_name: dependency.function_name.clone(), + module_name: dependency.module_name.clone(), + params: depend_comp.args.clone(), + recursive: depend_comp.recursive, + variant_name: dependency.variant_name.clone(), + }]; + + for (index, ir) in depend_comp.ir.iter().enumerate() { + match_ir_for_recursion( + ir.clone(), + &mut insert_var_vec, + &FunctionAccessKey { + function_name: dependency.function_name.clone(), + module_name: dependency.module_name.clone(), + variant_name: dependency.variant_name.clone(), + }, + index, + ); + } + + let mut recursion_ir = depend_comp.ir.clone(); + for (index, ir) in insert_var_vec.clone() { + recursion_ir.insert(index, ir); + + let current_call = recursion_ir[index - 1].clone(); + + match current_call { + Air::Call { scope, count } => { + recursion_ir[index - 1] = Air::Call { + scope, + count: count + 1, + } + } + _ => unreachable!(), + } + } + + temp_ir.append(&mut recursion_ir); + + temp_ir.append(&mut dep_ir); + + dep_ir = temp_ir; + self.defined_functions.insert(dependency, ()); + insert_var_vec = vec![]; + } + } + + final_func_dep_ir.insert(func.0, dep_ir); + } + + for (index, ir) in ir_stack.clone().into_iter().enumerate().rev() { + { + let temp_func_index_map = func_index_map.clone(); + let to_insert = temp_func_index_map + .iter() + .filter(|func| { + func.1.clone() == ir.scope() && !self.defined_functions.contains_key(func.0) + }) + .collect_vec(); + + for (function_access_key, scopes) in to_insert.into_iter() { + func_index_map.remove(function_access_key); + + self.defined_functions + .insert(function_access_key.clone(), ()); + + let mut full_func_ir = + final_func_dep_ir.get(function_access_key).unwrap().clone(); + + let mut func_comp = func_components.get(function_access_key).unwrap().clone(); + + full_func_ir.push(Air::DefineFunc { + scope: scopes.clone(), + func_name: function_access_key.function_name.clone(), + module_name: function_access_key.module_name.clone(), + params: func_comp.args.clone(), + recursive: func_comp.recursive, + variant_name: function_access_key.variant_name.clone(), + }); + + for (index, ir) in func_comp.ir.clone().iter().enumerate() { + match_ir_for_recursion( + ir.clone(), + &mut insert_var_vec, + function_access_key, + index, + ); + } + + for (index, ir) in insert_var_vec { + func_comp.ir.insert(index, ir); + + let current_call = func_comp.ir[index - 1].clone(); + + match current_call { + Air::Call { scope, count } => { + func_comp.ir[index - 1] = Air::Call { + scope, + count: count + 1, + } + } + _ => unreachable!("{current_call:#?}"), + } + } + insert_var_vec = vec![]; + + full_func_ir.extend(func_comp.ir.clone()); + + for ir in full_func_ir.into_iter().rev() { + ir_stack.insert(index, ir); + } + } + } + } + } + + fn define_recurse_ir( + &mut self, + ir_stack: &mut [Air], + func_components: &mut IndexMap, + func_index_map: &mut IndexMap>, + recursion_func_map: IndexMap, + ) { + self.process_define_ir(ir_stack, func_components, func_index_map); + + let mut recursion_func_map = recursion_func_map; + + for func_index in func_index_map.clone().iter() { + let func = func_index.0; + + let function_components = func_components.get(func).unwrap(); + let mut function_ir = function_components.ir.clone(); + + for ir in function_ir.clone() { + if let Air::Var { + constructor: + ValueConstructor { + variant: + ValueConstructorVariant::ModuleFn { + name: func_name, + module, + .. + }, + .. + }, + variant_name, + .. + } = ir + { + if recursion_func_map.contains_key(&FunctionAccessKey { + module_name: module.clone(), + function_name: func_name.clone(), + variant_name: variant_name.clone(), + }) { + return; + } else { + recursion_func_map.insert( + FunctionAccessKey { + module_name: module.clone(), + function_name: func_name.clone(), + variant_name: variant_name.clone(), + }, + (), + ); + } + } + } + + let mut inner_func_components = IndexMap::new(); + + let mut inner_func_index_map = IndexMap::new(); + + self.define_recurse_ir( + &mut function_ir, + &mut inner_func_components, + &mut inner_func_index_map, + recursion_func_map.clone(), + ); + + //now unify + for item in inner_func_components { + if !func_components.contains_key(&item.0) { + func_components.insert(item.0, item.1); + } + } + + for item in inner_func_index_map { + if let Some(entry) = func_index_map.get_mut(&item.0) { + *entry = get_common_ancestor(entry, &item.1); + } else { + func_index_map.insert(item.0, item.1); + } + } + } + } + + fn process_define_ir( + &mut self, + ir_stack: &mut [Air], + func_components: &mut IndexMap, + func_index_map: &mut IndexMap>, + ) { + let mut to_be_defined_map: IndexMap> = IndexMap::new(); + for (index, ir) in ir_stack.to_vec().iter().enumerate().rev() { + match ir { + Air::Var { + scope, constructor, .. + } => { + if let ValueConstructorVariant::ModuleFn { + name, + module, + builtin, + .. + } = &constructor.variant + { + if builtin.is_none() { + let mut function_key = FunctionAccessKey { + module_name: module.clone(), + function_name: name.clone(), + variant_name: String::new(), + }; + if let Some(scope_prev) = to_be_defined_map.get(&function_key) { + let new_scope = get_common_ancestor(scope, scope_prev); + + to_be_defined_map.insert(function_key, new_scope); + } else if func_components.get(&function_key).is_some() { + to_be_defined_map.insert(function_key.clone(), scope.to_vec()); + } else { + let function = self.functions.get(&function_key).unwrap(); + + let mut func_ir = vec![]; + + self.build_ir(&function.body, &mut func_ir, scope.to_vec()); + + let (param_types, _) = constructor.tipo.function_types().unwrap(); + + let mut generics_type_map: HashMap> = HashMap::new(); + + for (index, arg) in function.arguments.iter().enumerate() { + if arg.tipo.is_generic() { + let mut map = generics_type_map.into_iter().collect_vec(); + map.append(&mut get_generics_and_type( + &arg.tipo, + ¶m_types[index], + )); + + generics_type_map = map.into_iter().collect(); + } + } + + let (variant_name, mut func_ir) = + monomorphize(func_ir, generics_type_map, &constructor.tipo); + + function_key = FunctionAccessKey { + module_name: module.clone(), + function_name: function_key.function_name, + variant_name: variant_name.clone(), + }; + + to_be_defined_map.insert(function_key.clone(), scope.to_vec()); + let mut func_calls = vec![]; + + for (index, ir) in func_ir.clone().into_iter().enumerate() { + if let Air::Var { + constructor: + ValueConstructor { + variant: + ValueConstructorVariant::ModuleFn { + name: func_name, + module, + field_map, + arity, + location, + .. + }, + public, + tipo, + }, + scope, + name, + .. + } = ir + { + let current_func = FunctionAccessKey { + module_name: module.clone(), + function_name: func_name.clone(), + variant_name: String::new(), + }; + + let current_func_as_variant = FunctionAccessKey { + module_name: module.clone(), + function_name: func_name.clone(), + variant_name: variant_name.clone(), + }; + + let function = self.functions.get(¤t_func); + if function_key.clone() == current_func_as_variant { + func_ir[index] = Air::Var { + scope, + constructor: ValueConstructor { + public, + variant: ValueConstructorVariant::ModuleFn { + name: func_name, + field_map, + module, + arity, + location, + builtin: None, + }, + tipo, + }, + name, + variant_name: variant_name.clone(), + }; + func_calls.push(current_func_as_variant); + } else if let (Some(function), Type::Fn { args, .. }) = + (function, &*tipo) + { + if function + .arguments + .iter() + .any(|arg| arg.tipo.is_generic()) + { + let mut new_name = String::new(); + for arg in args.iter() { + get_variant_name(&mut new_name, arg); + } + func_calls.push(FunctionAccessKey { + module_name: module, + function_name: func_name, + variant_name: new_name, + }); + } else { + func_calls.push(current_func); + } + } else { + func_calls.push(current_func); + } + } + } + + let mut args = vec![]; + + for arg in function.arguments.iter() { + match &arg.arg_name { + ArgName::Named { name, .. } + | ArgName::NamedLabeled { name, .. } => { + args.push(name.clone()); + } + _ => {} + } + } + let recursive = if let Ok(index) = + func_calls.binary_search(&function_key) + { + func_calls.remove(index); + while let Ok(index) = func_calls.binary_search(&function_key) { + func_calls.remove(index); + } + true + } else { + false + }; + + ir_stack[index] = Air::Var { + scope: scope.clone(), + constructor: constructor.clone(), + name: name.clone(), + variant_name, + }; + + func_components.insert( + function_key, + FuncComponents { + ir: func_ir, + dependencies: func_calls, + recursive, + args, + }, + ); + } + } + } + } + a => { + let scope = a.scope(); + + for func in to_be_defined_map.clone().iter() { + if get_common_ancestor(&scope, func.1) == scope.to_vec() { + if let Some(index_scope) = func_index_map.get(func.0) { + if get_common_ancestor(index_scope, func.1) == scope.to_vec() { + func_index_map.insert(func.0.clone(), scope.clone()); + to_be_defined_map.shift_remove(func.0); + } else { + to_be_defined_map.insert( + func.0.clone(), + get_common_ancestor(index_scope, func.1), + ); + } + } else { + func_index_map.insert(func.0.clone(), scope.clone()); + to_be_defined_map.shift_remove(func.0); + } + } + } + } + } + } + + //Still to be defined + for func in to_be_defined_map.clone().iter() { + let index_scope = func_index_map.get(func.0).unwrap(); + func_index_map.insert(func.0.clone(), get_common_ancestor(func.1, index_scope)); + } + } + fn uplc_code_gen(&mut self, ir_stack: &mut Vec) -> Term { let mut arg_stack: Vec> = vec![]; @@ -1451,9 +1932,12 @@ impl<'a> CodeGenerator<'a> { arg_stack.push(term); } Air::Var { - name, constructor, .. + name, + constructor, + variant_name, + .. } => { - match constructor.variant { + match &constructor.variant { ValueConstructorVariant::LocalVariable { .. } => { arg_stack.push(Term::Var(Name { text: name, @@ -1468,11 +1952,12 @@ impl<'a> CodeGenerator<'a> { module, .. } => { - let name = if func_name == name { - format!("{module}_{func_name}") + let name = if *func_name == name { + format!("{module}_{func_name}{variant_name}") } else { - name + format!("{func_name}{variant_name}") }; + arg_stack.push(Term::Var(Name { text: name, unique: 0.into(), @@ -1481,6 +1966,7 @@ impl<'a> CodeGenerator<'a> { ValueConstructorVariant::Record { name: constr_name, field_map, + arity, .. } => { let data_type_key = match &*constructor.tipo { @@ -1499,12 +1985,13 @@ impl<'a> CodeGenerator<'a> { Type::Tuple { .. } => todo!(), }; - if data_type_key.defined_type == "Bool" { + if constructor.tipo.is_bool() { arg_stack .push(Term::Constant(UplcConstant::Bool(constr_name == "True"))); } else { let data_type = self.data_types.get(&data_type_key).unwrap(); - let (constr_index, _constr) = data_type + + let (constr_index, _) = data_type .constructors .iter() .enumerate() @@ -1517,9 +2004,8 @@ impl<'a> CodeGenerator<'a> { let tipo = constructor.tipo; let args_type = match tipo.as_ref() { - Type::Fn { args, .. } => args, - - _ => todo!(), + Type::Fn { args, .. } | Type::App { args, .. } => args, + _ => unreachable!(), }; if let Some(field_map) = field_map.clone() { @@ -1530,6 +2016,7 @@ impl<'a> CodeGenerator<'a> { .zip(args_type) .rev() { + // TODO revisit fields = Term::Apply { function: Term::Apply { function: Term::Builtin(DefaultFunction::MkCons) @@ -1548,6 +2035,26 @@ impl<'a> CodeGenerator<'a> { argument: fields.into(), }; } + } else { + for (index, arg) in args_type.iter().enumerate().take(*arity) { + fields = Term::Apply { + function: Term::Apply { + function: Term::Builtin(DefaultFunction::MkCons) + .force_wrap() + .into(), + argument: convert_type_to_data( + Term::Var(Name { + text: format!("__arg_{}", index), + unique: 0.into(), + }), + arg, + ) + .into(), + } + .into(), + argument: fields.into(), + }; + } } let mut term = Term::Apply { @@ -1577,6 +2084,16 @@ impl<'a> CodeGenerator<'a> { body: term.into(), }; } + } else { + for (index, _) in args_type.iter().enumerate().take(*arity) { + term = Term::Lambda { + parameter_name: Name { + text: format!("__arg_{}", index), + unique: 0.into(), + }, + body: term.into(), + }; + } } arg_stack.push(term); @@ -1603,7 +2120,7 @@ impl<'a> CodeGenerator<'a> { } } - let list_type = tipo.get_inner_type()[0].clone(); + let list_type = tipo.get_inner_types()[0].clone(); if constants.len() == args.len() && !tail { let list = if tipo.is_map() { @@ -1664,10 +2181,9 @@ impl<'a> CodeGenerator<'a> { }; term = Term::Apply { function: Term::Apply { - function: Term::Force( - Term::Builtin(DefaultFunction::MkCons).force_wrap().into(), - ) - .into(), + function: Term::Builtin(DefaultFunction::MkCons) + .force_wrap() + .into(), argument: list_item.into(), } .into(), @@ -1715,7 +2231,7 @@ impl<'a> CodeGenerator<'a> { }) .into(), }, - &tipo.get_inner_type()[0], + &tipo.get_inner_types()[0], ) }; @@ -1821,7 +2337,7 @@ impl<'a> CodeGenerator<'a> { }) .into(), }, - &tipo.get_inner_type()[0], + &tipo.get_inner_types()[0], ) }; term = Term::Apply { @@ -1839,11 +2355,27 @@ impl<'a> CodeGenerator<'a> { arg_stack.push(term); } + + Air::Fn { params, .. } => { + let mut term = arg_stack.pop().unwrap(); + + for param in params.iter().rev() { + term = Term::Lambda { + parameter_name: Name { + text: param.clone(), + unique: 0.into(), + }, + body: term.into(), + }; + } + + arg_stack.push(term); + } Air::Call { count, .. } => { - if count >= 2 { + if count >= 1 { let mut term = arg_stack.pop().unwrap(); - for _ in 0..count - 1 { + for _ in 0..count { let arg = arg_stack.pop().unwrap(); term = Term::Apply { @@ -1856,13 +2388,50 @@ impl<'a> CodeGenerator<'a> { todo!() } } - Air::Builtin { func, .. } => { - let mut term = Term::Builtin(func); - for _ in 0..func.force_count() { - term = Term::Force(term.into()); + Air::Builtin { func, tipo, .. } => match func { + DefaultFunction::FstPair | DefaultFunction::SndPair | DefaultFunction::HeadList => { + let id = self.id_gen.next(); + let mut term: Term = func.into(); + for _ in 0..func.force_count() { + term = term.force_wrap(); + } + + term = Term::Apply { + function: term.into(), + argument: Term::Var(Name { + text: format!("__arg_{}", id), + unique: 0.into(), + }) + .into(), + }; + + let inner_type = if matches!(func, DefaultFunction::SndPair) { + tipo.get_inner_types()[0].get_inner_types()[1].clone() + } else { + tipo.get_inner_types()[0].get_inner_types()[0].clone() + }; + + term = convert_data_to_type(term, &inner_type); + term = Term::Lambda { + parameter_name: Name { + text: format!("__arg_{}", id), + unique: 0.into(), + }, + body: term.into(), + }; + + arg_stack.push(term); } - arg_stack.push(term); - } + DefaultFunction::MkCons => todo!(), + DefaultFunction::MkPairData => todo!(), + _ => { + let mut term = Term::Builtin(func); + for _ in 0..func.force_count() { + term = term.force_wrap(); + } + arg_stack.push(term); + } + }, Air::BinOp { name, tipo, .. } => { let left = arg_stack.pop().unwrap(); let right = arg_stack.pop().unwrap(); @@ -1974,7 +2543,7 @@ impl<'a> CodeGenerator<'a> { arg_stack.push(term); return; } else if tipo.is_tuple() - && matches!(tipo.get_uplc_type(), UplcType::Pair(_, _)) + && matches!(tipo.clone().get_uplc_type(), UplcType::Pair(_, _)) { let term = Term::Apply { function: Term::Apply { @@ -2028,9 +2597,7 @@ impl<'a> CodeGenerator<'a> { }; arg_stack.push(term); return; - } else if tipo.is_list() - || matches!(tipo.get_uplc_type(), UplcType::List(_)) - { + } else if tipo.is_list() { let term = Term::Apply { function: Term::Apply { function: default_builtin.into(), @@ -2041,13 +2608,10 @@ impl<'a> CodeGenerator<'a> { .into(), } .into(), + argument: Term::Apply { - function: default_builtin.into(), - argument: Term::Apply { - function: DefaultFunction::ListData.into(), - argument: right.into(), - } - .into(), + function: DefaultFunction::ListData.into(), + argument: right.into(), } .into(), }; @@ -2146,7 +2710,7 @@ impl<'a> CodeGenerator<'a> { arg_stack.push(term); return; } else if tipo.is_tuple() - && matches!(tipo.get_uplc_type(), UplcType::Pair(_, _)) + && matches!(tipo.clone().get_uplc_type(), UplcType::Pair(_, _)) { // let term = Term::Apply { // function: Term::Apply { @@ -2360,12 +2924,13 @@ impl<'a> CodeGenerator<'a> { params, recursive, module_name, + variant_name, .. } => { let func_name = if module_name.is_empty() { - func_name + format!("{func_name}{variant_name}") } else { - format!("{module_name}_{func_name}") + format!("{module_name}_{func_name}{variant_name}") }; let mut func_body = arg_stack.pop().unwrap(); @@ -3022,7 +3587,7 @@ impl<'a> CodeGenerator<'a> { } } - let tuple_sub_types = tipo.get_inner_type(); + let tuple_sub_types = tipo.get_inner_types(); if constants.len() == args.len() { let data_constants = convert_constants_to_data(constants); @@ -3168,7 +3733,7 @@ impl<'a> CodeGenerator<'a> { }) .into(), }, - &tipo.get_inner_type()[0], + &tipo.get_inner_types()[0], ) }; @@ -3223,986 +3788,4 @@ impl<'a> CodeGenerator<'a> { } } } - - pub(crate) fn define_ir(&mut self, ir_stack: &mut Vec) { - let mut func_components = IndexMap::new(); - let mut func_index_map = IndexMap::new(); - - let recursion_func_map = IndexMap::new(); - - self.define_recurse_ir( - ir_stack, - &mut func_components, - &mut func_index_map, - recursion_func_map, - ); - - let mut final_func_dep_ir = IndexMap::new(); - - for func in func_index_map.clone() { - if self.defined_functions.contains_key(&func.0) { - continue; - } - - let mut funt_comp = func_components.get(&func.0).unwrap().clone(); - let func_scope = func_index_map.get(&func.0).unwrap(); - - let mut dep_ir = vec![]; - - while let Some(dependency) = funt_comp.dependencies.pop() { - if self.defined_functions.contains_key(&dependency) { - continue; - } - - let depend_comp = func_components.get(&dependency).unwrap(); - - let dep_scope = func_index_map.get(&dependency).unwrap(); - - if get_common_ancestor(dep_scope, func_scope) == func_scope.clone() { - funt_comp - .dependencies - .extend(depend_comp.dependencies.clone()); - - let mut temp_ir = vec![Air::DefineFunc { - scope: func_scope.clone(), - func_name: dependency.function_name.clone(), - module_name: dependency.module_name.clone(), - params: depend_comp.args.clone(), - recursive: depend_comp.recursive, - }]; - - temp_ir.extend(depend_comp.ir.clone()); - - temp_ir.append(&mut dep_ir); - - dep_ir = temp_ir; - self.defined_functions.insert(dependency, ()); - } - } - - final_func_dep_ir.insert(func.0, dep_ir); - } - - for (index, ir) in ir_stack.clone().into_iter().enumerate().rev() { - match ir { - Air::Var { constructor, .. } => { - if let ValueConstructorVariant::ModuleFn { .. } = &constructor.variant {} - } - a => { - let temp_func_index_map = func_index_map.clone(); - let to_insert = temp_func_index_map - .iter() - .filter(|func| { - func.1.clone() == a.scope() - && !self.defined_functions.contains_key(func.0) - }) - .collect_vec(); - - for item in to_insert.into_iter() { - func_index_map.remove(item.0); - self.defined_functions.insert(item.0.clone(), ()); - - let mut full_func_ir = final_func_dep_ir.get(item.0).unwrap().clone(); - - let funt_comp = func_components.get(item.0).unwrap(); - - full_func_ir.push(Air::DefineFunc { - scope: item.1.clone(), - func_name: item.0.function_name.clone(), - module_name: item.0.module_name.clone(), - params: funt_comp.args.clone(), - recursive: funt_comp.recursive, - }); - - full_func_ir.extend(funt_comp.ir.clone()); - - for ir in full_func_ir.into_iter().rev() { - ir_stack.insert(index, ir); - } - } - } - } - } - } - - fn define_recurse_ir( - &mut self, - ir_stack: &[Air], - func_components: &mut IndexMap, - func_index_map: &mut IndexMap>, - recursion_func_map: IndexMap, - ) { - self.process_define_ir(ir_stack, func_components, func_index_map); - - let mut recursion_func_map = recursion_func_map; - - for func_index in func_index_map.clone().iter() { - let func = func_index.0; - - let function_components = func_components.get(func).unwrap(); - let function_ir = function_components.ir.clone(); - - for ir in function_ir.clone() { - if let Air::Var { - constructor: - ValueConstructor { - variant: - ValueConstructorVariant::ModuleFn { - name: func_name, - module, - .. - }, - .. - }, - .. - } = ir - { - if recursion_func_map.contains_key(&FunctionAccessKey { - module_name: module.clone(), - function_name: func_name.clone(), - }) { - return; - } else { - recursion_func_map.insert( - FunctionAccessKey { - module_name: module.clone(), - function_name: func_name.clone(), - }, - (), - ); - } - } - } - - let mut inner_func_components = IndexMap::new(); - - let mut inner_func_index_map = IndexMap::new(); - - self.define_recurse_ir( - &function_ir, - &mut inner_func_components, - &mut inner_func_index_map, - recursion_func_map.clone(), - ); - - //now unify - for item in inner_func_components { - if !func_components.contains_key(&item.0) { - func_components.insert(item.0, item.1); - } - } - - for item in inner_func_index_map { - if let Some(entry) = func_index_map.get_mut(&item.0) { - *entry = get_common_ancestor(entry, &item.1); - } else { - func_index_map.insert(item.0, item.1); - } - } - } - } - - fn process_define_ir( - &mut self, - ir_stack: &[Air], - func_components: &mut IndexMap, - func_index_map: &mut IndexMap>, - ) { - let mut to_be_defined_map: IndexMap> = IndexMap::new(); - for ir in ir_stack.iter().rev() { - match ir { - Air::Var { - scope, constructor, .. - } => { - if let ValueConstructorVariant::ModuleFn { - name, - module, - builtin, - .. - } = &constructor.variant - { - if builtin.is_none() { - let function_key = FunctionAccessKey { - module_name: module.clone(), - function_name: name.clone(), - }; - - if let Some(scope_prev) = to_be_defined_map.get(&function_key) { - let new_scope = get_common_ancestor(scope, scope_prev); - - to_be_defined_map.insert(function_key, new_scope); - } else if func_components.get(&function_key).is_some() { - to_be_defined_map.insert(function_key.clone(), scope.to_vec()); - } else { - let function = self.functions.get(&function_key).unwrap(); - - let mut func_ir = vec![]; - - self.build_ir(&function.body, &mut func_ir, scope.to_vec()); - - to_be_defined_map.insert(function_key.clone(), scope.to_vec()); - let mut func_calls = vec![]; - - for ir in func_ir.clone() { - if let Air::Var { - constructor: - ValueConstructor { - variant: - ValueConstructorVariant::ModuleFn { - name: func_name, - module, - .. - }, - .. - }, - .. - } = ir - { - func_calls.push(FunctionAccessKey { - module_name: module.clone(), - function_name: func_name.clone(), - }) - } - } - - let mut args = vec![]; - - for arg in function.arguments.iter() { - match &arg.arg_name { - ArgName::Named { name, .. } - | ArgName::NamedLabeled { name, .. } => { - args.push(name.clone()); - } - _ => {} - } - } - let recursive = - if let Ok(index) = func_calls.binary_search(&function_key) { - func_calls.remove(index); - true - } else { - false - }; - - func_components.insert( - function_key, - FuncComponents { - ir: func_ir, - dependencies: func_calls, - recursive, - args, - }, - ); - } - } - } - } - a => { - let scope = a.scope(); - - for func in to_be_defined_map.clone().iter() { - if get_common_ancestor(&scope, func.1) == scope.to_vec() { - if let Some(index_scope) = func_index_map.get(func.0) { - if get_common_ancestor(index_scope, func.1) == scope.to_vec() { - func_index_map.insert(func.0.clone(), scope.clone()); - to_be_defined_map.shift_remove(func.0); - } else { - to_be_defined_map.insert( - func.0.clone(), - get_common_ancestor(index_scope, func.1), - ); - } - } else { - func_index_map.insert(func.0.clone(), scope.clone()); - to_be_defined_map.shift_remove(func.0); - } - } - } - } - } - } - - //Still to be defined - for func in to_be_defined_map.clone().iter() { - let index_scope = func_index_map.get(func.0).unwrap(); - func_index_map.insert(func.0.clone(), get_common_ancestor(func.1, index_scope)); - } - } -} - -fn convert_constants_to_data(constants: Vec) -> Vec { - let mut new_constants = vec![]; - for constant in constants { - let constant = match constant { - UplcConstant::Integer(i) => { - UplcConstant::Data(PlutusData::BigInt(BigInt::Int((i).try_into().unwrap()))) - } - UplcConstant::ByteString(b) => { - UplcConstant::Data(PlutusData::BoundedBytes(b.try_into().unwrap())) - } - UplcConstant::String(s) => UplcConstant::Data(PlutusData::BoundedBytes( - s.as_bytes().to_vec().try_into().unwrap(), - )), - - UplcConstant::Bool(b) => UplcConstant::Data(PlutusData::Constr(Constr { - tag: u64::from(b), - any_constructor: None, - fields: vec![], - })), - UplcConstant::ProtoList(_, _) => todo!(), - UplcConstant::ProtoPair(_, _, _, _) => todo!(), - d @ UplcConstant::Data(_) => d, - _ => unreachable!(), - }; - new_constants.push(constant); - } - new_constants -} - -fn constants_ir(literal: &Constant, String>, ir_stack: &mut Vec, scope: Vec) { - match literal { - Constant::Int { value, .. } => { - ir_stack.push(Air::Int { - scope, - value: value.clone(), - }); - } - Constant::String { value, .. } => { - ir_stack.push(Air::String { - scope, - value: value.clone(), - }); - } - Constant::Tuple { .. } => { - todo!() - } - Constant::List { elements, tipo, .. } => { - ir_stack.push(Air::List { - scope: scope.clone(), - count: elements.len(), - tipo: tipo.clone(), - tail: false, - }); - - for element in elements { - constants_ir(element, ir_stack, scope.clone()); - } - } - Constant::Record { .. } => { - // ir_stack.push(Air::Record { scope, }); - todo!() - } - Constant::ByteArray { bytes, .. } => { - ir_stack.push(Air::ByteArray { - scope, - bytes: bytes.clone(), - }); - } - Constant::Var { .. } => todo!(), - }; -} - -fn check_when_pattern_needs( - pattern: &Pattern>, - needs_access_to_constr_var: &mut bool, - needs_clause_guard: &mut bool, -) { - match pattern { - Pattern::Var { .. } => { - *needs_access_to_constr_var = true; - } - Pattern::List { .. } - | Pattern::Constructor { .. } - | Pattern::Tuple { .. } - | Pattern::Int { .. } => { - *needs_access_to_constr_var = true; - *needs_clause_guard = true; - } - Pattern::Discard { .. } => {} - - _ => todo!("{pattern:#?}"), - } -} - -fn get_common_ancestor(scope: &[u64], scope_prev: &[u64]) -> Vec { - let longest_length = if scope.len() >= scope_prev.len() { - scope.len() - } else { - scope_prev.len() - }; - - if *scope == *scope_prev { - return scope.to_vec(); - } - - for index in 0..longest_length { - if scope.get(index).is_none() { - return scope.to_vec(); - } else if scope_prev.get(index).is_none() { - return scope_prev.to_vec(); - } else if scope[index] != scope_prev[index] { - return scope[0..index].to_vec(); - } - } - vec![] -} - -fn list_access_to_uplc( - names: &[String], - id_list: &[u64], - tail: bool, - current_index: usize, - term: Term, - tipo: &Type, -) -> Term { - let (first, names) = names.split_first().unwrap(); - - let head_list = if tipo.is_map() { - Term::Apply { - function: Term::Force(Term::Builtin(DefaultFunction::HeadList).into()).into(), - argument: Term::Var(Name { - text: format!("tail_index_{}_{}", current_index, id_list[current_index]), - unique: 0.into(), - }) - .into(), - } - } else { - convert_data_to_type( - Term::Apply { - function: Term::Force(Term::Builtin(DefaultFunction::HeadList).into()).into(), - argument: Term::Var(Name { - text: format!("tail_index_{}_{}", current_index, id_list[current_index]), - unique: 0.into(), - }) - .into(), - }, - &tipo.clone().get_inner_type()[0], - ) - }; - - if names.len() == 1 && tail { - Term::Lambda { - parameter_name: Name { - text: format!("tail_index_{}_{}", current_index, id_list[current_index]), - unique: 0.into(), - }, - body: Term::Apply { - function: Term::Lambda { - parameter_name: Name { - text: first.clone(), - unique: 0.into(), - }, - body: Term::Apply { - function: Term::Lambda { - parameter_name: Name { - text: names[0].clone(), - unique: 0.into(), - }, - body: term.into(), - } - .into(), - argument: Term::Apply { - function: Term::Force(Term::Builtin(DefaultFunction::TailList).into()) - .into(), - argument: Term::Var(Name { - text: format!( - "tail_index_{}_{}", - current_index, id_list[current_index] - ), - unique: 0.into(), - }) - .into(), - } - .into(), - } - .into(), - } - .into(), - argument: head_list.into(), - } - .into(), - } - } else if names.is_empty() { - Term::Lambda { - parameter_name: Name { - text: format!("tail_index_{}_{}", current_index, id_list[current_index]), - unique: 0.into(), - }, - body: Term::Apply { - function: Term::Lambda { - parameter_name: Name { - text: first.clone(), - unique: 0.into(), - }, - body: term.into(), - } - .into(), - argument: Term::Apply { - function: Term::Force(Term::Builtin(DefaultFunction::HeadList).into()).into(), - argument: Term::Var(Name { - text: format!("tail_index_{}_{}", current_index, id_list[current_index]), - unique: 0.into(), - }) - .into(), - } - .into(), - } - .into(), - } - } else { - Term::Lambda { - parameter_name: Name { - text: format!("tail_index_{}_{}", current_index, id_list[current_index]), - unique: 0.into(), - }, - body: Term::Apply { - function: Term::Lambda { - parameter_name: Name { - text: first.clone(), - unique: 0.into(), - }, - body: Term::Apply { - function: list_access_to_uplc( - names, - id_list, - tail, - current_index + 1, - term, - tipo, - ) - .into(), - argument: Term::Apply { - function: Term::Force(Term::Builtin(DefaultFunction::TailList).into()) - .into(), - argument: Term::Var(Name { - text: format!( - "tail_index_{}_{}", - current_index, id_list[current_index] - ), - unique: 0.into(), - }) - .into(), - } - .into(), - } - .into(), - } - .into(), - argument: head_list.into(), - } - .into(), - } - } -} - -fn rearrange_clauses( - clauses: Vec, String>>, -) -> Vec, String>> { - let mut sorted_clauses = clauses; - - // if we have a list sort clauses so we can plug holes for cases not covered by clauses - // TODO: while having 10000000 element list is impossible to destructure in plutus budget, - // let's sort clauses by a safer manner - // TODO: how shall tails be weighted? Since any clause after will not run - sorted_clauses.sort_by(|clause1, clause2| { - let clause1_len = match &clause1.pattern[0] { - Pattern::List { elements, tail, .. } => elements.len() + usize::from(tail.is_some()), - _ => 10000000, - }; - let clause2_len = match &clause2.pattern[0] { - Pattern::List { elements, tail, .. } => elements.len() + usize::from(tail.is_some()), - _ => 10000001, - }; - - clause1_len.cmp(&clause2_len) - }); - - let mut elems_len = 0; - let mut final_clauses = sorted_clauses.clone(); - let mut holes_to_fill = vec![]; - let mut assign_plug_in_name = None; - let mut last_clause_index = 0; - let mut last_clause_set = false; - - // If we have a catch all, use that. Otherwise use todo which will result in error - // TODO: fill in todo label with description - let plug_in_then = match &sorted_clauses[sorted_clauses.len() - 1].pattern[0] { - Pattern::Var { name, .. } => { - assign_plug_in_name = Some(name); - sorted_clauses[sorted_clauses.len() - 1].clone().then - } - Pattern::Discard { .. } => sorted_clauses[sorted_clauses.len() - 1].clone().then, - _ => TypedExpr::Todo { - location: Span::empty(), - label: None, - tipo: sorted_clauses[sorted_clauses.len() - 1].then.tipo(), - }, - }; - - for (index, clause) in sorted_clauses.iter().enumerate() { - if let Pattern::List { elements, .. } = &clause.pattern[0] { - // found a hole and now we plug it - while elems_len < elements.len() { - let mut discard_elems = vec![]; - - for _ in 0..elems_len { - discard_elems.push(Pattern::Discard { - name: "_".to_string(), - location: Span::empty(), - }); - } - - // If we have a named catch all then in scope the name and create list of discards, otherwise list of discards - let clause_to_fill = if let Some(name) = assign_plug_in_name { - Clause { - location: Span::empty(), - pattern: vec![Pattern::Assign { - name: name.clone(), - location: Span::empty(), - pattern: Pattern::List { - location: Span::empty(), - elements: discard_elems, - tail: None, - } - .into(), - }], - alternative_patterns: vec![], - guard: None, - then: plug_in_then.clone(), - } - } else { - Clause { - location: Span::empty(), - pattern: vec![Pattern::List { - location: Span::empty(), - elements: discard_elems, - tail: None, - }], - alternative_patterns: vec![], - guard: None, - then: plug_in_then.clone(), - } - }; - - holes_to_fill.push((index, clause_to_fill)); - elems_len += 1; - } - } - - // if we have a pattern with no clause guards and a tail then no lists will get past here to other clauses - if let Pattern::List { - elements, - tail: Some(tail), - .. - } = &clause.pattern[0] - { - let mut elements = elements.clone(); - elements.push(*tail.clone()); - if elements - .iter() - .all(|element| matches!(element, Pattern::Var { .. } | Pattern::Discard { .. })) - && !last_clause_set - { - last_clause_index = index; - last_clause_set = true; - } - } - - // If the last condition doesn't have a catch all or tail then add a catch all with a todo - if index == sorted_clauses.len() - 1 { - if let Pattern::List { - elements, - tail: Some(tail), - .. - } = &clause.pattern[0] - { - let mut elements = elements.clone(); - elements.push(*tail.clone()); - if !elements - .iter() - .all(|element| matches!(element, Pattern::Var { .. } | Pattern::Discard { .. })) - { - final_clauses.push(Clause { - location: Span::empty(), - pattern: vec![Pattern::Discard { - name: "_".to_string(), - location: Span::empty(), - }], - alternative_patterns: vec![], - guard: None, - then: plug_in_then.clone(), - }); - } - } - } - - elems_len += 1; - } - - // Encountered a tail so stop there with that as last clause - final_clauses = final_clauses[0..(last_clause_index + 1)].to_vec(); - - // insert hole fillers into clauses - for (index, clause) in holes_to_fill.into_iter().rev() { - final_clauses.insert(index, clause); - } - - final_clauses -} - -fn convert_type_to_data(term: Term, field_type: &Arc) -> Term { - if field_type.is_bytearray() { - Term::Apply { - function: DefaultFunction::BData.into(), - argument: term.into(), - } - } else if field_type.is_int() { - Term::Apply { - function: DefaultFunction::IData.into(), - argument: term.into(), - } - } else if field_type.is_map() { - Term::Apply { - function: DefaultFunction::MapData.into(), - argument: term.into(), - } - } else if field_type.is_list() { - Term::Apply { - function: DefaultFunction::ListData.into(), - argument: term.into(), - } - } else if field_type.is_string() { - Term::Apply { - function: DefaultFunction::BData.into(), - argument: Term::Apply { - function: DefaultFunction::EncodeUtf8.into(), - argument: term.into(), - } - .into(), - } - } else if field_type.is_tuple() { - match field_type.get_uplc_type() { - UplcType::List(_) => Term::Apply { - function: DefaultFunction::ListData.into(), - argument: term.into(), - }, - UplcType::Pair(_, _) => Term::Apply { - function: Term::Lambda { - parameter_name: Name { - text: "__pair".to_string(), - unique: 0.into(), - }, - body: Term::Apply { - function: DefaultFunction::ListData.into(), - argument: Term::Apply { - function: Term::Apply { - function: Term::Builtin(DefaultFunction::MkCons) - .force_wrap() - .into(), - argument: Term::Apply { - function: Term::Builtin(DefaultFunction::FstPair) - .force_wrap() - .force_wrap() - .into(), - argument: Term::Var(Name { - text: "__pair".to_string(), - unique: 0.into(), - }) - .into(), - } - .into(), - } - .into(), - - argument: Term::Apply { - function: Term::Apply { - function: Term::Builtin(DefaultFunction::MkCons) - .force_wrap() - .into(), - argument: Term::Apply { - function: Term::Builtin(DefaultFunction::SndPair) - .force_wrap() - .force_wrap() - .into(), - argument: Term::Var(Name { - text: "__pair".to_string(), - unique: 0.into(), - }) - .into(), - } - .into(), - } - .into(), - argument: Term::Constant(UplcConstant::ProtoList( - UplcType::Data, - vec![], - )) - .into(), - } - .into(), - } - .into(), - } - .into(), - } - .into(), - argument: term.into(), - }, - _ => unreachable!(), - } - } else if field_type.is_bool() { - Term::Apply { - function: Term::Apply { - function: Term::Apply { - function: Term::Builtin(DefaultFunction::IfThenElse) - .force_wrap() - .into(), - argument: term.into(), - } - .into(), - argument: Term::Constant(UplcConstant::Data(PlutusData::Constr(Constr { - tag: convert_constr_to_tag(1), - any_constructor: None, - fields: vec![], - }))) - .into(), - } - .into(), - argument: Term::Constant(UplcConstant::Data(PlutusData::Constr(Constr { - tag: convert_constr_to_tag(0), - any_constructor: None, - fields: vec![], - }))) - .into(), - } - } else { - term - } -} - -fn convert_data_to_type(term: Term, field_type: &Arc) -> Term { - if field_type.is_int() { - Term::Apply { - function: DefaultFunction::UnIData.into(), - argument: term.into(), - } - } else if field_type.is_bytearray() { - Term::Apply { - function: DefaultFunction::UnBData.into(), - argument: term.into(), - } - } else if field_type.is_map() { - Term::Apply { - function: DefaultFunction::UnMapData.into(), - argument: term.into(), - } - } else if field_type.is_list() { - Term::Apply { - function: DefaultFunction::UnListData.into(), - argument: term.into(), - } - } else if field_type.is_string() { - Term::Apply { - function: DefaultFunction::DecodeUtf8.into(), - argument: Term::Apply { - function: DefaultFunction::UnBData.into(), - argument: term.into(), - } - .into(), - } - } else if field_type.is_tuple() { - match field_type.get_uplc_type() { - UplcType::List(_) => Term::Apply { - function: DefaultFunction::UnListData.into(), - argument: term.into(), - }, - UplcType::Pair(_, _) => Term::Apply { - function: Term::Lambda { - parameter_name: Name { - text: "__list_data".to_string(), - unique: 0.into(), - }, - body: Term::Apply { - function: Term::Lambda { - parameter_name: Name { - text: "__tail".to_string(), - unique: 0.into(), - }, - body: Term::Apply { - function: Term::Apply { - function: Term::Builtin(DefaultFunction::MkPairData).into(), - argument: Term::Apply { - function: Term::Builtin(DefaultFunction::HeadList) - .force_wrap() - .into(), - argument: Term::Var(Name { - text: "__list_data".to_string(), - unique: 0.into(), - }) - .into(), - } - .into(), - } - .into(), - argument: Term::Apply { - function: Term::Builtin(DefaultFunction::HeadList) - .force_wrap() - .into(), - argument: Term::Var(Name { - text: "__tail".to_string(), - unique: 0.into(), - }) - .into(), - } - .into(), - } - .into(), - } - .into(), - argument: Term::Apply { - function: Term::Builtin(DefaultFunction::TailList).force_wrap().into(), - argument: Term::Var(Name { - text: "__list_data".to_string(), - unique: 0.into(), - }) - .into(), - } - .into(), - } - .into(), - } - .into(), - argument: Term::Apply { - function: Term::Builtin(DefaultFunction::UnListData) - .force_wrap() - .into(), - argument: term.into(), - } - .into(), - }, - _ => unreachable!(), - } - } else if field_type.is_bool() { - Term::Apply { - function: Term::Apply { - function: Term::Builtin(DefaultFunction::EqualsInteger).into(), - argument: Term::Constant(UplcConstant::Integer(1)).into(), - } - .into(), - argument: Term::Apply { - function: Term::Builtin(DefaultFunction::FstPair) - .force_wrap() - .force_wrap() - .into(), - argument: Term::Apply { - function: Term::Builtin(DefaultFunction::UnConstrData).into(), - argument: term.into(), - } - .into(), - } - .into(), - } - } else { - term - } } diff --git a/crates/project/src/error.rs b/crates/project/src/error.rs index a1bc12f9..49924876 100644 --- a/crates/project/src/error.rs +++ b/crates/project/src/error.rs @@ -1,13 +1,18 @@ +use crate::{pretty, script::EvalHint}; +use aiken_lang::{ + ast::{BinOp, Span}, + parser::error::ParseError, + tipo, +}; +use miette::{ + Diagnostic, EyreContext, LabeledSpan, MietteHandlerOpts, NamedSource, RgbColors, SourceCode, +}; use std::{ fmt::{Debug, Display}, io, path::{Path, PathBuf}, }; - -use aiken_lang::{ast::Span, parser::error::ParseError, tipo}; -use miette::{ - Diagnostic, EyreContext, LabeledSpan, MietteHandlerOpts, NamedSource, RgbColors, SourceCode, -}; +use uplc::machine::cost_model::ExBudget; #[allow(dead_code)] #[derive(thiserror::Error)] @@ -28,7 +33,7 @@ pub enum Error { #[error(transparent)] StandardIo(#[from] io::Error), - #[error("Syclical module imports")] + #[error("Cyclical module imports")] ImportCycle { modules: Vec }, /// Useful for returning many [`Error::Parse`] at once @@ -73,6 +78,15 @@ pub enum Error { src: String, named: NamedSource, }, + + #[error("{name} failed{}", if *verbose { format!("\n{src}") } else { String::new() } )] + TestFailure { + name: String, + path: PathBuf, + verbose: bool, + src: String, + evaluation_hint: Option, + }, } impl Error { @@ -148,6 +162,7 @@ impl Error { Error::Type { path, .. } => Some(path.to_path_buf()), Error::ValidatorMustReturnBool { path, .. } => Some(path.to_path_buf()), Error::WrongValidatorArity { path, .. } => Some(path.to_path_buf()), + Error::TestFailure { path, .. } => Some(path.to_path_buf()), } } @@ -163,6 +178,7 @@ impl Error { Error::Type { src, .. } => Some(src.to_string()), Error::ValidatorMustReturnBool { src, .. } => Some(src.to_string()), Error::WrongValidatorArity { src, .. } => Some(src.to_string()), + Error::TestFailure { .. } => None, } } } @@ -203,6 +219,7 @@ impl Diagnostic for Error { Error::Format { .. } => None, Error::ValidatorMustReturnBool { .. } => Some(Box::new("aiken::scripts")), Error::WrongValidatorArity { .. } => Some(Box::new("aiken::validators")), + Error::TestFailure { path, .. } => Some(Box::new(path.to_str().unwrap_or(""))), } } @@ -225,6 +242,34 @@ impl Diagnostic for Error { Error::Format { .. } => None, Error::ValidatorMustReturnBool { .. } => Some(Box::new("Try annotating the validator's return type with Bool")), Error::WrongValidatorArity { .. } => Some(Box::new("Validators require a minimum number of arguments please add the missing arguments.\nIf you don't need one of the required arguments use an underscore `_datum`.")), + Error::TestFailure { evaluation_hint, .. } =>{ + match evaluation_hint { + None => None, + Some(hint) => { + let budget = ExBudget { mem: i64::MAX, cpu: i64::MAX, }; + let left = pretty::boxed("left", match hint.left.eval(budget) { + (Ok(term), _, _) => format!("{term}"), + (Err(err), _, _) => format!("{err}"), + }); + let right = pretty::boxed("right", match hint.right.eval(budget) { + (Ok(term), _, _) => format!("{term}"), + (Err(err), _, _) => format!("{err}"), + }); + let msg = match hint.bin_op { + BinOp::And => Some(format!("{left}\n\nand\n\n{right}\n\nshould both be true.")), + BinOp::Or => Some(format!("{left}\n\nor\n\n{right}\n\nshould be true.")), + BinOp::Eq => Some(format!("{left}\n\nshould be equal to\n\n{right}")), + BinOp::NotEq => Some(format!("{left}\n\nshould not be equal to\n\n{right}")), + BinOp::LtInt => Some(format!("{left}\n\nshould be lower than\n\n{right}")), + BinOp::LtEqInt => Some(format!("{left}\n\nshould be lower than or equal to\n\n{right}")), + BinOp::GtEqInt => Some(format!("{left}\n\nshould be greater than\n\n{right}")), + BinOp::GtInt => Some(format!("{left}\n\nshould be greater than or equal to\n\n{right}")), + _ => None + }?; + Some(Box::new(msg)) + } + } + }, } } @@ -244,6 +289,7 @@ impl Diagnostic for Error { Error::WrongValidatorArity { location, .. } => Some(Box::new( vec![LabeledSpan::new_with_span(None, *location)].into_iter(), )), + Error::TestFailure { .. } => None, } } @@ -259,6 +305,7 @@ impl Diagnostic for Error { Error::Format { .. } => None, Error::ValidatorMustReturnBool { named, .. } => Some(named), Error::WrongValidatorArity { named, .. } => Some(named), + Error::TestFailure { .. } => None, } } } diff --git a/crates/project/src/lib.rs b/crates/project/src/lib.rs index 08db8a97..f9e2397f 100644 --- a/crates/project/src/lib.rs +++ b/crates/project/src/lib.rs @@ -1,22 +1,21 @@ -use std::{ - collections::HashMap, - fs, - path::{Path, PathBuf}, -}; - pub mod config; pub mod error; pub mod format; pub mod module; pub mod options; +pub mod pretty; pub mod script; pub mod telemetry; use aiken_lang::{ - ast::{Definition, Function, ModuleKind, TypedFunction}, - builtins, + ast::{ + Annotation, DataType, Definition, Function, ModuleKind, RecordConstructor, + RecordConstructorArg, Span, TypedDataType, TypedDefinition, TypedFunction, + }, + builder::{DataTypeKey, FunctionAccessKey}, + builtins::{self, generic_var}, tipo::TypeInfo, - uplc::{CodeGenerator, DataTypeKey, FunctionAccessKey}, + uplc::CodeGenerator, IdGenerator, }; use miette::NamedSource; @@ -26,11 +25,16 @@ use pallas::{ ledger::{addresses::Address, primitives::babbage}, }; use pallas_traverse::ComputeHash; -use script::Script; +use script::{EvalHint, EvalInfo, Script}; use serde_json::json; -use telemetry::{EventListener, TestInfo}; +use std::{ + collections::HashMap, + fs, + path::{Path, PathBuf}, +}; +use telemetry::EventListener; use uplc::{ - ast::{DeBruijn, Program}, + ast::{Constant, DeBruijn, Program, Term}, machine::cost_model::ExBudget, }; @@ -101,12 +105,20 @@ where self.compile(options) } - pub fn check(&mut self, skip_tests: bool, match_tests: Option) -> Result<(), Error> { + pub fn check( + &mut self, + skip_tests: bool, + match_tests: Option, + verbose: bool, + ) -> Result<(), Error> { let options = Options { code_gen_mode: if skip_tests { CodeGenMode::NoOp } else { - CodeGenMode::Test(match_tests) + CodeGenMode::Test { + match_tests, + verbose, + } }, }; @@ -140,19 +152,47 @@ where self.event_listener.handle_event(Event::GeneratingUPLC { output_path: self.output_path(), }); - let programs = self.code_gen(validators, &checked_modules)?; - self.write_build_outputs(programs, uplc_dump)?; + Ok(()) } - CodeGenMode::Test(match_tests) => { - let tests = self.test_gen(&checked_modules)?; - self.run_tests(tests, match_tests); - } - CodeGenMode::NoOp => (), - } + CodeGenMode::Test { + match_tests, + verbose, + } => { + let tests = self + .collect_scripts(&checked_modules, |def| matches!(def, Definition::Test(..)))?; + if !tests.is_empty() { + self.event_listener.handle_event(Event::RunningTests); + } + let results = self.eval_scripts(tests, match_tests); + let errors: Vec = results + .iter() + .filter_map(|e| { + if e.success { + None + } else { + Some(Error::TestFailure { + name: e.script.name.clone(), + path: e.script.input_path.clone(), + evaluation_hint: e.script.evaluation_hint.clone(), + src: e.script.program.to_pretty(), + verbose, + }) + } + }) + .collect(); - Ok(()) + self.event_listener + .handle_event(Event::FinishedTests { tests: results }); + if !errors.is_empty() { + Err(Error::List(errors)) + } else { + Ok(()) + } + } + CodeGenMode::NoOp => Ok(()), + } } fn read_source_files(&mut self) -> Result<(), Error> { @@ -290,7 +330,7 @@ where fn validate_validators( &self, checked_modules: &mut CheckedModules, - ) -> Result, Error> { + ) -> Result, Error> { let mut errors = Vec::new(); let mut validators = Vec::new(); let mut indices_to_remove = Vec::new(); @@ -344,7 +384,11 @@ where }) } - validators.push((module.name.clone(), func_def.clone())); + validators.push(( + module.input_path.clone(), + module.name.clone(), + func_def.clone(), + )); indices_to_remove.push(index); } } @@ -364,7 +408,7 @@ where fn code_gen( &mut self, - validators: Vec<(String, TypedFunction)>, + validators: Vec<(PathBuf, String, TypedFunction)>, checked_modules: &CheckedModules, ) -> Result, Error> { let mut programs = Vec::new(); @@ -374,6 +418,16 @@ where let mut imports = HashMap::new(); let mut constants = HashMap::new(); + let option_data_type = make_option(); + + data_types.insert( + DataTypeKey { + module_name: "".to_string(), + defined_type: "Option".to_string(), + }, + &option_data_type, + ); + for module in checked_modules.values() { for def in module.ast.definitions() { match def { @@ -382,6 +436,7 @@ where FunctionAccessKey { module_name: module.name.clone(), function_name: func.name.clone(), + variant_name: String::new(), }, func, ); @@ -409,7 +464,7 @@ where } } - for (module_name, func_def) in validators { + for (input_path, module_name, func_def) in validators { let Function { arguments, name, @@ -426,9 +481,15 @@ where &self.module_types, ); - let program = generator.generate(body, arguments); + let program = generator.generate(body, arguments, true); - let script = Script::new(module_name, name, program.try_into().unwrap()); + let script = Script::new( + input_path, + module_name, + name, + program.try_into().unwrap(), + None, + ); programs.push(script); } @@ -437,7 +498,11 @@ where } // TODO: revisit ownership and lifetimes of data in this function - fn test_gen(&mut self, checked_modules: &CheckedModules) -> Result, Error> { + fn collect_scripts( + &mut self, + checked_modules: &CheckedModules, + should_collect: fn(&TypedDefinition) -> bool, + ) -> Result, Error> { let mut programs = Vec::new(); let mut functions = HashMap::new(); let mut type_aliases = HashMap::new(); @@ -445,8 +510,18 @@ where let mut imports = HashMap::new(); let mut constants = HashMap::new(); + let option_data_type = make_option(); + + data_types.insert( + DataTypeKey { + module_name: "".to_string(), + defined_type: "Option".to_string(), + }, + &option_data_type, + ); + // let mut indices_to_remove = Vec::new(); - let mut tests = Vec::new(); + let mut scripts = Vec::new(); for module in checked_modules.values() { for (_index, def) in module.ast.definitions().enumerate() { @@ -456,12 +531,18 @@ where FunctionAccessKey { module_name: module.name.clone(), function_name: func.name.clone(), + variant_name: String::new(), }, func, ); + if should_collect(def) { + scripts.push((module.input_path.clone(), module.name.clone(), func)); + } } Definition::Test(func) => { - tests.push((module.name.clone(), func)); + if should_collect(def) { + scripts.push((module.input_path.clone(), module.name.clone(), func)); + } // indices_to_remove.push(index); } Definition::TypeAlias(ta) => { @@ -490,7 +571,7 @@ where // } } - for (module_name, func_def) in tests { + for (input_path, module_name, func_def) in scripts { let Function { arguments, name, @@ -507,9 +588,34 @@ where &self.module_types, ); - let program = generator.generate(body.clone(), arguments.clone()); + let evaluation_hint = if let Some((bin_op, left_src, right_src)) = func_def.test_hint() + { + let left = CodeGenerator::new(&functions, &data_types, &self.module_types) + .generate(*left_src, vec![], false) + .try_into() + .unwrap(); + let right = CodeGenerator::new(&functions, &data_types, &self.module_types) + .generate(*right_src, vec![], false) + .try_into() + .unwrap(); + Some(EvalHint { + bin_op, + left, + right, + }) + } else { + None + }; - let script = Script::new(module_name, name.to_string(), program.try_into().unwrap()); + let program = generator.generate(body.clone(), arguments.clone(), false); + + let script = Script::new( + input_path, + module_name, + name.to_string(), + program.try_into().unwrap(), + evaluation_hint, + ); programs.push(script); } @@ -517,7 +623,7 @@ where Ok(programs) } - fn run_tests(&self, tests: Vec