diff --git a/certtostore_windows.go b/certtostore_windows.go index ff1bfa8..26141a7 100644 --- a/certtostore_windows.go +++ b/certtostore_windows.go @@ -35,6 +35,7 @@ import ( "path/filepath" "reflect" "strings" + "sync" "syscall" "time" "unicode/utf16" @@ -43,6 +44,7 @@ import ( "golang.org/x/crypto/cryptobyte/asn1" "golang.org/x/crypto/cryptobyte" "golang.org/x/sys/windows" + "github.com/hashicorp/go-multierror" "github.com/google/logger" ) @@ -233,7 +235,6 @@ func intendedKeyUsage(enc uint32, cert *windows.CertContext) (usage uint16) { // WinCertStore is a CertStorage implementation for the Windows Certificate Store. type WinCertStore struct { - CStore windows.Handle Prov uintptr ProvName string issuers []string @@ -241,6 +242,9 @@ type WinCertStore struct { container string keyStorageFlags uintptr certChains [][]*x509.Certificate + stores map[string]*storeHandle + + mu sync.Mutex } // OpenWinCertStore creates a WinCertStore. Call Close() when finished using the store. @@ -257,6 +261,7 @@ func OpenWinCertStore(provider, container string, issuers, intermediateIssuers [ issuers: issuers, intermediateIssuers: intermediateIssuers, container: container, + stores: make(map[string]*storeHandle), } if legacyKey { @@ -372,17 +377,10 @@ func (w *WinCertStore) CertWithContext() (*x509.Certificate, *windows.CertContex // cert is a helper function to lookup certificates based on a known issuer. // store is used to specify which store to perform the lookup in (system or user). func (w *WinCertStore) cert(issuers []string, searchRoot *uint16, store uint32) (*x509.Certificate, *windows.CertContext, error) { - // Open a handle to the system cert store - certStore, err := windows.CertOpenStore( - certStoreProvSystem, - 0, - 0, - store, - uintptr(unsafe.Pointer(searchRoot))) + h, err := w.storeHandle(store, searchRoot) if err != nil { - return nil, nil, fmt.Errorf("CertOpenStore returned: %v", err) + return nil, nil, err } - defer windows.CertCloseStore(certStore, 0) var prev *windows.CertContext var cert *x509.Certificate @@ -394,7 +392,7 @@ func (w *WinCertStore) cert(issuers []string, searchRoot *uint16, store uint32) // pass 0 as the third parameter because it is not used // https://msdn.microsoft.com/en-us/library/windows/desktop/aa376064(v=vs.85).aspx - nc, err := findCert(certStore, encodingX509ASN|encodingPKCS7, 0, findIssuerStr, i, prev) + nc, err := findCert(h, encodingX509ASN|encodingPKCS7, 0, findIssuerStr, i, prev) if err != nil { return nil, nil, fmt.Errorf("finding certificates: %v", err) } @@ -430,9 +428,20 @@ func freeObject(h uintptr) error { return fmt.Errorf("NCryptFreeObject returned %X: %v", r, err) } -// Close frees the handle to the certificate provider +// Close frees the handle to the certificate provider, the certificate store, etc. func (w *WinCertStore) Close() error { - return freeObject(w.Prov) + var result error + for _, v := range w.stores { + if v != nil { + if err := v.Close(); err != nil { + multierror.Append(result, err) + } + } + } + if err := freeObject(w.Prov); err != nil { + multierror.Append(result, err) + } + return result } // Link will associate the certificate installed in the system store to the user store. @@ -479,20 +488,13 @@ func (w *WinCertStore) Link() error { fmt.Printf("found a matching private key for the certificate, but association failed: %v", err) } - // Open a handle to the user cert store - userStore, err := windows.CertOpenStore( - certStoreProvSystem, - 0, - 0, - certStoreCurrentUser, - uintptr(unsafe.Pointer(my))) + h, err := w.storeHandle(certStoreCurrentUser, my) if err != nil { - return fmt.Errorf("CertOpenStore for the user store returned: %v", err) + return err } - defer windows.CertCloseStore(userStore, 0) // Add the cert context to the users certificate store - if err := windows.CertAddCertificateContextToStore(userStore, certContext, windows.CERT_STORE_ADD_ALWAYS, nil); err != nil { + if err := windows.CertAddCertificateContextToStore(h, certContext, windows.CERT_STORE_ADD_ALWAYS, nil); err != nil { return fmt.Errorf("CertAddCertificateContextToStore returned: %v", err) } @@ -507,6 +509,35 @@ func (w *WinCertStore) Link() error { return nil } +type storeHandle struct { + handle *windows.Handle +} + +func newStoreHandle(provider uint32, store *uint16) (*storeHandle, error) { + var s storeHandle + if s.handle != nil { + return &s, nil + } + st, err := windows.CertOpenStore( + certStoreProvSystem, + 0, + 0, + provider, + uintptr(unsafe.Pointer(store))) + if err != nil { + return nil, fmt.Errorf("CertOpenStore for the user store returned: %v", err) + } + s.handle = &st + return &s, nil +} + +func (s *storeHandle) Close() error { + if s.handle != nil { + return windows.CertCloseStore(*s.handle, 1) + } + return nil +} + // linkLegacy will associate the private key for a system certificate backed by cryptoAPI to // the copy of the certificate stored in the user store. This makes the key available to legacy // applications which may require it be specifically present in the users store to be read. @@ -565,19 +596,13 @@ func (w *WinCertStore) Remove(removeSystem bool) error { // remove removes a certificate issued by w.issuer from the user and/or system cert stores. func (w *WinCertStore) remove(issuer string, removeSystem bool) error { - userStore, err := windows.CertOpenStore( - certStoreProvSystem, - 0, - 0, - certStoreCurrentUser, - uintptr(unsafe.Pointer(my))) + h, err := w.storeHandle(certStoreCurrentUser, my) if err != nil { - return fmt.Errorf("certopenstore for the user store returned: %v", err) + return err } - defer windows.CertCloseStore(userStore, 0) userCertContext, err := findCert( - userStore, + h, encodingX509ASN|encodingPKCS7, 0, findIssuerStr, @@ -600,19 +625,13 @@ func (w *WinCertStore) remove(issuer string, removeSystem bool) error { return nil } - systemStore, err := windows.CertOpenStore( - certStoreProvSystem, - 0, - 0, - certStoreLocalMachine, - uintptr(unsafe.Pointer(my))) + h2, err := w.storeHandle(certStoreLocalMachine, my) if err != nil { - return fmt.Errorf("certopenstore for the system store returned: %v", err) + return err } - defer windows.CertCloseStore(systemStore, 0) systemCertContext, err := findCert( - systemStore, + h2, encodingX509ASN|encodingPKCS7, 0, findIssuerStr, @@ -1357,19 +1376,13 @@ func (w *WinCertStore) Store(cert *x509.Certificate, intermediate *x509.Certific } // Open a handle to the system cert store - systemStore, err := windows.CertOpenStore( - certStoreProvSystem, - 0, - 0, - certStoreLocalMachine, - uintptr(unsafe.Pointer(my))) + h, err := w.storeHandle(certStoreLocalMachine, my) if err != nil { - return fmt.Errorf("CertOpenStore for the system store returned: %v", err) + return err } - defer windows.CertCloseStore(systemStore, 0) // Add the cert context to the system certificate store - if err := windows.CertAddCertificateContextToStore(systemStore, certContext, windows.CERT_STORE_ADD_ALWAYS, nil); err != nil { + if err := windows.CertAddCertificateContextToStore(h, certContext, windows.CERT_STORE_ADD_ALWAYS, nil); err != nil { return fmt.Errorf("CertAddCertificateContextToStore returned: %v", err) } @@ -1383,26 +1396,35 @@ func (w *WinCertStore) Store(cert *x509.Certificate, intermediate *x509.Certific } defer windows.CertFreeCertificateContext(intContext) - // Open a handle to the intermediate cert store - caStore, err := windows.CertOpenStore( - certStoreProvSystem, - 0, - 0, - certStoreLocalMachine, - uintptr(unsafe.Pointer(ca))) + h2, err := w.storeHandle(certStoreLocalMachine, ca) if err != nil { - return fmt.Errorf("CertOpenStore for the intermediate store returned: %v", err) + return err } - defer windows.CertCloseStore(caStore, 0) // Add the intermediate cert context to the store - if err := windows.CertAddCertificateContextToStore(caStore, intContext, windows.CERT_STORE_ADD_ALWAYS, nil); err != nil { + if err := windows.CertAddCertificateContextToStore(h2, intContext, windows.CERT_STORE_ADD_ALWAYS, nil); err != nil { return fmt.Errorf("CertAddCertificateContextToStore returned: %v", err) } return nil } +// Returns a handle to a given cert store, opening the handle as needed. +func (w *WinCertStore) storeHandle(provider uint32, store *uint16) (windows.Handle, error) { + w.mu.Lock() + defer w.mu.Unlock() + + key := fmt.Sprintf("%d%s", provider, windows.UTF16PtrToString(store)) + var err error + if w.stores[key] == nil { + w.stores[key], err = newStoreHandle(provider, store) + if err != nil { + return 0, err + } + } + return *w.stores[key].handle, nil +} + // copyFile copies the contents of one file from one location to another func copyFile(from, to string) error { source, err := os.Open(from)