Skip to content

Commit

Permalink
example and test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dmihalcik-virtru committed Jan 10, 2025
1 parent eccd033 commit e542b85
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 30 deletions.
15 changes: 1 addition & 14 deletions examples/cmd/encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func init() {
encryptCmd.Flags().BoolVar(&nanoFormat, "nano", false, "Output in nanoTDF format")
encryptCmd.Flags().BoolVar(&autoconfigure, "autoconfigure", true, "Use attribute grants to select kases")
encryptCmd.Flags().BoolVar(&noKIDInKAO, "no-kid-in-kao", false, "[deprecated] Disable storing key identifiers in TDF KAOs")
encryptCmd.Flags().BoolVar(&noKIDInNano, "no-kid-in-nano", true, "Disable storing key identifiers in nanoTDF KAS ResourceLocator")
encryptCmd.Flags().BoolVar(&noKIDInNano, "no-kid-in-nano", false, "Disable storing key identifiers in nanoTDF KAS ResourceLocator")
encryptCmd.Flags().StringVarP(&outputName, "output", "o", "sensitive.txt.tdf", "name or path of output file; - for stdout")
encryptCmd.Flags().IntVarP(&collection, "collection", "c", 0, "number of nano's to create for collection. If collection >0 (default) then output will be <iteration>_<output>")

Expand All @@ -51,19 +51,6 @@ func encrypt(cmd *cobra.Command, args []string) error {
plainText := args[0]
in := strings.NewReader(plainText)

opts := []sdk.Option{
sdk.WithInsecurePlaintextConn(),
sdk.WithClientCredentials("opentdf-sdk", "secret", nil),
}

if noKIDInKAO {
opts = append(opts, sdk.WithNoKIDInKAO())
}
// double negative always gets me
if !noKIDInNano {
opts = append(opts, sdk.WithNoKIDInNano())
}

// Create new offline client
client, err := newSDK()
if err != nil {
Expand Down
11 changes: 10 additions & 1 deletion examples/cmd/examples.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,16 @@ func newSDK() (*sdk.SDK, error) {
if storeCollectionHeaders {
opts = append(opts, sdk.WithStoreCollectionHeaders())
}
if clientCredentials != "" {

if noKIDInKAO {
opts = append(opts, sdk.WithNoKIDInKAO())
}
if noKIDInNano {
opts = append(opts, sdk.WithNoKIDInNano())
}
if clientCredentials == "" {
opts = append(opts, sdk.WithClientCredentials("opentdf-sdk", "secret", nil))
} else {
i := strings.Index(clientCredentials, ":")
if i < 0 {
return nil, fmt.Errorf("invalid client id/secret pair")
Expand Down
6 changes: 3 additions & 3 deletions sdk/kas_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,16 +361,16 @@ func (s SDK) getPublicKey(ctx context.Context, url, algorithm string) (*KASInfo,
}

kid := resp.GetKid()
if s.config.tdfFeatures.noKID {
kid = ""
}

ki := KASInfo{
URL: url,
Algorithm: algorithm,
KID: kid,
PublicKey: resp.GetPublicKey(),
}
if s.config.tdfFeatures.noKID {
ki.KID = ""
}
if s.kasKeyCache != nil {
s.kasKeyCache.store(ki)
}
Expand Down
3 changes: 3 additions & 0 deletions sdk/tdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,9 @@ func (s SDK) prepareManifest(ctx context.Context, t *TDFObject, tdfConfig TDFCon
SplitID: splitID,
WrappedKey: string(ocrypto.Base64Encode(wrappedKey)),
}
if s.config.tdfFeatures.noKID {
keyAccess.KID = ""
}

manifest.EncryptionInformation.KeyAccessObjs = append(manifest.EncryptionInformation.KeyAccessObjs, keyAccess)
}
Expand Down
31 changes: 30 additions & 1 deletion service/internal/security/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestMarshalTo(t *testing.T) {
wantErr bool
}{
{
name: "upgrade2023CertID",
name: "upgrade2023CertIDA",
config: CryptoConfig2024{
Standard: Standard{
RSAKeys: map[string]StandardKeyInfo{
Expand All @@ -42,6 +42,35 @@ func TestMarshalTo(t *testing.T) {
},
wantErr: false,
},
{
name: "upgrade2023CertIDB",
config: CryptoConfig2024{
Standard: Standard{
RSAKeys: map[string]StandardKeyInfo{
"r1": {PrivateKeyPath: "r1_private.pem", PublicKeyPath: "r1_public.pem"},
"r2": {PrivateKeyPath: "r2_private.pem", PublicKeyPath: "r2_public.pem"},
},
ECKeys: map[string]StandardKeyInfo{
"e1": {PrivateKeyPath: "e1_private.pem", PublicKeyPath: "e1_public.pem"},
"e2": {PrivateKeyPath: "e2_private.pem", PublicKeyPath: "e2_public.pem"},
},
},
},
input: map[string]any{
"enabled": true,
"eccertid": "e1",
"rsacertid": "r1",
},
expected: KASConfigDupe{
Keyring: []CurrentKeyFor{
{Algorithm: "rsa:2048", KID: "r1", Private: "r1_private.pem", Certificate: "r1_public.pem", Active: true, Legacy: true},
{Algorithm: "rsa:2048", KID: "r2", Private: "r2_private.pem", Certificate: "r2_public.pem", Active: false, Legacy: true},
{Algorithm: "ec:secp256r1", KID: "e1", Private: "e1_private.pem", Certificate: "e1_public.pem", Active: true, Legacy: true},
{Algorithm: "ec:secp256r1", KID: "e2", Private: "e2_private.pem", Certificate: "e2_public.pem", Active: false, Legacy: true},
},
},
wantErr: false,
},
{
name: "upgrade2023NoCertIDs",
config: CryptoConfig2024{
Expand Down
23 changes: 18 additions & 5 deletions service/kas/access/publicKey.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,18 @@ func (p Provider) LegacyPublicKey(ctx context.Context, req *connect.Request[kasp
if err != nil {
return nil, err
}
kid, err := p.CryptoProvider.CurrentKID(algorithm)
kids, err := p.CryptoProvider.CurrentKID(algorithm)
if err != nil {
return nil, err
}
if len(kids) == 0 {
return nil, security.ErrCertNotFound
}
if len(kids) > 1 {
p.Logger.ErrorContext(ctx, "multiple keys found for algorithm", "algorithm", algorithm, "kids", kids)
}
fmt := recrypt.KeyFormatPEM
pem, err := p.CryptoProvider.PublicKey(algorithm, kid, fmt)
pem, err := p.CryptoProvider.PublicKey(algorithm, kids[:1], fmt)
if err != nil {
p.Logger.ErrorContext(ctx, "CryptoProvider.ECPublicKey failed", "err", err)
return nil, connect.NewError(connect.CodeInternal, errors.Join(ErrConfig, errors.New("configuration error")))
Expand All @@ -54,14 +60,21 @@ func (p Provider) PublicKey(ctx context.Context, req *connect.Request[kaspb.Publ
algorithm = recrypt.AlgorithmRSA2048
}

kid, err := p.CryptoProvider.CurrentKID(algorithm)
kids, err := p.CryptoProvider.CurrentKID(algorithm)
if err != nil {
return nil, connect.NewError(connect.CodeNotFound, err)
}
if len(kids) == 0 {
return nil, security.ErrCertNotFound
}
fmt, err := p.CryptoProvider.ParseKeyFormat(req.Msg.GetFmt())
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
if len(kids) > 1 && fmt != recrypt.KeyFormatJWK {
p.Logger.WarnContext(ctx, "multiple active keys found for algorithm, only returning the first one", "algorithm", algorithm, "kids", kids, "fmt", fmt)
kids = kids[:1]
}

r := func(value string, kid []recrypt.KeyIdentifier, err error) (*connect.Response[kaspb.PublicKeyResponse], error) {
if errors.Is(err, security.ErrCertNotFound) {
Expand All @@ -78,8 +91,8 @@ func (p Provider) PublicKey(ctx context.Context, req *connect.Request[kaspb.Publ
return connect.NewResponse(&kaspb.PublicKeyResponse{PublicKey: value, Kid: string(kid[0])}), nil
}

v, err := p.CryptoProvider.PublicKey(algorithm, kid, fmt)
return r(v, kid, err)
v, err := p.CryptoProvider.PublicKey(algorithm, kids, fmt)
return r(v, kids, err)
}

func exportRsaPublicKeyAsPemStr(pubkey *rsa.PublicKey) (string, error) {
Expand Down
5 changes: 3 additions & 2 deletions service/kas/access/rewrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,8 @@ func (p *Provider) Rewrap(ctx context.Context, req *connect.Request[kaspb.Rewrap
}

if body.Algorithm == "" {
p.Logger.DebugContext(ctx, "default rewrap algorithm")
body.Algorithm = "rsa:2048"
p.Logger.DebugContext(ctx, "default rewrap algorithm", "alg", body.Algorithm)
}

if body.Algorithm == "ec:secp256r1" {
Expand Down Expand Up @@ -315,12 +315,13 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, body *RequestBody, entity *en
return nil, err400("bad request")
}
}
p.Logger.DebugContext(ctx, "paging through legacy KIDs for kid free kao", "kids", kidsToCheck)
symmetricKey, err := p.CryptoProvider.Unwrap(kidsToCheck[0], body.KeyAccess.WrappedKey)
for _, kid := range kidsToCheck[1:] {
if err == nil {
break
}
p.Logger.DebugContext(ctx, "continue paging through legacy KIDs for kid free kao", "err", err)
p.Logger.DebugContext(ctx, "continue paging through legacy KIDs for kid free kao", "err", err, "kid", kid)
symmetricKey, err = p.CryptoProvider.Unwrap(kid, body.KeyAccess.WrappedKey)
}
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions test/tdf-roundtrips.bats
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,15 @@ server:
standard:
rsa:
r1:
private_key_path: kas-private.pem
public_key_path: kas-cert.pem
private_key_path: kas-r1-private.pem
public_key_path: kas-r1-cert.pem
r2:
private_key_path: kas-r2-private.pem
public_key_path: kas-r2-cert.pem
ec:
e1:
private_key_path: kas-ec-private.pem
public_key_path: kas-ec-cert.pem
private_key_path: kas-e1-private.pem
public_key_path: kas-e1-cert.pem
e2:
private_key_path: kas-e2-private.pem
public_key_path: kas-e2-cert.pem
Expand Down

0 comments on commit e542b85

Please sign in to comment.