Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from mwaa.utils.get_rds_iam_credentials import RDSIAMCredentialProvider

DB_IAM_USERNAME = "airflow_user"
DB_ADMIN_USERNAME = "adminuser"
DB_NAME = "AirflowMetadata"

# Usually, we pass the `__name__` variable instead as that defaults to the module path,
Expand All @@ -44,32 +45,32 @@ def _ensure_rds_iam_user():
try:
# Set db_connection_url using RDS IAM credentials
try:
# On default, try to connect to RDS using IAM authentication
logger.info("Creating db_connection_url using RDS IAM credentials")
token = RDSIAMCredentialProvider.get_token()
db_connection_url = RDSIAMCredentialProvider.create_db_connection_url(token)

logger.info("Creating engine using RDS IAM and validating connection")
# On default, try to connect to engine using admin user to create/update airflow_user
logger.info("Creating db_connection_url using static credentials")
db_connection_url = get_db_connection_string()
db_engine = create_engine(
db_connection_url,
connect_args={"connect_timeout": 3}
)
# Test that the connection is working
with db_engine.connect() as conn:
conn.execute(text("SELECT 1"))
logger.info("Engine created using RDS IAM and connection validated")
logger.info("Engine created using static credentials")

except Exception as e:
# If RDS IAM authentication fails, connect with static credentials
# This is needed on environment creation since airflow_user is not created yet
logger.warning(f"Exception type: {type(e).__name__}, message: {e}")
db_connection_url = get_db_connection_string()
logger.warning("Engine creation using RDS IAM failed... Attempting to create engine using static credentials")
# If adminuser connection fails due to RDS IAM set up, then use RDS IAM for connection
logger.info("Creating db_connection_url using RDS IAM credentials")
token = RDSIAMCredentialProvider.get_token()
db_connection_url = RDSIAMCredentialProvider.create_db_connection_url(token)
logger.info("Creating engine using RDS IAM and validating connection")
db_engine = create_engine(
db_connection_url,
connect_args={"connect_timeout": 3}
)
logger.info("Engine created using static credentials")
with db_engine.connect() as conn:
conn.execute(text("SELECT 1"))
logger.info("Engine created using RDS IAM and connection validated")

with db_engine.connect() as conn:
with conn.begin():
result = conn.execute(text("SELECT 1 FROM pg_roles WHERE rolname = :rolename"), {"rolename": DB_IAM_USERNAME})
Expand All @@ -80,16 +81,25 @@ def _ensure_rds_iam_user():
else:
logger.info(f"db rds iam user already exists")

# Always ensure permissions are up to date
conn.execute(text(f"GRANT rds_iam TO {DB_IAM_USERNAME}"))
conn.execute(text(f'GRANT ALL PRIVILEGES ON DATABASE "{DB_NAME}" TO {DB_IAM_USERNAME}'))
conn.execute(text(f"GRANT ALL ON SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL TABLES IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL SEQUENCES IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL FUNCTIONS IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON FUNCTIONS TO {DB_IAM_USERNAME}"))
current_role = conn.execute(
text("SELECT current_user")
).scalar()

if current_role == DB_ADMIN_USERNAME:
# Always ensure permissions are up to date
logger.info(f"Current role is {DB_ADMIN_USERNAME}, setting up permissions for airflow_user")
conn.execute(text(f"GRANT rds_iam TO {DB_IAM_USERNAME}"))
conn.execute(text(f'GRANT ALL PRIVILEGES ON DATABASE "{DB_NAME}" TO {DB_IAM_USERNAME}'))
conn.execute(text(f"GRANT ALL ON SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL TABLES IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL FUNCTIONS IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON FUNCTIONS TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT {DB_ADMIN_USERNAME} TO {DB_IAM_USERNAME}"))

elif current_role == "airflow_user":
logger.info("Current role is airflow_user")
except Exception as e:
logger.warning(f"Error while ensuring rds iam db credentials, skipping. {e}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from mwaa.utils.get_rds_iam_credentials import RDSIAMCredentialProvider

DB_IAM_USERNAME = "airflow_user"
DB_ADMIN_USERNAME = "adminuser"
DB_NAME = "AirflowMetadata"

# Usually, we pass the `__name__` variable instead as that defaults to the module path,
Expand All @@ -44,32 +45,31 @@ def _ensure_rds_iam_user():
try:
# Set db_connection_url using RDS IAM credentials
try:
# On default, try to connect to RDS using IAM authentication
logger.info("Creating db_connection_url using RDS IAM credentials")
token = RDSIAMCredentialProvider.get_token()
db_connection_url = RDSIAMCredentialProvider.create_db_connection_url(token)

logger.info("Creating engine using RDS IAM and validating connection")
# On default, try to connect to engine using admin user to create/update airflow_user
logger.info("Creating db_connection_url using static credentials")
db_connection_url = get_db_connection_string()
db_engine = create_engine(
db_connection_url,
connect_args={"connect_timeout": 3}
)
# Test that the connection is working
with db_engine.connect() as conn:
conn.execute(text("SELECT 1"))
logger.info("Engine created using RDS IAM and connection validated")
logger.info("Engine created using static credentials")

except Exception as e:
# If RDS IAM authentication fails, connect with static credentials
# This is needed on environment creation since airflow_user is not created yet
logger.warning(f"Exception type: {type(e).__name__}, message: {e}")
db_connection_url = get_db_connection_string()
logger.warning("Engine creation using RDS IAM failed... Attempting to create engine using static credentials")
# If adminuser connection fails due to RDS IAM set up, then use RDS IAM for connection
logger.info("Creating db_connection_url using RDS IAM credentials")
token = RDSIAMCredentialProvider.get_token()
db_connection_url = RDSIAMCredentialProvider.create_db_connection_url(token)
logger.info("Creating engine using RDS IAM and validating connection")
db_engine = create_engine(
db_connection_url,
connect_args={"connect_timeout": 3}
)
logger.info("Engine created using static credentials")
with db_engine.connect() as conn:
conn.execute(text("SELECT 1"))
logger.info("Engine created using RDS IAM and connection validated")

with db_engine.connect() as conn:
with conn.begin():
Expand All @@ -81,16 +81,26 @@ def _ensure_rds_iam_user():
else:
logger.info(f"db rds iam user already exists")

# Always ensure permissions are up to date
conn.execute(text(f"GRANT rds_iam TO {DB_IAM_USERNAME}"))
conn.execute(text(f'GRANT ALL PRIVILEGES ON DATABASE "{DB_NAME}" TO {DB_IAM_USERNAME}'))
conn.execute(text(f"GRANT ALL ON SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL TABLES IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL SEQUENCES IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL FUNCTIONS IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON FUNCTIONS TO {DB_IAM_USERNAME}"))
current_role = conn.execute(
text("SELECT current_user")
).scalar()

if current_role == DB_ADMIN_USERNAME:
# Always ensure permissions are up to date
logger.info(f"Current role is {DB_ADMIN_USERNAME}, setting up permissions for airflow_user")
conn.execute(text(f"GRANT rds_iam TO {DB_IAM_USERNAME}"))
conn.execute(text(f'GRANT ALL PRIVILEGES ON DATABASE "{DB_NAME}" TO {DB_IAM_USERNAME}'))
conn.execute(text(f"GRANT ALL ON SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL TABLES IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL SEQUENCES IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL FUNCTIONS IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON FUNCTIONS TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT {DB_ADMIN_USERNAME} TO {DB_IAM_USERNAME}"))

elif current_role == "airflow_user":
logger.info("Current role is airflow_user")
except Exception as e:
logger.warning(f"Error while ensuring rds iam db credentials, skipping. {e}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from mwaa.utils.get_rds_iam_credentials import RDSIAMCredentialProvider

DB_IAM_USERNAME = "airflow_user"
DB_ADMIN_USERNAME = "adminuser"
DB_NAME = "AirflowMetadata"

# Usually, we pass the `__name__` variable instead as that defaults to the module path,
Expand All @@ -44,32 +45,31 @@ def _ensure_rds_iam_user():
try:
# Set db_connection_url using RDS IAM credentials
try:
# On default, try to connect to RDS using IAM authentication
logger.info("Creating db_connection_url using RDS IAM credentials")
token = RDSIAMCredentialProvider.get_token()
db_connection_url = RDSIAMCredentialProvider.create_db_connection_url(token)

logger.info("Creating engine using RDS IAM and validating connection")
# On default, try to connect to engine using admin user to create/update airflow_user
logger.info("Creating db_connection_url using static credentials")
db_connection_url = get_db_connection_string()
db_engine = create_engine(
db_connection_url,
connect_args={"connect_timeout": 3}
)
# Test that the connection is working
with db_engine.connect() as conn:
conn.execute(text("SELECT 1"))
logger.info("Engine created using RDS IAM and connection validated")
logger.info("Engine created using static credentials")

except Exception as e:
# If RDS IAM authentication fails, connect with static credentials
# This is needed on environment creation since airflow_user is not created yet
logger.warning(f"Exception type: {type(e).__name__}, message: {e}")
db_connection_url = get_db_connection_string()
logger.warning("Engine creation using RDS IAM failed... Attempting to create engine using static credentials")
# If adminuser connection fails due to RDS IAM set up, then use RDS IAM for connection
logger.info("Creating db_connection_url using RDS IAM credentials")
token = RDSIAMCredentialProvider.get_token()
db_connection_url = RDSIAMCredentialProvider.create_db_connection_url(token)
logger.info("Creating engine using RDS IAM and validating connection")
db_engine = create_engine(
db_connection_url,
connect_args={"connect_timeout": 3}
)
logger.info("Engine created using static credentials")
with db_engine.connect() as conn:
conn.execute(text("SELECT 1"))
logger.info("Engine created using RDS IAM and connection validated")

with db_engine.connect() as conn:
with conn.begin():
Expand All @@ -81,16 +81,26 @@ def _ensure_rds_iam_user():
else:
logger.info(f"db rds iam user already exists")

# Always ensure permissions are up to date
conn.execute(text(f"GRANT rds_iam TO {DB_IAM_USERNAME}"))
conn.execute(text(f'GRANT ALL PRIVILEGES ON DATABASE "{DB_NAME}" TO {DB_IAM_USERNAME}'))
conn.execute(text(f"GRANT ALL ON SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL TABLES IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL SEQUENCES IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL FUNCTIONS IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON FUNCTIONS TO {DB_IAM_USERNAME}"))
current_role = conn.execute(
text("SELECT current_user")
).scalar()

if current_role == DB_ADMIN_USERNAME:
# Always ensure permissions are up to date
logger.info(f"Current role is {DB_ADMIN_USERNAME}, setting up permissions for airflow_user")
conn.execute(text(f"GRANT rds_iam TO {DB_IAM_USERNAME}"))
conn.execute(text(f'GRANT ALL PRIVILEGES ON DATABASE "{DB_NAME}" TO {DB_IAM_USERNAME}'))
conn.execute(text(f"GRANT ALL ON SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL TABLES IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL SEQUENCES IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT ALL ON ALL FUNCTIONS IN SCHEMA public TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {DB_IAM_USERNAME}"))
conn.execute(text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON FUNCTIONS TO {DB_IAM_USERNAME}"))
conn.execute(text(f"GRANT {DB_ADMIN_USERNAME} TO {DB_IAM_USERNAME}"))

elif current_role == "airflow_user":
logger.info("Current role is airflow_user")
except Exception as e:
logger.warning(f"Error while ensuring rds iam db credentials, skipping. {e}")

Expand Down
Loading
Loading