Skip to content

Commit

Permalink
Merge pull request #18 from markvai/add_closers_to_models
Browse files Browse the repository at this point in the history
Add close methods to all tasks
  • Loading branch information
matteo-grella authored Jun 12, 2023
2 parents ee9a7d9 + 25816b2 commit f7bd8ab
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 0 deletions.
6 changes: 6 additions & 0 deletions pkg/tasks/languagemodeling/bert/languagemodel.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ func LoadMaskedLanguageModel(modelPath string) (*LanguageModel, error) {
}, nil
}

// Close finalizes the LanguageModel resources.
// It satisfies the interface io.Closer.
func (m *LanguageModel) Close() error {
return m.embeddingsRepo.Close()
}

// Predict returns the predicted tokens
func (m *LanguageModel) Predict(_ context.Context, text string, parameters languagemodeling.Parameters) (languagemodeling.Response, error) {
if parameters.K == 0 {
Expand Down
6 changes: 6 additions & 0 deletions pkg/tasks/questionanswering/bert/questionanswering.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ func LoadQuestionAnswering(modelPath string) (*QuestionAnswering, error) {
}, nil
}

// Close finalizes the QuestionAnswering resources.
// It satisfies the interface io.Closer.
func (qa *QuestionAnswering) Close() error {
return qa.embeddingsRepo.Close()
}

// Answer returns the answers for the given question and passage.
// The options may assume default values if those are not set.
func (qa *QuestionAnswering) Answer(_ context.Context, question string, passage string, opts *questionanswering.Options) (questionanswering.Response, error) {
Expand Down
6 changes: 6 additions & 0 deletions pkg/tasks/textclassification/bert/textclassification.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ func ID2Label(value map[string]string) []string {
return y
}

// Close finalizes the TextClassification resources.
// It satisfies the interface io.Closer.
func (m *TextClassification) Close() error {
return m.embeddingsRepo.Close()
}

// Classify returns the classification of the given text.
func (m *TextClassification) Classify(_ context.Context, text string) (textclassification.Response, error) {
tokenized := m.tokenize(text)
Expand Down
6 changes: 6 additions & 0 deletions pkg/tasks/textencoding/bert/textencoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ func LoadTextEncoding(modelPath string) (*TextEncoding, error) {
}, nil
}

// Close finalizes the TextEncoding resources.
// It satisfies the interface io.Closer.
func (m *TextEncoding) Close() error {
return m.embeddingsRepo.Close()
}

// Encode returns the dense encoded representation of the given text.
func (m *TextEncoding) Encode(_ context.Context, text string, poolingStrategy int) (textencoding.Response, error) {
tokenized := m.tokenize(text)
Expand Down
6 changes: 6 additions & 0 deletions pkg/tasks/tokenclassification/bert/tokenclassification.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ func ID2Label(value map[string]string) []string {
return y
}

// Close finalizes the TokenClassification resources.
// It satisfies the interface io.Closer.
func (m *TokenClassification) Close() error {
return m.embeddingsRepo.Close()
}

// Classify returns the classification of the given text.
func (m *TokenClassification) Classify(_ context.Context, text string, parameters tokenclassification.Parameters) (tokenclassification.Response, error) {
tokenized := m.tokenize(text)
Expand Down
6 changes: 6 additions & 0 deletions pkg/tasks/tokenclassification/flair/tokenclassification.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ func ID2Label(value map[string]string) []string {
return y
}

// Close finalizes the TokenClassification resources.
// It satisfies the interface io.Closer.
func (m *TokenClassification) Close() error {
return m.embeddingsRepo.Close()
}

// Classify returns the classification of the given text.
func (m *TokenClassification) Classify(_ context.Context, text string, parameters tokenclassification.Parameters) (tokenclassification.Response, error) {
tokenized := m.tokenize(text)
Expand Down

0 comments on commit f7bd8ab

Please sign in to comment.