Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
- binja: add mypy config for top-level binaryninja module to fix mypy issues @devs6186 #2399
- ci: deprecate macos-13 runner and use Python v3.13 for testing @mike-hunhoff #2777
- ci: pin pip-audit action SHAs and update to v1.1.0 @kami922 #1131
- tests: add tests and refactor API extraction and symbol generation helpers @pranavthakur0-0 #1899

### Raw diffs
- [capa v9.3.1...master](https://github.com/mandiant/capa/compare/v9.3.1...master)
Expand Down
26 changes: 15 additions & 11 deletions capa/features/extractors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,26 @@ def generate_symbols(dll: str, symbol: str, include_dll=False) -> Iterator[str]:
dll = dll[0:-4] if dll.endswith(".drv") else dll
dll = dll[0:-3] if dll.endswith(".so") else dll

if include_dll or is_ordinal(symbol):
# ws2_32.#1
# ordinal imports like ws2_32.#1 always include the DLL name
if is_ordinal(symbol):
yield f"{dll}.{symbol}"
return

# for non-ordinal symbols
if include_dll:
# kernel32.CreateFileA
yield f"{dll}.{symbol}"

if not is_ordinal(symbol):
# CreateFileA
yield symbol
# CreateFileA
yield symbol

if is_aw_function(symbol):
if include_dll:
# kernel32.CreateFile
yield f"{dll}.{symbol[:-1]}"
if is_aw_function(symbol):
if include_dll:
# kernel32.CreateFile
yield f"{dll}.{symbol[:-1]}"

# CreateFile
yield symbol[:-1]
# CreateFile
yield symbol[:-1]


def reformat_forwarded_export_name(forwarded_name: str) -> str:
Expand Down
13 changes: 8 additions & 5 deletions capa/rules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,15 +567,18 @@ def pop_statement_description_entry(d):


def trim_dll_part(api: str) -> str:
    """Strip the leading DLL name from a ``dll.symbol`` style API name.

    The name is returned unchanged when it is an ordinal import, a
    ``::``-qualified (.NET) name, or does not contain exactly one dot.

    Args:
        api: API name, such as ``kernel32.CreateFileA``.

    Returns:
        The symbol without its DLL prefix, or ``api`` itself when nothing
        should be trimmed.
    """
    # ordinal imports, like ws2_32.#1, keep dll part
    if ".#" in api:
        return api

    # .NET namespace, like System.Diagnostics.Debugger::IsLogging, keep the namespace part
    if "::" in api:
        return api

    # kernel32.CreateFileA -> CreateFileA
    if api.count(".") == 1:
        return api.split(".")[1]

    # no dot, or more than one dot: nothing to trim
    return api


Expand Down
106 changes: 102 additions & 4 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import codecs

import capa.helpers
from capa.rules import trim_dll_part
from capa.features.extractors import helpers


Expand All @@ -30,14 +31,63 @@ def test_all_zeros():
assert helpers.all_zeros(d) is False


def test_is_aw_function():
    """Check which symbol names helpers.is_aw_function accepts."""
    # names ending in A or W, including a longer name ending in W
    accepted = ("CreateFileA", "CreateFileW", "LoadLibraryExW")
    for name in accepted:
        assert helpers.is_aw_function(name) is True

    # names without an A/W suffix, followed by inputs that are too short (length < 2)
    rejected = ("WriteFile", "recv", "A", "W", "")
    for name in rejected:
        assert helpers.is_aw_function(name) is False
Comment on lines +34 to +49
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The new test functions like test_is_aw_function, test_is_ordinal, test_trim_dll_part, and test_reformat_forwarded_export_name are great additions. To improve their maintainability and reduce code duplication, consider using pytest.mark.parametrize. This would require adding import pytest to the file.

For example, test_is_aw_function could be refactored like this:

@pytest.mark.parametrize(
    "symbol, expected",
    [
        # A-suffixed function
        ("CreateFileA", True),
        # W-suffixed function
        ("CreateFileW", True),
        # longer name ending with W
        ("LoadLibraryExW", True),
        # does not end with A or W
        ("WriteFile", False),
        ("recv", False),
        # too short (length < 2)
        ("A", False),
        ("W", False),
        ("", False),
    ],
)
def test_is_aw_function(symbol, expected):
    assert helpers.is_aw_function(symbol) is expected

This approach makes it easier to add new test cases in the future. A similar pattern can be applied to the other new test functions, which would make the test suite more concise and easier to maintain.



def test_is_ordinal():
    """Check that helpers.is_ordinal only accepts symbols starting with '#'."""
    cases = (
        # ordinal symbols start with #
        ("#1", True),
        ("#42", True),
        # normal symbol name
        ("CreateFileA", False),
        # empty string
        ("", False),
        # '#' present, but not at the start
        ("foo#1", False),
    )
    for symbol, expected in cases:
        assert helpers.is_ordinal(symbol) is expected


def test_generate_symbols():
# .dll extension is stripped
assert list(helpers.generate_symbols("name.dll", "api", include_dll=True)) == list(
helpers.generate_symbols("name", "api", include_dll=True)
)
assert list(helpers.generate_symbols("name.dll", "api", include_dll=False)) == list(
helpers.generate_symbols("name", "api", include_dll=False)
)

# .drv extension is stripped
assert list(helpers.generate_symbols("winspool.drv", "OpenPrinterA", include_dll=True)) == list(
helpers.generate_symbols("winspool", "OpenPrinterA", include_dll=True)
)

# .so extension is stripped
assert list(helpers.generate_symbols("libc.so", "printf", include_dll=True)) == list(
helpers.generate_symbols("libc", "printf", include_dll=True)
)

# uppercase DLL name is lowercased
symbols = list(helpers.generate_symbols("KERNEL32", "CreateFileA", include_dll=True))
assert "kernel32.CreateFileA" in symbols
assert "KERNEL32.CreateFileA" not in symbols

# A/W import
symbols = list(helpers.generate_symbols("kernel32", "CreateFileA", include_dll=True))
assert len(symbols) == 4
Expand All @@ -46,7 +96,15 @@ def test_generate_symbols():
assert "CreateFileA" in symbols
assert "CreateFile" in symbols

# import
# W-suffixed import
symbols = list(helpers.generate_symbols("kernel32", "CreateFileW", include_dll=True))
assert len(symbols) == 4
assert "kernel32.CreateFileW" in symbols
assert "kernel32.CreateFile" in symbols
assert "CreateFileW" in symbols
assert "CreateFile" in symbols

# import (non-A/W)
symbols = list(helpers.generate_symbols("kernel32", "WriteFile", include_dll=True))
assert len(symbols) == 2
assert "kernel32.WriteFile" in symbols
Expand All @@ -57,23 +115,63 @@ def test_generate_symbols():
assert len(symbols) == 1
assert "ws2_32.#1" in symbols

# A/W api
# A/W api (no DLL prefix in output)
symbols = list(helpers.generate_symbols("kernel32", "CreateFileA", include_dll=False))
assert len(symbols) == 2
assert "CreateFileA" in symbols
assert "CreateFile" in symbols

# api
# api (non-A/W, no DLL prefix in output)
symbols = list(helpers.generate_symbols("kernel32", "WriteFile", include_dll=False))
assert len(symbols) == 1
assert "WriteFile" in symbols

# ordinal api
# ordinal api (DLL prefix always included for ordinals)
symbols = list(helpers.generate_symbols("ws2_32", "#1", include_dll=False))
assert len(symbols) == 1
assert "ws2_32.#1" in symbols


def test_trim_dll_part():
    """Check when trim_dll_part strips the DLL prefix and when it leaves the name alone."""
    expectations = (
        # normal DLL.API: the DLL prefix is stripped
        ("kernel32.CreateFileA", "CreateFileA"),
        # ordinal import: kept as-is
        ("ws2_32.#1", "ws2_32.#1"),
        # .NET namespace with '::': kept as-is
        ("System.Convert::FromBase64String", "System.Convert::FromBase64String"),
        # .NET multi-dot namespace with '::': kept as-is
        ("System.Diagnostics.Debugger::IsLogging", "System.Diagnostics.Debugger::IsLogging"),
        # no dot: unchanged
        ("CreateFileA", "CreateFileA"),
        # multiple dots (count > 1) and no '::': unchanged
        ("a.b.c.CreateFile", "a.b.c.CreateFile"),
    )
    for api, trimmed in expectations:
        assert trim_dll_part(api) == trimmed


def test_reformat_forwarded_export_name():
    """Check that the DLL part of a forwarded export is lowercased and the symbol is kept verbatim."""
    samples = (
        # uppercase DLL is lowercased, symbol is preserved verbatim
        ("NTDLL.RtlAllocateHeap", "ntdll.RtlAllocateHeap"),
        # already lowercase
        ("kernel32.HeapAlloc", "kernel32.HeapAlloc"),
        # DLL name with hyphens
        ("api-ms-win-core-file-l1-1-0.CreateFileW", "api-ms-win-core-file-l1-1-0.CreateFileW"),
        # full path before the dot: the whole path part is lowercased
        ("C:\\Windows\\NTDLL.RtlAllocateHeap", "c:\\windows\\ntdll.RtlAllocateHeap"),
    )
    for forwarded, reformatted in samples:
        assert helpers.reformat_forwarded_export_name(forwarded) == reformatted


def test_is_dev_environment():
# testing environment should be a dev environment
assert capa.helpers.is_dev_environment() is True
Loading