diff --git a/oqpy/base.py b/oqpy/base.py index 1c76d9b..9bfe79a 100644 --- a/oqpy/base.py +++ b/oqpy/base.py @@ -210,18 +210,21 @@ def compute_product_types(left: AstConvertible, right: AstConvertible) -> ast.Cl (ast.FloatType, ast.IntType): left_type, (ast.FloatType, ast.UintType): left_type, (ast.FloatType, ast.DurationType): right_type, + (ast.FloatType, ast.StretchType): right_type, (ast.FloatType, ast.AngleType): right_type, (ast.FloatType, ast.ComplexType): right_type, (ast.IntType, ast.FloatType): right_type, (ast.IntType, ast.IntType): left_type, (ast.IntType, ast.UintType): left_type, (ast.IntType, ast.DurationType): right_type, + (ast.IntType, ast.StretchType): right_type, (ast.IntType, ast.AngleType): right_type, (ast.IntType, ast.ComplexType): right_type, (ast.UintType, ast.FloatType): right_type, (ast.UintType, ast.IntType): right_type, (ast.UintType, ast.UintType): left_type, (ast.UintType, ast.DurationType): right_type, + (ast.UintType, ast.StretchType): right_type, (ast.UintType, ast.AngleType): right_type, (ast.UintType, ast.ComplexType): right_type, (ast.DurationType, ast.FloatType): left_type, @@ -230,18 +233,28 @@ def compute_product_types(left: AstConvertible, right: AstConvertible) -> ast.Cl (ast.DurationType, ast.DurationType): TypeError( "Cannot multiply two durations. You may need to re-group computations to eliminate this." ), + (ast.DurationType, ast.StretchType): TypeError("Cannot multiply duration and stretch"), (ast.DurationType, ast.AngleType): TypeError("Cannot multiply duration and angle"), (ast.DurationType, ast.ComplexType): TypeError("Cannot multiply duration and complex"), + (ast.StretchType, ast.FloatType): left_type, + (ast.StretchType, ast.IntType): left_type, + (ast.StretchType, ast.UintType): left_type, + (ast.StretchType, ast.DurationType): TypeError("Cannot multiply stretch and duration"), + (ast.StretchType, ast.StretchType): TypeError("Cannot multiply two stretches"), + (ast.StretchType, ast.AngleType): TypeError("Cannot multiply stretch and angle"), + (ast.StretchType, ast.ComplexType): TypeError("Cannot multiply stretch and complex"), (ast.AngleType, ast.FloatType): left_type, (ast.AngleType, ast.IntType): left_type, (ast.AngleType, ast.UintType): left_type, (ast.AngleType, ast.DurationType): TypeError("Cannot multiply angle and duration"), + (ast.AngleType, ast.StretchType): TypeError("Cannot multiply angle and stretch"), (ast.AngleType, ast.AngleType): TypeError("Cannot multiply two angles"), (ast.AngleType, ast.ComplexType): TypeError("Cannot multiply angle and complex"), (ast.ComplexType, ast.FloatType): left_type, (ast.ComplexType, ast.IntType): left_type, (ast.ComplexType, ast.UintType): left_type, (ast.ComplexType, ast.DurationType): TypeError("Cannot multiply complex and duration"), + (ast.ComplexType, ast.StretchType): TypeError("Cannot multiply complex and stretch"), (ast.ComplexType, ast.AngleType): TypeError("Cannot multiply complex and angle"), (ast.ComplexType, ast.ComplexType): left_type, } @@ -266,36 +279,49 @@ def compute_quotient_types(left: AstConvertible, right: AstConvertible) -> ast.C (ast.FloatType, ast.IntType): left_type, (ast.FloatType, ast.UintType): left_type, (ast.FloatType, ast.DurationType): TypeError("Cannot divide float by duration"), + (ast.FloatType, ast.StretchType): TypeError("Cannot divide float by stretch"), (ast.FloatType, ast.AngleType): TypeError("Cannot divide float by angle"), (ast.FloatType, ast.ComplexType): right_type, (ast.IntType, ast.FloatType): right_type, (ast.IntType, ast.IntType): float_type, (ast.IntType, ast.UintType): float_type, (ast.IntType, ast.DurationType): TypeError("Cannot divide int by duration"), + (ast.IntType, ast.StretchType): TypeError("Cannot divide int by stretch"), (ast.IntType, ast.AngleType): TypeError("Cannot divide int by angle"), (ast.IntType, ast.ComplexType): right_type, (ast.UintType, ast.FloatType): right_type, (ast.UintType, ast.IntType): float_type, (ast.UintType, ast.UintType): float_type, (ast.UintType, ast.DurationType): TypeError("Cannot divide uint by duration"), + (ast.UintType, ast.StretchType): TypeError("Cannot divide uint by stretch"), (ast.UintType, ast.AngleType): TypeError("Cannot divide uint by angle"), (ast.UintType, ast.ComplexType): right_type, (ast.DurationType, ast.FloatType): left_type, (ast.DurationType, ast.IntType): left_type, (ast.DurationType, ast.UintType): left_type, (ast.DurationType, ast.DurationType): ast.FloatType(), + (ast.DurationType, ast.StretchType): ast.FloatType(), (ast.DurationType, ast.AngleType): TypeError("Cannot divide duration by angle"), (ast.DurationType, ast.ComplexType): TypeError("Cannot divide duration by complex"), + (ast.StretchType, ast.FloatType): left_type, + (ast.StretchType, ast.IntType): left_type, + (ast.StretchType, ast.UintType): left_type, + (ast.StretchType, ast.DurationType): ast.FloatType(), + (ast.StretchType, ast.StretchType): ast.FloatType(), + (ast.StretchType, ast.AngleType): TypeError("Cannot divide stretch by angle"), + (ast.StretchType, ast.ComplexType): TypeError("Cannot divide stretch by complex"), (ast.AngleType, ast.FloatType): left_type, (ast.AngleType, ast.IntType): left_type, (ast.AngleType, ast.UintType): left_type, (ast.AngleType, ast.DurationType): TypeError("Cannot divide by duration"), + (ast.AngleType, ast.StretchType): TypeError("Cannot divide by stretch"), (ast.AngleType, ast.AngleType): float_type, (ast.AngleType, ast.ComplexType): TypeError("Cannot divide by angle by complex"), (ast.ComplexType, ast.FloatType): left_type, (ast.ComplexType, ast.IntType): left_type, (ast.ComplexType, ast.UintType): left_type, (ast.ComplexType, ast.DurationType): TypeError("Cannot divide by duration"), + (ast.ComplexType, ast.StretchType): TypeError("Cannot divide by stretch"), (ast.ComplexType, ast.AngleType): TypeError("Cannot divide by angle"), (ast.ComplexType, ast.ComplexType): left_type, } @@ -416,7 +442,7 @@ def __init__( self.type = ast_type # Adding floats to durations is not allowed. So we promote types as necessary. - if isinstance(self.type, ast.DurationType) and self.op in [ + if isinstance(self.type, (ast.DurationType, ast.StretchType)) and self.op in [ ast.BinaryOperator["+"], ast.BinaryOperator["-"], ]: diff --git a/oqpy/classical_types.py b/oqpy/classical_types.py index f161b15..e22ac7b 100644 --- a/oqpy/classical_types.py +++ b/oqpy/classical_types.py @@ -236,6 +236,11 @@ def make_declaration_statement(self, program: Program) -> ast.Statement: stmt.annotations = make_annotations(self.annotations) return stmt + def __str__(self) -> str: + return ( + f"<{self.__class__.__name__} name={self.name} init_expression={self.init_expression}>" + ) + class BoolVar(_ClassicalVar): """An (unsized) oqpy variable with bool type.""" @@ -373,6 +378,17 @@ class StretchVar(_ClassicalVar): type_cls = ast.StretchType + def __init__( + self, + init_expression: AstConvertible | Literal["input", "output"] | None = None, + name: str | None = None, + *args: Any, + **type_kwargs: Any, + ) -> None: + if init_expression is not None and not isinstance(init_expression, str): + init_expression = convert_float_to_duration(init_expression) + super().__init__(init_expression, name, *args, **type_kwargs) + AllowedArrayTypes = Union[_SizedVar, DurationVar, BoolVar, ComplexVar] diff --git a/oqpy/program.py b/oqpy/program.py index 54c6536..ddcc419 100644 --- a/oqpy/program.py +++ b/oqpy/program.py @@ -639,7 +639,7 @@ def include(self, path: str) -> Program: def _do_assignment(self, var: AstConvertible, op: str, value: AstConvertible) -> None: """Helper function for variable assignment operations.""" - if isinstance(var, classical_types.DurationVar): + if isinstance(var, (classical_types.DurationVar, classical_types.StretchVar)): value = convert_float_to_duration(value) var_ast = to_ast(self, var) if isinstance(var_ast, ast.IndexExpression): diff --git a/oqpy/subroutines.py b/oqpy/subroutines.py index 375cbe5..9b01fce 100644 --- a/oqpy/subroutines.py +++ b/oqpy/subroutines.py @@ -249,14 +249,14 @@ def call_extern(*call_args: AstConvertible, **call_kwargs: AstConvertible) -> OQ if k_idx < len(call_args): raise TypeError(f"{name}() got multiple values for argument '{k}'.") - if type(arg_types[k_idx]) == ast.DurationType: + if type(arg_types[k_idx]) in (ast.DurationType, ast.StretchType): new_args[k_idx] = convert_float_to_duration(call_kwargs[k]) else: new_args[k_idx] = call_kwargs[k] # Casting floats into durations for the non-keyword arguments for i, a in enumerate(call_args): - if type(arg_types[i]) == ast.DurationType: + if type(arg_types[i]) in (ast.DurationType, ast.StretchType): new_args[i] = convert_float_to_duration(a) return OQFunctionCall( name, {k: v for k, v in zip(arg_names, new_args)}, return_type, extern_decl=extern_decl diff --git a/oqpy/timing.py b/oqpy/timing.py index ac80d65..99edea0 100644 --- a/oqpy/timing.py +++ b/oqpy/timing.py @@ -81,7 +81,7 @@ def convert_float_to_duration(time: AstConvertible, require_nonnegative: bool = if isinstance(time, OQPyExpression): if isinstance(time.type, (ast.UintType, ast.IntType, ast.FloatType)): time = time * OQDurationLiteral(1) - elif not isinstance(time.type, ast.DurationType): + elif not isinstance(time.type, (ast.DurationType, ast.StretchType)): raise TypeError(f"Cannot convert expression with type {time.type} to duration") if hasattr(time, "to_ast"): return time # type: ignore[return-value] @@ -91,7 +91,9 @@ def convert_float_to_duration(time: AstConvertible, require_nonnegative: bool = def convert_duration_to_float(value: AstConvertible) -> AstConvertible: - if isinstance(value, OQPyExpression) and isinstance(value.type, ast.DurationType): + if isinstance(value, OQPyExpression) and isinstance( + value.type, (ast.DurationType, ast.StretchType) + ): value = value / OQDurationLiteral(1) return value diff --git a/tests/test_directives.py b/tests/test_directives.py index 0ef5420..6b5d0f5 100644 --- a/tests/test_directives.py +++ b/tests/test_directives.py @@ -562,6 +562,67 @@ def test_binary_expressions(): _check_respects_type_hints(prog) +def test_stretch_as_duration(): + prog = Program() + s = StretchVar(1e-6, "s") + d = DurationVar(1e-7, "d") + + myext = declare_extern( + "myext", + [("dd", oqpy.duration), ("ss", oqpy.stretch)], + ) + + prog.do_expression(myext(s, 2e-8)) + prog.delay(s) + with Box(prog, s + d): + prog.delay(s / 2) + prog.set(s, 3e-6) + prog.set(s, s + d) + prog.set(d, s) + + expected = textwrap.dedent( + """ + OPENQASM 3.0; + extern myext(duration, stretch); + stretch s = 1.0us; + duration d = 100.0ns; + myext(s, 20.0ns); + delay[s]; + box[s + d] { + delay[s / 2]; + } + s = 3.0us; + s = s + d; + d = s; + """ + ).strip() + + assert prog.to_qasm() == expected + _check_respects_type_hints(prog) + + +def test_stretch_arithmetic_types(): + s = StretchVar(name="s") + d = DurationVar(1e-7, name="d") + + assert isinstance((s + s).type, ast.StretchType) + assert isinstance((s + d).type, ast.StretchType) + assert isinstance((d + s).type, ast.DurationType) + assert isinstance((2 * s).type, ast.StretchType) + assert isinstance((s * 2).type, ast.StretchType) + assert isinstance((s / 2).type, ast.StretchType) + assert isinstance((s / s).type, ast.FloatType) + assert isinstance((s / d).type, ast.FloatType) + assert isinstance((d / s).type, ast.FloatType) + + with pytest.raises(TypeError): + s * d + with pytest.raises(TypeError): + 2 / s + with pytest.raises(TypeError): + s * 1j + + @pytest.mark.xfail def test_add_incomptible_type(): # This test should fail since we add float to a duration and then don't type cast things