From 7c5b9aa35ebf14429e4e1c2c60fa6e2e5bd90bbc Mon Sep 17 00:00:00 2001 From: rvcas Date: Tue, 2 Apr 2024 19:22:19 -0400 Subject: [PATCH] feat(lsp): find_node for TypedArgVia --- crates/aiken-lang/src/ast.rs | 162 +++++++++++++++++++++++++-------- crates/aiken-lsp/src/server.rs | 10 ++ 2 files changed, 132 insertions(+), 40 deletions(-) diff --git a/crates/aiken-lang/src/ast.rs b/crates/aiken-lang/src/ast.rs index abade95d..ccd367f2 100644 --- a/crates/aiken-lang/src/ast.rs +++ b/crates/aiken-lang/src/ast.rs @@ -237,6 +237,34 @@ pub struct Function { pub can_error: bool, } +impl TypedFunction { + pub fn find_node(&self, byte_index: usize) -> Option> { + self.arguments + .iter() + .find_map(|arg| arg.find_node(byte_index)) + .or_else(|| self.body.find_node(byte_index)) + .or_else(|| { + self.return_annotation + .as_ref() + .and_then(|a| a.find_node(byte_index)) + }) + } +} + +impl TypedTest { + pub fn find_node(&self, byte_index: usize) -> Option> { + self.arguments + .iter() + .find_map(|arg| arg.find_node(byte_index)) + .or_else(|| self.body.find_node(byte_index)) + .or_else(|| { + self.return_annotation + .as_ref() + .and_then(|a| a.find_node(byte_index)) + }) + } +} + pub type TypedTypeAlias = TypeAlias>; pub type UntypedTypeAlias = TypeAlias<()>; @@ -496,6 +524,18 @@ pub struct Validator { } impl TypedValidator { + pub fn find_node(&self, byte_index: usize) -> Option> { + self.params + .iter() + .find_map(|arg| arg.find_node(byte_index)) + .or_else(|| self.fun.find_node(byte_index)) + .or_else(|| { + self.other_fun + .as_ref() + .and_then(|f| f.find_node(byte_index)) + }) + } + pub fn into_function_definition<'a, F>( &'a self, module_name: &str, @@ -592,43 +632,18 @@ impl TypedDefinition { pub fn find_node(&self, byte_index: usize) -> Option> { // Note that the fn span covers the function head, not // the entire statement. - match self { - Definition::Validator(Validator { - fun: Function { body, .. }, - other_fun: - Some(Function { - body: other_body, .. - }), - .. - }) => { - if let Some(located) = body.find_node(byte_index) { - return Some(located); - } + let located = match self { + Definition::Validator(validator) => validator.find_node(byte_index), + Definition::Fn(func) => func.find_node(byte_index), + Definition::Test(func) => func.find_node(byte_index), + _ => None, + }; - if let Some(located) = other_body.find_node(byte_index) { - return Some(located); - } - } - - Definition::Fn(Function { body, .. }) - | Definition::Test(Function { body, .. }) - | Definition::Validator(Validator { - fun: Function { body, .. }, - .. - }) => { - if let Some(located) = body.find_node(byte_index) { - return Some(located); - } - } - - _ => (), + if located.is_none() && self.location().contains(byte_index) { + return Some(Located::Definition(self)); } - if self.location().contains(byte_index) { - Some(Located::Definition(self)) - } else { - None - } + located } } @@ -637,20 +652,22 @@ pub enum Located<'a> { Expression(&'a TypedExpr), Pattern(&'a TypedPattern, Rc), Definition(&'a TypedDefinition), + Argument(&'a ArgName, Rc), + Annotation(&'a Annotation), } impl<'a> Located<'a> { pub fn definition_location(&self) -> Option> { match self { Self::Expression(expression) => expression.definition_location(), - // TODO: Revise definition location semantic for 'Pattern' - // e.g. for constructors, we might want to show the type definition - // for that constructor. - Self::Pattern(_, _) => None, Self::Definition(definition) => Some(DefinitionLocation { module: None, span: definition.location(), }), + // TODO: Revise definition location semantic for 'Pattern' + // e.g. for constructors, we might want to show the type definition + // for that constructor. + Self::Pattern(_, _) | Located::Argument(_, _) | Located::Annotation(_) => None, } } } @@ -811,6 +828,18 @@ impl Arg { } } +impl TypedArg { + pub fn find_node(&self, byte_index: usize) -> Option> { + if self.arg_name.location().contains(byte_index) { + Some(Located::Argument(&self.arg_name, self.tipo.clone())) + } else { + self.annotation + .as_ref() + .and_then(|annotation| annotation.find_node(byte_index)) + } + } +} + pub type TypedArgVia = ArgVia, TypedExpr>; pub type UntypedArgVia = ArgVia<(), UntypedExpr>; @@ -835,6 +864,23 @@ impl From> for Arg { } } +impl TypedArgVia { + pub fn find_node(&self, byte_index: usize) -> Option> { + if self.arg_name.location().contains(byte_index) { + Some(Located::Argument(&self.arg_name, self.tipo.clone())) + } else { + // `via` is done first here because when there is no manually written + // annotation, it seems one is injected leading to a `found` returning too early + // because the span of the filled in annotation matches the span of the via expr. + self.via.find_node(byte_index).or_else(|| { + self.annotation + .as_ref() + .and_then(|annotation| annotation.find_node(byte_index)) + }) + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum ArgName { Discarded { @@ -851,6 +897,15 @@ pub enum ArgName { } impl ArgName { + pub fn location(&self) -> Span { + match self { + ArgName::Discarded { location, .. } => *location, + ArgName::Named { location, .. } => *location, + } + } + + /// Returns the name of the variable if it is named, otherwise None. + /// Code gen uses the fact that this returns None to do certain things. pub fn get_variable_name(&self) -> Option<&str> { match self { ArgName::Discarded { .. } => None, @@ -858,10 +913,15 @@ impl ArgName { } } + pub fn get_name(&self) -> String { + match self { + ArgName::Discarded { name, .. } | ArgName::Named { name, .. } => name.clone(), + } + } + pub fn get_label(&self) -> String { match self { - ArgName::Discarded { label, .. } => label.to_string(), - ArgName::Named { label, .. } => label.to_string(), + ArgName::Discarded { label, .. } | ArgName::Named { label, .. } => label.to_string(), } } } @@ -1022,6 +1082,28 @@ impl Annotation { }, } } + + pub fn find_node(&self, byte_index: usize) -> Option> { + if !self.location().contains(byte_index) { + return None; + } + + let located = match self { + Annotation::Constructor { arguments, .. } => { + arguments.iter().find_map(|arg| arg.find_node(byte_index)) + } + Annotation::Fn { arguments, ret, .. } => arguments + .iter() + .find_map(|arg| arg.find_node(byte_index)) + .or_else(|| ret.find_node(byte_index)), + Annotation::Tuple { elems, .. } => { + elems.iter().find_map(|arg| arg.find_node(byte_index)) + } + Annotation::Var { .. } | Annotation::Hole { .. } => None, + }; + + located.or(Some(Located::Annotation(self))) + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] diff --git a/crates/aiken-lsp/src/server.rs b/crates/aiken-lsp/src/server.rs index c79dba2b..db35e26a 100644 --- a/crates/aiken-lsp/src/server.rs +++ b/crates/aiken-lsp/src/server.rs @@ -399,6 +399,12 @@ impl Server { // TODO: autocompletion for expressions Some(Located::Expression(_expression)) => None, + + // TODO: autocompletion for arguments? + Some(Located::Argument(_arg_name, _tipo)) => None, + + // TODO: autocompletion for annotation? + Some(Located::Annotation(_annotation)) => None, } } @@ -510,7 +516,9 @@ impl Server { Some(expression.tipo()), ), Located::Pattern(pattern, tipo) => (pattern.location(), None, Some(tipo)), + Located::Argument(arg_name, tipo) => (arg_name.location(), None, Some(tipo)), Located::Definition(_) => return Ok(None), + Located::Annotation(_) => return Ok(None), }; let doc = definition_location @@ -525,6 +533,8 @@ impl Server { .and_then(|node| match node { Located::Expression(_) => None, Located::Pattern(_, _) => None, + Located::Argument(_, _) => None, + Located::Annotation(_) => None, Located::Definition(def) => def.doc(), }) .unwrap_or_default();