forked from kubeflow/code-intelligence
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Combine repository specific labels with the universal model for issue…
… kind. * To combine multiple models we define a base class IssueLabelerModel that defines a common interface for all the models. * This way different models can just implement that interface and we can easily combine the results. * UniversalKindLabelModel will be used for the generic model to predict label kind that is trained on all repositories. * The UniversalKindLabelModel class is based on the IssueLabeler code https://github.com/machine-learning-apps/Issue-Label-Bot/blob/536e8bf4928b03d522dd021c0464587747e90a87/flask_app/utils.py#L67 * kubeflow#70 Combine multiple models
- Loading branch information
Jeremy Lewi
committed
Dec 23, 2019
1 parent
03709c7
commit 80ef7f7
Showing
3 changed files
with
103 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
"""The models packages defines wrappers around different models.""" | ||
|
||
class IssueLabelerModel:
    """A base class for all issue label models.

    This class defines a common interface for all issue label models;
    subclasses override ``predict_issue_labels``.
    """

    def predict_issue_labels(self, title: str, text: str):
        """Return a dictionary of label probabilities.

        Args:
            title: The title for the issue.
            text: The text for the issue.

        Returns:
            dict: Maps each label (str) to the probability (float) of that
            label for the issue.

        Raises:
            NotImplementedError: Always; subclasses must override this
                method.
        """
        # The message names this method (not a differently named one) so the
        # failure points at the override that is actually missing.
        raise NotImplementedError("predict_issue_labels should be overridden "
                                  "in a subclass.")
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
|
||
from tensorflow.keras import models as keras_models | ||
from tensorflow.keras import utils as keras_utils | ||
|
||
import dill as dpickle | ||
|
||
from urllib.request import urlopen | ||
from label_microservice import models | ||
|
||
class UniversalKindLabelModel(models.IssueLabelerModel):
    """UniversalKindLabelModel is a universal model trained across all repos.

    The model predicts the kind for an issue.
    """

    def __init__(self, class_names=None):
        """Instantiate the model.

        Loads the title and body preprocessors and the Keras model from
        hard-coded URLs, so constructing an instance requires network access.

        Args:
            class_names: The specific label names to use for the three
                classes. Defaults to ``['bug', 'feature', 'question']``.
        """
        super().__init__()

        # TODO(jlewi): We should probably parameterize the models rather than
        # hardcoding it.
        title_pp_url = "https://storage.googleapis.com/codenet/issue_labels/issue_label_model_files/title_pp.dpkl"
        body_pp_url = 'https://storage.googleapis.com/codenet/issue_labels/issue_label_model_files/body_pp.dpkl'
        model_url = 'https://storage.googleapis.com/codenet/issue_labels/issue_label_model_files/Issue_Label_v1_best_model.hdf5'
        model_filename = 'downloaded_model.hdf5'

        # NOTE(security): dill, like pickle, can execute arbitrary code while
        # deserializing. These payloads are fetched over the network, so this
        # is only safe as long as the bucket contents are trusted.
        with urlopen(title_pp_url) as f:
            self.title_pp = dpickle.load(f)

        with urlopen(body_pp_url) as f:
            self.body_pp = dpickle.load(f)

        # get_file hands back a local path for the downloaded model file.
        model_path = keras_utils.get_file(fname=model_filename,
                                          origin=model_url)
        self.model = keras_models.load_model(model_path)

        # Build the default per instance; a list literal in the signature
        # would be one object shared (and mutable) across every instance.
        if class_names is None:
            class_names = ['bug', 'feature', 'question']
        self.class_names = class_names

    def predict_issue_labels(self, body: str, title: str):
        """Get the probability of each class for an issue.

        NOTE(review): the base class declares this method as
        ``(title, text)`` while this override takes ``(body, title)``.
        Positional callers written against the base interface would swap the
        two arguments -- confirm against callers before reordering.

        Args:
            body: The issue body.
            title: The issue title.

        Returns:
            dict: Label name (str) -> probability (float), pairing
            ``self.class_names`` with the model outputs in order.

        Example:
            >>> model = UniversalKindLabelModel()  # doctest: +SKIP
            >>> model.predict_issue_labels('hello world', 'hello world')  # doctest: +SKIP
            {'bug': 0.08372017741203308,
             'feature': 0.6401631832122803,
             'question': 0.2761166989803314}
        """
        # Transform the raw text with the loaded preprocessors. Each input is
        # wrapped in a one-element list, i.e. a batch of one issue.
        vec_body = self.body_pp.transform([body])
        vec_title = self.title_pp.transform([title])

        # Only one issue was submitted, so take the first row of predictions.
        probs = self.model.predict(x=[vec_body, vec_title]).tolist()[0]

        # zip stops at the shorter sequence, so extra names or extra outputs
        # are dropped silently.
        return dict(zip(self.class_names, probs))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters