diff --git a/crates/lang/src/uplc.rs b/crates/lang/src/uplc.rs index 59e8f16f..6f9bcfa4 100644 --- a/crates/lang/src/uplc.rs +++ b/crates/lang/src/uplc.rs @@ -82,24 +82,59 @@ impl Default for ScopeLevels { } } +#[derive(Clone, Eq, PartialEq, Hash)] +pub struct ConstrFieldKey { + pub local_var: String, + pub field_name: String, +} + +#[derive(Clone, Eq, PartialEq, Hash)] +pub struct DataTypeKey { + pub module_name: String, + pub defined_type: String, +} + +pub type ConstrUsageKey = String; + +#[derive(Clone, Eq, PartialEq, Hash)] +pub struct FunctionAccessKey { + pub module_name: String, + pub function_name: String, +} + +#[derive(Clone)] +pub struct ConstrConversionInfo { + local_var: String, + field: Option, + scope: ScopeLevels, + index: Option, + returning_type: String, +} + +#[derive(Clone)] +pub struct ScopedExpr { + scope: ScopeLevels, + expr: TypedExpr, +} + pub struct CodeGenerator<'a> { uplc_function_holder: Vec<(String, Term)>, - uplc_function_holder_lookup: IndexMap<(String, String), ScopeLevels>, - uplc_data_holder_lookup: IndexMap<(String, String, String), (ScopeLevels, TypedExpr)>, - uplc_data_constr_lookup: IndexMap<(String, String), ScopeLevels>, - uplc_data_usage_holder_lookup: IndexMap<(String, String), ScopeLevels>, - functions: &'a HashMap<(String, String), &'a Function, TypedExpr>>, + uplc_function_holder_lookup: IndexMap, + uplc_data_holder_lookup: IndexMap, + uplc_data_constr_lookup: IndexMap, + uplc_data_usage_holder_lookup: IndexMap, + functions: &'a HashMap, TypedExpr>>, // type_aliases: &'a HashMap<(String, String), &'a TypeAlias>>, - data_types: &'a HashMap<(String, String), &'a DataType>>, + data_types: &'a HashMap>>, // imports: &'a HashMap<(String, String), &'a Use>, // constants: &'a HashMap<(String, String), &'a ModuleConstant, String>>, } impl<'a> CodeGenerator<'a> { pub fn new( - functions: &'a HashMap<(String, String), &'a Function, TypedExpr>>, + functions: &'a HashMap, TypedExpr>>, // type_aliases: &'a HashMap<(String, String), &'a TypeAlias>>, - data_types: &'a HashMap<(String, String), &'a DataType>>, + data_types: &'a HashMap>>, // imports: &'a HashMap<(String, String), &'a Use>, // constants: &'a HashMap<(String, String), &'a ModuleConstant, String>>, ) -> Self { @@ -232,26 +267,45 @@ impl<'a> CodeGenerator<'a> { ValueConstructorVariant::ModuleFn { name, module, .. } => { if self .uplc_function_holder_lookup - .get(&(module.to_string(), name.to_string())) + .get(&FunctionAccessKey { + module_name: module.to_string(), + function_name: name.to_string(), + }) .is_none() { let func_def = self .functions - .get(&(module.to_string(), name.to_string())) + .get(&FunctionAccessKey { + module_name: module.to_string(), + function_name: name.to_string(), + }) .unwrap(); self.recurse_scope_level(&func_def.body, scope_level.clone()); - self.uplc_function_holder_lookup - .insert((module, name), scope_level); + self.uplc_function_holder_lookup.insert( + FunctionAccessKey { + module_name: module, + function_name: name, + }, + scope_level, + ); } else if scope_level.is_less_than( self.uplc_function_holder_lookup - .get(&(module.to_string(), name.to_string())) + .get(&FunctionAccessKey { + module_name: module.to_string(), + function_name: name.to_string(), + }) .unwrap(), false, ) { - self.uplc_function_holder_lookup - .insert((module, name), scope_level); + self.uplc_function_holder_lookup.insert( + FunctionAccessKey { + module_name: module, + function_name: name, + }, + scope_level, + ); } } ValueConstructorVariant::Record { .. } => { @@ -321,11 +375,10 @@ impl<'a> CodeGenerator<'a> { self.recurse_scope_level(&branch.body, scope_level.scope_increment_sequence(1)); } } - a @ TypedExpr::RecordAccess { label, record, .. } => { + expr @ TypedExpr::RecordAccess { label, record, .. } => { self.recurse_scope_level(record, scope_level.clone()); let mut is_var = false; - let mut current_var_name = "".to_string(); - let mut module = "".to_string(); + let mut current_var_name = String::new(); let mut current_record = *record.clone(); let mut current_scope = scope_level; while !is_var { @@ -336,19 +389,13 @@ impl<'a> CodeGenerator<'a> { constructor.clone().variant.clone(), (*constructor.tipo).clone(), ) { - ( - ValueConstructorVariant::LocalVariable { .. }, - Type::App { - module: app_module, .. - }, - ) => { + (ValueConstructorVariant::LocalVariable { .. }, Type::App { .. }) => { current_var_name = if current_var_name.is_empty() { name } else { format!("{name}_field_{current_var_name}") }; is_var = true; - module = app_module.to_string(); } _ => todo!(), }, @@ -365,35 +412,46 @@ impl<'a> CodeGenerator<'a> { } } - if let Some(val) = self.uplc_data_holder_lookup.get(&( - module.to_string(), - current_var_name.clone(), - label.clone(), - )) { - if current_scope.is_less_than(&val.0, false) { + if let Some(val) = self.uplc_data_holder_lookup.get(&ConstrFieldKey { + local_var: current_var_name.clone(), + field_name: label.clone(), + }) { + if current_scope.is_less_than(&val.scope, false) { self.uplc_data_holder_lookup.insert( - (module.to_string(), current_var_name.clone(), label.clone()), - (current_scope.clone(), a.clone()), + ConstrFieldKey { + local_var: current_var_name.clone(), + field_name: label.clone(), + }, + ScopedExpr { + scope: current_scope.clone(), + expr: expr.clone(), + }, ); } } else { self.uplc_data_holder_lookup.insert( - (module.to_string(), current_var_name.clone(), label.clone()), - (current_scope.clone(), a.clone()), + ConstrFieldKey { + local_var: current_var_name.clone(), + field_name: label.clone(), + }, + ScopedExpr { + scope: current_scope.clone(), + expr: expr.clone(), + }, ); } if let Some(val) = self .uplc_data_usage_holder_lookup - .get(&(module.to_string(), current_var_name.clone())) + .get(¤t_var_name.clone()) { if current_scope.is_less_than(val, false) { self.uplc_data_usage_holder_lookup - .insert((module, current_var_name), current_scope); + .insert(current_var_name, current_scope); } } else { self.uplc_data_usage_holder_lookup - .insert((module, current_var_name), current_scope); + .insert(current_var_name, current_scope); } } TypedExpr::ModuleSelect { constructor, .. } => match constructor { @@ -401,12 +459,18 @@ impl<'a> CodeGenerator<'a> { ModuleValueConstructor::Fn { module, name, .. } => { if self .uplc_function_holder_lookup - .get(&(module.to_string(), name.to_string())) + .get(&FunctionAccessKey { + module_name: module.to_string(), + function_name: name.to_string(), + }) .is_none() { let func_def = self .functions - .get(&(module.to_string(), name.to_string())) + .get(&FunctionAccessKey { + module_name: module.to_string(), + function_name: name.to_string(), + }) .unwrap(); self.recurse_scope_level( @@ -415,21 +479,35 @@ impl<'a> CodeGenerator<'a> { .scope_increment_sequence(func_def.arguments.len() as i32 + 1), ); - self.uplc_function_holder_lookup - .insert((module.to_string(), name.to_string()), scope_level); + self.uplc_function_holder_lookup.insert( + FunctionAccessKey { + module_name: module.to_string(), + function_name: name.to_string(), + }, + scope_level, + ); } else if scope_level.is_less_than( self.uplc_function_holder_lookup - .get(&(module.to_string(), name.to_string())) + .get(&FunctionAccessKey { + module_name: module.to_string(), + function_name: name.to_string(), + }) .unwrap(), false, ) { let func_def = self .functions - .get(&(module.to_string(), name.to_string())) + .get(&FunctionAccessKey { + module_name: module.to_string(), + function_name: name.to_string(), + }) .unwrap(); self.uplc_function_holder_lookup.insert( - (module.to_string(), name.to_string()), + FunctionAccessKey { + module_name: module.to_string(), + function_name: name.to_string(), + }, scope_level .scope_increment_sequence(func_def.arguments.len() as i32 + 1), ); @@ -471,17 +549,27 @@ impl<'a> CodeGenerator<'a> { match &**tipo { Type::App { module, name, .. } => { - if let Some(val) = self - .uplc_data_constr_lookup - .get(&(module.to_string(), name.clone())) - { + if let Some(val) = self.uplc_data_constr_lookup.get(&DataTypeKey { + module_name: module.to_string(), + defined_type: name.clone(), + }) { if scope_level.is_less_than(val, false) { - self.uplc_data_constr_lookup - .insert((module.to_string(), name.clone()), scope_level); + self.uplc_data_constr_lookup.insert( + DataTypeKey { + module_name: module.to_string(), + defined_type: name.clone(), + }, + scope_level, + ); } } else { - self.uplc_data_constr_lookup - .insert((module.to_string(), name.clone()), scope_level); + self.uplc_data_constr_lookup.insert( + DataTypeKey { + module_name: module.to_string(), + defined_type: name.clone(), + }, + scope_level, + ); } } Type::Fn { .. } => { @@ -507,7 +595,7 @@ impl<'a> CodeGenerator<'a> { rest => todo!("implement: {:#?}", rest), }; - let mut type_name = "".to_string(); + let mut type_name = String::new(); let mut is_app = false; let current_tipo = &*tipo; while !is_app { @@ -584,61 +672,70 @@ impl<'a> CodeGenerator<'a> { kind: AssignmentKind::Let, }; - if let Some(val) = self.uplc_data_holder_lookup.get(&( - module.to_string(), - var_name.clone(), - label.clone(), - )) { - if scope_level.is_less_than(&val.0, false) { + if let Some(val) = + self.uplc_data_holder_lookup.get(&ConstrFieldKey { + local_var: var_name.clone(), + field_name: label.clone(), + }) + { + if scope_level.is_less_than(&val.scope, false) { self.uplc_data_holder_lookup.insert( - ( - module.to_string(), - var_name.clone(), - label.clone(), - ), - ( - scope_level.scope_increment(1), - record_access.clone(), - ), + ConstrFieldKey { + local_var: var_name.clone(), + field_name: label.clone(), + }, + ScopedExpr { + scope: scope_level.scope_increment(1), + expr: record_access.clone(), + }, ); } } else { self.uplc_data_holder_lookup.insert( - (module.to_string(), var_name.clone(), label.clone()), - (scope_level.scope_increment(1), record_access.clone()), + ConstrFieldKey { + local_var: var_name.clone(), + field_name: label.clone(), + }, + ScopedExpr { + scope: scope_level.scope_increment(1), + expr: record_access.clone(), + }, ); } - if let Some(val) = self - .uplc_data_usage_holder_lookup - .get(&(module.to_string(), var_name.clone())) + if let Some(val) = + self.uplc_data_usage_holder_lookup.get(&var_name.clone()) { if scope_level.is_less_than(val, false) { - self.uplc_data_usage_holder_lookup.insert( - (module.to_string(), var_name.clone()), - scope_level.clone(), - ); + self.uplc_data_usage_holder_lookup + .insert(var_name.clone(), scope_level.clone()); } } else { - self.uplc_data_usage_holder_lookup.insert( - (module.to_string(), var_name.clone()), - scope_level.clone(), - ); + self.uplc_data_usage_holder_lookup + .insert(var_name.clone(), scope_level.clone()); } - if let Some(val) = self - .uplc_data_constr_lookup - .get(&(module.to_string(), type_name.clone())) + if let Some(val) = + self.uplc_data_constr_lookup.get(&DataTypeKey { + module_name: module.to_string(), + defined_type: type_name.clone(), + }) { if scope_level.is_less_than(val, false) { self.uplc_data_constr_lookup.insert( - (module.to_string(), type_name.clone()), + DataTypeKey { + module_name: module.to_string(), + defined_type: type_name.clone(), + }, scope_level.clone(), ); } } else { self.uplc_data_constr_lookup.insert( - (module.to_string(), type_name.clone()), + DataTypeKey { + module_name: module.to_string(), + defined_type: type_name.clone(), + }, scope_level.clone(), ); } @@ -671,7 +768,7 @@ impl<'a> CodeGenerator<'a> { .maybe_insert_def(term, scope_level.scope_increment_sequence(i as i32 + 1)); self.uplc_function_holder - .push(("".to_string(), term.clone())); + .push((String::new(), term.clone())); } self.uplc_function_holder.pop().unwrap().1 @@ -711,9 +808,10 @@ impl<'a> CodeGenerator<'a> { let mut term: Term = Term::Constant(Constant::ProtoList(uplc::ast::Type::Data, vec![])); - if let Some(data_type) = - self.data_types.get(&(module.to_string(), name.to_string())) - { + if let Some(data_type) = self.data_types.get(&DataTypeKey { + module_name: module.to_string(), + defined_type: name.to_string(), + }) { let constr = data_type .constructors .iter() @@ -1075,7 +1173,7 @@ impl<'a> CodeGenerator<'a> { let mut is_var = false; - let mut current_var_name = "".to_string(); + let mut current_var_name = String::new(); let mut current_subject = subject.clone(); @@ -1111,7 +1209,7 @@ impl<'a> CodeGenerator<'a> { let current_clauses = clauses.clone(); - let mut current_module = "".to_string(); + let mut current_module = String::new(); let mut total_constr_length = 0; let pattern = &clauses[0].pattern[0]; @@ -1120,12 +1218,16 @@ impl<'a> CodeGenerator<'a> { Pattern::Constructor { tipo, .. } => { let mut is_app = false; let mut tipo = &**tipo; - let mut key: (String, String) = ("".to_string(), "".to_string()); + let mut key = DataTypeKey { + module_name: String::new(), + defined_type: String::new(), + }; while !is_app { match tipo { Type::App { module, name, .. } => { is_app = true; - key = (module.clone(), name.clone()); + key.module_name = module.clone(); + key.defined_type = name.clone(); } Type::Fn { ret, .. } => { tipo = ret; @@ -1161,13 +1263,13 @@ impl<'a> CodeGenerator<'a> { let label = field.clone().label.unwrap_or(format!("{index}")); - if let Some((_, TypedExpr::Assignment { pattern, .. })) = - self.uplc_data_holder_lookup.get(&( - key.0.clone(), - current_var_name.to_string(), - label.clone(), - )) - { + if let Some(ScopedExpr { + expr: TypedExpr::Assignment { pattern, .. }, + .. + }) = self.uplc_data_holder_lookup.get(&ConstrFieldKey { + local_var: current_var_name.to_string(), + field_name: label.clone(), + }) { let var_name = match pattern { Pattern::Var { name, .. } => name, _ => todo!(), @@ -1384,7 +1486,7 @@ impl<'a> CodeGenerator<'a> { } TypedExpr::RecordAccess { label, record, .. } => { let mut is_var = false; - let mut current_var_name = "".to_string(); + let mut current_var_name = String::new(); let mut current_record = *record.clone(); while !is_var { match current_record.clone() { @@ -1442,15 +1544,15 @@ impl<'a> CodeGenerator<'a> { let mut term = current_term; // attempt to insert function definitions where needed - for func in self.uplc_function_holder_lookup.clone().keys() { + for func_key in self.uplc_function_holder_lookup.clone().keys() { if scope_level.is_less_than( - self.uplc_function_holder_lookup.clone().get(func).unwrap(), + self.uplc_function_holder_lookup + .clone() + .get(func_key) + .unwrap(), false, ) { - let func_def = self - .functions - .get(&(func.0.to_string(), func.1.to_string())) - .unwrap(); + let func_def = self.functions.get(func_key).unwrap(); let mut function_body = self.recurse_code_gen( &func_def.body, @@ -1470,7 +1572,7 @@ impl<'a> CodeGenerator<'a> { term = Term::Apply { function: Term::Lambda { parameter_name: Name { - text: format!("{}_{}", func.0, func.1), + text: format!("{}_{}", func_key.module_name, func_key.function_name), unique: 0.into(), }, body: term.into(), @@ -1478,7 +1580,7 @@ impl<'a> CodeGenerator<'a> { .into(), argument: function_body.into(), }; - self.uplc_function_holder_lookup.shift_remove(func); + self.uplc_function_holder_lookup.shift_remove(func_key); } } @@ -1577,7 +1679,7 @@ impl<'a> CodeGenerator<'a> { } .into(), }; - let module = &key.0; + let module = &key.module_name; term = Term::Apply { function: Term::Lambda { @@ -1594,39 +1696,37 @@ impl<'a> CodeGenerator<'a> { } } - // Pull out all uplc data holder and data usage, filter by Scope Level, Sort By Scope Depth, Then Apply - #[allow(clippy::type_complexity)] - let mut data_holder: Vec<((String, String, String), (ScopeLevels, i128, String))> = self + // Pull out all uplc data holder fields and data usage, filter by Scope Level, Sort By Scope Depth, Then Apply + let mut data_holder: Vec = self .uplc_data_usage_holder_lookup - .iter() - .filter(|record_scope| scope_level.is_less_than(record_scope.1, false)) - .map(|((module, name), scope)| { - ( - (module.to_string(), name.to_string(), "".to_string()), - (scope.clone(), -1, "".to_string()), - ) + .clone() + .into_iter() + .filter(|record_scope| scope_level.is_less_than(&record_scope.1, false)) + .map(|(var_name, scope)| ConstrConversionInfo { + local_var: var_name, + field: None, + scope, + index: None, + returning_type: String::new(), }) .collect(); data_holder.extend( self.uplc_data_holder_lookup - .iter() - .filter(|record_scope| scope_level.is_less_than(&record_scope.1 .0, false)) - .map(|((module, name, label), (scope, expr))| { - let index_type = match expr { - TypedExpr::RecordAccess { index, tipo, .. } => { - let tipo = &**tipo; - - let name = match tipo { - Type::App { name, .. } => name, - Type::Fn { .. } => todo!(), - Type::Var { .. } => todo!(), - }; - (index, name.clone()) - } - TypedExpr::Assignment { value, .. } => match &**value { + .clone() + .into_iter() + .filter(|record_scope| scope_level.is_less_than(&record_scope.1.scope, false)) + .map( + |( + ConstrFieldKey { + local_var, + field_name, + }, + ScopedExpr { scope, expr }, + )| { + let index_type = match expr { TypedExpr::RecordAccess { index, tipo, .. } => { - let tipo = &**tipo; + let tipo = &*tipo; let name = match tipo { Type::App { name, .. } => name, @@ -1635,62 +1735,56 @@ impl<'a> CodeGenerator<'a> { }; (index, name.clone()) } - _ => todo!(), - }, - _ => todo!(), - }; + TypedExpr::Assignment { value, .. } => match *value { + TypedExpr::RecordAccess { index, tipo, .. } => { + let tipo = &*tipo; - ( - (module.to_string(), name.to_string(), label.to_string()), - (scope.clone(), *index_type.0 as i128, index_type.1), - ) - }) - .collect::>(), + let name = match tipo { + Type::App { name, .. } => name, + Type::Fn { .. } => todo!(), + Type::Var { .. } => todo!(), + }; + (index, name.clone()) + } + _ => todo!(), + }, + _ => todo!(), + }; + + ConstrConversionInfo { + local_var, + field: Some(field_name), + scope, + index: Some(index_type.0), + returning_type: index_type.1, + } + }, + ) + .collect::>(), ); data_holder.sort_by(|item1, item2| { - if item1.1 .0.is_less_than(&item2.1 .0, true) { + if item1.scope.is_less_than(&item2.scope, true) { Ordering::Less - } else if item2.1 .0.is_less_than(&item1.1 .0, true) { + } else if item2.scope.is_less_than(&item1.scope, true) { Ordering::Greater - } else if item1.1 .1 < item2.1 .1 { + } else if item1.index < item2.index { Ordering::Less - } else if item2.1 .1 < item1.1 .1 { + } else if item2.index < item1.index { Ordering::Greater } else { Ordering::Equal } }); - for (key @ (module, name, label), (_, index, tipo)) in data_holder.iter().rev() { - if index < &0 { - term = Term::Apply { - function: Term::Lambda { - parameter_name: Name { - text: format!("{name}_fields"), - unique: 0.into(), - }, - body: term.into(), - } - .into(), - // TODO: Find proper scope for this function if at all. - argument: Term::Apply { - function: Term::Var(Name { - text: "constr_fields_exposer".to_string(), - unique: 0.into(), - }) - .into(), - argument: Term::Var(Name { - text: name.to_string(), - unique: 0.into(), - }) - .into(), - } - .into(), - }; - - self.uplc_data_usage_holder_lookup - .shift_remove(&(module.clone(), name.clone())); - } else { + for ConstrConversionInfo { + local_var, + field, + index, + returning_type, + .. + } in data_holder.into_iter().rev() + { + if let (Some(index), Some(field)) = (index, field) { let var_term = Term::Apply { function: Term::Apply { function: Term::Var(Name { @@ -1699,16 +1793,16 @@ impl<'a> CodeGenerator<'a> { }) .into(), argument: Term::Var(Name { - text: format!("{name}_fields"), + text: format!("{local_var}_fields"), unique: 0.into(), }) .into(), } .into(), - argument: Term::Constant(Constant::Integer(*index as i128)).into(), + argument: Term::Constant(Constant::Integer(index as i128)).into(), }; - let type_conversion = match tipo.as_str() { + let type_conversion = match returning_type.as_str() { "ByteArray" => Term::Apply { function: Term::Builtin(DefaultFunction::UnBData).into(), argument: var_term.into(), @@ -1723,7 +1817,7 @@ impl<'a> CodeGenerator<'a> { term = Term::Apply { function: Term::Lambda { parameter_name: Name { - text: format!("{name}_field_{label}"), + text: format!("{local_var}_field_{field}"), unique: 0.into(), }, body: term.into(), @@ -1731,7 +1825,37 @@ impl<'a> CodeGenerator<'a> { .into(), argument: type_conversion.into(), }; - self.uplc_data_holder_lookup.shift_remove(key); + self.uplc_data_holder_lookup.shift_remove(&ConstrFieldKey { + local_var, + field_name: field, + }); + } else { + term = Term::Apply { + function: Term::Lambda { + parameter_name: Name { + text: format!("{local_var}_fields"), + unique: 0.into(), + }, + body: term.into(), + } + .into(), + // TODO: Find proper scope for this function if at all. + argument: Term::Apply { + function: Term::Var(Name { + text: "constr_fields_exposer".to_string(), + unique: 0.into(), + }) + .into(), + argument: Term::Var(Name { + text: local_var.to_string(), + unique: 0.into(), + }) + .into(), + } + .into(), + }; + + self.uplc_data_usage_holder_lookup.shift_remove(&local_var); } } diff --git a/crates/project/src/lib.rs b/crates/project/src/lib.rs index 62180ae7..57a61f93 100644 --- a/crates/project/src/lib.rs +++ b/crates/project/src/lib.rs @@ -14,7 +14,7 @@ use aiken_lang::{ ast::{Definition, Function, ModuleKind, TypedFunction}, builtins, tipo::TypeInfo, - uplc::CodeGenerator, + uplc::{CodeGenerator, DataTypeKey, FunctionAccessKey}, IdGenerator, }; use pallas::{ @@ -312,13 +312,25 @@ impl Project { for def in module.ast.definitions() { match def { Definition::Fn(func) => { - functions.insert((module.name.clone(), func.name.clone()), func); + functions.insert( + FunctionAccessKey { + module_name: module.name.clone(), + function_name: func.name.clone(), + }, + func, + ); } Definition::TypeAlias(ta) => { type_aliases.insert((module.name.clone(), ta.alias.clone()), ta); } Definition::DataType(dt) => { - data_types.insert((module.name.clone(), dt.name.clone()), dt); + data_types.insert( + DataTypeKey { + module_name: module.name.clone(), + defined_type: dt.name.clone(), + }, + dt, + ); } Definition::Use(import) => { imports.insert((module.name.clone(), import.module.join("/")), import);