Skip to content

Commit

Permalink
fix(bug): Added mutext for downloader providers. Fixes kubeflow#1531 (k…
Browse files Browse the repository at this point in the history
…ubeflow#1539)

* Added mutext

* add package

* pass as ref instead of copy

* updated after feedback
  • Loading branch information
NikeNano authored Apr 23, 2021
1 parent 9ac96ba commit 5a7e06a
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 33 deletions.
17 changes: 9 additions & 8 deletions cmd/agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ package main
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strconv"
"time"

"github.com/kelseyhightower/envconfig"
"github.com/kubeflow/kfserving/pkg/agent"
"github.com/kubeflow/kfserving/pkg/agent/storage"
Expand All @@ -21,13 +29,6 @@ import (
"knative.dev/serving/pkg/queue"
"knative.dev/serving/pkg/queue/health"
"knative.dev/serving/pkg/queue/readiness"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strconv"
"time"
)

var (
Expand Down Expand Up @@ -282,7 +283,7 @@ func startModelPuller(logger *zap.SugaredLogger) {
}
watcher := agent.NewWatcher(*configDir, *modelDir, logger)
logger.Info("Starting puller")
agent.StartPullerAndProcessModels(downloader, watcher.ModelEvents, logger)
agent.StartPullerAndProcessModels(&downloader, watcher.ModelEvents, logger)
go watcher.Start()
}

Expand Down
13 changes: 9 additions & 4 deletions pkg/agent/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,22 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"github.com/kubeflow/kfserving/pkg/agent/storage"
"github.com/kubeflow/kfserving/pkg/apis/serving/v1alpha1"
"github.com/pkg/errors"
"go.uber.org/zap"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"strings"
"sync"

"github.com/kubeflow/kfserving/pkg/agent/storage"
"github.com/kubeflow/kfserving/pkg/apis/serving/v1alpha1"
"github.com/pkg/errors"
"go.uber.org/zap"
)

type Downloader struct {
ModelDir string
mu sync.Mutex
Providers map[storage.Protocol]storage.Provider
Logger *zap.SugaredLogger
}
Expand Down Expand Up @@ -77,7 +80,9 @@ func (d *Downloader) download(modelName string, storageUri string) error {
if err != nil {
return errors.Wrapf(err, "unsupported protocol")
}
d.mu.Lock()
provider, err := storage.GetProvider(d.Providers, protocol)
d.mu.Unlock()
if err != nil {
return errors.Wrapf(err, "unable to create or get provider for protocol %s", protocol)
}
Expand Down
11 changes: 6 additions & 5 deletions pkg/agent/puller.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ package agent
import (
"bytes"
"fmt"
"github.com/kubeflow/kfserving/pkg/agent/storage"
v1 "github.com/kubeflow/kfserving/pkg/apis/serving/v1alpha1"
"go.uber.org/zap"
"io/ioutil"
"net/http"
"path/filepath"
"sync"

"github.com/kubeflow/kfserving/pkg/agent/storage"
v1 "github.com/kubeflow/kfserving/pkg/apis/serving/v1alpha1"
"go.uber.org/zap"
)

type OpType string
Expand All @@ -40,7 +41,7 @@ type Puller struct {
completions chan *ModelOp
opStats map[string]map[OpType]int
waitGroup WaitGroupWrapper
Downloader Downloader
Downloader *Downloader
logger *zap.SugaredLogger
}

Expand All @@ -55,7 +56,7 @@ type WaitGroupWrapper struct {
wg sync.WaitGroup
}

func StartPullerAndProcessModels(downloader Downloader, commands <-chan ModelOp, logger *zap.SugaredLogger) {
func StartPullerAndProcessModels(downloader *Downloader, commands <-chan ModelOp, logger *zap.SugaredLogger) {
puller := Puller{
channelMap: make(map[string]*ModelChannel),
completions: make(chan *ModelOp, 4),
Expand Down
33 changes: 17 additions & 16 deletions pkg/agent/watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@ limitations under the License.
package agent

import (
gstorage "cloud.google.com/go/storage"
"context"
"fmt"
"io/ioutil"
logger "log"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync"

gstorage "cloud.google.com/go/storage"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/kubeflow/kfserving/pkg/agent/mocks"
Expand All @@ -29,14 +37,7 @@ import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"go.uber.org/zap"
"io/ioutil"
"k8s.io/apimachinery/pkg/api/resource"
logger "log"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync"
)

var _ = Describe("Watcher", func() {
Expand Down Expand Up @@ -87,7 +88,7 @@ var _ = Describe("Watcher", func() {
completions: make(chan *ModelOp, 4),
opStats: make(map[string]map[OpType]int),
waitGroup: WaitGroupWrapper{sync.WaitGroup{}},
Downloader: Downloader{
Downloader: &Downloader{
ModelDir: modelDir + "/test1",
Providers: map[storage.Protocol]storage.Provider{
storage.S3: &storage.S3Provider{
Expand Down Expand Up @@ -120,7 +121,7 @@ var _ = Describe("Watcher", func() {
completions: make(chan *ModelOp, 4),
opStats: make(map[string]map[OpType]int),
waitGroup: WaitGroupWrapper{sync.WaitGroup{}},
Downloader: Downloader{
Downloader: &Downloader{
ModelDir: modelDir + "/test1",
Providers: map[storage.Protocol]storage.Provider{
storage.S3: &storage.S3Provider{
Expand Down Expand Up @@ -166,7 +167,7 @@ var _ = Describe("Watcher", func() {
completions: make(chan *ModelOp, 4),
opStats: make(map[string]map[OpType]int),
waitGroup: WaitGroupWrapper{sync.WaitGroup{}},
Downloader: Downloader{
Downloader: &Downloader{
ModelDir: modelDir + "/test2",
Providers: map[storage.Protocol]storage.Provider{
storage.S3: &storage.S3Provider{
Expand Down Expand Up @@ -224,7 +225,7 @@ var _ = Describe("Watcher", func() {
completions: make(chan *ModelOp, 4),
opStats: make(map[string]map[OpType]int),
waitGroup: WaitGroupWrapper{sync.WaitGroup{}},
Downloader: Downloader{
Downloader: &Downloader{
ModelDir: modelDir + "/test3",
Providers: map[storage.Protocol]storage.Provider{
storage.S3: &storage.S3Provider{
Expand Down Expand Up @@ -298,7 +299,7 @@ var _ = Describe("Watcher", func() {
channelMap: make(map[string]*ModelChannel),
completions: make(chan *ModelOp, 4),
opStats: make(map[string]map[OpType]int),
Downloader: Downloader{
Downloader: &Downloader{
ModelDir: modelDir + "/test4",
Providers: map[storage.Protocol]storage.Provider{
storage.S3: &storage.S3Provider{
Expand Down Expand Up @@ -470,7 +471,7 @@ var _ = Describe("Watcher", func() {
channelMap: make(map[string]*ModelChannel),
completions: make(chan *ModelOp, 4),
opStats: make(map[string]map[OpType]int),
Downloader: Downloader{
Downloader: &Downloader{
ModelDir: modelDir + "/test1",
Providers: map[storage.Protocol]storage.Provider{
storage.GCS: &cl,
Expand All @@ -496,7 +497,7 @@ var _ = Describe("Watcher", func() {
completions: make(chan *ModelOp, 4),
opStats: make(map[string]map[OpType]int),
waitGroup: WaitGroupWrapper{sync.WaitGroup{}},
Downloader: Downloader{
Downloader: &Downloader{
ModelDir: modelDir + "/test2",
Providers: map[storage.Protocol]storage.Provider{
storage.S3: &storage.S3Provider{
Expand Down Expand Up @@ -707,7 +708,7 @@ var _ = Describe("Watcher", func() {
channelMap: make(map[string]*ModelChannel),
completions: make(chan *ModelOp, 4),
opStats: make(map[string]map[OpType]int),
Downloader: Downloader{
Downloader: &Downloader{
ModelDir: modelDir + "/test1",
Providers: map[storage.Protocol]storage.Provider{
storage.HTTPS: &cl,
Expand Down

0 comments on commit 5a7e06a

Please sign in to comment.