Connecting to Snowflake with Scala Spark Using Key Pair Authentication
Snowflake is a popular cloud data platform known for its scalability and performance. You can connect to it in several ways, including username and password, OAuth, and key pair authentication. In this post, we will focus on connecting to Snowflake from Scala Spark using key pair authentication.
While this post focuses on key pair authentication, Snowflake supports:
- Username and Password: Simple but less secure for production.
- OAuth: Used for integration with external identity providers.
- External Browser: For connecting via a browser session.
Understanding Key Pair Authentication
Key pair authentication is a secure method for connecting to Snowflake without using traditional passwords. It involves a private key stored securely on your system and a public key registered with Snowflake. The private key is used to authenticate the user and establish a secure connection.
Steps to Create a Private Key
- Generate the Key Pair: You can use OpenSSL to generate a private key in .p8 format:
openssl genpkey -algorithm RSA -out private_key.p8 -pkeyopt rsa_keygen_bits:2048
This generates a non-encrypted private key.
- Encrypt the Private Key (Optional): Use OpenSSL to add a passphrase for added security:
openssl pkcs8 -in private_key.p8 -topk8 -out encrypted_private_key.p8 -v2 aes-256-cbc
Replace aes-256-cbc with your preferred encryption algorithm.
- Register the Public Key: Extract the public key from the private key and register it with your Snowflake user account:
openssl rsa -in private_key.p8 -pubout -out public_key.pem
Use the Snowflake UI or SQL to register the public key, for example: ALTER USER <username> SET RSA_PUBLIC_KEY='<public_key_value>'; (paste the key contents without the BEGIN/END PUBLIC KEY header and footer lines).
Code Example
The full code is included at the end of the post; the sections below walk through its key components.
Key Components of the Code
- Dependencies and Imports: The code relies on libraries such as BouncyCastle for handling private keys, SparkSession for Spark operations, and DriverManager for JDBC connections.
import org.bouncycastle.openssl.PEMParser
import org.apache.spark.sql.{Dataset, Row, SparkSession}
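If you manage the build with sbt, the dependency list might look roughly like the sketch below. The Maven coordinates are the standard ones for these libraries, but the version placeholders are assumptions; pick versions that match your Spark, Scala, and JDK setup.
// build.sbt (sketch) -- version placeholders are assumptions, align them with your environment
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-sql"       % "<spark_version>" % "provided",
  "net.snowflake"    %% "spark-snowflake" % "<connector_version>",
  "net.snowflake"     % "snowflake-jdbc"  % "<jdbc_version>",
  "org.bouncycastle"  % "bcpkix-jdk18on"  % "<bouncycastle_version>"
)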
- Configuration Map: The configuration map contains essential details such as the Snowflake account, region, and private key file path.
val configMap: Map[String, String] = Map(
ACCOUNT -> "<account_name>",
REGION -> "<region>",
USERNAME -> "<username>",
ROLE -> "<role>",
WAREHOUSE -> "<warehouse>",
DATABASE -> "<database>",
SCHEMA -> "<schema>",
PK_PATH -> "<private_key_path>",
PK_PASSPHRASE -> "<private_key_passphrase>"
)
- Private Key Handling: The getPrivateKey method reads and decrypts the private key file. It uses BouncyCastle for secure handling of encrypted and non-encrypted keys.
def getPrivateKey(filename: String, passphrase: String): PrivateKey = {
val pemParser = new PEMParser(new FileReader(filename))
val pemObject = pemParser.readObject()
val privateKeyInfo = pemObject match {
case encrypted: PKCS8EncryptedPrivateKeyInfo =>
val decryptorProvider = new JceOpenSSLPKCS8DecryptorProviderBuilder()
.build(passphrase.toCharArray)
encrypted.decryptPrivateKeyInfo(decryptorProvider)
case keyInfo: PrivateKeyInfo => keyInfo
case _ => throw new IllegalArgumentException("Unsupported key format")
}
pemParser.close()
new JcaPEMKeyConverter().getPrivateKey(privateKeyInfo)
}
- Encoding the Private Key: Spark expects the private key in Base64 format. The
getEncoded
method encodes the private key appropriately.
def getEncoded(privateKey: PrivateKey): String = {
Base64.getMimeEncoder(64, "\n".getBytes).encodeToString(privateKey.getEncoded)
}
- Building the Spark Config Map: The getSparkConfigMap method prepares the Snowflake connection parameters required for Spark to interact with Snowflake.
private val sparkConfigMap: Map[String, String] = Map(
"sfURL" -> s"${config(ACCOUNT)}.${config(REGION)}.snowflakecomputing.com",
"sfUser" -> config(USERNAME),
"pem_private_key" -> getEncoded(getPrivateKey(config(PK_PATH), config(PK_PASSPHRASE)))
)
- Spark Read and Write Operations: read fetches data from Snowflake based on a SQL query. write inserts data into a Snowflake table.
def read(session: SparkSession, query: String): Dataset[Row] = {
session.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(sparkConfigMap.asJava)
.option("query", query)
.load()
}
def write(df: Dataset[_], tableName: String, mode: SaveMode): Unit = {
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(sparkConfigMap.asJava)
.option("dbtable", tableName)
.mode(mode)
.save()
}
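As a usage sketch (assuming a running SparkSession named spark and hypothetical table names), reading and writing might look like this:
val connection = new SnowflakeConnection(configMap)
// Read the result of an arbitrary query into a DataFrame
val users = connection.read(spark, "SELECT ID, NAME FROM MY_TABLE")
// Append the rows to another (hypothetical) table
connection.write(users, "MY_TABLE_COPY", SaveMode.Append)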
- JDBC Connection: The getConnection method demonstrates a direct JDBC connection to Snowflake, which is useful for testing or advanced operations.
def getConnection: Connection = {
val jdbcUrl = s"jdbc:snowflake://${config(ACCOUNT)}.${config(REGION)}.snowflakecomputing.com/"
val properties = new Properties()
properties.put("user", config(USERNAME))
properties.put("privateKey", getPrivateKey(config(PK_PATH), config(PK_PASSPHRASE)))
DriverManager.getConnection(jdbcUrl, properties)
}
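For a quick smoke test of the JDBC path, you could run a trivial statement over that connection; this sketch assumes the SnowflakeConnection instance built from configMap:
val conn = new SnowflakeConnection(configMap).getConnection
try {
  // CURRENT_VERSION() is a lightweight way to confirm the session works
  val rs = conn.createStatement().executeQuery("SELECT CURRENT_VERSION()")
  while (rs.next()) println(rs.getString(1))
} finally {
  conn.close()
}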
- Main Method: The main method tests the connection and runs a sample query.
def main(args: Array[String]): Unit = {
  val spark = SparkSession.builder()
    .appName("SnowflakeConnectionTest")
    .master("local[*]")
    .getOrCreate()
  val snowflakeConnection = new SnowflakeConnection(configMap)
  val df = snowflakeConnection.read(spark, "SELECT * FROM TEST_TABLE LIMIT 10")
  df.show()
  spark.stop()
}
Conclusion
This implementation demonstrates how to securely connect to Snowflake from Scala Spark using key pair authentication. The code handles private keys carefully, supporting both passphrase-encrypted and unencrypted keys, and uses the Spark-Snowflake connector for seamless data operations. While key pair authentication is highly secure, you can explore the other methods depending on your use case.
Full Code
package com.innovid.identityTask
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.openssl.PEMParser
import org.bouncycastle.openssl.jcajce.{JcaPEMKeyConverter, JceOpenSSLPKCS8DecryptorProviderBuilder}
import org.bouncycastle.operator.InputDecryptorProvider
import org.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo
import java.io.FileReader
import java.nio.file.Paths
import java.security.{PrivateKey, Security}
import java.sql.Connection
import java.sql.DriverManager
import java.util.{Base64, Properties}
import scala.collection.JavaConverters._
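// Thin wrapper around the Spark-Snowflake connector and the Snowflake JDBC driver, configured from a key/value map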
class SnowflakeConnection(config: Map[String, String]) {
import SnowflakeConnection._
// Validate required keys
requiredKeys.foreach { key =>
if (!config.contains(key) || config(key) == null)
throw new IllegalArgumentException(s"Required key $key not found")
}
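// Options passed to the Spark-Snowflake connector, built once per instance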
private val sparkConfigMap: Map[String, String] = getSparkConfigMap
private def getSparkConfigMap: Map[String, String] = {
val sfOptions = Map(
"sfURL" -> getConnectionURL,
"sfUser" -> config(USERNAME),
"sfDatabase" -> config(DATABASE),
"sfSchema" -> config(SCHEMA),
"sfRole" -> config(ROLE),
"sfWarehouse" -> config(WAREHOUSE),
"pem_private_key" -> getEncoded(getPrivateKey(config(PK_PATH), config(PK_PASSPHRASE)))
)
sfOptions
}
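// Builds the host portion of the connection URL: <account>.<region>.snowflakecomputing.com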
private def getConnectionURL: String = {
val snowflakeAccount = config(ACCOUNT)
val snowflakeRegion = config(REGION)
s"$snowflakeAccount.$snowflakeRegion.snowflakecomputing.com"
}
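// Opens a direct JDBC connection; useful for smoke-testing credentials outside of Spark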
def getConnection: Connection = {
val jdbcUrl = s"jdbc:snowflake://${getConnectionURL}/"
val properties = new Properties()
properties.put("user", config(USERNAME))
properties.put("role", config(ROLE))
properties.put("privateKey", getPrivateKey(config(PK_PATH), config(PK_PASSPHRASE)))
properties.put("account", config(ACCOUNT))
properties.put("warehouse", config(WAREHOUSE))
properties.put("db", config(DATABASE))
properties.put("schema", config(SCHEMA))
try {
DriverManager.getConnection(jdbcUrl, properties)
} catch {
case e: Exception => throw new RuntimeException(e)
}
}
def read(session: SparkSession, query: String): Dataset[Row] = {
session.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(sparkConfigMap.asJava)
.option("query", query)
.load()
}
def write(df: Dataset[_], tableName: String, mode: SaveMode): Unit = {
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(sparkConfigMap.asJava)
.option("dbtable", tableName)
.mode(mode)
.save()
}
}
object SnowflakeConnection {
val ACCOUNT = "sfAccount"
val REGION = "sfRegion"
val USERNAME = "sfUsername"
val ROLE = "sfRole"
val WAREHOUSE = "sfWarehouse"
val DATABASE = "sfDatabase"
val SCHEMA = "sfSchema"
val PK_PATH = "sfPrivateKey"
val PK_PASSPHRASE = "sfPassphrase"
val SNOWFLAKE_SOURCE_NAME = "net.snowflake.spark.snowflake"
private val requiredKeys = Seq(DATABASE, REGION, WAREHOUSE, ROLE, USERNAME, ACCOUNT, SCHEMA, PK_PATH, PK_PASSPHRASE)
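// Encodes the private key bytes as MIME Base64 (64-character lines) for the pem_private_key option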
def getEncoded(privateKey: PrivateKey): String = {
try {
val keyBytes = privateKey.getEncoded
Base64.getMimeEncoder(64, "\n".getBytes).encodeToString(keyBytes)
} catch {
case e: Exception => throw new RuntimeException("Error encoding PrivateKey", e)
}
}
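// Loads a PKCS#8 private key from disk; encrypted keys are decrypted with the supplied passphrase via BouncyCastle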
def getPrivateKey(filename: String, passphrase: String): PrivateKey = {
try {
Security.addProvider(new BouncyCastleProvider())
val pemParser = new PEMParser(new FileReader(Paths.get(filename).toFile))
val pemObject = pemParser.readObject()
val privateKeyInfo = pemObject match {
case encrypted: PKCS8EncryptedPrivateKeyInfo =>
val decryptorProvider: InputDecryptorProvider =
new JceOpenSSLPKCS8DecryptorProviderBuilder().build(passphrase.toCharArray)
encrypted.decryptPrivateKeyInfo(decryptorProvider)
case keyInfo: PrivateKeyInfo => keyInfo
case _ => throw new IllegalArgumentException("Unsupported key format")
}
pemParser.close()
val converter = new JcaPEMKeyConverter().setProvider(BouncyCastleProvider.PROVIDER_NAME)
converter.getPrivateKey(privateKeyInfo)
} catch {
case e: Exception => throw new RuntimeException(e)
}
}
/**
* A simple main method to test the Snowflake connection and a sample query.
*/
def main(args: Array[String]): Unit = {
// Adjust these values to match your Snowflake account and key details
val configMap: Map[String, String] = Map(
ACCOUNT -> "<account_name>",
REGION -> "<region>",
USERNAME -> "<username>",
ROLE -> "<role>",
WAREHOUSE -> "<warehouse>",
DATABASE -> "<database>",
SCHEMA -> "<schema>",
PK_PATH -> "<private_key_path>",
PK_PASSPHRASE -> "<private_key_passphrase>"
)
val spark = SparkSession.builder()
.appName("SnowflakeConnectionTest")
.master("local[*]") // For local testing; use an appropriate cluster manager in production
.getOrCreate()
try {
// Create our SnowflakeConnection instance
val snowflakeConnection = new SnowflakeConnection(configMap)
// Test the JDBC connection
println("Testing JDBC connection...")
val jdbcConn = snowflakeConnection.getConnection
println("JDBC connection successful!")
jdbcConn.close()
// Test a simple query through Spark
val testQuery = "SELECT * From <SOME_TABLE>"
println(s"Running test query: $testQuery")
val df = snowflakeConnection.read(spark, testQuery)
df.show()
} catch {
case ex: Exception =>
println(s"An error occurred: ${ex.getMessage}")
ex.printStackTrace()
} finally {
spark.stop()
}
}
}