diff --git a/src/main/java/org/opensearch/security/authtoken/jwt/EncryptionDecryptionUtil.java b/src/main/java/org/opensearch/security/authtoken/jwt/EncryptionDecryptionUtil.java index e38a48cde3..2e11fed64a 100644 --- a/src/main/java/org/opensearch/security/authtoken/jwt/EncryptionDecryptionUtil.java +++ b/src/main/java/org/opensearch/security/authtoken/jwt/EncryptionDecryptionUtil.java @@ -21,16 +21,22 @@ public class EncryptionDecryptionUtil { - public static String encrypt(final String secret, final String data) { - final Cipher cipher = createCipherFromSecret(secret, CipherMode.ENCRYPT); - final byte[] cipherText = createCipherText(cipher, data.getBytes(StandardCharsets.UTF_8)); - return Base64.getEncoder().encodeToString(cipherText); + private final Cipher encryptCipher; + private final Cipher decryptCipher; + + public EncryptionDecryptionUtil(final String secret) { + this.encryptCipher = createCipherFromSecret(secret, CipherMode.ENCRYPT); + this.decryptCipher = createCipherFromSecret(secret, CipherMode.DECRYPT); + } + + public String encrypt(final String data) { + byte[] encryptedBytes = processWithCipher(data.getBytes(StandardCharsets.UTF_8), encryptCipher); + return Base64.getEncoder().encodeToString(encryptedBytes); } - public static String decrypt(final String secret, final String encryptedString) { - final Cipher cipher = createCipherFromSecret(secret, CipherMode.DECRYPT); - final byte[] cipherText = createCipherText(cipher, Base64.getDecoder().decode(encryptedString)); - return new String(cipherText, StandardCharsets.UTF_8); + public String decrypt(final String encryptedString) { + byte[] decodedBytes = Base64.getDecoder().decode(encryptedString); + return new String(processWithCipher(decodedBytes, decryptCipher), StandardCharsets.UTF_8); } private static Cipher createCipherFromSecret(final String secret, final CipherMode mode) { @@ -41,15 +47,15 @@ private static Cipher createCipherFromSecret(final String secret, final CipherMo cipher.init(mode.opmode, originalKey); return cipher; } catch (final Exception e) { - throw new RuntimeException("Error creating cipher from secret in mode " + mode.name()); + throw new RuntimeException("Error creating cipher from secret in mode " + mode.name(), e); } } - private static byte[] createCipherText(final Cipher cipher, final byte[] data) { + private static byte[] processWithCipher(final byte[] data, final Cipher cipher) { try { return cipher.doFinal(data); } catch (final Exception e) { - throw new RuntimeException("The cipher was unable to perform pass over data"); + throw new RuntimeException("Error processing data with cipher", e); } } diff --git a/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java b/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java index ec096ea117..e484300f18 100644 --- a/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java +++ b/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java @@ -43,6 +43,7 @@ public class JwtVendor { private final JoseJwtProducer jwtProducer; private final LongSupplier timeProvider; private final Boolean bwcModeEnabled; + private final EncryptionDecryptionUtil encryptionDecryptionUtil; public JwtVendor(final Settings settings, final Optional timeProvider) { JoseJwtProducer jwtProducer = new JoseJwtProducer(); @@ -56,6 +57,7 @@ public JwtVendor(final Settings settings, final Optional timeProvi throw new IllegalArgumentException("encryption_key cannot be null"); } else { this.claimsEncryptionKey = settings.get("encryption_key"); + this.encryptionDecryptionUtil = new EncryptionDecryptionUtil(claimsEncryptionKey); } if (timeProvider.isPresent()) { this.timeProvider = timeProvider.get(); @@ -140,7 +142,7 @@ public String createJwt( if (roles != null) { String listOfRoles = String.join(",", roles); - jwtClaims.setProperty("er", EncryptionDecryptionUtil.encrypt(claimsEncryptionKey, listOfRoles)); + jwtClaims.setProperty("er", encryptionDecryptionUtil.encrypt(listOfRoles)); } else { throw new Exception("Roles cannot be null"); } diff --git a/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java b/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java index 3c9d40054a..65f5a373cc 100644 --- a/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java @@ -60,12 +60,15 @@ public class OnBehalfOfAuthenticator implements HTTPAuthenticator { private final Boolean oboEnabled; private final String clusterName; + private final EncryptionDecryptionUtil encryptionUtil; + public OnBehalfOfAuthenticator(Settings settings, String clusterName) { String oboEnabledSetting = settings.get("enabled"); oboEnabled = oboEnabledSetting == null ? Boolean.TRUE : Boolean.valueOf(oboEnabledSetting); encryptionKey = settings.get("encryption_key"); jwtParser = initParser(settings.get("signing_key")); this.clusterName = clusterName; + this.encryptionUtil = new EncryptionDecryptionUtil(encryptionKey); } private JwtParser initParser(final String signingKey) { @@ -84,7 +87,7 @@ private List extractSecurityRolesFromClaims(Claims claims) { String rolesClaim = ""; if (er != null) { - rolesClaim = EncryptionDecryptionUtil.decrypt(encryptionKey, er.toString()); + rolesClaim = encryptionUtil.decrypt(er.toString()); } else if (dr != null) { rolesClaim = dr.toString(); } else { diff --git a/src/test/java/org/opensearch/security/authtoken/jwt/EncryptionDecryptionUtilsTest.java b/src/test/java/org/opensearch/security/authtoken/jwt/EncryptionDecryptionUtilsTest.java index 34b9b3a100..4890f380f9 100644 --- a/src/test/java/org/opensearch/security/authtoken/jwt/EncryptionDecryptionUtilsTest.java +++ b/src/test/java/org/opensearch/security/authtoken/jwt/EncryptionDecryptionUtilsTest.java @@ -22,8 +22,10 @@ public void testEncryptDecrypt() { String secret = Base64.getEncoder().encodeToString("mySecretKey12345".getBytes()); String data = "Hello, OpenSearch!"; - String encryptedString = EncryptionDecryptionUtil.encrypt(secret, data); - String decryptedString = EncryptionDecryptionUtil.decrypt(secret, encryptedString); + EncryptionDecryptionUtil util = new EncryptionDecryptionUtil(secret); + + String encryptedString = util.encrypt(data); + String decryptedString = util.decrypt(encryptedString); Assert.assertEquals(data, decryptedString); } @@ -34,11 +36,13 @@ public void testDecryptingWithWrongKey() { String secret2 = Base64.getEncoder().encodeToString("wrongKey1234567".getBytes()); String data = "Hello, OpenSearch!"; - String encryptedString = EncryptionDecryptionUtil.encrypt(secret1, data); + EncryptionDecryptionUtil util1 = new EncryptionDecryptionUtil(secret1); + String encryptedString = util1.encrypt(data); - RuntimeException ex = Assert.assertThrows(RuntimeException.class, () -> EncryptionDecryptionUtil.decrypt(secret2, encryptedString)); + EncryptionDecryptionUtil util2 = new EncryptionDecryptionUtil(secret2); + RuntimeException ex = Assert.assertThrows(RuntimeException.class, () -> util2.decrypt(encryptedString)); - Assert.assertEquals("The cipher was unable to perform pass over data", ex.getMessage()); + Assert.assertEquals("Error processing data with cipher", ex.getMessage()); } @Test @@ -46,10 +50,8 @@ public void testDecryptingCorruptedData() { String secret = Base64.getEncoder().encodeToString("mySecretKey12345".getBytes()); String corruptedEncryptedString = "corruptedData"; - RuntimeException ex = Assert.assertThrows( - RuntimeException.class, - () -> EncryptionDecryptionUtil.decrypt(secret, corruptedEncryptedString) - ); + EncryptionDecryptionUtil util = new EncryptionDecryptionUtil(secret); + RuntimeException ex = Assert.assertThrows(RuntimeException.class, () -> util.decrypt(corruptedEncryptedString)); Assert.assertEquals("Last unit does not have enough valid bits", ex.getMessage()); } @@ -59,8 +61,9 @@ public void testEncryptDecryptEmptyString() { String secret = Base64.getEncoder().encodeToString("mySecretKey12345".getBytes()); String data = ""; - String encryptedString = EncryptionDecryptionUtil.encrypt(secret, data); - String decryptedString = EncryptionDecryptionUtil.decrypt(secret, encryptedString); + EncryptionDecryptionUtil util = new EncryptionDecryptionUtil(secret); + String encryptedString = util.encrypt(data); + String decryptedString = util.decrypt(encryptedString); Assert.assertEquals(data, decryptedString); } @@ -70,7 +73,8 @@ public void testEncryptNullValue() { String secret = Base64.getEncoder().encodeToString("mySecretKey12345".getBytes()); String data = null; - EncryptionDecryptionUtil.encrypt(secret, data); + EncryptionDecryptionUtil util = new EncryptionDecryptionUtil(secret); + util.encrypt(data); } @Test(expected = NullPointerException.class) @@ -78,6 +82,7 @@ public void testDecryptNullValue() { String secret = Base64.getEncoder().encodeToString("mySecretKey12345".getBytes()); String data = null; - EncryptionDecryptionUtil.decrypt(secret, data); + EncryptionDecryptionUtil util = new EncryptionDecryptionUtil(secret); + util.decrypt(data); } } diff --git a/src/test/java/org/opensearch/security/authtoken/jwt/JwtVendorTest.java b/src/test/java/org/opensearch/security/authtoken/jwt/JwtVendorTest.java index 7c03ff912d..006f6ebc8d 100644 --- a/src/test/java/org/opensearch/security/authtoken/jwt/JwtVendorTest.java +++ b/src/test/java/org/opensearch/security/authtoken/jwt/JwtVendorTest.java @@ -60,8 +60,8 @@ public void testCreateJwtWithRoles() throws Exception { List roles = List.of("IT", "HR"); List backendRoles = List.of("Sales", "Support"); String expectedRoles = "IT,HR"; - Integer expirySeconds = 300; - LongSupplier currentTime = () -> (int) 100; + int expirySeconds = 300; + LongSupplier currentTime = () -> (long) 100; String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16); Settings settings = Settings.builder().put("signing_key", "abc123").put("encryption_key", claimsEncryptionKey).build(); Long expectedExp = currentTime.getAsLong() + expirySeconds; @@ -78,8 +78,8 @@ public void testCreateJwtWithRoles() throws Exception { Assert.assertNotNull(jwt.getClaim("iat")); Assert.assertNotNull(jwt.getClaim("exp")); Assert.assertEquals(expectedExp, jwt.getClaim("exp")); - Assert.assertNotEquals(expectedRoles, jwt.getClaim("er")); - Assert.assertEquals(expectedRoles, EncryptionDecryptionUtil.decrypt(claimsEncryptionKey, jwt.getClaim("er").toString())); + EncryptionDecryptionUtil encryptionUtil = new EncryptionDecryptionUtil(claimsEncryptionKey); + Assert.assertEquals(expectedRoles, encryptionUtil.decrypt(jwt.getClaim("er").toString())); Assert.assertNull(jwt.getClaim("br")); } @@ -93,15 +93,13 @@ public void testCreateJwtWithBackwardsCompatibilityMode() throws Exception { String expectedRoles = "IT,HR"; String expectedBackendRoles = "Sales,Support"; - Integer expirySeconds = 300; - LongSupplier currentTime = () -> (int) 100; + int expirySeconds = 300; + LongSupplier currentTime = () -> (long) 100; String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16); Settings settings = Settings.builder() .put("signing_key", "abc123") .put("encryption_key", claimsEncryptionKey) - // CS-SUPPRESS-SINGLE: RegexpSingleline get Extensions Settings .put(ConfigConstants.EXTENSIONS_BWC_PLUGIN_MODE, true) - // CS-ENFORCE-SINGLE .build(); Long expectedExp = currentTime.getAsLong() + expirySeconds; @@ -117,8 +115,8 @@ public void testCreateJwtWithBackwardsCompatibilityMode() throws Exception { Assert.assertNotNull(jwt.getClaim("iat")); Assert.assertNotNull(jwt.getClaim("exp")); Assert.assertEquals(expectedExp, jwt.getClaim("exp")); - Assert.assertNotEquals(expectedRoles, jwt.getClaim("er")); - Assert.assertEquals(expectedRoles, EncryptionDecryptionUtil.decrypt(claimsEncryptionKey, jwt.getClaim("er").toString())); + EncryptionDecryptionUtil encryptionUtil = new EncryptionDecryptionUtil(claimsEncryptionKey); + Assert.assertEquals(expectedRoles, encryptionUtil.decrypt(jwt.getClaim("er").toString())); Assert.assertNotNull(jwt.getClaim("br")); Assert.assertEquals(expectedBackendRoles, jwt.getClaim("br")); } @@ -170,14 +168,14 @@ public void testCreateJwtWithBadRoles() { String subject = "admin"; String audience = "audience_0"; List roles = null; - Integer expirySecond = 300; + Integer expirySeconds = 300; String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16); Settings settings = Settings.builder().put("signing_key", "abc123").put("encryption_key", claimsEncryptionKey).build(); JwtVendor jwtVendor = new JwtVendor(settings, Optional.empty()); Throwable exception = Assert.assertThrows(RuntimeException.class, () -> { try { - jwtVendor.createJwt(issuer, subject, audience, expirySecond, roles, List.of()); + jwtVendor.createJwt(issuer, subject, audience, expirySeconds, roles, List.of()); } catch (Exception e) { throw new RuntimeException(e); }