import io
import json
import os
import base64
import oci
import requests
import subprocess
import tempfile
import time
import logging
import secrets
import string
import snowflake.connector

# ======================================================
# Logging setup
# ======================================================
log_level = os.environ.get("log_level", "INFO").upper()
logger = logging.getLogger()
logger.setLevel(getattr(logging, log_level, logging.INFO))

# Environment variable used to hand passphrases to openssl via
# `-passin env:NAME` / `-passout env:NAME`, so secrets never appear in argv,
# the process list, a shell command line, or the debug log.
_OPENSSL_PASS_ENV = "KEY_ROTATION_OPENSSL_PASS"


# ======================================================
# Helper functions
# ======================================================
def generate_random_password(length=24):
    """Return a random password of *length* chars from letters, digits, '-' and '_'.

    Only '-' and '_' are used as specials so the value is safe in URLs,
    config files and openssl arguments.
    """
    safe_specials = "-_"
    alphabet = string.ascii_letters + string.digits + safe_specials
    return "".join(secrets.choice(alphabet) for _ in range(length))


def b64_decode_to_str(b64_or_plain: str) -> str:
    """Decode a base64 secret to a UTF-8 string.

    If the input is not valid base64, or does not decode to valid UTF-8, it is
    returned unchanged (treated as plain text).
    """
    try:
        return base64.b64decode(b64_or_plain).decode("utf-8")
    except (ValueError, TypeError):
        # binascii.Error and UnicodeDecodeError are both ValueError subclasses.
        return b64_or_plain


def write_secure_file(path: str, content: str, mode: int = 0o600):
    """Write *content* (UTF-8) to *path* with permissions *mode* from the start.

    The file is created through os.open with *mode*. It is then re-chmod'ed
    through the descriptor before any data is written, because os.open's mode
    is filtered by the umask and ignored for pre-existing files. The secret is
    therefore never readable with looser default permissions.
    """
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        os.fchmod(f.fileno(), mode)
        f.write(content)


def read_text_file(path: str) -> str:
    """Return the UTF-8 text content of *path*, closing the file afterwards."""
    with open(path, encoding="utf-8") as f:
        return f.read()


def run_openssl(
    cmd: str | list[str],
    *,
    env_extra: dict[str, str] | None = None,
    input_data: bytes | None = None,
) -> bytes:
    """Run an openssl command and return its stdout.

    Args:
        cmd: Argument list (preferred; executed without a shell). A plain
            string is still accepted for backward compatibility and runs with
            shell=True. Never embed secrets in that form.
        env_extra: Extra environment variables for the child process, used to
            pass passphrases via `env:NAME` instead of the command line.
        input_data: Bytes fed to the command's stdin.

    Raises:
        subprocess.CalledProcessError: if openssl exits non-zero.
    """
    if isinstance(cmd, str):
        shell = True
        logger.debug("Running openssl: %s", cmd)
    else:
        shell = False
        cmd = list(cmd)
        logger.debug("Running openssl: %s", " ".join(cmd))
    env = {**os.environ, **env_extra} if env_extra else None
    result = subprocess.run(
        cmd,
        shell=shell,
        check=True,
        env=env,
        input=input_data,
        stdout=subprocess.PIPE,
    )
    return result.stdout


def generate_encrypted_private_key(priv_out: str, passphrase: str, bits: int = 2048):
    """Generate an RSA key and write it as DES3 (PBES2) encrypted PKCS#8 PEM to *priv_out*.

    This is equivalent to
    `openssl genrsa 2048 | openssl pkcs8 -topk8 -v2 des3 -inform PEM ...`,
    with two differences:
    - The unencrypted key only exists in this process's memory.
    - The passphrase is passed through the environment rather than argv.
    """
    raw_key = run_openssl(["openssl", "genrsa", str(bits)])
    run_openssl(
        [
            "openssl", "pkcs8", "-topk8", "-v2", "des3", "-inform", "PEM",
            "-out", priv_out, "-passout", f"env:{_OPENSSL_PASS_ENV}",
        ],
        env_extra={_OPENSSL_PASS_ENV: passphrase},
        input_data=raw_key,
    )
    os.chmod(priv_out, 0o600)


def extract_pub_from_priv(priv_path: str, passphrase: str | None, pub_out: str):
    """Write the PEM public key of the private key at *priv_path* to *pub_out*.

    When *passphrase* is given, it is used to decrypt the private key. openssl
    reads the encrypted key directly, so no unencrypted copy is written to disk.
    """
    cmd = ["openssl", "rsa", "-in", priv_path, "-pubout", "-out", pub_out]
    env_extra = None
    if passphrase:
        cmd += ["-passin", f"env:{_OPENSSL_PASS_ENV}"]
        env_extra = {_OPENSSL_PASS_ENV: passphrase}
    run_openssl(cmd, env_extra=env_extra)


def normalize_pubkey_text(pub_text: str | None) -> str:
    """Strip PEM header/footer lines and whitespace, returning the bare base64 body.

    Empty or None input yields "".
    """
    if not pub_text:
        return ""
    return "".join(
        line.strip()
        for line in pub_text.splitlines()
        if line.strip() and "PUBLIC KEY" not in line
    )


# ======================================================
# Main OCI Function handler
# ======================================================
def handler(ctx, data: io.BytesIO | None = None):
    """OCI Functions entry point: rotate the key pair GoldenGate uses for Snowflake.

    Flow:
    1. Stop the replicat.
    2. Verify that the Vault private key matches Snowflake's RSA_PUBLIC_KEY.
    3. Generate a new key pair.
    4. Install the new public key, keeping the old one in RSA_PUBLIC_KEY_2.
    5. Store the new private key and passphrase in Vault, then prune an old
       secret version.
    6. Refresh the GoldenGate connection and restart the replicat.
    7. Only then drop the old key from RSA_PUBLIC_KEY_2.

    Returns a JSON string: {"result": "success"} or
    {"result": "failure", "step": ..., "error": ...}. A notification is
    published to the configured ONS topic in both cases.
    """
    logger.info(f'=== Starting GoldenGate → Snowflake key rotation ===')

    # Load config from Fn environment variables
    gg_cfg = {
        "url": os.environ["goldengate_url"],
        "user": os.environ["goldengate_user"],
        "password_secret_ocid": os.environ["goldengate_password_secret_ocid"],
        "replicat_name": os.environ["goldengate_replicat_name"],
        "connection_id": os.environ["goldengate_connection_id"],
    }
    sf_cfg = {
        "account": os.environ["snowflake_account"],
        "user": os.environ["snowflake_user"],
        "warehouse": os.environ.get("snowflake_warehouse"),
        "database": os.environ.get("snowflake_database"),
        "schema": os.environ.get("snowflake_schema"),
        "role": os.environ.get("snowflake_role"),
    }
    vault_cfg = {
        "private_key_secret_ocid": os.environ["vault_private_key_secret_ocid"],
        "key_password_secret_ocid": os.environ["vault_key_password_secret_ocid"],
    }
    environment = os.environ["environment"]
    notif_topic_ocid = os.environ["notifications_topic_ocid"]

    # --------------------------------------------------
    # OCI Clients
    # --------------------------------------------------
    signer = oci.auth.signers.get_resource_principals_signer()
    secrets_client = oci.secrets.SecretsClient(config={}, signer=signer)
    gg_client = oci.golden_gate.GoldenGateClient(config={}, signer=signer)
    vault_client = oci.vault.VaultsClient(config={}, signer=signer)
    notification_client = oci.ons.NotificationDataPlaneClient(config={}, signer=signer)

    # --------------------------------------------------
    # Utility functions
    # --------------------------------------------------
    def send_notification(message: str, subject: str):
        """Publish *message* with title *subject* to the ONS topic.

        Best effort: failures are logged and never raised.
        """
        try:
            notification_client.publish_message(
                topic_id=notif_topic_ocid,
                message_details=oci.ons.models.MessageDetails(body=message, title=subject),
            )
            logger.info('Notification sent: %s', subject)
        except Exception:
            logger.exception("Failed to send notification")

    def get_secret_value(secret_ocid: str) -> str:
        """Fetch the current bundle of *secret_ocid* from Vault and decode it to text.

        Raises:
            RuntimeError: if the secret cannot be retrieved.
        """
        try:
            logger.info(f"Fetching secret value from Vault (OCID: {secret_ocid})")
            sb = secrets_client.get_secret_bundle(secret_ocid).data
            return b64_decode_to_str(sb.secret_bundle_content.content)
        except Exception as e:
            logger.exception(f"Failed to retrieve secret from Vault (OCID: {secret_ocid})")
            raise RuntimeError(f"Failed to retrieve secret from Vault (OCID: {secret_ocid}): {e}") from e

    def update_secret(secret_id: str, plain_value: str):
        """Create a new version of *secret_id* holding *plain_value* (sent base64-encoded).

        Raises:
            RuntimeError: if the Vault update fails.
        """
        try:
            logger.info(f"Updating secret in Vault (OCID: {secret_id})")
            encoded = base64.b64encode(plain_value.encode("utf-8")).decode("utf-8")
            vault_client.update_secret(
                secret_id=secret_id,
                update_secret_details=oci.vault.models.UpdateSecretDetails(
                    secret_content=oci.vault.models.Base64SecretContentDetails(
                        content=encoded, content_type="BASE64"
                    )
                ),
            )
            logger.info(f"Secret updated successfully (OCID: {secret_id})")
        except Exception as e:
            logger.exception(f"Failed to update secret in Vault (OCID: {secret_id})")
            raise RuntimeError(f"Failed to update secret in Vault (OCID: {secret_id}): {e}") from e

    def delete_secret_version(secret_id: str, version_number: int):
        """Schedule deletion of *version_number* of *secret_id*.

        No deletion time is passed, so the OCI default is used. Errors are
        logged and re-raised unchanged.
        """
        try:
            logger.info(f"Scheduling deletion of secret version {version_number} for secret {secret_id}")
            vault_client.schedule_secret_version_deletion(
                secret_id=secret_id,
                secret_version_number=version_number,
                schedule_secret_version_deletion_details=oci.vault.models.ScheduleSecretVersionDeletionDetails(),
            )
            logger.info(f"Secret version {version_number} deletion scheduled")
        except Exception:
            logger.exception(f"Failed to delete secret version {version_number} for secret {secret_id}")
            raise

    def get_old_version_to_delete(secret_id: str) -> int | None:
        """Return the secret version number to schedule for deletion, or None.

        The candidate is the third-newest version. The newest (current) version
        and the one before it (previous) are never returned. Returns None when
        there are fewer than three versions, or when the candidate is already
        scheduled for deletion.

        Raises:
            RuntimeError: if the versions cannot be listed.
        """
        try:
            # list_secret_versions is paginated; fetch every page so the
            # "newest" versions really are the newest.
            versions = oci.pagination.list_call_get_all_results(
                vault_client.list_secret_versions, secret_id
            ).data
            versions_sorted = sorted(versions, key=lambda v: v.version_number)
            if len(versions_sorted) > 2:
                candidate = versions_sorted[-3]
                if getattr(candidate, "time_of_deletion", None) is None:
                    return candidate.version_number
                logger.info(
                    f"Secret version {candidate.version_number} of secret {secret_id} "
                    f"is already scheduled for deletion."
                )
                return None
            logger.info(f"No deletable secret versions found for secret {secret_id}.")
            return None
        except Exception as e:
            logger.exception(f"Failed to list secret versions for secret {secret_id}")
            raise RuntimeError(f"Failed to list secret versions for secret {secret_id}: {e}") from e

    def wait_for_secret_ready(secret_id: str, timeout=120, interval=10):
        """Poll until the secret's lifecycle state is no longer UPDATING.

        This prevents IncorrectState errors when scheduling version deletion.
        Individual polling errors are logged and retried.

        Raises:
            RuntimeError: if the state is still UPDATING after *timeout* seconds.
        """
        start = time.monotonic()
        try:
            while time.monotonic() - start < timeout:
                try:
                    secret = vault_client.get_secret(secret_id).data
                    state = secret.lifecycle_state
                    logger.info(f"Secret {secret_id} lifecycle state: {state}")
                    if state != "UPDATING":
                        return  # Secret is now ready
                except Exception as inner_e:
                    logger.warning(f"Error checking state for secret {secret_id}: {inner_e}")
                time.sleep(interval)
            raise TimeoutError(
                f"Secret {secret_id} stuck in UPDATING state for more than {timeout} seconds."
            )
        except Exception as e:
            logger.exception(f"Failed while waiting for secret {secret_id} to leave UPDATING state")
            raise RuntimeError(f"wait_for_secret_ready failed for secret {secret_id}: {e}") from e

    # The GoldenGate password is fetched lazily once per invocation. Otherwise
    # replicat status polling would hit Vault on every poll interval.
    gg_password_cache: dict[str, str] = {}

    def gg_auth() -> tuple[str, str]:
        """Return the (user, password) basic-auth pair for the GoldenGate REST API."""
        if "password" not in gg_password_cache:
            gg_password_cache["password"] = get_secret_value(gg_cfg["password_secret_ocid"])
        return gg_cfg["user"], gg_password_cache["password"]

    def gg_command(cmd_name: str):
        """Execute a GoldenGate REST command (e.g. "stop"/"start") against the replicat.

        A read timeout on "stop" is treated as "still stopping": the function
        polls until the replicat reports stopped and returns
        {"status": "stopped"}. Otherwise it returns the JSON response.

        Raises:
            RuntimeError: if the command or the follow-up polling fails.
        """
        try:
            logger.info("Executing GoldenGate command: %s", cmd_name)
            payload = {
                "name": cmd_name,
                "processName": gg_cfg["replicat_name"],
            }
            # Longer timeout ONLY for stop
            timeout = (5, 60) if cmd_name.lower() == "stop" else 30
            try:
                r = requests.post(
                    f"{gg_cfg['url'].rstrip('/')}/services/v2/commands/execute",
                    auth=gg_auth(),
                    headers={"Content-Type": "application/json"},
                    json=payload,
                    timeout=timeout,
                )
                r.raise_for_status()
                return r.json()
            except requests.exceptions.ReadTimeout:
                if cmd_name.lower() != "stop":
                    raise
                logger.warning(
                    "STOP command timed out; assuming replicat is stopping and polling status"
                )
                wait_for_replicat_stopped()
                return {"status": "stopped"}
        except Exception as e:
            logger.exception("GoldenGate command '%s' failed", cmd_name)
            raise RuntimeError(f"GoldenGate command '{cmd_name}' failed: {e}") from e

    def get_replicat_status():
        """Return the replicat status string from the GoldenGate REST API.

        Example values are "running", "stopping", "stopped".

        Raises:
            RuntimeError: on HTTP errors or an unexpected response shape.
        """
        try:
            r = requests.get(
                f"{gg_cfg['url'].rstrip('/')}/services/v2/replicats/{gg_cfg['replicat_name']}/info/status",
                auth=gg_auth(),
                timeout=30,
            )
            r.raise_for_status()
            try:
                return r.json()["response"]["status"]
            except (KeyError, TypeError, ValueError) as parse_e:
                logger.error(f"Unexpected replicat status response format: {r.text}")
                raise RuntimeError(
                    "Invalid response format while fetching replicat status"
                ) from parse_e
        except Exception as e:
            logger.exception("Failed to fetch GoldenGate replicat status")
            raise RuntimeError(f"get_replicat_status failed: {e}") from e

    def wait_for_replicat_stopped(timeout=900, interval=10):
        """Poll the replicat status until it is "stopped" (case-insensitive).

        Transient status-fetch errors are logged and retried.

        Raises:
            RuntimeError: if the replicat is not stopped within *timeout* seconds.
        """
        start = time.monotonic()
        try:
            while time.monotonic() - start < timeout:
                try:
                    status = get_replicat_status()
                    logger.info(f"Replicat status: {status}")
                    if str(status).lower() == "stopped":
                        logger.info("Replicat stopped successfully")
                        return
                except Exception as inner_e:
                    # Temporary failures should not abort the stop sequence
                    logger.warning(f"Failed to fetch replicat status; will retry: {inner_e}")
                time.sleep(interval)
            raise TimeoutError(f"Replicat did not reach stopped state within {timeout} seconds")
        except Exception as e:
            logger.exception("Error while waiting for replicat to stop")
            raise RuntimeError(f"wait_for_replicat_stopped failed: {e}") from e

    def refresh_gg_connection(connection_id: str, private_key_file_secret_id: str,
                              private_key_passphrase_secret_id: str):
        """Refresh the GoldenGate connection so it reloads its Vault secret references.

        Waits up to 10 minutes for the refresh work request to finish. The two
        secret-id parameters are accepted for interface compatibility but are
        not used. The refresh reloads whatever secrets the connection already
        references.

        Raises:
            RuntimeError: if the work request does not end in SUCCEEDED.
        """
        logger.info('Refreshing GoldenGate connection to reload secret references...')
        try:
            refresh_details = oci.golden_gate.models.DefaultRefreshConnectionDetails(type="DEFAULT")
            response = gg_client.refresh_connection(
                connection_id=connection_id,
                refresh_connection_details=refresh_details,
            )
            work_request_id = response.headers.get("opc-work-request-id")
            if not work_request_id:
                raise RuntimeError("No work request ID found in GoldenGate refresh response.")

            max_wait_secs = 600
            poll_interval = 10
            waited = 0
            status = None
            while waited < max_wait_secs:
                wr = gg_client.get_work_request(work_request_id).data
                status = wr.status
                pct = wr.percent_complete or 0
                logger.info(f'Work request status: {status} ({pct:.0f}%)')
                if status in ("SUCCEEDED", "FAILED", "CANCELED"):
                    break
                time.sleep(poll_interval)
                waited += poll_interval

            if status != "SUCCEEDED":
                raise RuntimeError(f"GoldenGate connection refresh failed with status: {status}")
            logger.info('GoldenGate connection refreshed successfully — secret references reloaded.')
        except Exception as e:
            logger.exception("Failed to refresh GoldenGate connection vault references")
            raise RuntimeError(f"refresh_gg_connection failed: {e}") from e

    # --------------------------------------------------
    # Snowflake helpers
    # --------------------------------------------------
    def connect_snowflake_with_jwt(key_file: str, key_passphrase: str):
        """Open a Snowflake connection using key-pair (JWT) authentication."""
        return snowflake.connector.connect(
            account=sf_cfg["account"],
            user=sf_cfg["user"],
            authenticator="SNOWFLAKE_JWT",
            private_key_file=key_file,
            private_key_file_pwd=key_passphrase,
            warehouse=sf_cfg.get("warehouse"),
            database=sf_cfg.get("database"),
            schema=sf_cfg.get("schema"),
            role=sf_cfg.get("role"),
        )

    def get_snowflake_current_pubkey(key_file: str, key_pass: str) -> str:
        """Return the user's RSA_PUBLIC_KEY property from DESCRIBE USER ("" if absent)."""
        conn = connect_snowflake_with_jwt(key_file, key_pass)
        try:
            cur = conn.cursor()
            try:
                # Bound parameter instead of f-string interpolation of the user name.
                cur.execute("DESCRIBE USER IDENTIFIER(%s)", (sf_cfg["user"],))
                for r in cur.fetchall():
                    # Exact match: a substring test would also hit
                    # RSA_PUBLIC_KEY_FP / RSA_PUBLIC_KEY_2 / RSA_PUBLIC_KEY_2_FP.
                    if len(r) >= 2 and str(r[0]).strip().upper() == "RSA_PUBLIC_KEY":
                        return r[1] or ""
                return ""
            finally:
                cur.close()
        finally:
            conn.close()

    def set_snowflake_key_for_connected_user(key_file: str, key_pass: str, public_key_pem: str,
                                             key_slot="RSA_PUBLIC_KEY"):
        """Set *key_slot* (RSA_PUBLIC_KEY or RSA_PUBLIC_KEY_2) for the connected user.

        The slot is set to the normalized body of *public_key_pem*; an empty
        value clears it. This calls the
        ADMIN_UTILS.ENV_MANAGEMENT.SET_RSA_KEY_FOR_CONNECTED_USER procedure.
        """
        conn = connect_snowflake_with_jwt(key_file, key_pass)
        try:
            cur = conn.cursor()
            try:
                key_stripped = normalize_pubkey_text(public_key_pem) if public_key_pem else ""
                sql = "CALL ADMIN_UTILS.ENV_MANAGEMENT.SET_RSA_KEY_FOR_CONNECTED_USER(%s, %s)"
                logger.debug(f"Executing Snowflake SQL: {sql} (slot={key_slot})")
                cur.execute(sql, (key_slot, key_stripped))
            finally:
                cur.close()
        finally:
            conn.close()

    def backup_snowflake_pubkey(curr_priv_path, curr_pass):
        """Copy the user's current RSA_PUBLIC_KEY into RSA_PUBLIC_KEY_2.

        Raises:
            RuntimeError: if reading or writing the key fails.
        """
        try:
            logger.info("Backing up current Snowflake RSA key to RSA_PUBLIC_KEY_2")
            pubkey = get_snowflake_current_pubkey(curr_priv_path, curr_pass)
            set_snowflake_key_for_connected_user(curr_priv_path, curr_pass, pubkey, key_slot="RSA_PUBLIC_KEY_2")
        except Exception as e:
            logger.exception("Failed to backup Snowflake public key")
            raise RuntimeError(f"backup_snowflake_pubkey failed: {e}") from e

    def clear_secondary_pubkey(curr_priv_path, curr_pass):
        """Clear RSA_PUBLIC_KEY_2 for the user.

        Raises:
            RuntimeError: if the update fails.
        """
        try:
            logger.info("Clearing RSA_PUBLIC_KEY_2 after successful connection")
            set_snowflake_key_for_connected_user(curr_priv_path, curr_pass, "", key_slot="RSA_PUBLIC_KEY_2")
        except Exception as e:
            logger.exception("Failed to clear RSA_PUBLIC_KEY_2")
            raise RuntimeError(f"clear_secondary_pubkey failed: {e}") from e

    # --------------------------------------------------
    # Rotation flow
    # --------------------------------------------------
    # Initialized before the try so the failure handler can always report it.
    current_step = "Initializing"
    try:
        current_step = "Stopping replicat"
        logger.info(current_step)
        gg_command("stop")

        current_step = "Fetching current private key and passphrase from Vault"
        logger.info(current_step)
        current_priv = get_secret_value(vault_cfg["private_key_secret_ocid"])
        current_pass = get_secret_value(vault_cfg["key_password_secret_ocid"])

        with tempfile.TemporaryDirectory() as tmpdir:
            curr_priv_path = os.path.join(tmpdir, "current_key.p8")
            write_secure_file(curr_priv_path, current_priv)

            # Generate new keypair
            current_step = "Generating new keypair"
            logger.info(current_step)
            new_passphrase = generate_random_password()
            new_priv_path = os.path.join(tmpdir, "new_key.p8")
            new_pub_path = os.path.join(tmpdir, "new_key.pub")
            generate_encrypted_private_key(new_priv_path, new_passphrase)
            extract_pub_from_priv(new_priv_path, new_passphrase, new_pub_path)
            new_pub = read_text_file(new_pub_path)
            new_priv = read_text_file(new_priv_path)

            # Extract current public key
            current_step = "Extracting current public key"
            logger.info(current_step)
            curr_pub_path = os.path.join(tmpdir, "current_key.pub")
            extract_pub_from_priv(curr_priv_path, current_pass, curr_pub_path)
            curr_pub = read_text_file(curr_pub_path)

            # Validate and backup
            current_step = "Validating Snowflake public key matches Vault private key"
            logger.info(current_step)
            sf_pub = get_snowflake_current_pubkey(curr_priv_path, current_pass)
            if normalize_pubkey_text(sf_pub) != normalize_pubkey_text(curr_pub):
                logger.error("Mismatch between Snowflake and Vault keys")
                raise RuntimeError("Snowflake public key mismatch with stored private key")

            current_step = "Backing up current Snowflake key"
            logger.info(current_step)
            backup_snowflake_pubkey(curr_priv_path, current_pass)

            # Set new key (the old key stays valid via RSA_PUBLIC_KEY_2)
            current_step = "Updating Snowflake with new RSA key"
            logger.info(current_step)
            set_snowflake_key_for_connected_user(curr_priv_path, current_pass, new_pub)

            # Validate new key; on failure restore the old key as primary
            current_step = "Validating Snowflake connection with new key"
            logger.info(current_step)
            try:
                conn = connect_snowflake_with_jwt(new_priv_path, new_passphrase)
                conn.close()
            except Exception as validate_err:
                logger.exception("Failed to validate new key; rolling back")
                try:
                    set_snowflake_key_for_connected_user(curr_priv_path, current_pass, curr_pub)
                except Exception:
                    logger.exception("Rollback of RSA_PUBLIC_KEY failed; old key is still in RSA_PUBLIC_KEY_2")
                    raise RuntimeError("Failed to validate new key; rollback also failed") from validate_err
                raise RuntimeError("Failed to validate new key; rolled back") from validate_err
            logger.info("New key validated successfully.")

            # Manage Vault
            current_step = "Managing Vault secrets"
            logger.info(current_step)
            update_secret(vault_cfg["private_key_secret_ocid"], new_priv)
            update_secret(vault_cfg["key_password_secret_ocid"], new_passphrase)

            # Waits until secrets finish updating
            wait_for_secret_ready(vault_cfg["private_key_secret_ocid"])
            wait_for_secret_ready(vault_cfg["key_password_secret_ocid"])

            # Save old secret version numbers
            old_private_version = get_old_version_to_delete(vault_cfg["private_key_secret_ocid"])
            old_pass_version = get_old_version_to_delete(vault_cfg["key_password_secret_ocid"])

            # Delete old secret versions after successful rotation
            if old_private_version is not None:
                delete_secret_version(vault_cfg["private_key_secret_ocid"], old_private_version)
            else:
                logger.info("No deletable private key version found.")
            if old_pass_version is not None:
                delete_secret_version(vault_cfg["key_password_secret_ocid"], old_pass_version)
            else:
                logger.info("No deletable passphrase version found.")

            # Refresh GoldenGate connection
            current_step = "Refreshing GoldenGate connection"
            logger.info(current_step)
            refresh_gg_connection(
                gg_cfg["connection_id"],
                vault_cfg["private_key_secret_ocid"],
                vault_cfg["key_password_secret_ocid"],
            )

            current_step = "Starting replicat"
            logger.info(current_step)
            gg_command("start")

            # Drop the old key only now. Until Vault and GoldenGate hold the new
            # key, RSA_PUBLIC_KEY_2 is what keeps the old Vault credentials
            # working if a step above fails.
            current_step = "Clearing backup Snowflake key"
            logger.info(current_step)
            clear_secondary_pubkey(new_priv_path, new_passphrase)

        send_notification(
            f"[{environment}] GoldenGate → Snowflake key rotation completed successfully ✅",
            f"[{environment}] Key Rotation Success",
        )
        logger.info("=== Key rotation completed successfully ===")
        return json.dumps({"result": "success"})

    except Exception as e:
        # NOTE: the replicat is not restarted automatically on failure; the
        # notification tells operators which step failed.
        logger.exception("Key rotation failed")
        subject = f"[{environment}] Key Rotation Failure - {current_step}"
        message = f"Key rotation failed during step '{current_step}' with error: {e}"
        send_notification(message, subject)
        return json.dumps({"result": "failure", "step": current_step, "error": str(e)})