diff --git a/base.go b/base.go index a79f3eb..1ca51ab 100644 --- a/base.go +++ b/base.go @@ -101,7 +101,7 @@ func (bc *baseConnector) disconnectDevicesByDeviceName(ctx context.Context, name } func (bc *baseConnector) disconnectNVMEDevicesByDeviceName(ctx context.Context, name string) error { - defer tracer.TraceFuncCall(ctx, "baseConnector.disconnectDevicesByDeviceName")() + defer tracer.TraceFuncCall(ctx, "baseConnector.disconnectNVMEDevicesByDeviceName")() if !bc.scsi.IsDeviceExist(ctx, name) { logger.Info(ctx, "device %s not found", name) return nil @@ -109,7 +109,7 @@ func (bc *baseConnector) disconnectNVMEDevicesByDeviceName(ctx context.Context, var err error var wwn string if strings.HasPrefix(name, deviceMapperPrefix) { - wwn, err = bc.getDMWWN(ctx, name) + wwn, err = bc.getNVMEDMWWN(ctx, name) } else { wwn, err = bc.scsi.GetNVMEDeviceWWN(ctx, []string{name}) @@ -210,3 +210,29 @@ func (bc *baseConnector) getDMWWN(ctx context.Context, dm string) (string, error logger.Info(ctx, "WWN for DM %s is: %s", dm, wwn) return wwn, nil } + +func (bc *baseConnector) getNVMEDMWWN(ctx context.Context, dm string) (string, error) { + defer tracer.TraceFuncCall(ctx, "baseConnector.getDMWWN")() + logger.Info(ctx, "resolve wwn for DM: %s", dm) + children, err := bc.scsi.GetDMChildren(ctx, dm) + if err == nil { + logger.Debug(ctx, "children for DM %s: %s", dm, children) + wwn, err := bc.scsi.GetNVMEDeviceWWN(ctx, children) + if err != nil { + logger.Error(ctx, "failed to read WWN for DM %s children: %s", dm, err.Error()) + return "", err + } + logger.Debug(ctx, "WWN for DM %s is: %s", dm, wwn) + return wwn, nil + } + logger.Debug(ctx, "failed to get children for DM %s: %s", dm, err.Error()) + logger.Info(ctx, "can't resolve DM %s WWN from children devices, query multipathd", dm) + wwn, err := bc.multipath.GetDMWWID(ctx, dm) + if err != nil { + msg := fmt.Sprintf("failed to resolve DM %s WWN: %s", dm, err.Error()) + logger.Error(ctx, msg) + return "", errors.New(msg) + } + logger.Info(ctx, "WWN for DM %s is: %s", dm, wwn) + return wwn, nil +} diff --git a/internal/scsi/interface.go b/internal/scsi/interface.go index 3b0eba4..d58abe4 100644 --- a/internal/scsi/interface.go +++ b/internal/scsi/interface.go @@ -20,9 +20,12 @@ type SCSI interface { GetNVMEDeviceWWN(ctx context.Context, devices []string) (string, error) GetDevicesByWWN(ctx context.Context, wwn string) ([]string, error) GetDMDeviceByChildren(ctx context.Context, devices []string) (string, error) + GetNVMEDMDeviceByChildren(ctx context.Context, devices []string) (string, error) + GetNVMEMultipathDMName(device string, pattern string) ([]string, error) GetDMChildren(ctx context.Context, dmPath string) ([]string, error) CheckDeviceIsValid(ctx context.Context, device string) bool GetDeviceNameByHCTL(ctx context.Context, h scsi.HCTL) (string, error) WaitUdevSymlink(ctx context.Context, deviceName string, wwn string) error WaitUdevSymlinkNVMe(ctx context.Context, deviceName string, wwn string) error + GetNVMESymlink(checkPath string) (string, error) } diff --git a/internal/scsi/scsi_mock.go b/internal/scsi/scsi_mock.go index 7f078b3..2110e97 100644 --- a/internal/scsi/scsi_mock.go +++ b/internal/scsi/scsi_mock.go @@ -166,6 +166,21 @@ func (mr *MockSCSIMockRecorder) GetDevicesByWWN(ctx, wwn interface{}) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDevicesByWWN", reflect.TypeOf((*MockSCSI)(nil).GetDevicesByWWN), ctx, wwn) } +// GetNVMEDMDeviceByChildren mocks base method. +func (m *MockSCSI) GetNVMEDMDeviceByChildren(ctx context.Context, devices []string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNVMEDMDeviceByChildren", ctx, devices) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetNVMEDMDeviceByChildren indicates an expected call of GetNVMEDMDeviceByChildren. +func (mr *MockSCSIMockRecorder) GetNVMEDMDeviceByChildren(ctx, devices interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNVMEDMDeviceByChildren", reflect.TypeOf((*MockSCSI)(nil).GetNVMEDMDeviceByChildren), ctx, devices) +} + // GetNVMEDeviceWWN mocks base method. func (m *MockSCSI) GetNVMEDeviceWWN(ctx context.Context, devices []string) (string, error) { m.ctrl.T.Helper() @@ -181,6 +196,36 @@ func (mr *MockSCSIMockRecorder) GetNVMEDeviceWWN(ctx, devices interface{}) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNVMEDeviceWWN", reflect.TypeOf((*MockSCSI)(nil).GetNVMEDeviceWWN), ctx, devices) } +// GetNVMEMultipathDMName mocks base method. +func (m *MockSCSI) GetNVMEMultipathDMName(device, pattern string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNVMEMultipathDMName", device, pattern) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetNVMEMultipathDMName indicates an expected call of GetNVMEMultipathDMName. +func (mr *MockSCSIMockRecorder) GetNVMEMultipathDMName(device, pattern interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNVMEMultipathDMName", reflect.TypeOf((*MockSCSI)(nil).GetNVMEMultipathDMName), device, pattern) +} + +// GetNVMESymlink mocks base method. +func (m *MockSCSI) GetNVMESymlink(checkPath string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNVMESymlink", checkPath) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetNVMESymlink indicates an expected call of GetNVMESymlink. +func (mr *MockSCSIMockRecorder) GetNVMESymlink(checkPath interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNVMESymlink", reflect.TypeOf((*MockSCSI)(nil).GetNVMESymlink), checkPath) +} + // IsDeviceExist mocks base method. func (m *MockSCSI) IsDeviceExist(ctx context.Context, device string) bool { m.ctrl.T.Helper() diff --git a/nvme_tcp.go b/nvme_tcp.go index b30da41..e98c1d9 100644 --- a/nvme_tcp.go +++ b/nvme_tcp.go @@ -398,7 +398,7 @@ func (c *NVMeTCPConnector) connectMultipathDevice( if wwn != "" && mpath == "" { var err error - mpath, err = c.scsi.GetDMDeviceByChildren(ctx, devices) + mpath, err = c.scsi.GetNVMEDMDeviceByChildren(ctx, devices) if err != nil { logger.Debug(ctx, "failed to get DM by children: %s", err.Error()) } diff --git a/pkg/scsi/scsi.go b/pkg/scsi/scsi.go index a9fe566..d5cce8a 100644 --- a/pkg/scsi/scsi.go +++ b/pkg/scsi/scsi.go @@ -36,12 +36,16 @@ import ( "golang.org/x/sync/singleflight" ) +// constants const ( - diskByIDPath = "/dev/disk/by-id/" - diskByIDSCSIPath = diskByIDPath + "scsi-" - diskByIDDMPath = diskByIDPath + "dm-uuid-mpath-" - diskByIDDMPathNVMe = diskByIDPath + "dm-uuid-mpath-eui." - scsiIDPath = "/lib/udev/scsi_id" + diskByIDPath = "/dev/disk/by-id/" + diskByIDSCSIPath = diskByIDPath + "scsi-" + diskByIDDMPath = diskByIDPath + "dm-uuid-mpath-" + diskByIDDMPathNVMe = diskByIDPath + "dm-uuid-mpath-eui." + scsiIDPath = "/lib/udev/scsi_id" + maxRetryCount = 10 + NVMEMultipathSleepTime = 500 + NVMESymlinkSleepTime = 200 ) // NewSCSI initializes scsi struct @@ -157,6 +161,12 @@ func (s *Scsi) GetDMDeviceByChildren(ctx context.Context, devices []string) (str return s.getDMDeviceByChildren(ctx, devices) } +// GetNVMEDMDeviceByChildren fetches multipath device name +func (s *Scsi) GetNVMEDMDeviceByChildren(ctx context.Context, devices []string) (string, error) { + defer tracer.TraceFuncCall(ctx, "scsi.GetNVMEDMDeviceByChildren")() + return s.getNVMEDMDeviceByChildren(ctx, devices) +} + // GetDMChildren fetches multipath block devices func (s *Scsi) GetDMChildren(ctx context.Context, dm string) ([]string, error) { defer tracer.TraceFuncCall(ctx, "scsi.GetDMChildren")() @@ -355,6 +365,54 @@ func (s *Scsi) getDMDeviceByChildren(ctx context.Context, devices []string) (str return "", errors.New("dm not found") } +//GetNVMEMultipathDMName finds the multipath DM mame for NVMe +func (s *Scsi) GetNVMEMultipathDMName(device string, pattern string) ([]string, error) { + + var retryCount = 0 + for { + matches, err := s.filePath.Glob(fmt.Sprintf(pattern, device)) + if len(matches) > 0 || retryCount == maxRetryCount { + return matches, err + } + time.Sleep(NVMEMultipathSleepTime * time.Millisecond) + retryCount = retryCount + 1 + } +} + +func (s *Scsi) getNVMEDMDeviceByChildren(ctx context.Context, devices []string) (string, error) { + logger.Info(ctx, "multipath - trying to find multipath DM name") + + pattern := "/sys/block/%s/holders/dm-*" + + var match string + + for _, d := range devices { + matches, err := s.GetNVMEMultipathDMName(d, pattern) + if err != nil { + return "", err + } + for _, m := range matches { + data, err := s.fileReader.ReadFile(path.Join(m, "dm/uuid")) + if err != nil { + logger.Error(ctx, "multipath - failed to read dm id file: %s", err.Error()) + continue + } + if strings.HasPrefix(string(data), "mpath") { + _, dm := path.Split(m) + if match == "" { + match = dm + } else if dm != match { + return "", &DevicesHaveDifferentParentsErr{} + } + } + } + } + if match != "" { + return match, nil + } + return "", errors.New("dm not found") +} + func (s *Scsi) getDMChildren(ctx context.Context, dm string) ([]string, error) { logger.Info(ctx, "multipath - get block device included in DM") var devices []string @@ -503,6 +561,20 @@ func (s *Scsi) waitUdevSymlink(ctx context.Context, deviceName string, wwn strin return nil } +//GetNVMESymlink return the NVMe symlink for the given path +func (s *Scsi) GetNVMESymlink(checkPath string) (string, error) { + + var retryCount = 1 + for { + symlink, err := s.filePath.EvalSymlinks(checkPath) + if err == nil || retryCount == maxRetryCount { + return symlink, err + } + time.Sleep(NVMESymlinkSleepTime * time.Millisecond) + retryCount = retryCount + 1 + } +} + func (s *Scsi) waitUdevSymlinkNVMe(ctx context.Context, deviceName string, wwn string) error { var checkPath string if strings.HasPrefix(deviceName, "dm-") { @@ -510,7 +582,7 @@ func (s *Scsi) waitUdevSymlinkNVMe(ctx context.Context, deviceName string, wwn s } else { checkPath = diskByIDSCSIPath + wwn } - symlink, err := s.filePath.EvalSymlinks(checkPath) + symlink, err := s.GetNVMESymlink(checkPath) if err != nil { msg := fmt.Sprintf("symlink for path %s not found: %s", checkPath, err.Error()) logger.Error(ctx, msg)