feat: implement bls primitives in code gen

This commit is contained in:
microproofs 2023-11-07 14:14:33 -05:00 committed by Lucas
parent d51374aac1
commit 8b89ba3b93
11 changed files with 70 additions and 16 deletions

View File

@ -1,5 +1,5 @@
use crate::{ use crate::{
builtins::{self, bool}, builtins::{self, bool, g1_element, g2_element},
expr::{TypedExpr, UntypedExpr}, expr::{TypedExpr, UntypedExpr},
parser::token::{Base, Token}, parser::token::{Base, Token},
tipo::{PatternConstructor, Type, TypeInfo}, tipo::{PatternConstructor, Type, TypeInfo},
@ -496,7 +496,7 @@ pub enum Constant {
CurvePoint { CurvePoint {
location: Span, location: Span,
point: Curve, point: Box<Curve>,
preferred_format: ByteArrayFormatPreference, preferred_format: ByteArrayFormatPreference,
}, },
} }
@ -507,7 +507,7 @@ impl Constant {
Constant::Int { .. } => builtins::int(), Constant::Int { .. } => builtins::int(),
Constant::String { .. } => builtins::string(), Constant::String { .. } => builtins::string(),
Constant::ByteArray { .. } => builtins::byte_array(), Constant::ByteArray { .. } => builtins::byte_array(),
Constant::CurvePoint { point, .. } => match point { Constant::CurvePoint { point, .. } => match point.as_ref() {
Curve::Bls12_381(Bls12_381Point::G1(_)) => builtins::g1_element(), Curve::Bls12_381(Bls12_381Point::G1(_)) => builtins::g1_element(),
Curve::Bls12_381(Bls12_381Point::G2(_)) => builtins::g2_element(), Curve::Bls12_381(Bls12_381Point::G2(_)) => builtins::g2_element(),
}, },
@ -1067,6 +1067,12 @@ impl Curve {
}, },
} }
} }
pub fn tipo(&self) -> Rc<Type> {
match self {
Curve::Bls12_381(point) => point.tipo(),
}
}
} }
#[derive(Debug, Clone, PartialEq, Eq, Copy)] #[derive(Debug, Clone, PartialEq, Eq, Copy)]
@ -1075,6 +1081,15 @@ pub enum Bls12_381Point {
G2(blst::blst_p2), G2(blst::blst_p2),
} }
impl Bls12_381Point {
pub fn tipo(&self) -> Rc<Type> {
match self {
Bls12_381Point::G1(_) => g1_element(),
Bls12_381Point::G2(_) => g2_element(),
}
}
}
impl Default for Bls12_381Point { impl Default for Bls12_381Point {
fn default() -> Self { fn default() -> Self {
Bls12_381Point::G1(Default::default()) Bls12_381Point::G1(Default::default())

View File

@ -37,7 +37,7 @@ pub enum TypedExpr {
CurvePoint { CurvePoint {
location: Span, location: Span,
tipo: Rc<Type>, tipo: Rc<Type>,
point: Curve, point: Box<Curve>,
}, },
Sequence { Sequence {
@ -485,7 +485,7 @@ pub enum UntypedExpr {
CurvePoint { CurvePoint {
location: Span, location: Span,
point: Curve, point: Box<Curve>,
preferred_format: ByteArrayFormatPreference, preferred_format: ByteArrayFormatPreference,
}, },

View File

@ -357,7 +357,11 @@ impl<'comments> Formatter<'comments> {
point, point,
preferred_format, preferred_format,
.. ..
} => self.bytearray(&point.compress(), Some(point.into()), preferred_format), } => self.bytearray(
&point.compress(),
Some(point.as_ref().into()),
preferred_format,
),
Constant::Int { value, base, .. } => self.int(value, base), Constant::Int { value, base, .. } => self.int(value, base),
Constant::String { value, .. } => self.string(value), Constant::String { value, .. } => self.string(value),
} }
@ -792,7 +796,11 @@ impl<'comments> Formatter<'comments> {
point, point,
preferred_format, preferred_format,
.. ..
} => self.bytearray(&point.compress(), Some(point.into()), preferred_format), } => self.bytearray(
&point.compress(),
Some(point.as_ref().into()),
preferred_format,
),
UntypedExpr::If { UntypedExpr::If {
branches, branches,

View File

@ -19,8 +19,8 @@ use uplc::{
use crate::{ use crate::{
ast::{ ast::{
AssignmentKind, BinOp, Pattern, Span, TypedArg, TypedClause, TypedDataType, TypedFunction, AssignmentKind, BinOp, Bls12_381Point, Curve, Pattern, Span, TypedArg, TypedClause,
TypedPattern, TypedValidator, UnOp, TypedDataType, TypedFunction, TypedPattern, TypedValidator, UnOp,
}, },
builtins::{bool, data, int, list, string, void}, builtins::{bool, data, int, list, string, void},
expr::TypedExpr, expr::TypedExpr,
@ -725,6 +725,7 @@ impl<'a> CodeGenerator<'a> {
} }
TypedExpr::UnOp { value, op, .. } => AirTree::unop(*op, self.build(value)), TypedExpr::UnOp { value, op, .. } => AirTree::unop(*op, self.build(value)),
TypedExpr::CurvePoint { point, .. } => AirTree::curve(*point.as_ref()),
} }
} }
@ -3550,6 +3551,10 @@ impl<'a> CodeGenerator<'a> {
Air::Bool { value } => { Air::Bool { value } => {
arg_stack.push(Term::bool(value)); arg_stack.push(Term::bool(value));
} }
Air::CurvePoint { point, .. } => match point {
Curve::Bls12_381(Bls12_381Point::G1(g1)) => arg_stack.push(Term::bls12_381_g1(g1)),
Curve::Bls12_381(Bls12_381Point::G2(g2)) => arg_stack.push(Term::bls12_381_g2(g2)),
},
Air::Var { Air::Var {
name, name,
constructor, constructor,

View File

@ -3,7 +3,7 @@ use std::rc::Rc;
use uplc::builtins::DefaultFunction; use uplc::builtins::DefaultFunction;
use crate::{ use crate::{
ast::{BinOp, UnOp}, ast::{BinOp, Curve, UnOp},
tipo::{Type, ValueConstructor}, tipo::{Type, ValueConstructor},
}; };
@ -19,6 +19,9 @@ pub enum Air {
ByteArray { ByteArray {
bytes: Vec<u8>, bytes: Vec<u8>,
}, },
CurvePoint {
point: Curve,
},
Bool { Bool {
value: bool, value: bool,
}, },

View File

@ -506,6 +506,7 @@ pub fn constants_ir(literal: &Constant) -> AirTree {
Constant::Int { value, .. } => AirTree::int(value), Constant::Int { value, .. } => AirTree::int(value),
Constant::String { value, .. } => AirTree::string(value), Constant::String { value, .. } => AirTree::string(value),
Constant::ByteArray { bytes, .. } => AirTree::byte_array(bytes.clone()), Constant::ByteArray { bytes, .. } => AirTree::byte_array(bytes.clone()),
Constant::CurvePoint { point, .. } => AirTree::curve(*point.as_ref()),
} }
} }
@ -1217,6 +1218,8 @@ pub fn convert_data_to_type(term: Term<Name>, field_type: &Rc<Type>) -> Term<Nam
Term::bls12_381_g1_uncompress().apply(Term::un_b_data().apply(term)) Term::bls12_381_g1_uncompress().apply(Term::un_b_data().apply(term))
} else if field_type.is_bls381_12_g2() { } else if field_type.is_bls381_12_g2() {
Term::bls12_381_g2_uncompress().apply(Term::un_b_data().apply(term)) Term::bls12_381_g2_uncompress().apply(Term::un_b_data().apply(term))
} else if field_type.is_ml_result() {
panic!("ML Result not supported")
} else { } else {
term term
} }
@ -1302,7 +1305,7 @@ pub fn convert_constants_to_data(constants: Vec<Rc<UplcConstant>>) -> Vec<UplcCo
UplcConstant::Bls12_381G2Element(b) => UplcConstant::Data(PlutusData::BoundedBytes( UplcConstant::Bls12_381G2Element(b) => UplcConstant::Data(PlutusData::BoundedBytes(
b.deref().clone().compress().into(), b.deref().clone().compress().into(),
)), )),
UplcConstant::Bls12_381MlResult(_) => unreachable!(), UplcConstant::Bls12_381MlResult(_) => unreachable!("Bls12_381MlResult not supported"),
}; };
new_constants.push(constant); new_constants.push(constant);
} }
@ -1365,6 +1368,8 @@ pub fn convert_type_to_data(term: Term<Name>, field_type: &Rc<Type>) -> Term<Nam
Term::bls12_381_g1_compress().apply(Term::b_data().apply(term)) Term::bls12_381_g1_compress().apply(Term::b_data().apply(term))
} else if field_type.is_bls381_12_g2() { } else if field_type.is_bls381_12_g2() {
Term::bls12_381_g2_compress().apply(Term::b_data().apply(term)) Term::bls12_381_g2_compress().apply(Term::b_data().apply(term))
} else if field_type.is_ml_result() {
panic!("ML Result not supported")
} else { } else {
term term
} }

View File

@ -4,7 +4,7 @@ use std::{borrow::BorrowMut, rc::Rc, slice::Iter};
use uplc::{builder::EXPECT_ON_LIST, builtins::DefaultFunction}; use uplc::{builder::EXPECT_ON_LIST, builtins::DefaultFunction};
use crate::{ use crate::{
ast::{BinOp, Span, UnOp}, ast::{BinOp, Curve, Span, UnOp},
builtins::{bool, byte_array, data, int, list, string, void}, builtins::{bool, byte_array, data, int, list, string, void},
tipo::{Type, ValueConstructor, ValueConstructorVariant}, tipo::{Type, ValueConstructor, ValueConstructorVariant},
}; };
@ -197,6 +197,9 @@ pub enum AirExpression {
ByteArray { ByteArray {
bytes: Vec<u8>, bytes: Vec<u8>,
}, },
CurvePoint {
point: Curve,
},
Bool { Bool {
value: bool, value: bool,
}, },
@ -341,6 +344,9 @@ impl AirTree {
pub fn byte_array(bytes: Vec<u8>) -> AirTree { pub fn byte_array(bytes: Vec<u8>) -> AirTree {
AirTree::Expression(AirExpression::ByteArray { bytes }) AirTree::Expression(AirExpression::ByteArray { bytes })
} }
pub fn curve(point: Curve) -> AirTree {
AirTree::Expression(AirExpression::CurvePoint { point })
}
pub fn bool(value: bool) -> AirTree { pub fn bool(value: bool) -> AirTree {
AirTree::Expression(AirExpression::Bool { value }) AirTree::Expression(AirExpression::Bool { value })
} }
@ -1058,6 +1064,9 @@ impl AirTree {
AirExpression::ByteArray { bytes } => air_vec.push(Air::ByteArray { AirExpression::ByteArray { bytes } => air_vec.push(Air::ByteArray {
bytes: bytes.clone(), bytes: bytes.clone(),
}), }),
AirExpression::CurvePoint { point } => {
air_vec.push(Air::CurvePoint { point: *point })
}
AirExpression::Bool { value } => air_vec.push(Air::Bool { value: *value }), AirExpression::Bool { value } => air_vec.push(Air::Bool { value: *value }),
AirExpression::List { tipo, tail, items } => { AirExpression::List { tipo, tail, items } => {
air_vec.push(Air::List { air_vec.push(Air::List {
@ -1286,6 +1295,7 @@ impl AirTree {
AirExpression::String { .. } => string(), AirExpression::String { .. } => string(),
AirExpression::ByteArray { .. } => byte_array(), AirExpression::ByteArray { .. } => byte_array(),
AirExpression::Bool { .. } => bool(), AirExpression::Bool { .. } => bool(),
AirExpression::CurvePoint { point } => point.tipo(),
AirExpression::List { tipo, .. } AirExpression::List { tipo, .. }
| AirExpression::Tuple { tipo, .. } | AirExpression::Tuple { tipo, .. }
| AirExpression::Call { tipo, .. } | AirExpression::Call { tipo, .. }

View File

@ -66,7 +66,7 @@ pub fn value() -> impl Parser<Token, ast::Constant, Error = ParseError> {
ast::Constant::CurvePoint { ast::Constant::CurvePoint {
location, location,
point: ast::Curve::Bls12_381(point), point: ast::Curve::Bls12_381(point).into(),
preferred_format, preferred_format,
} }
} }

View File

@ -27,7 +27,7 @@ pub fn parser() -> impl Parser<Token, UntypedExpr, Error = ParseError> {
UntypedExpr::CurvePoint { UntypedExpr::CurvePoint {
location, location,
point: ast::Curve::Bls12_381(point), point: ast::Curve::Bls12_381(point).into(),
preferred_format, preferred_format,
} }
} }

View File

@ -323,7 +323,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> {
UntypedExpr::CurvePoint { UntypedExpr::CurvePoint {
location, point, .. location, point, ..
} => self.infer_curve_point(point, location), } => self.infer_curve_point(*point, location),
UntypedExpr::RecordUpdate { UntypedExpr::RecordUpdate {
location, location,
@ -377,7 +377,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> {
Ok(TypedExpr::CurvePoint { Ok(TypedExpr::CurvePoint {
location, location,
point: curve, point: curve.into(),
tipo, tipo,
}) })
} }

View File

@ -38,6 +38,14 @@ impl<T> Term<T> {
Term::Constant(Constant::ByteString(b).into()) Term::Constant(Constant::ByteString(b).into())
} }
pub fn bls12_381_g1(b: blst::blst_p1) -> Self {
Term::Constant(Constant::Bls12_381G1Element(b.into()).into())
}
pub fn bls12_381_g2(b: blst::blst_p2) -> Self {
Term::Constant(Constant::Bls12_381G2Element(b.into()).into())
}
pub fn bool(b: bool) -> Self { pub fn bool(b: bool) -> Self {
Term::Constant(Constant::Bool(b).into()) Term::Constant(Constant::Bool(b).into())
} }