nested lists now work

This commit is contained in:
Kasey White 2022-12-24 04:32:14 -05:00 committed by Lucas
parent 6126ee4cb4
commit b7d506a8db
4 changed files with 281 additions and 105 deletions

View File

@ -134,6 +134,7 @@ pub enum Air {
tail_name: String,
next_tail_name: Option<String>,
complex_clause: bool,
inverse: bool,
},
TupleClause {
@ -150,7 +151,6 @@ pub enum Air {
scope: Vec<u64>,
subject_name: String,
tipo: Arc<Type>,
invert: bool,
},
Discard {

View File

@ -443,7 +443,7 @@ pub fn rearrange_clauses(
let mut final_clauses = sorted_clauses.clone();
let mut holes_to_fill = vec![];
let mut assign_plug_in_name = None;
let mut last_clause_index = sorted_clauses.len() - 1;
let mut last_clause_index = 0;
let mut last_clause_set = false;
// If we have a catch all, use that. Otherwise use todo which will result in error
@ -454,10 +454,10 @@ pub fn rearrange_clauses(
sorted_clauses[sorted_clauses.len() - 1].clone().then
}
Pattern::Discard { .. } => sorted_clauses[sorted_clauses.len() - 1].clone().then,
_ => TypedExpr::Todo {
_ => TypedExpr::ErrorTerm {
location: Span::empty(),
label: None,
tipo: sorted_clauses[sorted_clauses.len() - 1].then.tipo(),
label: Some("Clause not filled".to_string()),
},
};
@ -512,49 +512,47 @@ pub fn rearrange_clauses(
}
// if we have a pattern with no clause guards and a tail then no lists will get past here to other clauses
if let Pattern::List {
elements,
tail: Some(tail),
..
} = &clause.pattern[0]
{
let mut elements = elements.clone();
elements.push(*tail.clone());
if elements
.iter()
.all(|element| matches!(element, Pattern::Var { .. } | Pattern::Discard { .. }))
&& !last_clause_set
{
match &clause.pattern[0] {
Pattern::Var { .. } => {
last_clause_index = index;
last_clause_set = true;
}
Pattern::Discard { .. } => {
last_clause_index = index;
last_clause_set = true;
}
Pattern::List {
elements,
tail: Some(tail),
..
} => {
let mut elements = elements.clone();
elements.push(*tail.clone());
if elements
.iter()
.all(|element| matches!(element, Pattern::Var { .. } | Pattern::Discard { .. }))
&& !last_clause_set
{
last_clause_index = index + 1;
last_clause_set = true;
}
}
_ => {}
}
// If the last condition doesn't have a catch all or tail then add a catch all with a todo
if index == sorted_clauses.len() - 1 {
if let Pattern::List {
elements,
tail: Some(tail),
..
} = &clause.pattern[0]
{
let mut elements = elements.clone();
elements.push(*tail.clone());
if !elements
.iter()
.all(|element| matches!(element, Pattern::Var { .. } | Pattern::Discard { .. }))
{
final_clauses.push(Clause {
if let Pattern::List { tail: None, .. } = &clause.pattern[0] {
final_clauses.push(Clause {
location: Span::empty(),
pattern: vec![Pattern::Discard {
name: "_".to_string(),
location: Span::empty(),
pattern: vec![Pattern::Discard {
name: "_".to_string(),
location: Span::empty(),
}],
alternative_patterns: vec![],
guard: None,
then: plug_in_then.clone(),
});
}
}],
alternative_patterns: vec![],
guard: None,
then: plug_in_then.clone(),
});
}
}
@ -562,7 +560,9 @@ pub fn rearrange_clauses(
}
// Encountered a tail so stop there with that as last clause
final_clauses = final_clauses[0..(last_clause_index + 1)].to_vec();
if last_clause_set {
final_clauses = final_clauses[0..last_clause_index].to_vec();
}
// insert hole fillers into clauses
for (index, clause) in holes_to_fill.into_iter().rev() {
@ -1201,6 +1201,7 @@ pub fn monomorphize(
tail_name,
complex_clause,
next_tail_name,
inverse,
} => {
if tipo.is_generic() {
let mut tipo = tipo.clone();
@ -1212,6 +1213,7 @@ pub fn monomorphize(
tail_name,
complex_clause,
next_tail_name,
inverse,
};
needs_variant = false;
}
@ -1220,7 +1222,6 @@ pub fn monomorphize(
tipo,
scope,
subject_name,
invert,
} => {
if tipo.is_generic() {
let mut tipo = tipo.clone();
@ -1230,7 +1231,6 @@ pub fn monomorphize(
scope,
subject_name,
tipo,
invert,
};
needs_variant = false;
}

View File

@ -5,7 +5,7 @@ use itertools::Itertools;
use uplc::{
ast::{
builder::{
self, constr_index_exposer, delayed_choose_list, delayed_if_else, if_else,
self, choose_list, constr_index_exposer, delayed_choose_list, delayed_if_else, if_else,
CONSTR_FIELDS_EXPOSER, CONSTR_GET_FIELD,
},
Constant as UplcConstant, Name, NamedDeBruijn, Program, Term, Type as UplcType,
@ -621,9 +621,17 @@ impl<'a> CodeGenerator<'a> {
scope,
tipo: subject_type.clone(),
tail_name: subject_name,
complex_clause: *clause_properties.is_complex_clause(),
next_tail_name: next_tail,
complex_clause: *clause_properties.is_complex_clause(),
inverse: false,
});
match clause_properties {
ClauseProperties::ListClause { current_index, .. } => {
*current_index += 1;
}
_ => unreachable!(),
}
}
ClauseProperties::TupleClause {
original_subject_name,
@ -763,13 +771,11 @@ impl<'a> CodeGenerator<'a> {
self.when_recursive_ir(
pattern,
pattern_vec,
&mut vec![],
values,
clause_properties.clone(),
tipo,
scope,
);
pattern_vec.append(values);
}
Pattern::Constructor {
arguments,
@ -890,8 +896,6 @@ impl<'a> CodeGenerator<'a> {
pattern_vec.append(values);
}
Pattern::List { elements, tail, .. } => {
// let mut elements_vec = vec![];
let mut names = vec![];
let mut nested_pattern = vec![];
let items_type = &tipo.get_inner_types()[0];
@ -957,8 +961,8 @@ impl<'a> CodeGenerator<'a> {
});
}
pattern_vec.append(values);
pattern_vec.append(&mut nested_pattern);
pattern_vec.append(values);
}
Pattern::Constructor {
is_record,
@ -1011,7 +1015,11 @@ impl<'a> CodeGenerator<'a> {
.iter()
.filter_map(|item| {
let label = item.label.clone().unwrap_or_default();
let field_index = field_map.fields.get(&label).unwrap_or(&0);
let field_index = field_map
.fields
.get(&label)
.map(|(index, _)| index)
.unwrap_or(&0);
let var_name = self.nested_pattern_ir_and_label(
&item.value,
&mut nested_pattern,
@ -1095,30 +1103,173 @@ impl<'a> CodeGenerator<'a> {
Pattern::Var { name, .. } => Some(name.clone()),
Pattern::Discard { .. } => None,
a @ Pattern::List { elements, tail, .. } => {
let item_name = format!("list_item_id_{}", self.id_gen.next());
let item_name = format!("__list_item_id_{}", self.id_gen.next());
let new_tail_name = "__list_tail".to_string();
if elements.is_empty() {
pattern_vec.push(Air::ClauseGuard {
pattern_vec.push(Air::ListClause {
scope: scope.clone(),
subject_name: item_name.clone(),
tipo: pattern_type.clone(),
invert: false,
tail_name: item_name.clone(),
next_tail_name: None,
complex_clause: false,
inverse: true,
});
pattern_vec.push(Air::Discard {
scope: scope.clone(),
});
pattern_vec.push(Air::Var {
scope,
constructor: ValueConstructor::public(
pattern_type.clone(),
ValueConstructorVariant::LocalVariable {
location: Span::empty(),
},
),
name: "__other_clauses_delayed".to_string(),
variant_name: String::new(),
});
} else {
for (index, element) in elements.iter().enumerate() {
if index == 0 {
pattern_vec.push(Air::ClauseGuard {
scope: scope.clone(),
subject_name: item_name.clone(),
tipo: pattern_type.clone(),
invert: true,
});
pattern_vec.push(Air::ListAccessor {
for (index, _) in elements.iter().enumerate() {
let prev_tail_name = if index == 0 {
item_name.clone()
} else {
format!("{}_{}", new_tail_name, index - 1)
};
let mut clause_properties = ClauseProperties::ListClause {
clause_var_name: item_name.clone(),
needs_constr_var: false,
is_complex_clause: false,
original_subject_name: item_name.clone(),
current_index: index,
};
let tail_name = format!("{}_{}", new_tail_name, index);
if elements.len() - 1 == index {
if tail.is_some() {
let tail_name = match *tail.clone().unwrap() {
Pattern::Var { name, .. } => name,
Pattern::Discard { .. } => "_".to_string(),
_ => unreachable!(),
};
pattern_vec.push(Air::ListClause {
scope: scope.clone(),
tipo: pattern_type.clone(),
tail_name: prev_tail_name,
next_tail_name: Some(tail_name),
complex_clause: false,
inverse: false,
});
pattern_vec.push(Air::Discard {
scope: scope.clone(),
});
pattern_vec.push(Air::Var {
scope: scope.clone(),
constructor: ValueConstructor::public(
pattern_type.clone(),
ValueConstructorVariant::LocalVariable {
location: Span::empty(),
},
),
name: "__other_clauses_delayed".to_string(),
variant_name: "".to_string(),
});
self.when_ir(
a,
pattern_vec,
&mut vec![],
pattern_type,
&mut clause_properties,
scope.clone(),
);
} else {
pattern_vec.push(Air::ListClause {
scope: scope.clone(),
tipo: pattern_type.clone(),
tail_name: prev_tail_name,
next_tail_name: Some(tail_name.clone()),
complex_clause: false,
inverse: false,
});
pattern_vec.push(Air::Discard {
scope: scope.clone(),
});
pattern_vec.push(Air::Var {
scope: scope.clone(),
constructor: ValueConstructor::public(
pattern_type.clone(),
ValueConstructorVariant::LocalVariable {
location: Span::empty(),
},
),
name: "__other_clauses_delayed".to_string(),
variant_name: String::new(),
});
pattern_vec.push(Air::ListClause {
scope: scope.clone(),
tipo: pattern_type.clone(),
tail_name: tail_name.clone(),
next_tail_name: None,
complex_clause: false,
inverse: true,
});
pattern_vec.push(Air::Discard {
scope: scope.clone(),
});
pattern_vec.push(Air::Var {
scope: scope.clone(),
constructor: ValueConstructor::public(
pattern_type.clone(),
ValueConstructorVariant::LocalVariable {
location: Span::empty(),
},
),
name: "__other_clauses_delayed".to_string(),
variant_name: String::new(),
});
self.when_ir(
a,
pattern_vec,
&mut vec![],
pattern_type,
&mut clause_properties,
scope.clone(),
);
}
} else {
let tail_name = match *tail.clone().unwrap() {
Pattern::Var { name, .. } => name,
Pattern::Discard { .. } => "_".to_string(),
_ => unreachable!(),
};
pattern_vec.push(Air::ListClause {
scope: scope.clone(),
tipo: pattern_type.clone(),
names: vec![format!("todo")],
tail: false,
tail_name: prev_tail_name,
next_tail_name: Some(tail_name),
complex_clause: false,
inverse: false,
});
pattern_vec.push(Air::Discard {
scope: scope.clone(),
});
pattern_vec.push(Air::Var {
scope: scope.clone(),
constructor: ValueConstructor::public(
@ -1127,30 +1278,22 @@ impl<'a> CodeGenerator<'a> {
location: Span::empty(),
},
),
name: item_name.clone(),
variant_name: String::new(),
name: "__other_clauses_delayed".to_string(),
variant_name: "".to_string(),
});
}
self.when_ir(
a,
pattern_vec,
&mut vec![],
pattern_type,
&mut clause_properties,
scope.clone(),
);
};
}
}
let mut clause_properties = ClauseProperties::ListClause {
clause_var_name: item_name.clone(),
needs_constr_var: false,
is_complex_clause: false,
original_subject_name: item_name.clone(),
current_index: 0,
};
self.when_ir(
a,
pattern_vec,
&mut vec![],
pattern_type,
&mut clause_properties,
scope,
);
// self.when_recursive_ir(a);
Some(item_name)
}
@ -1183,7 +1326,6 @@ impl<'a> CodeGenerator<'a> {
scope: scope.clone(),
tipo: tipo.clone(),
subject_name: constr_var_name.clone(),
invert: false,
});
}
@ -1200,7 +1342,7 @@ impl<'a> CodeGenerator<'a> {
&mut vec![],
tipo,
&mut clause_properties,
scope.clone(),
scope,
);
Some(constr_var_name)
@ -3196,16 +3338,26 @@ impl<'a> CodeGenerator<'a> {
Air::ListClause {
tail_name,
next_tail_name,
inverse,
complex_clause,
..
} => {
// discard to pop off
let _ = arg_stack.pop().unwrap();
// the body to be run if the clause matches
let body = arg_stack.pop().unwrap();
// the next branch in the when expression
let mut term = arg_stack.pop().unwrap();
let (body, mut term) = if inverse {
let term = arg_stack.pop().unwrap();
let body = arg_stack.pop().unwrap();
(body, term)
} else {
let body = arg_stack.pop().unwrap();
let term = arg_stack.pop().unwrap();
(body, term)
};
let arg = if let Some(next_tail_name) = next_tail_name {
Term::Apply {
@ -3231,22 +3383,46 @@ impl<'a> CodeGenerator<'a> {
term
};
term = delayed_choose_list(
Term::Var(Name {
text: tail_name,
unique: 0.into(),
}),
body,
arg,
);
if complex_clause {
term = choose_list(
Term::Var(Name {
text: tail_name,
unique: 0.into(),
}),
Term::Delay(body.into()),
Term::Var(Name {
text: "__other_clauses_delayed".to_string(),
unique: 0.into(),
}),
)
.force_wrap();
term = Term::Apply {
function: Term::Lambda {
parameter_name: Name {
text: "__other_clauses_delayed".into(),
unique: 0.into(),
},
body: term.into(),
}
.into(),
argument: Term::Delay(arg.into()).into(),
};
} else {
term = delayed_choose_list(
Term::Var(Name {
text: tail_name,
unique: 0.into(),
}),
body,
arg,
);
}
arg_stack.push(term);
}
Air::ClauseGuard {
subject_name,
tipo,
invert,
..
subject_name, tipo, ..
} => {
let condition = arg_stack.pop().unwrap();

View File

@ -314,7 +314,7 @@ pub fn delayed_choose_list(
Term::Apply {
function: Term::Apply {
function: Term::Apply {
function: Term::Builtin(DefaultFunction::IfThenElse)
function: Term::Builtin(DefaultFunction::ChooseList)
.force_wrap()
.force_wrap()
.into(),
@ -337,7 +337,7 @@ pub fn choose_list(
Term::Apply {
function: Term::Apply {
function: Term::Apply {
function: Term::Builtin(DefaultFunction::IfThenElse)
function: Term::Builtin(DefaultFunction::ChooseList)
.force_wrap()
.force_wrap()
.into(),