Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions oqpy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def compute_product_types(left: AstConvertible, right: AstConvertible) -> ast.Cl
raise TypeError(f"Could not identify types for product {left} and {right}") from e
if isinstance(result_type, Exception):
raise result_type
assert result_type is not None
return result_type


Expand Down Expand Up @@ -306,6 +307,7 @@ def compute_quotient_types(left: AstConvertible, right: AstConvertible) -> ast.C
raise TypeError(f"Could not identify types for quotient {left} and {right}") from e
if isinstance(result_type, Exception):
raise result_type
assert result_type is not None
return result_type


Expand Down Expand Up @@ -529,10 +531,16 @@ def to_ast(program: Program, item: AstConvertible) -> ast.Expression:
return detect_and_convert_constants(item, program)
return ast.FloatLiteral(item)
if isinstance(item, slice):
return ast.RangeDefinition(
to_ast(program, item.start) if item.start is not None else None,
to_ast(program, item.stop - 1) if item.stop is not None else None,
to_ast(program, item.step) if item.step is not None else None,
# Technically, RangeDefinition is not an Expression
# but we can treat it like one, since an expression is allowable anywhere
# a RangeDefinition is allowed.
return cast(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return cast(
# Technically, RangeDefinition is not an Expression
# but we can treat it like one, since an expression is allowable anywhere
# a RangeDefinition is allowed.
return cast(

ast.Expression,
ast.RangeDefinition(
to_ast(program, item.start) if item.start is not None else None,
to_ast(program, item.stop - 1) if item.stop is not None else None,
to_ast(program, item.step) if item.step is not None else None,
),
)
if isinstance(item, Iterable):
return ast.ArrayLiteral([to_ast(program, i) for i in item])
Expand All @@ -555,6 +563,24 @@ def map_to_ast(program: Program, items: Iterable[AstConvertible]) -> list[ast.Ex
return [to_ast(program, item) for item in items]


def map_to_identifiers(
program: Program, items: Iterable[AstConvertible]
) -> list[ast.IndexedIdentifier | ast.Identifier]:
"""Convert a sequence of items into identifier ast nodes.

Verifies at runtime that each item converts to an Identifier or IndexedIdentifier.
"""
result: list[ast.IndexedIdentifier | ast.Identifier] = []
for item in items:
node = to_ast(program, item)
if not isinstance(node, (ast.Identifier, ast.IndexedIdentifier)):
raise TypeError(
f"Expected Identifier or IndexedIdentifier, got {type(node).__name__}"
)
result.append(node)
return result


def make_annotations(vals: Sequence[str | tuple[str, str]]) -> list[ast.Annotation]:
"""Convert strings/tuples of strings into Annotation ast nodes."""
anns: list[ast.Annotation] = []
Expand All @@ -570,15 +596,15 @@ def make_annotations(vals: Sequence[str | tuple[str, str]]) -> list[ast.Annotati
def detect_and_convert_constants(val: float | np.floating[Any], program: Program) -> ast.Expression:
"""Construct a float ast expression which is either a literal or an expression using constants."""
if val == 0:
return ast.FloatLiteral(val)
return ast.FloatLiteral(float(val))
if val < 0.5 or val > 100:
return ast.FloatLiteral(val)
return ast.FloatLiteral(float(val))
if math.isnan(val):
return ast.FloatLiteral(val)
return ast.FloatLiteral(float(val))
x = val / (math.pi / 4.0)
rx = round(x)
if not math.isclose(x, rx, rel_tol=1e-12):
return ast.FloatLiteral(val)
return ast.FloatLiteral(float(val))
term: OQPyExpression
if rx == 4:
term = classical_types.pi
Expand Down
22 changes: 12 additions & 10 deletions oqpy/classical_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import functools
import random
import string
import sys
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -51,10 +50,6 @@

from oqpy.program import Program

if sys.version_info < (3, 10):
EllipsisType = type(Ellipsis)
else:
from types import EllipsisType

__all__ = [
"pi",
Expand Down Expand Up @@ -150,7 +145,7 @@ def arrayreference_(
dims: int | list[int],
) -> ast.ArrayReferenceType:
"""Create an array reference type."""
dim = (
dim: ast.Expression | list[ast.Expression] = (
ast.IntegerLiteral(dims) if isinstance(dims, int) else [ast.IntegerLiteral(d) for d in dims]
)
return ast.ArrayReferenceType(base_type=dtype, dimensions=dim)
Expand Down Expand Up @@ -226,6 +221,8 @@ def to_ast(self, program: Program) -> ast.Identifier:

def make_declaration_statement(self, program: Program) -> ast.Statement:
"""Make an ast statement that declares the OQpy variable."""
assert self.type is not None
stmt: ast.IODeclaration | ast.ClassicalDeclaration
if isinstance(self.init_expression, str) and self.init_expression in ("input", "output"):
stmt = ast.IODeclaration(
ast.IOKeyword[self.init_expression], self.type, self.to_ast(program)
Expand Down Expand Up @@ -253,7 +250,7 @@ def __class_getitem__(cls: Type[_SizedVarT], item: int | None) -> Callable[...,
# Allows IntVar[64]() notation
return functools.partial(cls, size=item)

def __init__(self, *args: Any, size: int | None | EllipsisType = ..., **kwargs: Any):
def __init__(self, *args: Any, size: Any = ..., **kwargs: Any):
if size is ...:
self.size = self.default_size
elif size is None:
Expand Down Expand Up @@ -408,9 +405,12 @@ def __init__(
# Creating a dummy variable supports IntVar[64] etc.
base_type_instance = base_type()
if isinstance(base_type_instance, _SizedVar):
array_base_type = base_type_instance.type_cls(
size=ast.IntegerLiteral(base_type_instance.size)
size_arg = (
ast.IntegerLiteral(base_type_instance.size)
if base_type_instance.size is not None
else None
)
array_base_type = base_type_instance.type_cls(size=size_arg) # type: ignore[call-arg]
elif isinstance(base_type_instance, ComplexVar):
array_base_type = base_type_instance.type_cls(base_type=base_type_instance.base_type)
else:
Expand All @@ -437,7 +437,9 @@ def __init__(
)

def __getitem__(self, index: AstConvertible) -> OQIndexExpression:
return OQIndexExpression(collection=self, index=index, type_=self.base_type().type_cls())
base_instance = self.base_type()
assert base_instance.type is not None
return OQIndexExpression(collection=self, index=index, type_=base_instance.type)


class OQIndexExpression(OQPyExpression):
Expand Down
28 changes: 19 additions & 9 deletions oqpy/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING, Iterable, Iterator, Optional, TypeVar, overload
from typing import TYPE_CHECKING, Iterable, Iterator, Optional, TypeVar, cast, overload

from openpulse import ast

Expand Down Expand Up @@ -118,6 +118,7 @@ def ForIn(
yield var
state = program._pop()

set_declaration: ast.RangeDefinition | ast.DiscreteSet | ast.Expression
if isinstance(iterator, range):
# A range can only be iterated over integers.
assert identifier_type is IntVar, "A range can only be looped over an integer."
Expand All @@ -131,7 +132,10 @@ def ForIn(
set_declaration = to_ast(program, iterator)

stmt = ast.ForInLoop(
identifier_type.type_cls(), var.to_ast(program), set_declaration, state.body
identifier_type.type_cls(),
var.to_ast(program),
set_declaration,
state.body,
)
program._add_statement(stmt)

Expand All @@ -150,14 +154,20 @@ def __init__(self, start: AstConvertible, stop: AstConvertible, step: AstConvert

def to_ast(self, program: Program) -> ast.Expression:
"""Convert to an ast.RangeDefinition."""
return ast.RangeDefinition(
to_ast(program, self.start),
ast.BinaryExpression(
lhs=to_ast(program, self.stop),
op=ast.BinaryOperator["-"],
rhs=ast.IntegerLiteral(value=1),
# Technically, RangeDefinition is not an Expression
# but we can treat it like one, since an expression is allowable anywhere
# a RangeDefinition is allowed.
return cast(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return cast(
# Technically, RangeDefinition is not an Expression
# but we can treat it like one, since an expression is allowable anywhere
# a RangeDefinition is allowed.
return cast(

ast.Expression,
ast.RangeDefinition(
to_ast(program, self.start),
ast.BinaryExpression(
lhs=to_ast(program, self.stop),
op=ast.BinaryOperator["-"],
rhs=ast.IntegerLiteral(value=1),
),
to_ast(program, self.step) if self.step != 1 else None,
),
to_ast(program, self.step) if self.step != 1 else None,
)


Expand Down
Loading
Loading