diff --git a/membersrvc/ca/aca.go b/membersrvc/ca/aca.go index 85b02957bc3..8ecc3b46b77 100644 --- a/membersrvc/ca/aca.go +++ b/membersrvc/ca/aca.go @@ -340,6 +340,9 @@ func (aca *ACA) fetchAndPopulateAttributes(id, affiliation string) error { func (aca *ACA) findAttribute(owner *AttributeOwner, attributeName string) (*AttributePair, error) { var count int + mutex.RLock() + defer mutex.RUnlock() + err := aca.db.QueryRow("SELECT count(row) AS cant FROM Attributes WHERE id=? AND affiliation =? AND attributeName =?", owner.GetID(), owner.GetAffiliation(), attributeName).Scan(&count) if err != nil { diff --git a/membersrvc/ca/ca.go b/membersrvc/ca/ca.go index a2c2110aa0d..fe2858b549b 100644 --- a/membersrvc/ca/ca.go +++ b/membersrvc/ca/ca.go @@ -74,7 +74,7 @@ type AffiliationGroup struct { } var ( - mutex = &sync.Mutex{} + mutex = &sync.RWMutex{} caOrganization string caCountry string rootPath string @@ -367,9 +367,6 @@ func (ca *CA) createCertificate(id string, pub interface{}, usage x509.KeyUsage, } func (ca *CA) createCertificateFromSpec(spec *CertificateSpec, timestamp int64, kdfKey []byte, persist bool) ([]byte, error) { - mutex.Lock() - defer mutex.Unlock() - Trace.Println("Creating certificate for " + spec.GetID() + ".") raw, err := ca.newCertificateFromSpec(spec) @@ -386,6 +383,9 @@ func (ca *CA) createCertificateFromSpec(spec *CertificateSpec, timestamp int64, } func (ca *CA) persistCertificate(id string, timestamp int64, usage x509.KeyUsage, certRaw []byte, kdfKey []byte) error { + mutex.Lock() + defer mutex.Unlock() + hash := primitives.NewHash() hash.Write(certRaw) var err error @@ -451,6 +451,9 @@ func (ca *CA) newCertificateFromSpec(spec *CertificateSpec) ([]byte, error) { func (ca *CA) readCertificateByKeyUsage(id string, usage x509.KeyUsage) ([]byte, error) { Trace.Printf("Reading certificate for %s and usage %v", id, usage) + mutex.RLock() + defer mutex.RUnlock() + var raw []byte err := ca.db.QueryRow("SELECT cert FROM Certificates WHERE id=? AND usage=?", id, usage).Scan(&raw) @@ -464,6 +467,9 @@ func (ca *CA) readCertificateByKeyUsage(id string, usage x509.KeyUsage) ([]byte, func (ca *CA) readCertificateByTimestamp(id string, ts int64) ([]byte, error) { Trace.Println("Reading certificate for " + id + ".") + mutex.RLock() + defer mutex.RUnlock() + var raw []byte err := ca.db.QueryRow("SELECT cert FROM Certificates WHERE id=? AND timestamp=?", id, ts).Scan(&raw) @@ -473,6 +479,9 @@ func (ca *CA) readCertificateByTimestamp(id string, ts int64) ([]byte, error) { func (ca *CA) readCertificates(id string, opt ...int64) (*sql.Rows, error) { Trace.Println("Reading certificatess for " + id + ".") + mutex.RLock() + defer mutex.RUnlock() + if len(opt) > 0 && opt[0] != 0 { return ca.db.Query("SELECT cert, kdfkey FROM Certificates WHERE id=? AND timestamp=? ORDER BY usage", id, opt[0]) } @@ -483,12 +492,18 @@ func (ca *CA) readCertificates(id string, opt ...int64) (*sql.Rows, error) { func (ca *CA) readCertificateSets(id string, start, end int64) (*sql.Rows, error) { Trace.Println("Reading certificate sets for " + id + ".") + mutex.RLock() + defer mutex.RUnlock() + return ca.db.Query("SELECT cert, kdfKey, timestamp FROM Certificates WHERE id=? AND timestamp BETWEEN ? AND ? ORDER BY timestamp", id, start, end) } func (ca *CA) readCertificateByHash(hash []byte) ([]byte, error) { Trace.Println("Reading certificate for hash " + string(hash) + ".") + mutex.RLock() + defer mutex.RUnlock() + var raw []byte row := ca.db.QueryRow("SELECT cert FROM Certificates WHERE hash=?", hash) err := row.Scan(&raw) @@ -499,6 +514,9 @@ func (ca *CA) readCertificateByHash(hash []byte) ([]byte, error) { func (ca *CA) isValidAffiliation(affiliation string) (bool, error) { Trace.Println("Validating affiliation: " + affiliation) + mutex.RLock() + defer mutex.RUnlock() + var count int var err error err = ca.db.QueryRow("SELECT count(row) FROM AffiliationGroups WHERE name=?", affiliation).Scan(&count) @@ -662,6 +680,9 @@ func (ca *CA) registerAffiliationGroup(name string, parentName string) error { func (ca *CA) deleteUser(id string) error { Trace.Println("Deleting user " + id + ".") + mutex.Lock() + defer mutex.Unlock() + var row int err := ca.db.QueryRow("SELECT row FROM Users WHERE id=?", id).Scan(&row) if err == nil { @@ -684,6 +705,9 @@ func (ca *CA) deleteUser(id string) error { func (ca *CA) readUser(id string) *sql.Row { Trace.Println("Reading token for " + id + ".") + mutex.RLock() + defer mutex.RUnlock() + return ca.db.QueryRow("SELECT role, token, state, key, enrollmentId FROM Users WHERE id=?", id) } @@ -700,6 +724,9 @@ func (ca *CA) readUsers(role int) (*sql.Rows, error) { func (ca *CA) readRole(id string) int { Trace.Println("Reading role for " + id + ".") + mutex.RLock() + defer mutex.RUnlock() + var role int ca.db.QueryRow("SELECT role FROM Users WHERE id=?", id).Scan(&role) @@ -771,6 +798,9 @@ func (ca *CA) parseEnrollID(enrollID string) (id string, role string, affiliatio // and with metadata associated with 'newMemberMetadataStr' // Return nil if allowed, or an error if not allowed func (ca *CA) canRegister(registrar string, newMemberRole string, newMemberMetadataStr string) error { + mutex.RLock() + defer mutex.RUnlock() + // Read the user metadata associated with 'registrar' var registrarMetadataStr string err := ca.db.QueryRow("SELECT metadata FROM Users WHERE id=?", registrar).Scan(®istrarMetadataStr) diff --git a/membersrvc/ca/ecap.go b/membersrvc/ca/ecap.go index 2f1635bdfa5..4b02652ecc8 100644 --- a/membersrvc/ca/ecap.go +++ b/membersrvc/ca/ecap.go @@ -106,6 +106,7 @@ func (ecap *ECAP) CreateCertificatePair(ctx context.Context, in *pb.ECertCreateR id := in.Id.Id err := ecap.eca.readUser(id).Scan(&role, &tok, &state, &prev, &enrollID) + if err != nil { errMsg := "Identity lookup error: " + err.Error() Trace.Println(errMsg) @@ -127,7 +128,10 @@ func (ecap *ECAP) CreateCertificatePair(ctx context.Context, in *pb.ECertCreateR // initial request, create encryption challenge tok = []byte(randomString(12)) + mutex.Lock() _, err = ecap.eca.db.Exec("UPDATE Users SET token=?, state=?, key=? WHERE id=?", tok, 1, in.Enc.Key, id) + mutex.Unlock() + if err != nil { Error.Println(err) return nil, err @@ -190,14 +194,20 @@ func (ecap *ECAP) CreateCertificatePair(ctx context.Context, in *pb.ECertCreateR spec = NewDefaultCertificateSpecWithCommonName(id, enrollID, ekey.(*ecdsa.PublicKey), x509.KeyUsageDataEncipherment, pkix.Extension{Id: ECertSubjectRole, Critical: true, Value: []byte(strconv.Itoa(ecap.eca.readRole(id)))}) eraw, err := ecap.eca.createCertificateFromSpec(spec, ts, nil, true) if err != nil { + mutex.Lock() ecap.eca.db.Exec("DELETE FROM Certificates Where id=?", id) + mutex.Unlock() Error.Println(err) return nil, err } + mutex.Lock() _, err = ecap.eca.db.Exec("UPDATE Users SET state=? WHERE id=?", 2, id) + mutex.Unlock() if err != nil { + mutex.Lock() ecap.eca.db.Exec("DELETE FROM Certificates Where id=?", id) + mutex.Unlock() Error.Println(err) return nil, err } diff --git a/membersrvc/ca/tca.go b/membersrvc/ca/tca.go index 386a0217d40..5818c4a9c5d 100644 --- a/membersrvc/ca/tca.go +++ b/membersrvc/ca/tca.go @@ -240,6 +240,9 @@ func (tca *TCA) startTCAA(srv *grpc.Server) { } func (tca *TCA) getCertificateSets(enrollmentID string) ([]*TCertSet, error) { + mutex.RLock() + defer mutex.RUnlock() + var sets = []*TCertSet{} var err error @@ -269,6 +272,9 @@ func (tca *TCA) getCertificateSets(enrollmentID string) ([]*TCertSet, error) { } func (tca *TCA) persistCertificateSet(enrollmentID string, timestamp int64, nonce []byte, kdfKey []byte) error { + mutex.Lock() + defer mutex.Unlock() + var err error if _, err = tca.db.Exec("INSERT INTO TCertificateSets (enrollmentID, timestamp, nonce, kdfkey) VALUES (?, ?, ?, ?)", enrollmentID, timestamp, nonce, kdfKey); err != nil {