Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Release History

## 1.51.0 (TBD)

### Snowpark Python API Updates

#### New Features

- Added support for `DataFrame.pipe`.

## 1.50.0 (TBD)

### Snowpark Python API Updates
Expand Down
1 change: 1 addition & 0 deletions docs/source/snowpark/dataframe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ DataFrame
DataFrame.natural_join
DataFrame.orderBy
DataFrame.order_by
DataFrame.pipe
DataFrame.pivot
DataFrame.print_schema
DataFrame.printSchema
Expand Down
47 changes: 47 additions & 0 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Optional,
Set,
Tuple,
TypeVar,
Union,
overload,
)
Expand Down Expand Up @@ -243,10 +244,23 @@
else:
from collections.abc import Iterable

# Python 3.9 needs to use typing_extensions.ParamSpec and typing_extensions.Concatenate
# Python 3.10+ can use typing.ParamSpec and typing.Concatenate because they are available in the standard library
if sys.version_info < (3, 10):
    from typing_extensions import Concatenate, ParamSpec
else:
    from typing import Concatenate, ParamSpec


if TYPE_CHECKING:
    # Imported only for static type checking; never executed at runtime.
    import modin.pandas  # pragma: no cover

    # NOTE(review): bare `from table import Table` looks like it should be a
    # package-relative import (e.g. `from snowflake.snowpark.table import Table`)
    # — confirm this resolves under type checking.
    from table import Table  # pragma: no cover


# Generic typing helpers used by DataFrame.pipe: the piped callable receives a
# DataFrame as its first argument, then arbitrary extra parameters captured by
# P, and returns a value of type T.
T = TypeVar("T")
P = ParamSpec("P")


_logger = getLogger(__name__)

_ONE_MILLION = 1000000
Expand Down Expand Up @@ -7099,6 +7113,39 @@ def print_schema(self, level: Optional[int] = None) -> None:
# naturalJoin = natural_join
# withColumns = with_columns

def pipe(
    self,
    function: Callable[Concatenate["DataFrame", P], T],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T:
    """Feed this DataFrame into ``function`` and return whatever it produces.

    This lets transformations that take a DataFrame as their first positional
    argument be chained fluently — ``df.pipe(f, x).pipe(g, y=1)`` — instead of
    the harder-to-read nested form ``g(f(df, x), y=1)``.

    Args:
        function: A callable whose first parameter is this DataFrame.
        *args: Additional positional arguments forwarded to ``function``.
        **kwargs: Additional keyword arguments forwarded to ``function``.

    Returns:
        The value returned by ``function`` (not necessarily a DataFrame).

    Example::

        >>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
        >>> def test_function(df: DataFrame, col: str, threshold: float = 0):
        ...     df = df.filter(df[col] > threshold)
        ...     return df
        >>> result = df.pipe(test_function, "a", threshold=1)
        >>> result.show()
        -------------
        |"A" |"B" |
        -------------
        |3 |4 |
        -------------
        <BLANKLINE>
    """
    # Forward this frame as the first positional argument, untouched.
    result = function(self, *args, **kwargs)
    return result


def map(
dataframe: DataFrame,
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,28 @@ def test_dataFrame_printSchema(capfd, mock_server_connection):
)


def test_dataframe_pipe(session):
    """pipe() must forward the frame plus extra args to the callable and
    return the callable's result unchanged."""
    df: DataFrame = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])

    # Plain function: takes the frame first, then positional + keyword args.
    def test_function(frame: DataFrame, col: str, threshold: float = 0.0):
        filtered = frame.filter(frame[col] > threshold)
        return filtered.collect(), filtered.count()

    # pipe() with keyword arg must match a direct call with the same values.
    assert df.pipe(test_function, "a", threshold=1) == test_function(df, "a", 1)

    # Lambda: pipe() returns the non-DataFrame result as-is.
    count_rows = lambda frame: int(frame.count())
    assert df.pipe(count_rows) == count_rows(df)


def test_session():
fake_session = mock.create_autospec(Session, _session_id=123456)
fake_session._analyzer = mock.Mock()
Expand Down
Loading