Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion oqpy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,18 +210,21 @@ def compute_product_types(left: AstConvertible, right: AstConvertible) -> ast.Cl
(ast.FloatType, ast.IntType): left_type,
(ast.FloatType, ast.UintType): left_type,
(ast.FloatType, ast.DurationType): right_type,
(ast.FloatType, ast.StretchType): right_type,
(ast.FloatType, ast.AngleType): right_type,
(ast.FloatType, ast.ComplexType): right_type,
(ast.IntType, ast.FloatType): right_type,
(ast.IntType, ast.IntType): left_type,
(ast.IntType, ast.UintType): left_type,
(ast.IntType, ast.DurationType): right_type,
(ast.IntType, ast.StretchType): right_type,
(ast.IntType, ast.AngleType): right_type,
(ast.IntType, ast.ComplexType): right_type,
(ast.UintType, ast.FloatType): right_type,
(ast.UintType, ast.IntType): right_type,
(ast.UintType, ast.UintType): left_type,
(ast.UintType, ast.DurationType): right_type,
(ast.UintType, ast.StretchType): right_type,
(ast.UintType, ast.AngleType): right_type,
(ast.UintType, ast.ComplexType): right_type,
(ast.DurationType, ast.FloatType): left_type,
Expand All @@ -230,18 +233,28 @@ def compute_product_types(left: AstConvertible, right: AstConvertible) -> ast.Cl
(ast.DurationType, ast.DurationType): TypeError(
"Cannot multiply two durations. You may need to re-group computations to eliminate this."
),
(ast.DurationType, ast.StretchType): TypeError("Cannot multiply duration and stretch"),
(ast.DurationType, ast.AngleType): TypeError("Cannot multiply duration and angle"),
(ast.DurationType, ast.ComplexType): TypeError("Cannot multiply duration and complex"),
(ast.StretchType, ast.FloatType): left_type,
(ast.StretchType, ast.IntType): left_type,
(ast.StretchType, ast.UintType): left_type,
(ast.StretchType, ast.DurationType): TypeError("Cannot multiply stretch and duration"),
(ast.StretchType, ast.StretchType): TypeError("Cannot multiply two stretches"),
(ast.StretchType, ast.AngleType): TypeError("Cannot multiply stretch and angle"),
(ast.StretchType, ast.ComplexType): TypeError("Cannot multiply stretch and complex"),
(ast.AngleType, ast.FloatType): left_type,
(ast.AngleType, ast.IntType): left_type,
(ast.AngleType, ast.UintType): left_type,
(ast.AngleType, ast.DurationType): TypeError("Cannot multiply angle and duration"),
(ast.AngleType, ast.StretchType): TypeError("Cannot multiply angle and stretch"),
(ast.AngleType, ast.AngleType): TypeError("Cannot multiply two angles"),
(ast.AngleType, ast.ComplexType): TypeError("Cannot multiply angle and complex"),
(ast.ComplexType, ast.FloatType): left_type,
(ast.ComplexType, ast.IntType): left_type,
(ast.ComplexType, ast.UintType): left_type,
(ast.ComplexType, ast.DurationType): TypeError("Cannot multiply complex and duration"),
(ast.ComplexType, ast.StretchType): TypeError("Cannot multiply complex and stretch"),
(ast.ComplexType, ast.AngleType): TypeError("Cannot multiply complex and angle"),
(ast.ComplexType, ast.ComplexType): left_type,
}
Expand All @@ -266,36 +279,49 @@ def compute_quotient_types(left: AstConvertible, right: AstConvertible) -> ast.C
(ast.FloatType, ast.IntType): left_type,
(ast.FloatType, ast.UintType): left_type,
(ast.FloatType, ast.DurationType): TypeError("Cannot divide float by duration"),
(ast.FloatType, ast.StretchType): TypeError("Cannot divide float by stretch"),
(ast.FloatType, ast.AngleType): TypeError("Cannot divide float by angle"),
(ast.FloatType, ast.ComplexType): right_type,
(ast.IntType, ast.FloatType): right_type,
(ast.IntType, ast.IntType): float_type,
(ast.IntType, ast.UintType): float_type,
(ast.IntType, ast.DurationType): TypeError("Cannot divide int by duration"),
(ast.IntType, ast.StretchType): TypeError("Cannot divide int by stretch"),
(ast.IntType, ast.AngleType): TypeError("Cannot divide int by angle"),
(ast.IntType, ast.ComplexType): right_type,
(ast.UintType, ast.FloatType): right_type,
(ast.UintType, ast.IntType): float_type,
(ast.UintType, ast.UintType): float_type,
(ast.UintType, ast.DurationType): TypeError("Cannot divide uint by duration"),
(ast.UintType, ast.StretchType): TypeError("Cannot divide uint by stretch"),
(ast.UintType, ast.AngleType): TypeError("Cannot divide uint by angle"),
(ast.UintType, ast.ComplexType): right_type,
(ast.DurationType, ast.FloatType): left_type,
(ast.DurationType, ast.IntType): left_type,
(ast.DurationType, ast.UintType): left_type,
(ast.DurationType, ast.DurationType): ast.FloatType(),
(ast.DurationType, ast.StretchType): ast.FloatType(),
(ast.DurationType, ast.AngleType): TypeError("Cannot divide duration by angle"),
(ast.DurationType, ast.ComplexType): TypeError("Cannot divide duration by complex"),
(ast.StretchType, ast.FloatType): left_type,
(ast.StretchType, ast.IntType): left_type,
(ast.StretchType, ast.UintType): left_type,
(ast.StretchType, ast.DurationType): ast.FloatType(),
(ast.StretchType, ast.StretchType): ast.FloatType(),
(ast.StretchType, ast.AngleType): TypeError("Cannot divide stretch by angle"),
(ast.StretchType, ast.ComplexType): TypeError("Cannot divide stretch by complex"),
(ast.AngleType, ast.FloatType): left_type,
(ast.AngleType, ast.IntType): left_type,
(ast.AngleType, ast.UintType): left_type,
(ast.AngleType, ast.DurationType): TypeError("Cannot divide by duration"),
(ast.AngleType, ast.StretchType): TypeError("Cannot divide by stretch"),
(ast.AngleType, ast.AngleType): float_type,
(ast.AngleType, ast.ComplexType): TypeError("Cannot divide by angle by complex"),
(ast.ComplexType, ast.FloatType): left_type,
(ast.ComplexType, ast.IntType): left_type,
(ast.ComplexType, ast.UintType): left_type,
(ast.ComplexType, ast.DurationType): TypeError("Cannot divide by duration"),
(ast.ComplexType, ast.StretchType): TypeError("Cannot divide by stretch"),
(ast.ComplexType, ast.AngleType): TypeError("Cannot divide by angle"),
(ast.ComplexType, ast.ComplexType): left_type,
}
Expand Down Expand Up @@ -416,7 +442,7 @@ def __init__(
self.type = ast_type

# Adding floats to durations is not allowed. So we promote types as necessary.
if isinstance(self.type, ast.DurationType) and self.op in [
if isinstance(self.type, (ast.DurationType, ast.StretchType)) and self.op in [
ast.BinaryOperator["+"],
ast.BinaryOperator["-"],
]:
Expand Down
16 changes: 16 additions & 0 deletions oqpy/classical_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,11 @@ def make_declaration_statement(self, program: Program) -> ast.Statement:
stmt.annotations = make_annotations(self.annotations)
return stmt

def __str__(self) -> str:
return (
f"<{self.__class__.__name__} name={self.name} init_expression={self.init_expression}>"
)


class BoolVar(_ClassicalVar):
"""An (unsized) oqpy variable with bool type."""
Expand Down Expand Up @@ -373,6 +378,17 @@ class StretchVar(_ClassicalVar):

type_cls = ast.StretchType

def __init__(
self,
init_expression: AstConvertible | Literal["input", "output"] | None = None,
name: str | None = None,
*args: Any,
**type_kwargs: Any,
) -> None:
if init_expression is not None and not isinstance(init_expression, str):
init_expression = convert_float_to_duration(init_expression)
super().__init__(init_expression, name, *args, **type_kwargs)


AllowedArrayTypes = Union[_SizedVar, DurationVar, BoolVar, ComplexVar]

Expand Down
2 changes: 1 addition & 1 deletion oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def include(self, path: str) -> Program:

def _do_assignment(self, var: AstConvertible, op: str, value: AstConvertible) -> None:
"""Helper function for variable assignment operations."""
if isinstance(var, classical_types.DurationVar):
if isinstance(var, (classical_types.DurationVar, classical_types.StretchVar)):
value = convert_float_to_duration(value)
var_ast = to_ast(self, var)
if isinstance(var_ast, ast.IndexExpression):
Expand Down
4 changes: 2 additions & 2 deletions oqpy/subroutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,14 +249,14 @@ def call_extern(*call_args: AstConvertible, **call_kwargs: AstConvertible) -> OQ
if k_idx < len(call_args):
raise TypeError(f"{name}() got multiple values for argument '{k}'.")

if type(arg_types[k_idx]) == ast.DurationType:
if type(arg_types[k_idx]) in (ast.DurationType, ast.StretchType):
new_args[k_idx] = convert_float_to_duration(call_kwargs[k])
else:
new_args[k_idx] = call_kwargs[k]

# Casting floats into durations for the non-keyword arguments
for i, a in enumerate(call_args):
if type(arg_types[i]) == ast.DurationType:
if type(arg_types[i]) in (ast.DurationType, ast.StretchType):
new_args[i] = convert_float_to_duration(a)
return OQFunctionCall(
name, {k: v for k, v in zip(arg_names, new_args)}, return_type, extern_decl=extern_decl
Expand Down
6 changes: 4 additions & 2 deletions oqpy/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def convert_float_to_duration(time: AstConvertible, require_nonnegative: bool =
if isinstance(time, OQPyExpression):
if isinstance(time.type, (ast.UintType, ast.IntType, ast.FloatType)):
time = time * OQDurationLiteral(1)
elif not isinstance(time.type, ast.DurationType):
elif not isinstance(time.type, (ast.DurationType, ast.StretchType)):
raise TypeError(f"Cannot convert expression with type {time.type} to duration")
if hasattr(time, "to_ast"):
return time # type: ignore[return-value]
Expand All @@ -91,7 +91,9 @@ def convert_float_to_duration(time: AstConvertible, require_nonnegative: bool =


def convert_duration_to_float(value: AstConvertible) -> AstConvertible:
if isinstance(value, OQPyExpression) and isinstance(value.type, ast.DurationType):
if isinstance(value, OQPyExpression) and isinstance(
value.type, (ast.DurationType, ast.StretchType)
):
value = value / OQDurationLiteral(1)
return value

Expand Down
61 changes: 61 additions & 0 deletions tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,67 @@ def test_binary_expressions():
_check_respects_type_hints(prog)


def test_stretch_as_duration():
prog = Program()
s = StretchVar(1e-6, "s")
d = DurationVar(1e-7, "d")

myext = declare_extern(
"myext",
[("dd", oqpy.duration), ("ss", oqpy.stretch)],
)

prog.do_expression(myext(s, 2e-8))
prog.delay(s)
with Box(prog, s + d):
prog.delay(s / 2)
prog.set(s, 3e-6)
prog.set(s, s + d)
prog.set(d, s)

expected = textwrap.dedent(
"""
OPENQASM 3.0;
extern myext(duration, stretch);
stretch s = 1.0us;
duration d = 100.0ns;
myext(s, 20.0ns);
delay[s];
box[s + d] {
delay[s / 2];
}
s = 3.0us;
s = s + d;
d = s;
"""
).strip()

assert prog.to_qasm() == expected
_check_respects_type_hints(prog)


def test_stretch_arithmetic_types():
s = StretchVar(name="s")
d = DurationVar(1e-7, name="d")

assert isinstance((s + s).type, ast.StretchType)
assert isinstance((s + d).type, ast.StretchType)
assert isinstance((d + s).type, ast.DurationType)
assert isinstance((2 * s).type, ast.StretchType)
assert isinstance((s * 2).type, ast.StretchType)
assert isinstance((s / 2).type, ast.StretchType)
assert isinstance((s / s).type, ast.FloatType)
assert isinstance((s / d).type, ast.FloatType)
assert isinstance((d / s).type, ast.FloatType)

with pytest.raises(TypeError):
s * d
with pytest.raises(TypeError):
2 / s
with pytest.raises(TypeError):
s * 1j


@pytest.mark.xfail
def test_add_incomptible_type():
# This test should fail since we add float to a duration and then don't type cast things
Expand Down
Loading