Connecting to Snowflake with Scala Spark Using Key Pair Authentication

“Every winter has its spring.” H. Tuttle

Snowflake is a popular cloud data platform known for its scalability and performance. It supports several authentication methods, including username and password, OAuth, and key pair authentication. In this post, we will focus on connecting to Snowflake from Scala Spark using key pair authentication.

Besides key pair authentication, which this post covers, Snowflake also supports:

  • Username and Password: Simple but less secure for production.
  • OAuth: Used for integration with external identity providers.
  • External Browser: For connecting via a browser session.

Understanding Key Pair Authentication

Key pair authentication lets you connect to Snowflake without a traditional password. You keep a private key stored securely on your system and register the matching public key with your Snowflake user. When connecting, the client proves that it holds the private key, and Snowflake checks that proof against the registered public key.
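
To make the idea concrete, here is a minimal sketch that uses only the JDK's java.security classes. It is not Snowflake's actual login exchange, which the drivers handle for you. It only illustrates the asymmetry behind key pair authentication: something signed with a private key can be checked by anyone who holds the matching public key, and by no other key. The object name KeyPairSketch, the sample message, and the choice of SHA256withRSA are assumptions made for illustration.

```scala
import java.nio.charset.StandardCharsets
import java.security.{KeyPairGenerator, PrivateKey, PublicKey, Signature}

object KeyPairSketch {
  // Sign a message with the private key (the part only the client holds).
  def sign(privateKey: PrivateKey, message: Array[Byte]): Array[Byte] = {
    val signer = Signature.getInstance("SHA256withRSA")
    signer.initSign(privateKey)
    signer.update(message)
    signer.sign()
  }

  // Check a signature with a public key (the part registered with the server).
  def verify(publicKey: PublicKey, message: Array[Byte], signature: Array[Byte]): Boolean = {
    val verifier = Signature.getInstance("SHA256withRSA")
    verifier.initVerify(publicKey)
    verifier.update(message)
    verifier.verify(signature)
  }

  // Returns (result with the matching public key, result with an unrelated public key).
  def demo(): (Boolean, Boolean) = {
    val generator = KeyPairGenerator.getInstance("RSA")
    generator.initialize(2048)
    val clientKeys = generator.generateKeyPair()
    val otherKeys = generator.generateKeyPair()
    val message = "login request for <username>".getBytes(StandardCharsets.UTF_8)
    val signature = sign(clientKeys.getPrivate, message)
    (verify(clientKeys.getPublic, message, signature),
      verify(otherKeys.getPublic, message, signature))
  }
}
```

Calling KeyPairSketch.demo() signs one message and checks it twice, once against the matching public key and once against an unrelated one, so you can see that only the registered key accepts the signature.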

Steps to Create a Private Key

  1. Generate the Key Pair: You can use OpenSSL to generate a private key in .p8 format:
openssl genpkey -algorithm RSA -out private_key.p8 -pkeyopt rsa_keygen_bits:2048

This generates an unencrypted private key in PKCS#8 PEM format.

  2. Encrypt the Private Key (Optional): Use OpenSSL to add a passphrase for added security:
openssl pkcs8 -in private_key.p8 -topk8 -out encrypted_private_key.p8 -v2 aes-256-cbc

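OpenSSL prompts you for the passphrase; keep it, because the code needs it later to decrypt the key. You can replace aes-256-cbc with another cipher that your OpenSSL build supports.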

  3. Register the Public Key: Extract the public key from the private key and register it with your Snowflake user account:
openssl rsa -in private_key.p8 -pubout -out public_key.pem

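Register the public key with a SQL statement, for example from a worksheet in the Snowflake UI: ALTER USER <username> SET RSA_PUBLIC_KEY='<public_key_body>'; where the key body is the content of public_key.pem without the BEGIN and END delimiter lines.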
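
As a small illustration of the registration step, the sketch below turns the text of public_key.pem into an ALTER USER statement by dropping the delimiter lines and joining the Base64 body. The object and method names are hypothetical, the user name is whatever you pass in, and running the statement assumes your role is allowed to alter that user.

```scala
object PublicKeyRegistration {
  // Drop the "-----BEGIN/END PUBLIC KEY-----" lines and join the Base64 body into one string.
  def pemBody(pem: String): String =
    pem.split("\\r?\\n")
      .map(_.trim)
      .filter(line => line.nonEmpty && !line.startsWith("-----"))
      .mkString

  // Build the statement that registers the key; userName is supplied by the caller.
  def alterUserSql(userName: String, publicKeyPem: String): String =
    s"ALTER USER $userName SET RSA_PUBLIC_KEY='${pemBody(publicKeyPem)}';"
}
```

You could read public_key.pem into a string, pass it to PublicKeyRegistration.alterUserSql together with your user name, and paste the printed statement into a worksheet.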

Code Example

Key Components of the Code

The full code is listed at the end of the post.

  1. Dependencies and Imports: The code relies on BouncyCastle for handling private keys, SparkSession for Spark operations, and DriverManager for JDBC connections. The main imports are shown here; the complete list is in the full code.
import org.bouncycastle.openssl.PEMParser
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import java.sql.{Connection, DriverManager}
  2. Configuration Map: The configuration map contains essential details such as the Snowflake account, region, and private key file path.
val configMap: Map[String, String] = Map(
  ACCOUNT        -> "<account_name>",
  REGION         -> "<region>",
  USERNAME       -> "<username>",
  ROLE           -> "<role>",
  WAREHOUSE      -> "<warehouse>",
  DATABASE       -> "<database>",
  SCHEMA         -> "<schema>",
  PK_PATH        -> "<private_key_path>",
  PK_PASSPHRASE  -> "<private_key_passphrase>"
)
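
Hard-coding the passphrase in source is rarely what you want outside a quick test. One possible approach, sketched below, builds the same map from environment variables and reports which ones are missing. The variable names (SNOWFLAKE_ACCOUNT and so on) and the ConfigFromEnv object are assumptions for illustration; the map keys are the string values of the constants defined in the companion object in the full code.

```scala
object ConfigFromEnv {
  // Config key (as used by SnowflakeConnection) -> hypothetical environment variable name.
  val envNames: Map[String, String] = Map(
    "sfAccount"    -> "SNOWFLAKE_ACCOUNT",
    "sfRegion"     -> "SNOWFLAKE_REGION",
    "sfUsername"   -> "SNOWFLAKE_USER",
    "sfRole"       -> "SNOWFLAKE_ROLE",
    "sfWarehouse"  -> "SNOWFLAKE_WAREHOUSE",
    "sfDatabase"   -> "SNOWFLAKE_DATABASE",
    "sfSchema"     -> "SNOWFLAKE_SCHEMA",
    "sfPrivateKey" -> "SNOWFLAKE_PRIVATE_KEY_PATH",
    "sfPassphrase" -> "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE"
  )

  // Returns the config map, or the sorted names of the variables that are not set.
  def load(env: Map[String, String]): Either[Seq[String], Map[String, String]] = {
    val missing = envNames.values.toSeq.filterNot(env.contains).sorted
    if (missing.nonEmpty) Left(missing)
    else Right(envNames.map { case (key, name) => key -> env(name) })
  }
}
```

In an application you would call ConfigFromEnv.load(sys.env) and pass the resulting map to the SnowflakeConnection constructor; taking the environment as a parameter keeps the function easy to test.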
  3. Private Key Handling: The getPrivateKey method reads the private key file and, if the key is encrypted, decrypts it with the passphrase. It uses BouncyCastle to parse both encrypted and unencrypted PKCS#8 keys.
def getPrivateKey(filename: String, passphrase: String): PrivateKey = {
    val pemParser = new PEMParser(new FileReader(filename))
    try {
        // readObject returns a different type depending on whether the key is encrypted
        val privateKeyInfo = pemParser.readObject() match {
            case encrypted: PKCS8EncryptedPrivateKeyInfo =>
                val decryptorProvider = new JceOpenSSLPKCS8DecryptorProviderBuilder()
                  .build(passphrase.toCharArray)
                encrypted.decryptPrivateKeyInfo(decryptorProvider)
            case keyInfo: PrivateKeyInfo => keyInfo
            case _ => throw new IllegalArgumentException("Unsupported key format")
        }

        new JcaPEMKeyConverter().getPrivateKey(privateKeyInfo)
    } finally {
        // Release the file handle even if parsing or decryption fails
        pemParser.close()
    }
}
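
If your key is not encrypted, BouncyCastle is not strictly needed. The following sketch, which assumes an unencrypted RSA key in PKCS#8 PEM form such as the one produced by the genpkey command earlier, loads it with JDK classes only. PlainPkcs8Loader is a hypothetical name and is not part of the post's code; an encrypted key will fail here and still needs the BouncyCastle path.

```scala
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}
import java.security.{KeyFactory, PrivateKey}
import java.security.spec.PKCS8EncodedKeySpec
import java.util.Base64

object PlainPkcs8Loader {
  // Parse PEM text delimited by "-----BEGIN PRIVATE KEY-----" and "-----END PRIVATE KEY-----".
  def parse(pem: String): PrivateKey = {
    // Keep only the Base64 body: drop the delimiter lines and any blank lines.
    val body = pem.split("\\r?\\n")
      .map(_.trim)
      .filter(line => line.nonEmpty && !line.startsWith("-----"))
      .mkString
    val der = Base64.getDecoder.decode(body)
    // Assumes an RSA key; other algorithms need a different KeyFactory name.
    KeyFactory.getInstance("RSA").generatePrivate(new PKCS8EncodedKeySpec(der))
  }

  // Read the PEM file from disk and parse it.
  def load(path: String): PrivateKey =
    parse(new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8))
}
```

The trade-off is one less dependency in exchange for losing passphrase support, so this only makes sense when the key file is protected some other way.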
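  4. Encoding the Private Key: The pem_private_key option of the Spark connector takes the key as Base64 text rather than as a PrivateKey object. The getEncoded method Base64-encodes the key's PKCS#8 bytes in 64-character lines, without the BEGIN and END delimiter lines.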
def getEncoded(privateKey: PrivateKey): String = {
    Base64.getMimeEncoder(64, "\n".getBytes).encodeToString(privateKey.getEncoded)
}
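
To see what this encoder produces, here is a small stand-alone sketch with the same settings applied to arbitrary bytes instead of a real key. MimeEncodingSketch is a hypothetical name; the point is only that the output is plain Base64 wrapped at 64 characters, with no delimiter lines, and that it decodes back to the original bytes.

```scala
import java.util.Base64

object MimeEncodingSketch {
  // Same encoder settings as getEncoded: 64-character lines separated by "\n".
  def encode(bytes: Array[Byte]): String =
    Base64.getMimeEncoder(64, "\n".getBytes).encodeToString(bytes)

  // The MIME decoder skips the line breaks, so the round trip returns the original bytes.
  def decode(text: String): Array[Byte] =
    Base64.getMimeDecoder.decode(text)
}
```

A quick check like this is handy when a connection fails and you want to rule out the encoding step before looking at the key itself.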
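  5. Building the Spark Config Map: The sparkConfigMap value holds the connection parameters the Spark connector needs to talk to Snowflake. The excerpt below shows only the URL, user, and key; the full code builds the map in getSparkConfigMap and also sets the database, schema, role, and warehouse.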
private val sparkConfigMap: Map[String, String] = Map(
  "sfURL" -> s"${config(ACCOUNT)}.${config(REGION)}.snowflakecomputing.com",
  "sfUser" -> config(USERNAME),
  "pem_private_key" -> getEncoded(getPrivateKey(config(PK_PATH), config(PK_PASSPHRASE)))
)
  6. Spark Read and Write Operations:
    1. read fetches data from Snowflake based on a SQL query.
    2. write inserts data into a Snowflake table.
def read(session: SparkSession, query: String): Dataset[Row] = {
    session.read
      .format(SNOWFLAKE_SOURCE_NAME)
      .options(sparkConfigMap.asJava)
      .option("query", query)
      .load()
}

def write(df: Dataset[_], tableName: String, mode: SaveMode): Unit = {
    df.write
      .format(SNOWFLAKE_SOURCE_NAME)
      .options(sparkConfigMap.asJava)
      .option("dbtable", tableName)
      .mode(mode)
      .save()
}
  7. JDBC Connection: The getConnection method demonstrates a direct JDBC connection to Snowflake, which is useful for testing or advanced operations.
def getConnection: Connection = {
    val jdbcUrl = s"jdbc:snowflake://${config(ACCOUNT)}.${config(REGION)}.snowflakecomputing.com/"
    val properties = new Properties()
    properties.put("user", config(USERNAME))
    properties.put("privateKey", getPrivateKey(config(PK_PATH), config(PK_PASSPHRASE)))

    DriverManager.getConnection(jdbcUrl, properties)
}
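
A JDBC connection, its statements, and its result sets should all be closed after use. Since the code imports scala.collection.JavaConverters, it may be compiled with a Scala version that predates scala.util.Using, so the sketch below uses a small hand-written helper instead. JdbcSketch, withResource, and firstValue are hypothetical names, not part of the post's code.

```scala
import java.sql.Connection

object JdbcSketch {
  // Run a block with a resource and always close it, even if the block throws.
  def withResource[R <: AutoCloseable, A](resource: R)(body: R => A): A =
    try body(resource) finally resource.close()

  // Return the first column of the first row as a string, or None if there are no rows.
  def firstValue(connection: Connection, sql: String): Option[String] =
    withResource(connection.createStatement()) { statement =>
      withResource(statement.executeQuery(sql)) { rows =>
        if (rows.next()) Option(rows.getString(1)) else None
      }
    }
}
```

For a quick connectivity check you could wrap the connection returned by getConnection in withResource and call firstValue on it with a query such as SELECT CURRENT_VERSION().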
  8. Main Method: The main method tests the connection and runs a sample query. In this excerpt, spark is the SparkSession that the full code creates before running the query.
def main(args: Array[String]): Unit = {
    val snowflakeConnection = new SnowflakeConnection(configMap)
    val df = snowflakeConnection.read(spark, "SELECT * FROM TEST_TABLE LIMIT 10;")
    df.show()
}

Conclusion

This implementation demonstrates how to connect to Snowflake from Scala Spark using key pair authentication. The code handles both encrypted and unencrypted PKCS#8 private keys and uses the Spark-Snowflake connector for read and write operations. Key pair authentication avoids sending a password, but you can explore the other methods depending on your use case.

Full Code

package com.innovid.identityTask

import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.openssl.PEMParser
import org.bouncycastle.openssl.jcajce.{JcaPEMKeyConverter, JceOpenSSLPKCS8DecryptorProviderBuilder}
import org.bouncycastle.operator.InputDecryptorProvider
import org.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo

import java.io.FileReader
import java.nio.file.Paths
import java.security.{PrivateKey, Security}
import java.sql.Connection
import java.sql.DriverManager
import java.util.{Base64, Properties}

import scala.collection.JavaConverters._

class SnowflakeConnection(config: Map[String, String]) {

  import SnowflakeConnection._

  // Validate required keys
  requiredKeys.foreach { key =>
    if (!config.contains(key) || config(key) == null)
      throw new IllegalArgumentException(s"Required key $key not found")
  }

  private val sparkConfigMap: Map[String, String] = getSparkConfigMap

  private def getSparkConfigMap: Map[String, String] = {
    val sfOptions = Map(
      "sfURL" -> getConnectionURL,
      "sfUser" -> config(USERNAME),
      "sfDatabase" -> config(DATABASE),
      "sfSchema" -> config(SCHEMA),
      "sfRole" -> config(ROLE),
      "sfWarehouse" -> config(WAREHOUSE),
      "pem_private_key" -> getEncoded(getPrivateKey(config(PK_PATH), config(PK_PASSPHRASE)))
    )
    sfOptions
  }

  private def getConnectionURL: String = {
    val snowflakeAccount = config(ACCOUNT)
    val snowflakeRegion = config(REGION)
    s"$snowflakeAccount.$snowflakeRegion.snowflakecomputing.com"
  }

  def getConnection: Connection = {
    val jdbcUrl = s"jdbc:snowflake://${getConnectionURL}/"
    val properties = new Properties()
    properties.put("user", config(USERNAME))
    properties.put("role", config(ROLE))
    properties.put("privateKey", getPrivateKey(config(PK_PATH), config(PK_PASSPHRASE)))
    properties.put("account", config(ACCOUNT))
    properties.put("warehouse", config(WAREHOUSE))
    properties.put("db", config(DATABASE))
    properties.put("schema", config(SCHEMA))

    try {
      DriverManager.getConnection(jdbcUrl, properties)
    } catch {
      case e: Exception => throw new RuntimeException(e)
    }
  }

  def read(session: SparkSession, query: String): Dataset[Row] = {
    session.read
      .format(SNOWFLAKE_SOURCE_NAME)
      .options(sparkConfigMap.asJava)
      .option("query", query)
      .load()
  }

  def write(df: Dataset[_], tableName: String, mode: SaveMode): Unit = {
    df.write
      .format(SNOWFLAKE_SOURCE_NAME)
      .options(sparkConfigMap.asJava)
      .option("dbtable", tableName)
      .mode(mode)
      .save()
  }
}

object SnowflakeConnection {
  val ACCOUNT = "sfAccount"
  val REGION = "sfRegion"
  val USERNAME = "sfUsername"
  val ROLE = "sfRole"
  val WAREHOUSE = "sfWarehouse"
  val DATABASE = "sfDatabase"
  val SCHEMA = "sfSchema"
  val PK_PATH = "sfPrivateKey"
  val PK_PASSPHRASE = "sfPassphrase"

  val SNOWFLAKE_SOURCE_NAME = "net.snowflake.spark.snowflake"

  private val requiredKeys = Seq(DATABASE, REGION, WAREHOUSE, ROLE, USERNAME, ACCOUNT, SCHEMA, PK_PATH, PK_PASSPHRASE)

  def getEncoded(privateKey: PrivateKey): String = {
    try {
      val keyBytes = privateKey.getEncoded
      Base64.getMimeEncoder(64, "\n".getBytes).encodeToString(keyBytes)
    } catch {
      case e: Exception => throw new RuntimeException("Error encoding PrivateKey", e)
    }
  }

  def getPrivateKey(filename: String, passphrase: String): PrivateKey = {
    try {
      // addProvider does nothing if the BouncyCastle provider is already registered
      Security.addProvider(new BouncyCastleProvider())
      val pemParser = new PEMParser(new FileReader(Paths.get(filename).toFile))
      try {
        // readObject returns a different type depending on whether the key is encrypted
        val privateKeyInfo = pemParser.readObject() match {
          case encrypted: PKCS8EncryptedPrivateKeyInfo =>
            val decryptorProvider: InputDecryptorProvider =
              new JceOpenSSLPKCS8DecryptorProviderBuilder().build(passphrase.toCharArray)
            encrypted.decryptPrivateKeyInfo(decryptorProvider)
          case keyInfo: PrivateKeyInfo => keyInfo
          case _ => throw new IllegalArgumentException("Unsupported key format")
        }

        val converter = new JcaPEMKeyConverter().setProvider(BouncyCastleProvider.PROVIDER_NAME)
        converter.getPrivateKey(privateKeyInfo)
      } finally {
        // Release the file handle even if parsing or decryption fails
        pemParser.close()
      }
    } catch {
      case e: Exception => throw new RuntimeException(e)
    }
  }

  /**
   * A simple main method to test the Snowflake connection and a sample query.
   */
  def main(args: Array[String]): Unit = {

    // Adjust these values to match your Snowflake account and key details
    val configMap: Map[String, String] = Map(
      ACCOUNT        -> "<account_name>",
      REGION         -> "<region>",
      USERNAME       -> "<username>",
      ROLE           -> "<role>",
      WAREHOUSE      -> "<warehouse>",
      DATABASE       -> "<database>",
      SCHEMA         -> "<schema>",
      PK_PATH        -> "<private_key_path>",
      PK_PASSPHRASE  -> "<private_key_passphrase>"
    )

    val spark = SparkSession.builder()
      .appName("SnowflakeConnectionTest")
      .master("local[*]") // For local testing; use an appropriate cluster manager in production
      .getOrCreate()

    try {
      // Create our SnowflakeConnection instance
      val snowflakeConnection = new SnowflakeConnection(configMap)

      // Test the JDBC connection
      println("Testing JDBC connection...")
      val jdbcConn = snowflakeConnection.getConnection
      println("JDBC connection successful!")
      jdbcConn.close()

      // Test a simple query through Spark
      val testQuery = "SELECT * FROM <SOME_TABLE>"
      println(s"Running test query: $testQuery")

      val df = snowflakeConnection.read(spark, testQuery)
      df.show()

    } catch {
      case ex: Exception =>
        println(s"An error occurred: ${ex.getMessage}")
        ex.printStackTrace()
    } finally {
      spark.stop()
    }
  }
}