|
1 | 1 | # type: ignore |
2 | 2 |
|
| 3 | +from types import SimpleNamespace |
| 4 | + |
3 | 5 | import pytest |
4 | 6 | import sqlparse |
5 | 7 | from sqlparse.sql import Identifier, IdentifierList, Token, TokenList |
@@ -563,6 +565,98 @@ def split(self): |
563 | 565 | assert need_completion_reset('ignored') is False |
564 | 566 |
|
565 | 567 |
|
| 568 | +def test_classify_sandbox_statement_treats_token_error_as_quit(monkeypatch): |
| 569 | + def raise_token_error(*_args, **_kwargs): |
| 570 | + raise sql_utils.sqlglot.errors.TokenError('bad token') |
| 571 | + |
| 572 | + monkeypatch.setattr(sql_utils.sqlglot, 'tokenize', raise_token_error) |
| 573 | + |
| 574 | + assert sql_utils.classify_sandbox_statement('`') == ('quit', None) |
| 575 | + |
| 576 | + |
| 577 | +def test_classify_sandbox_statement_treats_empty_tokens_as_quit(monkeypatch): |
| 578 | + monkeypatch.setattr(sql_utils.sqlglot, 'tokenize', lambda *_args, **_kwargs: []) |
| 579 | + |
| 580 | + assert sql_utils.classify_sandbox_statement('ignored') == ('quit', None) |
| 581 | + |
| 582 | + |
| 583 | +def test_find_password_after_eq_returns_none_for_non_string_token() -> None: |
| 584 | + token_type = sql_utils.sqlglot.tokens.TokenType |
| 585 | + tokens = [ |
| 586 | + SimpleNamespace(token_type=token_type.EQ, text='='), |
| 587 | + SimpleNamespace(token_type=token_type.VAR, text='CURRENT_USER'), |
| 588 | + ] |
| 589 | + |
| 590 | + assert sql_utils._find_password_after_eq(tokens) is None |
| 591 | + |
| 592 | + |
| 593 | +@pytest.mark.parametrize( |
| 594 | + ('text', 'expected'), |
| 595 | + [ |
| 596 | + ('', ('quit', None)), |
| 597 | + (' ', ('quit', None)), |
| 598 | + ('quit', ('quit', None)), |
| 599 | + ('exit', ('quit', None)), |
| 600 | + ('\\q', ('quit', None)), |
| 601 | + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'new'", ('alter_user', 'new')), |
| 602 | + ('ALTER USER root IDENTIFIED WITH mysql_native_password', ('alter_user', None)), |
| 603 | + ("SET PASSWORD = 'newpass'", ('set_password', 'newpass')), |
| 604 | + ('SELECT 1', (None, None)), |
| 605 | + ], |
| 606 | +) |
| 607 | +def test_classify_sandbox_statement(text: str, expected: tuple[str | None, str | None]) -> None: |
| 608 | + assert sql_utils.classify_sandbox_statement(text) == expected |
| 609 | + |
| 610 | + |
| 611 | +@pytest.mark.parametrize( |
| 612 | + ('text', 'expected'), |
| 613 | + [ |
| 614 | + ('', True), |
| 615 | + (' ', True), |
| 616 | + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'new'", True), |
| 617 | + ('alter user root identified by "pw"', True), |
| 618 | + ("SET PASSWORD = 'newpass'", True), |
| 619 | + ("set password = 'newpass'", True), |
| 620 | + ('quit', True), |
| 621 | + ('exit', True), |
| 622 | + ('\\q', True), |
| 623 | + ('SELECT 1', False), |
| 624 | + ('DROP TABLE t', False), |
| 625 | + ('USE mydb', False), |
| 626 | + ('SHOW DATABASES', False), |
| 627 | + ], |
| 628 | +) |
| 629 | +def test_is_sandbox_allowed(text: str, expected: bool) -> None: |
| 630 | + assert sql_utils.is_sandbox_allowed(text) is expected |
| 631 | + |
| 632 | + |
| 633 | +@pytest.mark.parametrize( |
| 634 | + ('text', 'expected'), |
| 635 | + [ |
| 636 | + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'new'", True), |
| 637 | + ("SET PASSWORD = 'newpass'", True), |
| 638 | + ('SELECT 1', False), |
| 639 | + ('quit', False), |
| 640 | + ], |
| 641 | +) |
| 642 | +def test_is_password_change(text: str, expected: bool) -> None: |
| 643 | + assert sql_utils.is_password_change(text) is expected |
| 644 | + |
| 645 | + |
| 646 | +@pytest.mark.parametrize( |
| 647 | + ('text', 'expected'), |
| 648 | + [ |
| 649 | + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'", 'newpass'), |
| 650 | + ("SET PASSWORD = 'secret123'", 'secret123'), |
| 651 | + ("ALTER USER root IDENTIFIED BY 'p@ss w0rd!'", 'p@ss w0rd!'), |
| 652 | + ('ALTER USER root IDENTIFIED WITH mysql_native_password', None), |
| 653 | + ('SELECT 1', None), |
| 654 | + ], |
| 655 | +) |
| 656 | +def test_extract_new_password(text: str, expected: str | None) -> None: |
| 657 | + assert sql_utils.extract_new_password(text) == expected |
| 658 | + |
| 659 | + |
566 | 660 | @pytest.mark.parametrize( |
567 | 661 | ('status_plain', 'expected'), |
568 | 662 | [ |
|
0 commit comments