diff --git a/crates/lang/src/air.rs b/crates/lang/src/air.rs index 1740367e..d7863db3 100644 --- a/crates/lang/src/air.rs +++ b/crates/lang/src/air.rs @@ -31,14 +31,10 @@ pub enum Air { variant_name: String, }, - // Fn { - // scope: Vec, - // tipo: Arc, - // is_capture: bool, - // args: Vec>>, - // body: Box, - // return_annotation: Option, - // }, + Fn { + scope: Vec, + params: Vec, + }, List { scope: Vec, count: usize, @@ -239,6 +235,7 @@ impl Air { | Air::List { scope, .. } | Air::ListAccessor { scope, .. } | Air::ListExpose { scope, .. } + | Air::Fn { scope, .. } | Air::Call { scope, .. } | Air::Builtin { scope, .. } | Air::BinOp { scope, .. } diff --git a/crates/lang/src/tipo.rs b/crates/lang/src/tipo.rs index ec09c20a..4a16a4a0 100644 --- a/crates/lang/src/tipo.rs +++ b/crates/lang/src/tipo.rs @@ -434,6 +434,7 @@ impl TypeVar { pub fn get_generic(&self) -> Option { match self { TypeVar::Generic { id } => Some(*id), + TypeVar::Link { tipo } => tipo.get_generic(), _ => None, } } diff --git a/crates/lang/src/uplc.rs b/crates/lang/src/uplc.rs index 130208ce..c006d780 100644 --- a/crates/lang/src/uplc.rs +++ b/crates/lang/src/uplc.rs @@ -181,7 +181,27 @@ impl<'a> CodeGenerator<'a> { }); } } - TypedExpr::Fn { .. } => todo!(), + TypedExpr::Fn { args, body, .. } => { + let mut func_body = vec![]; + let mut func_scope = scope.clone(); + func_scope.push(self.id_gen.next()); + self.build_ir(body, &mut func_body, func_scope); + let mut arg_names = vec![]; + for arg in args { + let name = arg + .arg_name + .get_variable_name() + .unwrap_or_default() + .to_string(); + arg_names.push(name); + } + + ir_stack.push(Air::Fn { + scope, + params: arg_names, + }); + ir_stack.append(&mut func_body); + } TypedExpr::List { elements, tail, @@ -1850,6 +1870,22 @@ impl<'a> CodeGenerator<'a> { arg_stack.push(term); } + + Air::Fn { params, .. } => { + let mut term = arg_stack.pop().unwrap(); + + for param in params.iter().rev() { + term = Term::Lambda { + parameter_name: Name { + text: param.clone(), + unique: 0.into(), + }, + body: term.into(), + }; + } + + arg_stack.push(term); + } Air::Call { count, .. } => { if count >= 1 { let mut term = arg_stack.pop().unwrap(); @@ -3244,20 +3280,21 @@ impl<'a> CodeGenerator<'a> { recursion_func_map, ); + let mut insert_var_vec = vec![]; let mut final_func_dep_ir = IndexMap::new(); - for func in func_index_map.clone() { if self.defined_functions.contains_key(&func.0) { continue; } - let mut funt_comp = func_components.get(&func.0).unwrap().clone(); let func_scope = func_index_map.get(&func.0).unwrap(); let mut dep_ir = vec![]; - + // deal with function dependencies while let Some(dependency) = funt_comp.dependencies.pop() { - if self.defined_functions.contains_key(&dependency) { + if self.defined_functions.contains_key(&dependency) + || func_components.get(&dependency).is_none() + { continue; } @@ -3276,15 +3313,46 @@ impl<'a> CodeGenerator<'a> { module_name: dependency.module_name.clone(), params: depend_comp.args.clone(), recursive: depend_comp.recursive, - variant_name: func.0.variant_name.clone(), + variant_name: dependency.variant_name.clone(), }]; - temp_ir.extend(depend_comp.ir.clone()); + for (index, ir) in depend_comp.ir.iter().enumerate() { + match_ir_for_recursion( + ir.clone(), + &mut insert_var_vec, + &FunctionAccessKey { + function_name: dependency.function_name.clone(), + module_name: dependency.module_name.clone(), + variant_name: dependency.variant_name.clone(), + }, + index, + ); + } + + let mut recursion_ir = depend_comp.ir.clone(); + for (index, ir) in insert_var_vec.clone() { + recursion_ir.insert(index, ir); + + let current_call = recursion_ir[index - 1].clone(); + + match current_call { + Air::Call { scope, count } => { + recursion_ir[index - 1] = Air::Call { + scope, + count: count + 1, + } + } + _ => unreachable!(), + } + } + + temp_ir.append(&mut recursion_ir); temp_ir.append(&mut dep_ir); dep_ir = temp_ir; self.defined_functions.insert(dependency, ()); + insert_var_vec = vec![]; } } @@ -3292,95 +3360,65 @@ impl<'a> CodeGenerator<'a> { } for (index, ir) in ir_stack.clone().into_iter().enumerate().rev() { - match ir { - Air::Var { constructor, .. } => { - if let ValueConstructorVariant::ModuleFn { .. } = &constructor.variant {} - } - a => { - let temp_func_index_map = func_index_map.clone(); - let to_insert = temp_func_index_map - .iter() - .filter(|func| { - func.1.clone() == a.scope() - && !self.defined_functions.contains_key(func.0) - }) - .collect_vec(); + { + let temp_func_index_map = func_index_map.clone(); + let to_insert = temp_func_index_map + .iter() + .filter(|func| { + func.1.clone() == ir.scope() && !self.defined_functions.contains_key(func.0) + }) + .collect_vec(); - for (function_access_key, scopes) in to_insert.into_iter() { - func_index_map.remove(function_access_key); + for (function_access_key, scopes) in to_insert.into_iter() { + func_index_map.remove(function_access_key); - self.defined_functions - .insert(function_access_key.clone(), ()); + self.defined_functions + .insert(function_access_key.clone(), ()); - let mut full_func_ir = - final_func_dep_ir.get(function_access_key).unwrap().clone(); + let mut full_func_ir = + final_func_dep_ir.get(function_access_key).unwrap().clone(); - let mut func_comp = - func_components.get(function_access_key).unwrap().clone(); + let mut func_comp = func_components.get(function_access_key).unwrap().clone(); - full_func_ir.push(Air::DefineFunc { - scope: scopes.clone(), - func_name: function_access_key.function_name.clone(), - module_name: function_access_key.module_name.clone(), - params: func_comp.args.clone(), - recursive: func_comp.recursive, - variant_name: function_access_key.variant_name.clone(), - }); + full_func_ir.push(Air::DefineFunc { + scope: scopes.clone(), + func_name: function_access_key.function_name.clone(), + module_name: function_access_key.module_name.clone(), + params: func_comp.args.clone(), + recursive: func_comp.recursive, + variant_name: function_access_key.variant_name.clone(), + }); - let mut insert_var_vec = vec![]; - for (index, air) in func_comp.ir.clone().into_iter().enumerate().rev() { - if let Air::Var { - scope, - constructor, - variant_name, - .. - } = air - { - if let ValueConstructorVariant::ModuleFn { - name: func_name, - module, - .. - } = constructor.clone().variant - { - if func_name.clone() - == function_access_key.function_name.clone() - && module == function_access_key.module_name.clone() - { - insert_var_vec.push(( - index, - Air::Var { - scope: scope.clone(), - constructor: constructor.clone(), - name: func_name.clone(), - variant_name, - }, - )); - } + for (index, ir) in func_comp.ir.clone().iter().enumerate() { + match_ir_for_recursion( + ir.clone(), + &mut insert_var_vec, + function_access_key, + index, + ); + } + + for (index, ir) in insert_var_vec { + func_comp.ir.insert(index, ir); + + let current_call = func_comp.ir[index - 1].clone(); + + match current_call { + Air::Call { scope, count } => { + func_comp.ir[index - 1] = Air::Call { + scope, + count: count + 1, } } + _ => unreachable!("{current_call:#?}"), } + } + insert_var_vec = vec![]; - for (index, ir) in insert_var_vec { - func_comp.ir.insert(index, ir); + full_func_ir.extend(func_comp.ir.clone()); - let current_call = func_comp.ir[index - 1].clone(); - - match current_call { - Air::Call { scope, count } => { - func_comp.ir[index - 1] = Air::Call { - scope, - count: count + 1, - } - } - _ => unreachable!(), - } - } - - full_func_ir.extend(func_comp.ir.clone()); - - for ir in full_func_ir.into_iter().rev() { - ir_stack.insert(index, ir); - } + for ir in full_func_ir.into_iter().rev() { + ir_stack.insert(index, ir); } } } @@ -3492,7 +3530,6 @@ impl<'a> CodeGenerator<'a> { function_name: name.clone(), variant_name: String::new(), }; - if let Some(scope_prev) = to_be_defined_map.get(&function_key) { let new_scope = get_common_ancestor(scope, scope_prev); @@ -3508,19 +3545,25 @@ impl<'a> CodeGenerator<'a> { let (param_types, _) = constructor.tipo.function_types().unwrap(); - let mut generic_id_type_vec = vec![]; + let mut generics_type_map: HashMap> = HashMap::new(); for (index, arg) in function.arguments.iter().enumerate() { if arg.tipo.is_generic() { - generic_id_type_vec.append(&mut get_generics_and_type( + let mut map = generics_type_map.into_iter().collect_vec(); + map.append(&mut get_generics_and_type( &arg.tipo, ¶m_types[index], )); + + generics_type_map = map.into_iter().collect(); } } - let (variant_name, mut func_ir) = - self.monomorphize(func_ir, generic_id_type_vec); + let (variant_name, mut func_ir) = self.monomorphize( + func_ir, + generics_type_map, + &constructor.tipo, + ); function_key = FunctionAccessKey { module_name: module.clone(), @@ -3563,6 +3606,8 @@ impl<'a> CodeGenerator<'a> { function_name: func_name.clone(), variant_name: variant_name.clone(), }; + + let function = self.functions.get(¤t_func); if function_key.clone() == current_func_as_variant { func_ir[index] = Air::Var { scope, @@ -3582,6 +3627,26 @@ impl<'a> CodeGenerator<'a> { variant_name: variant_name.clone(), }; func_calls.push(current_func_as_variant); + } else if let (Some(function), Type::Fn { args, .. }) = + (function, &*tipo) + { + if function + .arguments + .iter() + .any(|arg| arg.tipo.is_generic()) + { + let mut new_name = String::new(); + for arg in args.iter() { + get_variant_name(&mut new_name, arg); + } + func_calls.push(FunctionAccessKey { + module_name: module, + function_name: func_name, + variant_name: new_name, + }); + } else { + func_calls.push(current_func); + } } else { func_calls.push(current_func); } @@ -3666,9 +3731,11 @@ impl<'a> CodeGenerator<'a> { fn monomorphize( &mut self, ir: Vec, - generic_types: Vec<(u64, Arc)>, + generic_types: HashMap>, + full_type: &Arc, ) -> (String, Vec) { let mut new_air = ir.clone(); + let mut new_name = String::new(); for (index, ir) in ir.into_iter().enumerate() { match ir { @@ -3676,20 +3743,33 @@ impl<'a> CodeGenerator<'a> { constructor, scope, name, - variant_name, + .. } => { if constructor.tipo.is_generic() { let mut tipo = constructor.tipo.clone(); + find_generics_to_replace(&mut tipo, &generic_types); + let mut variant = String::new(); + let mut constructor = constructor.clone(); constructor.tipo = tipo; + if let Type::Fn { args, .. } = &*constructor.tipo { + if matches!( + constructor.variant, + ValueConstructorVariant::ModuleFn { .. } + ) { + for arg in args { + get_variant_name(&mut variant, arg); + } + } + } new_air[index] = Air::Var { scope, constructor, name, - variant_name, + variant_name: variant, }; } } @@ -3887,24 +3967,59 @@ impl<'a> CodeGenerator<'a> { } } - let mut new_name = String::new(); - - for (_, t) in generic_types { - get_variant_name(&mut new_name, t); + if let Type::Fn { args, .. } = &**full_type { + for arg in args { + get_variant_name(&mut new_name, arg); + } } (new_name, new_air) } } -fn find_generics_to_replace(tipo: &mut Arc, generic_types: &[(u64, Arc)]) { - if let Some(id) = tipo.get_generic() { - if let Some((_, t)) = generic_types - .iter() - .find(|(generic_id, _)| id == *generic_id) +fn match_ir_for_recursion( + ir: Air, + insert_var_vec: &mut Vec<(usize, Air)>, + function_access_key: &FunctionAccessKey, + index: usize, +) { + if let Air::Var { + scope, + constructor, + variant_name, + .. + } = ir + { + if let ValueConstructorVariant::ModuleFn { + name: func_name, + module, + .. + } = constructor.clone().variant { - *tipo = t.clone(); + let var_func_access = FunctionAccessKey { + module_name: module, + function_name: func_name.clone(), + variant_name: variant_name.clone(), + }; + + if function_access_key.clone() == var_func_access { + insert_var_vec.push(( + index, + Air::Var { + scope, + constructor, + name: func_name, + variant_name, + }, + )); + } } + } +} + +fn find_generics_to_replace(tipo: &mut Arc, generic_types: &HashMap>) { + if let Some(id) = tipo.get_generic() { + *tipo = generic_types.get(&id).unwrap().clone(); } else if tipo.is_generic() { match &**tipo { Type::App { @@ -3992,7 +4107,7 @@ fn get_generics_and_type(tipo: &Type, param: &Type) -> Vec<(u64, Arc)> { generics_ids } -fn get_variant_name(new_name: &mut String, t: Arc) { +fn get_variant_name(new_name: &mut String, t: &Arc) { new_name.push_str(&format!( "_{}", if t.is_string() { @@ -4007,13 +4122,13 @@ fn get_variant_name(new_name: &mut String, t: Arc) { let fst_type = &pair_type.get_inner_types()[0]; let snd_type = &pair_type.get_inner_types()[1]; - get_variant_name(&mut full_type, fst_type.clone()); - get_variant_name(&mut full_type, snd_type.clone()); + get_variant_name(&mut full_type, fst_type); + get_variant_name(&mut full_type, snd_type); full_type } else if t.is_list() { let mut full_type = "list".to_string(); let list_type = &t.get_inner_types()[0]; - get_variant_name(&mut full_type, list_type.clone()); + get_variant_name(&mut full_type, list_type); full_type } else { "data".to_string()