1572 lines
		
	
	
		
			51 KiB
		
	
	
	
		
			Rust
		
	
	
	
			
		
		
	
	
			1572 lines
		
	
	
		
			51 KiB
		
	
	
	
		
			Rust
		
	
	
	
| use core::fmt;
 | |
| use pretty::RcDoc;
 | |
| use std::{cmp::Ordering, fmt::Display, rc::Rc};
 | |
| 
 | |
| use indexmap::IndexMap;
 | |
| use itertools::{Itertools, Position};
 | |
| 
 | |
| use crate::{
 | |
|     ast::{DataTypeKey, Pattern, TypedClause, TypedDataType, TypedPattern},
 | |
|     expr::{lookup_data_type_by_tipo, Type, TypeVar, TypedExpr},
 | |
| };
 | |
| 
 | |
| use super::{interner::AirInterner, tree::AirTree};
 | |
| 
 | |
/// Column count associated with pair patterns.
///
/// NOTE(review): not referenced in this part of the file; presumably the number of
/// new columns a `Pair` pattern expands into — confirm against its use sites.
const PAIR_NEW_COLUMNS: usize = 2;

/// Smallest "new columns" count.
///
/// NOTE(review): not referenced in this part of the file — confirm meaning at its use sites.
const MIN_NEW_COLUMNS: usize = 1;

/// Per-column tally.
///
/// NOTE(review): presumably filled in by `highest_occurrence` when choosing which
/// column to switch on; that function is outside this view — TODO confirm.
#[derive(Copy, Clone, Default)]
struct Occurrence {
    /// Presumably how many patterns were counted for the column.
    amount: usize,
    /// Presumably whether a wildcard pattern has already been seen in the column.
    passed_wild_card: bool,
}
 | |
| 
 | |
| #[derive(Clone, Debug)]
 | |
| pub enum Path {
 | |
|     Pair(usize),
 | |
|     Tuple(usize),
 | |
|     Constr(Rc<Type>, usize),
 | |
|     OpaqueConstr(Rc<Type>),
 | |
|     List(usize),
 | |
|     ListTail(usize),
 | |
| }
 | |
| 
 | |
| impl ToString for Path {
 | |
|     fn to_string(&self) -> String {
 | |
|         match self {
 | |
|             Path::Pair(i) => {
 | |
|                 format!("pair_{}", i)
 | |
|             }
 | |
|             Path::Tuple(i) => {
 | |
|                 format!("tuple_{}", i)
 | |
|             }
 | |
|             Path::Constr(_, i) => {
 | |
|                 format!("constr_{}", i)
 | |
|             }
 | |
|             Path::OpaqueConstr(_) => "opaqueconstr".to_string(),
 | |
|             Path::List(i) => {
 | |
|                 format!("list_{}", i)
 | |
|             }
 | |
|             Path::ListTail(i) => {
 | |
|                 format!("listtail_{}", i)
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| impl PartialEq for Path {
 | |
|     fn eq(&self, other: &Self) -> bool {
 | |
|         match (self, other) {
 | |
|             (Path::Pair(a), Path::Pair(b))
 | |
|             | (Path::Tuple(a), Path::Tuple(b))
 | |
|             | (Path::Constr(_, a), Path::Constr(_, b))
 | |
|             | (Path::List(a), Path::List(b))
 | |
|             | (Path::ListTail(a), Path::ListTail(b)) => a == b,
 | |
|             (Path::OpaqueConstr(_), Path::OpaqueConstr(_)) => true,
 | |
|             _ => false,
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| impl Eq for Path {}
 | |
| 
 | |
| #[derive(Clone, Debug)]
 | |
| pub struct Assigned {
 | |
|     pub path: Vec<Path>,
 | |
|     pub assigned: String,
 | |
| }
 | |
| 
 | |
| #[derive(Clone, Debug)]
 | |
| struct RowItem<'a> {
 | |
|     path: Vec<Path>,
 | |
|     pattern: &'a TypedPattern,
 | |
| }
 | |
| 
 | |
| #[derive(Clone, Debug)]
 | |
| struct Row<'a> {
 | |
|     assigns: Vec<Assigned>,
 | |
|     columns: Vec<RowItem<'a>>,
 | |
|     then: String,
 | |
| }
 | |
| 
 | |
| #[derive(Clone, Debug)]
 | |
| struct PatternMatrix<'a> {
 | |
|     rows: Vec<Row<'a>>,
 | |
| }
 | |
| 
 | |
| #[derive(Debug, Clone, Eq, PartialEq)]
 | |
| pub enum CaseTest {
 | |
|     Constr(usize),
 | |
|     Int(String),
 | |
|     Bytes(Vec<u8>),
 | |
|     List(usize),
 | |
|     ListWithTail(usize),
 | |
|     Wild,
 | |
| }
 | |
| 
 | |
| impl CaseTest {
 | |
|     pub fn get_air_pattern(&self, current_type: Rc<Type>) -> AirTree {
 | |
|         match self {
 | |
|             CaseTest::Constr(i) => {
 | |
|                 if current_type.is_bool() {
 | |
|                     AirTree::bool(1 == *i)
 | |
|                 } else {
 | |
|                     AirTree::int(i)
 | |
|                 }
 | |
|             }
 | |
|             CaseTest::Int(i) => AirTree::int(i),
 | |
|             CaseTest::Bytes(vec) => AirTree::byte_array(vec.clone()),
 | |
|             CaseTest::List(_) => unreachable!(),
 | |
|             CaseTest::ListWithTail(_) => unreachable!(),
 | |
|             CaseTest::Wild => unreachable!(),
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| impl Display for CaseTest {
 | |
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 | |
|         match self {
 | |
|             CaseTest::Constr(i) => write!(f, "Constr({})", i),
 | |
|             CaseTest::Int(i) => write!(f, "Int({})", i),
 | |
|             CaseTest::Bytes(vec) => write!(f, "Bytes({:?})", vec),
 | |
|             CaseTest::List(i) => write!(f, "List({})", i),
 | |
|             CaseTest::ListWithTail(i) => write!(f, "ListWithTail({})", i),
 | |
|             CaseTest::Wild => write!(f, "Wild"),
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| #[derive(Debug, Clone)]
 | |
| pub enum DecisionTree<'a> {
 | |
|     Switch {
 | |
|         path: Vec<Path>,
 | |
|         cases: Vec<(CaseTest, DecisionTree<'a>)>,
 | |
|         default: Option<Box<DecisionTree<'a>>>,
 | |
|     },
 | |
|     ListSwitch {
 | |
|         path: Vec<Path>,
 | |
|         cases: Vec<(CaseTest, DecisionTree<'a>)>,
 | |
|         tail_cases: Vec<(CaseTest, DecisionTree<'a>)>,
 | |
|         default: Option<Box<DecisionTree<'a>>>,
 | |
|     },
 | |
|     HoistedLeaf(String, Vec<Assigned>),
 | |
|     HoistThen {
 | |
|         name: String,
 | |
|         assigns: Vec<Assigned>,
 | |
|         pattern: Box<DecisionTree<'a>>,
 | |
|         then: &'a TypedExpr,
 | |
|     },
 | |
| }
 | |
| 
 | |
/// One segment of a position in the decision tree: which branch of a switch was taken.
#[derive(Eq, Hash, PartialEq, Clone, Debug)]
pub enum ScopePath {
    /// The n-th case (list tail cases are numbered after the plain cases).
    Case(usize),
    /// The default branch.
    Fallback,
}

/// Ordering used to line hoist scopes up with a post-order walk of the tree.
///
/// Lower case indices compare *greater* (`Case(0) > Case(1)`), and every case is
/// greater than `Fallback`. Sorting scopes ascending and popping from the end
/// therefore yields `Case(0)` subtrees first and `Fallback` last.
impl Ord for ScopePath {
    fn cmp(&self, other: &Self) -> Ordering {
        match (self, other) {
            (ScopePath::Case(a), ScopePath::Case(b)) => b.cmp(a),
            (ScopePath::Case(_), ScopePath::Fallback) => Ordering::Greater,
            (ScopePath::Fallback, ScopePath::Case(_)) => Ordering::Less,
            (ScopePath::Fallback, ScopePath::Fallback) => Ordering::Equal,
        }
    }
}

// Canonical form: `partial_cmp` delegates to the total order above, so the two
// can never disagree.
impl PartialOrd for ScopePath {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// A path of branch choices from the root of a decision tree to one node.
///
/// Ordered lexicographically by segment. With equal prefixes, the longer (deeper)
/// scope compares greater.
#[derive(Eq, Hash, PartialEq, Clone, Debug, Default, PartialOrd, Ord)]
pub struct Scope {
    scope: Vec<ScopePath>,
}

impl Scope {
    /// The root scope (no segments).
    pub fn new() -> Self {
        Self::default()
    }

    /// Descends one level into `path`.
    pub fn push(&mut self, path: ScopePath) {
        self.scope.push(path);
    }

    /// Goes back up one level; does nothing at the root.
    pub fn pop(&mut self) {
        self.scope.pop();
    }

    /// Shrinks `self` in place to the longest prefix it shares with `other`.
    pub fn common_ancestor(&mut self, other: &Scope) {
        let shared = self
            .scope
            .iter()
            .zip(&other.scope)
            .take_while(|(mine, theirs)| mine == theirs)
            .count();

        self.scope.truncate(shared);
    }

    /// Number of segments (depth below the root).
    pub fn len(&self) -> usize {
        self.scope.len()
    }

    /// `true` for the root scope.
    pub fn is_empty(&self) -> bool {
        self.scope.is_empty()
    }
}
 | |
| 
 | |
| enum Marker<'a, 'b> {
 | |
|     Pop,
 | |
|     Push(ScopePath, &'b DecisionTree<'a>),
 | |
|     PopPush(ScopePath, &'b DecisionTree<'a>),
 | |
| }
 | |
| 
 | |
| impl<'a> DecisionTree<'a> {
 | |
|     pub fn to_pretty(&self) -> String {
 | |
|         let mut w = Vec::new();
 | |
| 
 | |
|         self.to_doc().render(80, &mut w).unwrap();
 | |
| 
 | |
|         String::from_utf8(w)
 | |
|             .unwrap()
 | |
|             .lines()
 | |
|             // This is a hack to deal with blank newlines
 | |
|             // that end up with a bunch of useless whitespace
 | |
|             // because of the nesting
 | |
|             .map(|l| {
 | |
|                 if l.chars().all(|c| c.is_whitespace()) {
 | |
|                     "".to_string()
 | |
|                 } else {
 | |
|                     l.to_string()
 | |
|                 }
 | |
|             })
 | |
|             .collect::<Vec<_>>()
 | |
|             .join("\n")
 | |
|     }
 | |
| 
 | |
|     fn to_doc(&self) -> RcDoc<()> {
 | |
|         match self {
 | |
|             DecisionTree::Switch {
 | |
|                 path,
 | |
|                 cases,
 | |
|                 default,
 | |
|                 ..
 | |
|             } => RcDoc::text("Switch(")
 | |
|                 .append(
 | |
|                     path.iter()
 | |
|                         .fold(RcDoc::line().append(RcDoc::text("path(")), |acc, p| {
 | |
|                             acc.append(
 | |
|                                 RcDoc::line()
 | |
|                                     .append(RcDoc::text(format!("{}", p.to_string())).nest(4)),
 | |
|                             )
 | |
|                         })
 | |
|                         .append(RcDoc::line())
 | |
|                         .append(RcDoc::text(")"))
 | |
|                         .nest(4),
 | |
|                 )
 | |
|                 .append(
 | |
|                     cases
 | |
|                         .iter()
 | |
|                         .fold(
 | |
|                             RcDoc::line().append(RcDoc::text("cases(")),
 | |
|                             |acc, (con, tree)| {
 | |
|                                 acc.append(RcDoc::line())
 | |
|                                     .append(RcDoc::text(format!("({}): ", con)))
 | |
|                                     .append(RcDoc::line())
 | |
|                                     .append(tree.to_doc().nest(4))
 | |
|                             },
 | |
|                         )
 | |
|                         .append(RcDoc::line())
 | |
|                         .append(RcDoc::text(")"))
 | |
|                         .nest(4),
 | |
|                 )
 | |
|                 .append(
 | |
|                     RcDoc::line()
 | |
|                         .append(RcDoc::text("default : "))
 | |
|                         .append(RcDoc::line())
 | |
|                         .append(
 | |
|                             default
 | |
|                                 .as_ref()
 | |
|                                 .map(|i| i.to_doc())
 | |
|                                 .unwrap_or(RcDoc::text("None")),
 | |
|                         )
 | |
|                         .append(RcDoc::line())
 | |
|                         .nest(4),
 | |
|                 )
 | |
|                 .append(RcDoc::text(")")),
 | |
|             DecisionTree::ListSwitch {
 | |
|                 path,
 | |
|                 cases,
 | |
|                 tail_cases,
 | |
|                 default,
 | |
|                 ..
 | |
|             } => RcDoc::text("ListSwitch(")
 | |
|                 .append(
 | |
|                     path.iter()
 | |
|                         .fold(RcDoc::line().append(RcDoc::text("path(")), |acc, p| {
 | |
|                             acc.append(
 | |
|                                 RcDoc::line()
 | |
|                                     .append(RcDoc::text(format!("{}", p.to_string())).nest(4)),
 | |
|                             )
 | |
|                         })
 | |
|                         .append(RcDoc::line())
 | |
|                         .append(RcDoc::text(")"))
 | |
|                         .nest(4),
 | |
|                 )
 | |
|                 .append(
 | |
|                     cases
 | |
|                         .iter()
 | |
|                         .fold(
 | |
|                             RcDoc::line().append(RcDoc::text("cases(")),
 | |
|                             |acc, (con, tree)| {
 | |
|                                 acc.append(RcDoc::line())
 | |
|                                     .append(RcDoc::text(format!("({}): ", con)))
 | |
|                                     .append(RcDoc::line())
 | |
|                                     .append(tree.to_doc().nest(4))
 | |
|                             },
 | |
|                         )
 | |
|                         .append(RcDoc::line())
 | |
|                         .append(RcDoc::text(")"))
 | |
|                         .nest(4),
 | |
|                 )
 | |
|                 .append(
 | |
|                     tail_cases
 | |
|                         .iter()
 | |
|                         .fold(
 | |
|                             RcDoc::line().append(RcDoc::text("tail_cases(")),
 | |
|                             |acc, (con, tree)| {
 | |
|                                 acc.append(RcDoc::line())
 | |
|                                     .append(RcDoc::text(format!("({}): ", con)))
 | |
|                                     .append(RcDoc::line())
 | |
|                                     .append(tree.to_doc().nest(4))
 | |
|                             },
 | |
|                         )
 | |
|                         .append(RcDoc::line())
 | |
|                         .append(RcDoc::text(")"))
 | |
|                         .nest(4),
 | |
|                 )
 | |
|                 .append(
 | |
|                     RcDoc::line()
 | |
|                         .append(RcDoc::text("default : "))
 | |
|                         .append(RcDoc::line())
 | |
|                         .append(
 | |
|                             default
 | |
|                                 .as_ref()
 | |
|                                 .map(|i| i.to_doc())
 | |
|                                 .unwrap_or(RcDoc::text("None")),
 | |
|                         )
 | |
|                         .append(RcDoc::line())
 | |
|                         .nest(4),
 | |
|                 )
 | |
|                 .append(RcDoc::text(")")),
 | |
|             DecisionTree::HoistedLeaf(name, _) => RcDoc::text(format!("Leaf({})", name)),
 | |
|             DecisionTree::HoistThen { name, pattern, .. } => RcDoc::text("HoistThen(")
 | |
|                 .append(
 | |
|                     RcDoc::line()
 | |
|                         .append(RcDoc::text(format!("name : {}", name)))
 | |
|                         .append(RcDoc::line())
 | |
|                         .nest(4),
 | |
|                 )
 | |
|                 .append(
 | |
|                     RcDoc::line()
 | |
|                         .append(pattern.to_doc())
 | |
|                         .append(RcDoc::line())
 | |
|                         .nest(4),
 | |
|                 )
 | |
|                 .append(RcDoc::text(")")),
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     /// For fun I decided to do this without recursion
 | |
|     /// It doesn't look to bad lol
 | |
|     fn get_hoist_paths<'b>(&self, names: Vec<&'b String>) -> IndexMap<&'b String, Scope> {
 | |
|         let mut prev = vec![];
 | |
| 
 | |
|         let mut current_path = Scope::new();
 | |
| 
 | |
|         let mut tree = self;
 | |
| 
 | |
|         let mut scope_map: IndexMap<&String, Scope> =
 | |
|             names.into_iter().map(|item| (item, Scope::new())).collect();
 | |
| 
 | |
|         loop {
 | |
|             match tree {
 | |
|                 DecisionTree::Switch { cases, default, .. } => {
 | |
|                     prev.push(Marker::Pop);
 | |
| 
 | |
|                     if let Some(def) = default {
 | |
|                         prev.push(Marker::PopPush(ScopePath::Fallback, def.as_ref()));
 | |
|                     }
 | |
| 
 | |
|                     cases
 | |
|                         .iter()
 | |
|                         .enumerate()
 | |
|                         .rev()
 | |
|                         .for_each(|(index, (_, detree))| {
 | |
|                             if index == 0 {
 | |
|                                 prev.push(Marker::Push(ScopePath::Case(index), detree));
 | |
|                             } else {
 | |
|                                 prev.push(Marker::PopPush(ScopePath::Case(index), detree));
 | |
|                             }
 | |
|                         });
 | |
|                 }
 | |
| 
 | |
|                 DecisionTree::ListSwitch {
 | |
|                     cases,
 | |
|                     tail_cases,
 | |
|                     default,
 | |
|                     ..
 | |
|                 } => {
 | |
|                     prev.push(Marker::Pop);
 | |
| 
 | |
|                     if let Some(def) = default {
 | |
|                         prev.push(Marker::PopPush(ScopePath::Fallback, def.as_ref()));
 | |
|                     }
 | |
| 
 | |
|                     tail_cases
 | |
|                         .iter()
 | |
|                         .enumerate()
 | |
|                         .rev()
 | |
|                         .for_each(|(index, (_, detree))| {
 | |
|                             prev.push(Marker::PopPush(
 | |
|                                 ScopePath::Case(index + cases.len()),
 | |
|                                 detree,
 | |
|                             ));
 | |
|                         });
 | |
| 
 | |
|                     cases
 | |
|                         .iter()
 | |
|                         .enumerate()
 | |
|                         .rev()
 | |
|                         .for_each(|(index, (_, detree))| {
 | |
|                             if index == 0 {
 | |
|                                 prev.push(Marker::Push(ScopePath::Case(index), detree));
 | |
|                             } else {
 | |
|                                 prev.push(Marker::PopPush(ScopePath::Case(index), detree));
 | |
|                             }
 | |
|                         });
 | |
|                 }
 | |
|                 DecisionTree::HoistedLeaf(leaf_name, _) => {
 | |
|                     let scope_for_name = scope_map
 | |
|                         .get_mut(leaf_name)
 | |
|                         .expect("Impossible, Leaf is based off of given names");
 | |
| 
 | |
|                     if scope_for_name.is_empty() {
 | |
|                         *scope_for_name = current_path.clone();
 | |
|                     } else {
 | |
|                         scope_for_name.common_ancestor(¤t_path);
 | |
|                     }
 | |
|                 }
 | |
|                 DecisionTree::HoistThen { .. } => unreachable!(),
 | |
|             }
 | |
| 
 | |
|             if let Some(action) = prev.pop() {
 | |
|                 match action {
 | |
|                     Marker::Pop => {
 | |
|                         current_path.pop();
 | |
|                     }
 | |
|                     Marker::Push(p, dec_tree) => {
 | |
|                         current_path.push(p);
 | |
| 
 | |
|                         tree = dec_tree;
 | |
|                     }
 | |
|                     Marker::PopPush(p, dec_tree) => {
 | |
|                         current_path.pop();
 | |
| 
 | |
|                         current_path.push(p);
 | |
| 
 | |
|                         tree = dec_tree;
 | |
|                     }
 | |
|                 }
 | |
|             } else {
 | |
|                 break;
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         scope_map
 | |
|     }
 | |
| 
 | |
|     fn hoist_by_path(
 | |
|         &mut self,
 | |
|         current_path: &mut Scope,
 | |
|         name_paths: &mut Vec<(String, Scope)>,
 | |
|         hoistables: &mut IndexMap<String, (Vec<Assigned>, &'a TypedExpr)>,
 | |
|     ) {
 | |
|         match self {
 | |
|             DecisionTree::Switch { cases, default, .. } => {
 | |
|                 cases.iter_mut().enumerate().for_each(|(index, (_, tree))| {
 | |
|                     current_path.push(ScopePath::Case(index));
 | |
|                     tree.hoist_by_path(current_path, name_paths, hoistables);
 | |
|                     current_path.pop();
 | |
|                 });
 | |
| 
 | |
|                 current_path.push(ScopePath::Fallback);
 | |
|                 if let Some(def) = default {
 | |
|                     def.hoist_by_path(current_path, name_paths, hoistables);
 | |
|                 }
 | |
|                 current_path.pop();
 | |
|             }
 | |
|             DecisionTree::ListSwitch {
 | |
|                 cases,
 | |
|                 tail_cases,
 | |
|                 default,
 | |
|                 ..
 | |
|             } => {
 | |
|                 cases.iter_mut().enumerate().for_each(|(index, (_, tree))| {
 | |
|                     current_path.push(ScopePath::Case(index));
 | |
|                     tree.hoist_by_path(current_path, name_paths, hoistables);
 | |
|                     current_path.pop();
 | |
|                 });
 | |
| 
 | |
|                 tail_cases
 | |
|                     .iter_mut()
 | |
|                     .enumerate()
 | |
|                     .for_each(|(index, (_, tree))| {
 | |
|                         current_path.push(ScopePath::Case(index + cases.len()));
 | |
|                         tree.hoist_by_path(current_path, name_paths, hoistables);
 | |
|                         current_path.pop();
 | |
|                     });
 | |
| 
 | |
|                 current_path.push(ScopePath::Fallback);
 | |
|                 if let Some(def) = default {
 | |
|                     def.hoist_by_path(current_path, name_paths, hoistables);
 | |
|                 }
 | |
|                 current_path.pop();
 | |
|             }
 | |
|             DecisionTree::HoistedLeaf(_, _) => (),
 | |
|             DecisionTree::HoistThen { .. } => unreachable!(),
 | |
|         }
 | |
| 
 | |
|         loop {
 | |
|             if let Some(name_path) = name_paths.pop() {
 | |
|                 if name_path.1 == *current_path {
 | |
|                     let (assigns, then) = hoistables.remove(&name_path.0).unwrap();
 | |
|                     let pattern =
 | |
|                         std::mem::replace(self, DecisionTree::HoistedLeaf("".to_string(), vec![]));
 | |
| 
 | |
|                     *self = DecisionTree::HoistThen {
 | |
|                         name: name_path.0,
 | |
|                         assigns,
 | |
|                         pattern: pattern.into(),
 | |
|                         then,
 | |
|                     };
 | |
|                 } else {
 | |
|                     name_paths.push(name_path);
 | |
|                     break;
 | |
|                 }
 | |
|             } else {
 | |
|                 break;
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| impl<'a> Display for DecisionTree<'a> {
 | |
|     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 | |
|         write!(f, "{}", self.to_pretty())
 | |
|     }
 | |
| }
 | |
| 
 | |
| pub struct TreeGen<'a, 'b> {
 | |
|     interner: &'b mut AirInterner,
 | |
|     data_types: &'b IndexMap<&'a DataTypeKey, &'a TypedDataType>,
 | |
|     wild_card_pattern: RowItem<'a>,
 | |
| }
 | |
| 
 | |
| impl<'a, 'b> TreeGen<'a, 'b> {
 | |
|     pub fn new(
 | |
|         interner: &'b mut AirInterner,
 | |
|         data_types: &'b IndexMap<&'a DataTypeKey, &'a TypedDataType>,
 | |
|         wild_card_pattern: &'a TypedPattern,
 | |
|     ) -> Self {
 | |
|         TreeGen {
 | |
|             interner,
 | |
|             data_types,
 | |
|             wild_card_pattern: RowItem {
 | |
|                 path: vec![],
 | |
|                 pattern: wild_card_pattern,
 | |
|             },
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     pub fn build_tree(
 | |
|         mut self,
 | |
|         subject_name: &String,
 | |
|         subject_tipo: &Rc<Type>,
 | |
|         clauses: &'a [TypedClause],
 | |
|     ) -> DecisionTree<'a> {
 | |
|         let mut hoistables = IndexMap::new();
 | |
| 
 | |
|         let rows = clauses
 | |
|             .iter()
 | |
|             .enumerate()
 | |
|             .map(|(index, clause)| {
 | |
|                 let (assign, row_items) =
 | |
|                     self.map_pattern_to_row(&clause.pattern, subject_tipo, vec![]);
 | |
| 
 | |
|                 self.interner.intern(format!("__clause_then_{}", index));
 | |
|                 let clause_then_name = self
 | |
|                     .interner
 | |
|                     .lookup_interned(&format!("__clause_then_{}", index));
 | |
| 
 | |
|                 hoistables.insert(clause_then_name.clone(), (vec![], &clause.then));
 | |
| 
 | |
|                 let row = Row {
 | |
|                     assigns: assign.into_iter().collect_vec(),
 | |
|                     columns: row_items,
 | |
|                     then: clause_then_name,
 | |
|                 };
 | |
| 
 | |
|                 self.interner.pop_text(format!("__clause_then_{}", index));
 | |
| 
 | |
|                 row
 | |
|             })
 | |
|             .collect_vec();
 | |
| 
 | |
|         let mut tree = self.do_build_tree(
 | |
|             subject_name,
 | |
|             subject_tipo,
 | |
|             PatternMatrix { rows },
 | |
|             &mut hoistables,
 | |
|         );
 | |
| 
 | |
|         let scope_map = tree.get_hoist_paths(hoistables.keys().collect_vec());
 | |
| 
 | |
|         let mut name_paths = scope_map
 | |
|             .into_iter()
 | |
|             .sorted_by(|a, b| a.1.cmp(&b.1))
 | |
|             .map(|(name, path)| (name.clone(), path))
 | |
|             .collect_vec();
 | |
| 
 | |
|         tree.hoist_by_path(&mut Scope::new(), &mut name_paths, &mut hoistables);
 | |
| 
 | |
|         // Do hoisting of thens here
 | |
|         tree
 | |
|     }
 | |
| 
 | |
|     fn do_build_tree(
 | |
|         &mut self,
 | |
|         subject_name: &String,
 | |
|         subject_tipo: &Rc<Type>,
 | |
|         matrix: PatternMatrix<'a>,
 | |
|         then_map: &mut IndexMap<String, (Vec<Assigned>, &'a TypedExpr)>,
 | |
|     ) -> DecisionTree<'a> {
 | |
|         let column_length = matrix.rows[0].columns.len();
 | |
|         // First step make sure all rows have same number of columns
 | |
|         // or something went wrong
 | |
|         assert!(matrix
 | |
|             .rows
 | |
|             .iter()
 | |
|             .all(|row| { row.columns.len() == column_length }));
 | |
| 
 | |
|         let occurrence_col = highest_occurrence(&matrix, column_length);
 | |
|         // Find which column has the most important pattern
 | |
| 
 | |
|         let Some(occurrence_col) = occurrence_col else {
 | |
|             // No more patterns to match on so we grab the first default row and return that
 | |
|             let mut fallback = matrix.rows;
 | |
| 
 | |
|             let row = fallback.swap_remove(0);
 | |
| 
 | |
|             let Some((assigns, _)) = then_map.get_mut(&row.then) else {
 | |
|                 unreachable!()
 | |
|             };
 | |
| 
 | |
|             if assigns.is_empty() {
 | |
|                 *assigns = row.assigns.clone();
 | |
|             }
 | |
| 
 | |
|             return DecisionTree::HoistedLeaf(row.then, row.assigns);
 | |
|         };
 | |
| 
 | |
|         let mut longest_elems_no_tail = None;
 | |
|         let mut longest_elems_with_tail = None;
 | |
|         let mut has_list_pattern = false;
 | |
| 
 | |
|         // List patterns are special so we need more information on length
 | |
|         matrix.rows.iter().for_each(|item| {
 | |
|             let col = &item.columns[occurrence_col];
 | |
| 
 | |
|             match col.pattern {
 | |
|                 Pattern::List { elements, tail, .. } => {
 | |
|                     has_list_pattern = true;
 | |
|                     if tail.is_none() {
 | |
|                         match longest_elems_no_tail {
 | |
|                             Some(elems_count) => {
 | |
|                                 if elems_count < elements.len() {
 | |
|                                     longest_elems_no_tail = Some(elements.len());
 | |
|                                 }
 | |
|                             }
 | |
|                             None => {
 | |
|                                 longest_elems_no_tail = Some(elements.len());
 | |
|                             }
 | |
|                         }
 | |
|                     } else {
 | |
|                         match longest_elems_with_tail {
 | |
|                             Some(elems_count) => {
 | |
|                                 if elems_count < elements.len() {
 | |
|                                     longest_elems_with_tail = Some(elements.len());
 | |
|                                 }
 | |
|                             }
 | |
|                             None => {
 | |
|                                 longest_elems_with_tail = Some(elements.len());
 | |
|                             }
 | |
|                         }
 | |
|                     }
 | |
|                 }
 | |
|                 _ => (),
 | |
|             }
 | |
|         });
 | |
| 
 | |
|         let path = matrix
 | |
|             .rows
 | |
|             .get(0)
 | |
|             .unwrap()
 | |
|             .columns
 | |
|             .get(occurrence_col)
 | |
|             .map(|col| col.path.clone())
 | |
|             .unwrap_or(vec![]);
 | |
| 
 | |
|         let specialized_tipo = get_tipo_by_path(subject_tipo.clone(), &path);
 | |
| 
 | |
|         let (default_matrix, specialized_matrices) = matrix.rows.into_iter().fold(
 | |
|             (vec![], vec![]),
 | |
|             |(mut default_matrix, mut case_matrices): (Vec<Row>, Vec<(CaseTest, Vec<Row>)>),
 | |
|              mut row| {
 | |
|                 // For example in the case of matching on []
 | |
|                 if row.columns.is_empty() {
 | |
|                     default_matrix.push(row);
 | |
|                     return (default_matrix, case_matrices);
 | |
|                 }
 | |
| 
 | |
|                 let col = row.columns.remove(occurrence_col);
 | |
| 
 | |
|                 let (case, remaining_patts) = match col.pattern {
 | |
|                     Pattern::Var { .. } | Pattern::Discard { .. } => (CaseTest::Wild, vec![]),
 | |
|                     Pattern::Int { value, .. } => (CaseTest::Int(value.clone()), vec![]),
 | |
|                     Pattern::ByteArray { value, .. } => (CaseTest::Bytes(value.clone()), vec![]),
 | |
|                     Pattern::List { elements, tail, .. } => (
 | |
|                         if tail.is_none() {
 | |
|                             CaseTest::List(elements.len())
 | |
|                         } else {
 | |
|                             CaseTest::ListWithTail(elements.len())
 | |
|                         },
 | |
|                         elements
 | |
|                             .iter()
 | |
|                             .chain(tail.as_ref().map(|tail| tail.as_ref()))
 | |
|                             .enumerate()
 | |
|                             .with_position()
 | |
|                             .map(|elem| match elem {
 | |
|                                 Position::First((index, element))
 | |
|                                 | Position::Middle((index, element))
 | |
|                                 | Position::Only((index, element)) => {
 | |
|                                     let mut item_path = col.path.clone();
 | |
| 
 | |
|                                     item_path.push(Path::List(index));
 | |
| 
 | |
|                                     self.map_pattern_to_row(element, subject_tipo, item_path)
 | |
|                                 }
 | |
| 
 | |
|                                 Position::Last((index, element)) => {
 | |
|                                     if tail.is_none() {
 | |
|                                         let mut item_path = col.path.clone();
 | |
| 
 | |
|                                         item_path.push(Path::List(index));
 | |
| 
 | |
|                                         self.map_pattern_to_row(element, subject_tipo, item_path)
 | |
|                                     } else {
 | |
|                                         let mut item_path = col.path.clone();
 | |
| 
 | |
|                                         item_path.push(Path::ListTail(index));
 | |
| 
 | |
|                                         self.map_pattern_to_row(element, subject_tipo, item_path)
 | |
|                                     }
 | |
|                                 }
 | |
|                             })
 | |
|                             .collect_vec(),
 | |
|                     ),
 | |
| 
 | |
|                     Pattern::Constructor {
 | |
|                         name,
 | |
|                         arguments,
 | |
|                         tipo,
 | |
|                         ..
 | |
|                     } => {
 | |
|                         let data_type =
 | |
|                             lookup_data_type_by_tipo(&self.data_types, &specialized_tipo).unwrap();
 | |
| 
 | |
|                         let (constr_index, _) = data_type
 | |
|                             .constructors
 | |
|                             .iter()
 | |
|                             .enumerate()
 | |
|                             .find(|(_, dt)| &dt.name == name)
 | |
|                             .unwrap();
 | |
| 
 | |
|                         (
 | |
|                             CaseTest::Constr(constr_index),
 | |
|                             arguments
 | |
|                                 .iter()
 | |
|                                 .enumerate()
 | |
|                                 .map(|(index, arg)| {
 | |
|                                     let mut item_path = col.path.clone();
 | |
| 
 | |
|                                     item_path.push(Path::Constr(tipo.clone(), index));
 | |
| 
 | |
|                                     self.map_pattern_to_row(&arg.value, subject_tipo, item_path)
 | |
|                                 })
 | |
|                                 .collect_vec(),
 | |
|                         )
 | |
|                     }
 | |
|                     Pattern::Tuple { .. } | Pattern::Pair { .. } | Pattern::Assign { .. } => {
 | |
|                         unreachable!("{:#?}", col.pattern)
 | |
|                     }
 | |
|                 };
 | |
| 
 | |
|                 // Assert path is the same for each specialized row
 | |
|                 assert!(path == col.path || matches!(case, CaseTest::Wild));
 | |
| 
 | |
|                 // expand assigns by newly added ones
 | |
|                 row.assigns
 | |
|                     .extend(remaining_patts.iter().flat_map(|x| x.0.clone()));
 | |
| 
 | |
|                 // Add inner patterns to existing row
 | |
|                 let mut new_cols = remaining_patts.into_iter().flat_map(|x| x.1).collect_vec();
 | |
| 
 | |
|                 if matches!(case, CaseTest::ListWithTail(_)) {
 | |
|                     new_cols.pop();
 | |
|                 }
 | |
| 
 | |
|                 let added_columns = new_cols.len();
 | |
| 
 | |
|                 // Pop off tail so that it aligns more easily with other list patterns
 | |
| 
 | |
|                 new_cols.extend(row.columns);
 | |
| 
 | |
|                 row.columns = new_cols;
 | |
| 
 | |
|                 if let CaseTest::Wild = case {
 | |
|                     let current_wild_cols = row.columns.len();
 | |
|                     default_matrix.push(row.clone());
 | |
| 
 | |
|                     case_matrices.iter_mut().for_each(|(_, matrix)| {
 | |
|                         let mut row = row.clone();
 | |
|                         let total_cols = matrix[0].columns.len();
 | |
| 
 | |
|                         if total_cols != 0 {
 | |
|                             let added_columns = total_cols - current_wild_cols;
 | |
| 
 | |
|                             for _ in 0..added_columns {
 | |
|                                 row.columns.insert(0, self.wild_card_pattern.clone());
 | |
|                             }
 | |
| 
 | |
|                             matrix.push(row);
 | |
|                         }
 | |
|                     });
 | |
|                 } else if let CaseTest::ListWithTail(case_length) = case {
 | |
|                     // For lists with tail it's a special case where we also add it to existing patterns
 | |
|                     // all the way to the longest element. The reason being that each list size greater
 | |
|                     // than the list with tail could also match with could also match depending on the inner pattern.
 | |
|                     // See tests below for an example
 | |
|                     if let Some(longest_elems_no_tail) = longest_elems_no_tail {
 | |
|                         for elem_count in case_length..=longest_elems_no_tail {
 | |
|                             let case = CaseTest::List(elem_count);
 | |
| 
 | |
|                             let mut row = row.clone();
 | |
| 
 | |
|                             for _ in 0..(elem_count - case_length) {
 | |
|                                 row.columns
 | |
|                                     .insert(case_length, self.wild_card_pattern.clone());
 | |
|                             }
 | |
| 
 | |
|                             self.insert_case(
 | |
|                                 &mut case_matrices,
 | |
|                                 case,
 | |
|                                 &default_matrix,
 | |
|                                 row,
 | |
|                                 added_columns,
 | |
|                             );
 | |
|                         }
 | |
|                     }
 | |
| 
 | |
|                     let Some(longest_elems_with_tail) = longest_elems_with_tail else {
 | |
|                         unreachable!()
 | |
|                     };
 | |
| 
 | |
|                     for elem_count in case_length..=longest_elems_with_tail {
 | |
|                         let case = CaseTest::ListWithTail(elem_count);
 | |
| 
 | |
|                         let mut row = row.clone();
 | |
| 
 | |
|                         for _ in 0..(elem_count - case_length) {
 | |
|                             row.columns
 | |
|                                 .insert(case_length, self.wild_card_pattern.clone());
 | |
|                         }
 | |
| 
 | |
|                         self.insert_case(
 | |
|                             &mut case_matrices,
 | |
|                             case,
 | |
|                             &default_matrix,
 | |
|                             row,
 | |
|                             added_columns,
 | |
|                         );
 | |
|                     }
 | |
|                 } else {
 | |
|                     self.insert_case(
 | |
|                         &mut case_matrices,
 | |
|                         case,
 | |
|                         &default_matrix,
 | |
|                         row,
 | |
|                         added_columns,
 | |
|                     );
 | |
|                 }
 | |
| 
 | |
|                 (default_matrix, case_matrices)
 | |
|             },
 | |
|         );
 | |
| 
 | |
|         let default_matrix = PatternMatrix {
 | |
|             rows: default_matrix,
 | |
|         };
 | |
| 
 | |
|         if has_list_pattern {
 | |
|             // Since the list_tail case might cover the rest of the possible matches extensively
 | |
|             // then fallback is optional here
 | |
|             let fallback_option = if default_matrix.rows.is_empty() {
 | |
|                 None
 | |
|             } else {
 | |
|                 Some(
 | |
|                     self.do_build_tree(
 | |
|                         subject_name,
 | |
|                         subject_tipo,
 | |
|                         // Since everything after this point had a wild card on or above
 | |
|                         // the row for the selected column in front. Then we ignore the
 | |
|                         // cases and continue to check other columns.
 | |
|                         default_matrix,
 | |
|                         then_map,
 | |
|                     )
 | |
|                     .into(),
 | |
|                 )
 | |
|             };
 | |
| 
 | |
|             let (tail_cases, cases): (Vec<_>, Vec<_>) = specialized_matrices
 | |
|                 .into_iter()
 | |
|                 .partition(|(case, _)| matches!(case, CaseTest::ListWithTail(_)));
 | |
| 
 | |
|             DecisionTree::ListSwitch {
 | |
|                 path,
 | |
|                 cases: cases
 | |
|                     .into_iter()
 | |
|                     .map(|x| {
 | |
|                         (
 | |
|                             x.0,
 | |
|                             self.do_build_tree(
 | |
|                                 subject_name,
 | |
|                                 subject_tipo,
 | |
|                                 PatternMatrix { rows: x.1 },
 | |
|                                 then_map,
 | |
|                             ),
 | |
|                         )
 | |
|                     })
 | |
|                     .collect_vec(),
 | |
|                 tail_cases: tail_cases
 | |
|                     .into_iter()
 | |
|                     .map(|x| {
 | |
|                         (
 | |
|                             x.0,
 | |
|                             self.do_build_tree(
 | |
|                                 subject_name,
 | |
|                                 subject_tipo,
 | |
|                                 PatternMatrix { rows: x.1 },
 | |
|                                 then_map,
 | |
|                             ),
 | |
|                         )
 | |
|                     })
 | |
|                     .collect_vec(),
 | |
|                 default: fallback_option,
 | |
|             }
 | |
|         } else {
 | |
|             let fallback_option = if default_matrix.rows.is_empty() {
 | |
|                 None
 | |
|             } else {
 | |
|                 Some(
 | |
|                     self.do_build_tree(
 | |
|                         subject_name,
 | |
|                         subject_tipo,
 | |
|                         // Since everything after this point had a wild card on or above
 | |
|                         // the row for the selected column in front. Then we ignore the
 | |
|                         // cases and continue to check other columns.
 | |
|                         default_matrix,
 | |
|                         then_map,
 | |
|                     )
 | |
|                     .into(),
 | |
|                 )
 | |
|             };
 | |
| 
 | |
|             DecisionTree::Switch {
 | |
|                 path,
 | |
|                 cases: specialized_matrices
 | |
|                     .into_iter()
 | |
|                     .map(|x| {
 | |
|                         (
 | |
|                             x.0,
 | |
|                             self.do_build_tree(
 | |
|                                 subject_name,
 | |
|                                 subject_tipo,
 | |
|                                 PatternMatrix { rows: x.1 },
 | |
|                                 then_map,
 | |
|                             ),
 | |
|                         )
 | |
|                     })
 | |
|                     .collect_vec(),
 | |
|                 default: fallback_option.into(),
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     fn map_pattern_to_row(
 | |
|         &self,
 | |
|         pattern: &'a TypedPattern,
 | |
|         subject_tipo: &Rc<Type>,
 | |
|         path: Vec<Path>,
 | |
|     ) -> (Vec<Assigned>, Vec<RowItem<'a>>) {
 | |
|         let current_tipo = get_tipo_by_path(subject_tipo.clone(), &path);
 | |
| 
 | |
|         let new_columns_added = if current_tipo.is_pair() {
 | |
|             PAIR_NEW_COLUMNS
 | |
|         } else if current_tipo.is_tuple() {
 | |
|             let Type::Tuple { elems, .. } = current_tipo.as_ref() else {
 | |
|                 unreachable!("{:#?}", current_tipo)
 | |
|             };
 | |
|             elems.len()
 | |
|         } else if let Some(data) = lookup_data_type_by_tipo(self.data_types, ¤t_tipo) {
 | |
|             if data.constructors.len() == 1 {
 | |
|                 data.constructors[0].arguments.len()
 | |
|             } else if data.is_never() {
 | |
|                 0
 | |
|             } else {
 | |
|                 MIN_NEW_COLUMNS
 | |
|             }
 | |
|         } else {
 | |
|             MIN_NEW_COLUMNS
 | |
|         };
 | |
| 
 | |
|         match pattern {
 | |
|             Pattern::Var { name, .. } => (
 | |
|                 vec![Assigned {
 | |
|                     path: path.clone(),
 | |
|                     assigned: name.clone(),
 | |
|                 }],
 | |
|                 vec![RowItem {
 | |
|                     pattern,
 | |
|                     path: path.clone(),
 | |
|                 }]
 | |
|                 .into_iter()
 | |
|                 .cycle()
 | |
|                 .take(new_columns_added)
 | |
|                 .collect_vec(),
 | |
|             ),
 | |
| 
 | |
|             Pattern::Assign { name, pattern, .. } => (
 | |
|                 vec![Assigned {
 | |
|                     path: path.clone(),
 | |
|                     assigned: name.clone(),
 | |
|                 }],
 | |
|                 self.map_pattern_to_row(pattern, subject_tipo, path).1,
 | |
|             ),
 | |
|             Pattern::Int { .. }
 | |
|             | Pattern::ByteArray { .. }
 | |
|             | Pattern::Discard { .. }
 | |
|             | Pattern::List { .. } => (
 | |
|                 vec![],
 | |
|                 vec![RowItem {
 | |
|                     pattern,
 | |
|                     path: path.clone(),
 | |
|                 }]
 | |
|                 .into_iter()
 | |
|                 .cycle()
 | |
|                 .take(new_columns_added)
 | |
|                 .collect_vec(),
 | |
|             ),
 | |
| 
 | |
|             Pattern::Constructor {
 | |
|                 arguments, tipo, ..
 | |
|             } => {
 | |
|                 let data_type = lookup_data_type_by_tipo(self.data_types, ¤t_tipo).unwrap();
 | |
| 
 | |
|                 let is_transparent =
 | |
|                     data_type.opaque && data_type.constructors[0].arguments.len() == 1;
 | |
| 
 | |
|                 if data_type.constructors.len() == 1 || data_type.is_never() {
 | |
|                     arguments
 | |
|                         .iter()
 | |
|                         .enumerate()
 | |
|                         .fold((vec![], vec![]), |mut acc, (index, arg)| {
 | |
|                             let arg_value = &arg.value;
 | |
| 
 | |
|                             let mut item_path = path.clone();
 | |
| 
 | |
|                             if is_transparent {
 | |
|                                 item_path.push(Path::OpaqueConstr(tipo.clone()));
 | |
|                             } else {
 | |
|                                 item_path.push(Path::Constr(tipo.clone(), index));
 | |
|                             }
 | |
| 
 | |
|                             let (assigns, patts) =
 | |
|                                 self.map_pattern_to_row(arg_value, subject_tipo, item_path);
 | |
| 
 | |
|                             acc.0.extend(assigns);
 | |
|                             acc.1.extend(patts);
 | |
| 
 | |
|                             acc
 | |
|                         })
 | |
|                 } else {
 | |
|                     (
 | |
|                         vec![],
 | |
|                         vec![RowItem {
 | |
|                             pattern,
 | |
|                             path: path.clone(),
 | |
|                         }]
 | |
|                         .into_iter()
 | |
|                         .cycle()
 | |
|                         .take(new_columns_added)
 | |
|                         .collect_vec(),
 | |
|                     )
 | |
|                 }
 | |
|             }
 | |
| 
 | |
|             Pattern::Pair { fst, snd, .. } => {
 | |
|                 let mut fst_path = path.clone();
 | |
|                 fst_path.push(Path::Pair(0));
 | |
|                 let mut snd_path = path;
 | |
|                 snd_path.push(Path::Pair(1));
 | |
| 
 | |
|                 let (mut assigns, mut patts) = self.map_pattern_to_row(fst, subject_tipo, fst_path);
 | |
| 
 | |
|                 let (assign_snd, patt_snd) = self.map_pattern_to_row(snd, subject_tipo, snd_path);
 | |
| 
 | |
|                 assigns.extend(assign_snd);
 | |
| 
 | |
|                 patts.extend(patt_snd);
 | |
| 
 | |
|                 (assigns, patts)
 | |
|             }
 | |
|             Pattern::Tuple { elems, .. } => {
 | |
|                 elems
 | |
|                     .iter()
 | |
|                     .enumerate()
 | |
|                     .fold((vec![], vec![]), |mut acc, (index, item)| {
 | |
|                         let mut item_path = path.clone();
 | |
| 
 | |
|                         item_path.push(Path::Tuple(index));
 | |
| 
 | |
|                         let (assigns, patts) =
 | |
|                             self.map_pattern_to_row(item, subject_tipo, item_path);
 | |
| 
 | |
|                         acc.0.extend(assigns);
 | |
|                         acc.1.extend(patts);
 | |
| 
 | |
|                         acc
 | |
|                     })
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     fn insert_case(
 | |
|         &self,
 | |
|         case_matrices: &mut Vec<(CaseTest, Vec<Row<'a>>)>,
 | |
|         case: CaseTest,
 | |
|         default_matrix: &Vec<Row<'a>>,
 | |
|         new_row: Row<'a>,
 | |
|         added_columns: usize,
 | |
|     ) {
 | |
|         if let Some(entry) = case_matrices.iter_mut().find(|item| item.0 == case) {
 | |
|             entry.1.push(new_row);
 | |
|         } else {
 | |
|             let mut rows = default_matrix.clone();
 | |
| 
 | |
|             for _ in 0..added_columns {
 | |
|                 for row in &mut rows {
 | |
|                     row.columns.insert(0, self.wild_card_pattern.clone());
 | |
|                 }
 | |
|             }
 | |
| 
 | |
|             rows.push(new_row);
 | |
|             case_matrices.push((case, rows));
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| pub fn get_tipo_by_path(mut subject_tipo: Rc<Type>, mut path: &[Path]) -> Rc<Type> {
 | |
|     while let Some((p, rest)) = path.split_first() {
 | |
|         subject_tipo = match p {
 | |
|             Path::Pair(index) | Path::Tuple(index) => {
 | |
|                 subject_tipo.get_inner_types().swap_remove(*index)
 | |
|             }
 | |
|             Path::List(_) => subject_tipo.get_inner_types().swap_remove(0),
 | |
|             Path::ListTail(_) => subject_tipo,
 | |
|             Path::Constr(tipo, index) => tipo.arg_types().unwrap().swap_remove(*index),
 | |
|             Path::OpaqueConstr(tipo) => {
 | |
|                 let x = tipo.arg_types().unwrap().swap_remove(0);
 | |
|                 x
 | |
|             }
 | |
|         };
 | |
| 
 | |
|         path = rest
 | |
|     }
 | |
|     match subject_tipo.as_ref() {
 | |
|         Type::Var { tipo, .. } => match &*tipo.borrow() {
 | |
|             TypeVar::Unbound { .. } | TypeVar::Generic { .. } => subject_tipo.clone(),
 | |
|             TypeVar::Link { tipo } => get_tipo_by_path(tipo.clone(), &[]),
 | |
|         },
 | |
|         _ => subject_tipo,
 | |
|     }
 | |
| }
 | |
| 
 | |
| fn match_wild_card(pattern: &TypedPattern) -> bool {
 | |
|     match pattern {
 | |
|         Pattern::Var { .. } | Pattern::Discard { .. } => true,
 | |
|         Pattern::Assign { pattern, .. } => match_wild_card(pattern),
 | |
|         _ => false,
 | |
|     }
 | |
| }
 | |
| 
 | |
| // A function to get which column has the most pattern matches before a wild card
 | |
| // Returns none if all columns in the first row are wild cards
 | |
| fn highest_occurrence(matrix: &PatternMatrix, column_length: usize) -> Option<usize> {
 | |
|     let occurrences = [Occurrence::default()].repeat(column_length);
 | |
| 
 | |
|     let occurrences =
 | |
|         matrix
 | |
|             .rows
 | |
|             .iter()
 | |
|             .fold(occurrences, |mut occurrences: Vec<Occurrence>, row| {
 | |
|                 row.columns
 | |
|                     .iter()
 | |
|                     .enumerate()
 | |
|                     .for_each(|(column_index, row_item)| {
 | |
|                         let Some(occurrence_col) = occurrences.get_mut(column_index) else {
 | |
|                             unreachable!()
 | |
|                         };
 | |
|                         if !match_wild_card(row_item.pattern) && !occurrence_col.passed_wild_card {
 | |
|                             occurrence_col.amount += 1;
 | |
|                         } else {
 | |
|                             occurrence_col.passed_wild_card = true;
 | |
|                         }
 | |
|                     });
 | |
| 
 | |
|                 occurrences
 | |
|             });
 | |
| 
 | |
|     // index and count
 | |
|     let mut highest_occurrence = (0, 0);
 | |
| 
 | |
|     occurrences.iter().enumerate().for_each(|(index, occ)| {
 | |
|         if occ.amount > highest_occurrence.1 {
 | |
|             highest_occurrence.0 = index;
 | |
|             highest_occurrence.1 = occ.amount;
 | |
|         }
 | |
|     });
 | |
| 
 | |
|     if highest_occurrence.1 == 0 {
 | |
|         None
 | |
|     } else {
 | |
|         Some(highest_occurrence.0)
 | |
|     }
 | |
| }
 | |
| 
 | |
#[cfg(test)]
mod tester {
    //! Smoke tests for decision-tree generation. Each test type-checks a small
    //! module whose first definition is a `test` with a `when` body. It then
    //! builds the decision tree for that `when` and prints it. A test fails
    //! only if parsing, type-checking or tree building panics.

    use std::collections::HashMap;

    use indexmap::IndexMap;

    use crate::{
        ast::{
            Definition, ModuleKind, Span, TraceLevel, Tracing, TypedModule, TypedPattern,
            UntypedModule,
        },
        builtins,
        expr::TypedExpr,
        gen_uplc::{decision_tree::TreeGen, interner::AirInterner},
        parser,
        tipo::error::{Error, Warning},
        utils, IdGenerator,
    };

    /// Parses `source_code` as a library module, panicking on syntax errors.
    fn parse(source_code: &str) -> UntypedModule {
        let parsed =
            parser::module(source_code, ModuleKind::Lib).expect("Failed to parse module");
        parsed.0
    }

    /// Type-checks `ast` against the prelude, the plutus builtins and each
    /// module in `extra` (checked in order, each one visible to the next).
    ///
    /// Warnings raised while checking `extra` are collected but not returned;
    /// only the warnings for `ast` come back, with the typed module or error.
    fn check_module(
        ast: UntypedModule,
        extra: Vec<(String, UntypedModule)>,
        kind: ModuleKind,
        tracing: Tracing,
    ) -> Result<(Vec<Warning>, TypedModule), (Vec<Warning>, Error)> {
        let id_gen = IdGenerator::new();

        let mut module_types = HashMap::new();
        module_types.insert("aiken".to_string(), builtins::prelude(&id_gen));
        module_types.insert("aiken/builtin".to_string(), builtins::plutus(&id_gen));

        for (package, module) in extra {
            let mut dependency_warnings = vec![];
            let typed_module = module
                .infer(
                    &id_gen,
                    kind,
                    &package,
                    &module_types,
                    Tracing::All(TraceLevel::Verbose),
                    &mut dependency_warnings,
                    None,
                )
                .expect("extra dependency did not compile");
            module_types.insert(package.clone(), typed_module.type_info.clone());
        }

        let mut warnings = vec![];
        let result = ast.infer(
            &id_gen,
            kind,
            "test/project",
            &module_types,
            tracing,
            &mut warnings,
            None,
        );

        match result {
            Ok(module) => Ok((warnings, module)),
            Err(error) => Err((warnings, error)),
        }
    }

    /// Type-checks `ast` as a library module with verbose tracing and no extra
    /// dependencies.
    fn check(ast: UntypedModule) -> Result<(Vec<Warning>, TypedModule), (Vec<Warning>, Error)> {
        check_module(ast, Vec::new(), ModuleKind::Lib, Tracing::verbose())
    }

    /// Parses and type-checks `source_code`, panicking if either step fails.
    fn typed_module(source_code: &str) -> TypedModule {
        let (_warnings, module) = check(parse(source_code)).unwrap();
        module
    }

    /// Body of the module's first definition, which must be a `test`.
    fn first_test_body(ast: &TypedModule) -> &TypedExpr {
        let Definition::Test(function) = &ast.definitions[0] else {
            panic!()
        };
        &function.body
    }

    /// The discard pattern handed to `TreeGen` as its wild card.
    fn discard_pattern() -> TypedPattern {
        TypedPattern::Discard {
            name: "_".to_string(),
            location: Span::empty(),
        }
    }

    #[test]
    fn thing() {
        let source_code = r#"
            test thing(){
                when [1, 2, 3] is {
                  [] -> False
                  [4] -> fail
                  [a, 2, b] -> True
                  _ -> False
                }
            }
        "#;

        let ast = typed_module(source_code);
        let TypedExpr::When {
            clauses, subject, ..
        } = first_test_body(&ast)
        else {
            panic!()
        };

        let mut air_interner = AirInterner::new();
        let data_types = IndexMap::new();
        let wild_card = discard_pattern();
        let tree_gen = TreeGen::new(&mut air_interner, &data_types, &wild_card);

        let tree = tree_gen.build_tree(&"subject".to_string(), &subject.tipo(), clauses);
        println!("{}", tree);
    }

    #[test]
    fn thing2() {
        let source_code = r#"
            test thing(){
                when (1,2,#"",[]) is {
                  (a,b,#"", []) -> True
                  (1,b,#"", [1]) -> False
                  (3,b,#"aa", _) -> 2 == 2
                  _ -> 1 == 1
                }
            }
        "#;

        let ast = typed_module(source_code);
        let TypedExpr::When {
            clauses, subject, ..
        } = first_test_body(&ast)
        else {
            panic!()
        };

        let mut air_interner = AirInterner::new();
        let data_types = IndexMap::new();
        let wild_card = discard_pattern();
        let tree_gen = TreeGen::new(&mut air_interner, &data_types, &wild_card);

        let tree = tree_gen.build_tree(&"subject".to_string(), &subject.tipo(), clauses);
        println!("{}", tree);
    }

    #[test]
    fn thing3() {
        let source_code = r#"
            test thing(){
                when (1,2,#"",[]) is {
                  (2,b,#"", []) -> 4 == 4
                  (a,b,#"", [2, ..y]) -> True
                  (1,b,#"", [a]) -> False
                  (3,b,#"aa", [x, y, ..z]) -> 2 == 2
                  _ -> 1 == 1
                }
            }
        "#;

        let ast = typed_module(source_code);
        let TypedExpr::When {
            clauses, subject, ..
        } = first_test_body(&ast)
        else {
            panic!()
        };

        let mut air_interner = AirInterner::new();
        let data_types = IndexMap::new();
        let wild_card = discard_pattern();
        let tree_gen = TreeGen::new(&mut air_interner, &data_types, &wild_card);

        let tree = tree_gen.build_tree(&"subject".to_string(), &subject.tipo(), clauses);
        println!("{}", tree);
    }

    #[test]
    fn thing4() {
        let source_code = r#"
            test thing(){
                when (1,2,#"",[]) is {
                  (2,b,#"", []) -> 4 == 4
                  (a,b,#"", [2, ..y]) -> True
                  (1,b,#"", [a]) -> False
                  (3,b,#"aa", [x, y, ..z]) -> 2 == 2
                  (3,b, c, [x, 3 as q]) -> fail
                  _ -> 1 == 1
                }
            }
        "#;

        let ast = typed_module(source_code);
        let TypedExpr::When {
            clauses, subject, ..
        } = first_test_body(&ast)
        else {
            panic!()
        };

        let mut air_interner = AirInterner::new();
        let data_types = IndexMap::new();
        let wild_card = discard_pattern();
        let tree_gen = TreeGen::new(&mut air_interner, &data_types, &wild_card);

        let tree = tree_gen.build_tree(&"subject".to_string(), &subject.tipo(), clauses);
        println!("{}", tree);
    }

    #[test]
    fn thing5() {
        let source_code = r#"
            test thing(){
                when (1,[],#"",None) is {
                  (2,b,#"", Some(Seeded { choices: #"", .. })) -> 4 == 4
                  (a,b,#"", None) -> True
                  (1,b,#"", Some(Seeded{ choices: #"", ..})) -> False
                  (3,b,#"aa", Some(Replayed(..))) -> 2 == 2
                  (3,b, c, y) -> fail
                  _ -> 1 == 1
                }
            }
        "#;

        let ast = typed_module(source_code);
        let TypedExpr::When {
            clauses, subject, ..
        } = first_test_body(&ast)
        else {
            panic!()
        };

        let mut air_interner = AirInterner::new();

        // This test needs the prelude data types (Option, fuzzing PRNG) so
        // constructor patterns can be resolved.
        let id_gen = IdGenerator::new();
        let data_types = builtins::prelude_data_types(&id_gen);
        let data_types = utils::indexmap::as_ref_values(&data_types);

        let wild_card = discard_pattern();
        let tree_gen = TreeGen::new(&mut air_interner, &data_types, &wild_card);

        let tree = tree_gen.build_tree(&"subject".to_string(), &subject.tipo(), clauses);
        println!("{}", tree);
    }
}
 |