Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/user/next/workshop/exercises/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from gt4py.next import Dimension, DimensionKind, FieldOffset
from gt4py.next.program_processors.runners import roundtrip
from gt4py.next.program_processors.runners.gtfn import (
run_gtfn_cached as gtfn_cpu,
run_gtfn as gtfn_cpu,
run_gtfn_gpu as gtfn_gpu,
)

Expand Down
5 changes: 1 addition & 4 deletions src/gt4py/next/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,7 @@
where,
)
from .otf.compiled_program import wait_for_compilation
from .program_processors.runners.gtfn import (
run_gtfn_cached as gtfn_cpu,
run_gtfn_gpu_cached as gtfn_gpu,
)
from .program_processors.runners.gtfn import run_gtfn as gtfn_cpu, run_gtfn_gpu as gtfn_gpu
from .program_processors.runners.roundtrip import default as itir_python


Expand Down
5 changes: 2 additions & 3 deletions src/gt4py/next/program_processors/formatters/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@

@program_formatter.program_formatter
def format_cpp(program: itir.Program, *args: Any, **kwargs: Any) -> str:
# TODO(tehrengruber): This is a little ugly. Revisit.
gtfn_translation = gtfn.GTFNBackendFactory().executor.translation # type: ignore[attr-defined]
gtfn_translation = gtfn.GTFNCompileWorkflowFactory(cached_translation=False).translation
assert isinstance(gtfn_translation, GTFNTranslationStep)
return gtfn_translation.generate_stencil_source(
program,
offset_provider=kwargs.get("offset_provider", None), # type: ignore[arg-type]
offset_provider=kwargs.get("offset_provider", {}),
column_axis=kwargs.get("column_axis", None),
)
4 changes: 0 additions & 4 deletions src/gt4py/next/program_processors/runners/dace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@
from gt4py.next.program_processors.runners.dace.workflow.backend import (
make_dace_backend,
run_dace_cpu,
run_dace_cpu_cached,
run_dace_cpu_noopt,
run_dace_gpu,
run_dace_gpu_cached,
run_dace_gpu_noopt,
)

Expand All @@ -23,9 +21,7 @@
"get_sdfg_args",
"make_dace_backend",
"run_dace_cpu",
"run_dace_cpu_cached",
"run_dace_cpu_noopt",
"run_dace_gpu",
"run_dace_gpu_cached",
"run_dace_gpu_noopt",
]
13 changes: 5 additions & 8 deletions src/gt4py/next/program_processors/runners/dace/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,12 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG:

compile_workflow = typing.cast(
recipes.OTFCompileWorkflow,
self.backend.executor
if not hasattr(self.backend.executor, "step")
else self.backend.executor.step,
self.backend.executor.step if self.backend.cached else self.backend.executor, # type: ignore[attr-defined]
) # We know which backend we are using, but we don't know if the compile workflow is cached.
compile_workflow_translation = (
compile_workflow.translation
if not hasattr(compile_workflow.translation, "step")
else compile_workflow.translation.step
) # Same for the translation stage, which could be a `CachedStep` depending on backend configuration.
assert hasattr(
compile_workflow.translation, "step"
) # the translation stage is always cached
compile_workflow_translation = compile_workflow.translation.step
# TODO(ricoh): switch 'disable_itir_transforms=True' because we ran them separately previously
# and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with
# the other parts of the workaround when possible.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class Params:
hash_function = stages.compilation_hash
otf_workflow = factory.SubFactory(
DaCeWorkflowFactory,
cached_translation=True,
device_type=factory.SelfAttribute("..device_type"),
auto_optimize=factory.SelfAttribute("..auto_optimize"),
)
Expand Down Expand Up @@ -127,7 +128,6 @@ def make_dace_backend(
gpu=gpu,
cached=cached,
auto_optimize=auto_optimize,
otf_workflow__cached_translation=cached,
otf_workflow__bare_translation__async_sdfg_call=(async_sdfg_call if gpu else False),
otf_workflow__bare_translation__auto_optimize_args=optimization_args,
otf_workflow__bare_translation__unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride,
Expand All @@ -139,38 +139,22 @@ def make_dace_backend(

run_dace_cpu = make_dace_backend(
gpu=False,
cached=False,
auto_optimize=True,
async_sdfg_call=False,
)
run_dace_cpu_noopt = make_dace_backend(
gpu=False,
cached=False,
auto_optimize=False,
async_sdfg_call=False,
)
run_dace_cpu_cached = make_dace_backend(
gpu=False,
cached=True,
auto_optimize=True,
async_sdfg_call=False,
)

run_dace_gpu = make_dace_backend(
gpu=True,
cached=False,
auto_optimize=True,
async_sdfg_call=True,
)
run_dace_gpu_noopt = make_dace_backend(
gpu=True,
cached=False,
auto_optimize=False,
async_sdfg_call=True,
)
run_dace_gpu_cached = make_dace_backend(
gpu=True,
cached=True,
auto_optimize=True,
async_sdfg_call=True,
)
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ class Params:
lambda: config.CMAKE_BUILD_TYPE
)

bare_translation = factory.SubFactory(
DaCeTranslationStepFactory,
device_type=factory.SelfAttribute("..device_type"),
auto_optimize=factory.SelfAttribute("..auto_optimize"),
)
cached_translation = factory.Trait(
translation=factory.LazyAttribute(
lambda o: workflow.CachedStep(
Expand All @@ -58,12 +63,6 @@ class Params:
),
)

bare_translation = factory.SubFactory(
DaCeTranslationStepFactory,
device_type=factory.SelfAttribute("..device_type"),
auto_optimize=factory.SelfAttribute("..auto_optimize"),
)

translation = factory.LazyAttribute(lambda o: o.bare_translation)
bindings = factory.LazyAttribute(
lambda o: functools.partial(
Expand Down
35 changes: 13 additions & 22 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ class Params:
lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type)
)

bare_translation = factory.SubFactory(
gtfn_module.GTFNTranslationStepFactory,
device_type=factory.SelfAttribute("..device_type"),
)
cached_translation = factory.Trait(
translation=factory.LazyAttribute(
lambda o: workflow.CachedStep(
Expand All @@ -131,13 +135,7 @@ class Params:
),
)

bare_translation = factory.SubFactory(
gtfn_module.GTFNTranslationStepFactory,
device_type=factory.SelfAttribute("..device_type"),
)

translation = factory.LazyAttribute(lambda o: o.bare_translation)

bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] = (
nanobind.bind_source
)
Expand All @@ -158,7 +156,6 @@ class Meta:
class Params:
name_device = "cpu"
name_cached = ""
name_temps = ""
name_postfix = ""
gpu = factory.Trait(
allocator=next_allocators.StandardGPUFieldBufferAllocator(),
Expand All @@ -174,32 +171,26 @@ class Params:
device_type = core_defs.DeviceType.CPU
hash_function = stages.compilation_hash
otf_workflow = factory.SubFactory(
GTFNCompileWorkflowFactory, device_type=factory.SelfAttribute("..device_type")
GTFNCompileWorkflowFactory,
cached_translation=True,
device_type=factory.SelfAttribute("..device_type"),
)

name = factory.LazyAttribute(
lambda o: f"run_gtfn_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}"
lambda o: f"run_gtfn_{o.name_device}{o.name_cached}{o.name_postfix}"
)

executor = factory.LazyAttribute(lambda o: o.otf_workflow)
allocator = next_allocators.StandardCPUFieldBufferAllocator()
transforms = backend.DEFAULT_TRANSFORMS


run_gtfn = GTFNBackendFactory()
run_gtfn = GTFNBackendFactory(cached=True)

run_gtfn_imperative = GTFNBackendFactory(
name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True
)

run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True)

run_gtfn_gpu = GTFNBackendFactory(gpu=True)

run_gtfn_gpu_cached = GTFNBackendFactory(
gpu=True, cached=True, otf_workflow__cached_translation=True
cached=True,
name_postfix="_imperative",
otf_workflow__translation__use_imperative_backend=True,
)

run_gtfn_no_transforms = GTFNBackendFactory(
otf_workflow__bare_translation__enable_itir_transforms=False
)
run_gtfn_gpu = GTFNBackendFactory(cached=True, gpu=True)
8 changes: 4 additions & 4 deletions tests/next_tests/benchmarks/benchmark_program_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from gt4py.next.program_processors.runners import dace as dace_backends

DACE_BACKENDS = (
[dace_backends.run_dace_cpu_cached, dace_backends.run_dace_gpu_cached]
[dace_backends.run_dace_cpu, dace_backends.run_dace_gpu]
if cp is not None
else [dace_backends.run_dace_cpu_cached]
else [dace_backends.run_dace_cpu]
)
except ImportError:
DACE_BACKENDS = []
Expand Down Expand Up @@ -357,9 +357,9 @@ def benchmark(*args, **kwargs):
backend_name = (arg.split("=", 1)[-1]).strip()
match backend_name:
case "dace-cpu":
backends.append(dace_backends.run_dace_cpu_cached)
backends.append(dace_backends.run_dace_cpu)
case "dace-gpu":
backends.append(dace_backends.run_dace_gpu_cached)
backends.append(dace_backends.run_dace_gpu)
case "gtfn-cpu":
backends.append(gtfn_cpu)
case "gtfn-gpu":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from gt4py import next as gtx
from gt4py.next import backend, common
from gt4py.next.iterator.transforms import apply_common_transforms
from gt4py.next.program_processors.runners.gtfn import run_gtfn
from gt4py.next.program_processors.runners import gtfn
from gt4py.next import custom_layout_allocators as next_allocators

from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import (
Expand All @@ -33,16 +34,16 @@ def exec_alloc_descriptor():
return backend.Backend(
name="run_gtfn_with_temporaries_and_sizes",
transforms=backend.DEFAULT_TRANSFORMS,
executor=run_gtfn.executor.replace(
translation=run_gtfn.executor.translation.replace(
executor=gtfn.GTFNCompileWorkflowFactory(
translation=gtfn.gtfn_module.GTFNTranslationStepFactory(
symbolic_domain_sizes={
"Cell": "num_cells",
"Edge": "num_edges",
"Vertex": "num_vertices",
}
)
),
allocator=run_gtfn.allocator,
allocator=next_allocators.StandardCPUFieldBufferAllocator(),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
try:
from gt4py.next.program_processors.runners import dace as dace_backends

BACKENDS = [None, gtfn_cpu, dace_backends.run_dace_cpu_cached]
BACKENDS = [None, gtfn_cpu, dace_backends.run_dace_cpu]
except ImportError:
BACKENDS = [None, gtfn_cpu]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
abs,
)
from gt4py.next.iterator.runtime import fendef, fundef, offset, set_at
from gt4py.next.program_processors.runners.gtfn import run_gtfn
from gt4py.next.program_processors.runners import gtfn

from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data
from next_tests.unit_tests.conftest import program_processor, run_processor
Expand Down Expand Up @@ -189,11 +189,8 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected):
inps = field_maker(*array_maker(*inputs))
out = field_maker((np.zeros_like(*array_maker(expected))))[0]

gtfn_without_transforms = dataclasses.replace(
run_gtfn,
executor=run_gtfn.executor.replace(
translation=run_gtfn.executor.translation.replace(enable_itir_transforms=False),
), # avoid inlining the function
gtfn_without_transforms = gtfn.GTFNBackendFactory(
otf_workflow__bare_translation__enable_itir_transforms=False
)

fencil(builtin, out, *inps, processor=gtfn_without_transforms)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,13 @@ def test_gtfn_file_cache(program_example):
data=fencil,
args=arguments.CompileTimeArgs.from_concrete(*parameters, **{"offset_provider": {}}),
)
cached_gtfn_translation_step = gtfn.GTFNBackendFactory(
gpu=False, cached=True, otf_workflow__cached_translation=True
).executor.step.translation
cached_gtfn_translation_step = gtfn.GTFNCompileWorkflowFactory(
cached_translation=True
).translation

bare_gtfn_translation_step = gtfn.GTFNBackendFactory(
gpu=False, cached=True, otf_workflow__cached_translation=False
).executor.step.translation
bare_gtfn_translation_step = gtfn.GTFNCompileWorkflowFactory(
cached_translation=False
).translation

cache_key = stages.fingerprint_compilable_program(compilable_program)

Expand All @@ -156,29 +156,3 @@ def test_gtfn_file_cache(program_example):
bare_gtfn_translation_step(compilable_program)
== cached_gtfn_translation_step.cache[cache_key]
)


# TODO(egparedes): we should switch to use the cached backend by default and then remove this test
def test_gtfn_file_cache_whole_workflow(cartesian_case_no_backend):
cartesian_case = cartesian_case_no_backend
cartesian_case.backend = gtfn.GTFNBackendFactory(
gpu=False, cached=True, otf_workflow__cached_translation=True
)
cartesian_case.allocator = next_allocators.StandardCPUFieldBufferAllocator()

assert cartesian_case.backend is not None
assert cartesian_case.allocator is not None

@gtx.field_operator
def testee(a: cases.IJKField) -> cases.IJKField:
field_tuple = (a, a)
field_0 = field_tuple[0]
field_1 = field_tuple[1]
return field_0

# first call: this generates the cache file
cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a)
# clearing the OTFCompileWorkflow cache such that the OTFCompileWorkflow step is executed again
object.__setattr__(cartesian_case.backend.executor, "cache", {})
# second call: the cache file is used
cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a)
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
# see https://docs.pytest.org/en/latest/how-to/fixtures.html#override-a-fixture-on-a-test-module-level
@pytest.fixture(
params=[
pytest.param(dace_backends.run_dace_cpu_cached, marks=pytest.mark.requires_dace),
pytest.param(dace_backends.run_dace_cpu, marks=pytest.mark.requires_dace),
pytest.param(
dace_backends.run_dace_gpu_cached,
dace_backends.run_dace_gpu,
marks=(pytest.mark.requires_gpu, pytest.mark.requires_dace),
),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ def test_backend_factory_trait_device():
assert cpu_version.name == "run_gtfn_cpu"
assert gpu_version.name == "run_gtfn_gpu"

assert cpu_version.executor.translation.device_type is core_defs.DeviceType.CPU
assert gpu_version.executor.translation.device_type is core_defs.DeviceType.CUDA
assert isinstance(cpu_version.executor.translation, workflow.CachedStep)
assert cpu_version.executor.translation.step.device_type is core_defs.DeviceType.CPU
assert isinstance(gpu_version.executor.translation, workflow.CachedStep)
assert gpu_version.executor.translation.step.device_type is core_defs.DeviceType.CUDA

assert cpu_version.executor.decoration.keywords["device"] is core_defs.DeviceType.CPU
assert gpu_version.executor.decoration.keywords["device"] is core_defs.DeviceType.CUDA
Expand Down
Loading