diff --git a/crates/aiken-lang/src/test_framework.rs b/crates/aiken-lang/src/test_framework.rs index 85f9abae..806b8ec8 100644 --- a/crates/aiken-lang/src/test_framework.rs +++ b/crates/aiken-lang/src/test_framework.rs @@ -117,14 +117,19 @@ impl Test { }) } - pub fn from_function_definition( + pub fn from_test_definition( generator: &mut CodeGenerator<'_>, test: TypedTest, module_name: String, input_path: PathBuf, + is_benchmark: bool, ) -> Test { if test.arguments.is_empty() { - Self::unit_test(generator, test, module_name, input_path) + if is_benchmark { + unreachable!("benchmark must have at least one argument"); + } else { + Self::unit_test(generator, test, module_name, input_path) + } } else { let parameter = test.arguments.first().unwrap().to_owned(); @@ -143,25 +148,58 @@ impl Test { &module_name, ); - // NOTE: We need not to pass any parameter to the fuzzer here because the fuzzer + // NOTE: We need not to pass any parameter to the fuzzer/sampler here because the fuzzer // argument is a Data constructor which needs not any conversion. So we can just safely // apply onto it later. - let fuzzer = generator.clone().generate_raw(&via, &[], &module_name); + let generator_program = generator.clone().generate_raw(&via, &[], &module_name); - Self::property_test( - input_path, - module_name, - test.name, - test.on_test_failure, - program, - Fuzzer { - program: fuzzer, - stripped_type_info, - type_info, - }, - ) + if is_benchmark { + Test::Benchmark(Benchmark { + input_path, + module: module_name, + name: test.name, + program, + on_test_failure: test.on_test_failure, + sampler: Sampler { + program: generator_program, + type_info, + stripped_type_info, + }, + }) + } else { + Self::property_test( + input_path, + module_name, + test.name, + test.on_test_failure, + program, + Fuzzer { + program: generator_program, + stripped_type_info, + type_info, + }, + ) + } } } + + pub fn from_benchmark_definition( + generator: &mut CodeGenerator<'_>, + test: TypedTest, + module_name: String, + input_path: PathBuf, + ) -> Test { + Self::from_test_definition(generator, test, module_name, input_path, true) + } + + pub fn from_function_definition( + generator: &mut CodeGenerator<'_>, + test: TypedTest, + module_name: String, + input_path: PathBuf, + ) -> Test { + Self::from_test_definition(generator, test, module_name, input_path, false) + } } /// ----- UnitTest ----------------------------------------------------------------- diff --git a/crates/aiken-lang/src/tipo/infer.rs b/crates/aiken-lang/src/tipo/infer.rs index 4431ff7e..f5fb0a61 100644 --- a/crates/aiken-lang/src/tipo/infer.rs +++ b/crates/aiken-lang/src/tipo/infer.rs @@ -499,7 +499,7 @@ fn infer_definition( )?; // Ensure that the annotation, if any, matches the type inferred from the - // Fuzzer. + // Sampler. if let Some(provided_inner_type) = provided_inner_type { if !arg .arg @@ -518,7 +518,7 @@ fn infer_definition( } } - // Replace the pre-registered type for the test function, to allow inferring + // Replace the pre-registered type for the benchmark function, to allow inferring // the function body with the right type arguments. let scope = environment .scope @@ -567,7 +567,7 @@ fn infer_definition( }); } - Ok(Definition::Test(Function { + Ok(Definition::Benchmark(Function { doc: typed_f.doc, location: typed_f.location, name: typed_f.name, diff --git a/crates/aiken-project/src/lib.rs b/crates/aiken-project/src/lib.rs index 5f83cfa8..e9d41bb3 100644 --- a/crates/aiken-project/src/lib.rs +++ b/crates/aiken-project/src/lib.rs @@ -952,8 +952,9 @@ where Ok(()) } - fn collect_benchmarks( + fn collect_test_items( &mut self, + kind: &str, // "test" or "bench" verbose: bool, match_tests: Option>, exact_match: bool, @@ -991,7 +992,13 @@ where } for def in checked_module.ast.definitions() { - if let Definition::Benchmark(func) = def { + let func = match (kind, def) { + ("test", Definition::Test(func)) => Some(func), + ("bench", Definition::Benchmark(func)) => Some(func), + _ => None, + }; + + if let Some(func) = func { if let Some(match_tests) = &match_tests { let is_match = match_tests.iter().any(|(module, names)| { let matched_module = @@ -1035,19 +1042,27 @@ where for (input_path, module_name, test) in scripts.into_iter() { if verbose { - // TODO: We may want to handle the event listener differently for benchmarks self.event_listener.handle_event(Event::GeneratingUPLCFor { name: test.name.clone(), path: input_path.clone(), }) } - tests.push(Test::from_function_definition( - &mut generator, - test.to_owned(), - module_name, - input_path, - )); + tests.push(match kind { + "test" => Test::from_function_definition( + &mut generator, + test.to_owned(), + module_name, + input_path, + ), + "bench" => Test::from_benchmark_definition( + &mut generator, + test.to_owned(), + module_name, + input_path, + ), + _ => unreachable!("Invalid test kind"), + }); } Ok(tests) @@ -1060,97 +1075,17 @@ where exact_match: bool, tracing: Tracing, ) -> Result, Error> { - let mut scripts = Vec::new(); + self.collect_test_items("test", verbose, match_tests, exact_match, tracing) + } - let match_tests = match_tests.map(|mt| { - mt.into_iter() - .map(|match_test| { - let mut match_split_dot = match_test.split('.'); - - let match_module = if match_test.contains('.') || match_test.contains('/') { - match_split_dot.next().unwrap_or("") - } else { - "" - }; - - let match_names = match_split_dot.next().map(|names| { - let names = names.replace(&['{', '}'][..], ""); - - let names_split_comma = names.split(','); - - names_split_comma.map(str::to_string).collect() - }); - - (match_module.to_string(), match_names) - }) - .collect::>)>>() - }); - - for checked_module in self.checked_modules.values() { - if checked_module.package != self.config.name.to_string() { - continue; - } - - for def in checked_module.ast.definitions() { - if let Definition::Test(func) = def { - if let Some(match_tests) = &match_tests { - let is_match = match_tests.iter().any(|(module, names)| { - let matched_module = - module.is_empty() || checked_module.name.contains(module); - - let matched_name = match names { - None => true, - Some(names) => names.iter().any(|name| { - if exact_match { - name == &func.name - } else { - func.name.contains(name) - } - }), - }; - - matched_module && matched_name - }); - - if is_match { - scripts.push(( - checked_module.input_path.clone(), - checked_module.name.clone(), - func, - )) - } - } else { - scripts.push(( - checked_module.input_path.clone(), - checked_module.name.clone(), - func, - )) - } - } - } - } - - let mut generator = self.new_generator(tracing); - - let mut tests = Vec::new(); - - for (input_path, module_name, test) in scripts.into_iter() { - if verbose { - self.event_listener.handle_event(Event::GeneratingUPLCFor { - name: test.name.clone(), - path: input_path.clone(), - }) - } - - tests.push(Test::from_function_definition( - &mut generator, - test.to_owned(), - module_name, - input_path, - )); - } - - Ok(tests) + fn collect_benchmarks( + &mut self, + verbose: bool, + match_tests: Option>, + exact_match: bool, + tracing: Tracing, + ) -> Result, Error> { + self.collect_test_items("bench", verbose, match_tests, exact_match, tracing) } fn run_tests(