-
Notifications
You must be signed in to change notification settings - Fork 6.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature]Add Online Explainability notebooks for SageMaker Clarify #3613
Changes from all commits
d7c7408
58322c4
974ad94
d6acf89
37c1d27
a61dd5d
a8d1a2d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from io import StringIO | ||
import numpy as np | ||
import os | ||
import pandas as pd | ||
import json | ||
from transformers import AutoTokenizer, AutoModelForSequenceClassification | ||
import torch | ||
from typing import Any, Dict, List | ||
|
||
|
||
def model_fn(model_dir: str) -> Dict[str, Any]: | ||
""" | ||
Load the model for inference | ||
""" | ||
model_path = os.path.join(model_dir, "model") | ||
|
||
# Load HuggingFace tokenizer. | ||
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") | ||
|
||
# Load HuggingFace model from disk. | ||
model = AutoModelForSequenceClassification.from_pretrained(model_path, local_files_only=True) | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
model.to(device) | ||
model.eval() | ||
model_dict = {"model": model, "tokenizer": tokenizer} | ||
return model_dict | ||
|
||
|
||
def predict_fn(input_data: List, model: Dict) -> np.ndarray: | ||
""" | ||
Apply model to the incoming request | ||
""" | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
tokenizer = model["tokenizer"] | ||
huggingface_model = model["model"] | ||
|
||
encoded_input = tokenizer(input_data, truncation=True, padding=True, max_length=128, return_tensors="pt") | ||
encoded_input = {k: v.to(device) for k, v in encoded_input.items()} | ||
with torch.no_grad(): | ||
output = huggingface_model(input_ids=encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"]) | ||
res = torch.nn.Softmax(dim=1)(output.logits).detach().cpu().numpy()[:, 1] | ||
return res | ||
|
||
|
||
def input_fn(request_body: str, request_content_type: str) -> List[str]: | ||
""" | ||
Deserialize and prepare the prediction input | ||
""" | ||
if request_content_type == "application/json": | ||
sentences = [json.loads(request_body)] | ||
|
||
elif request_content_type == "text/csv": | ||
# We have a single column with the text. | ||
sentences = list(pd.read_csv(StringIO(request_body), header=None).values[:, 0].astype(str)) | ||
else: | ||
sentences = request_body | ||
return sentences |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
transformers==4.2.1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this file being used? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, all files in code directory are being packaged in model container. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it looks like its saying pip install sagemaker captum --upgrade and then it doesn't know what captum is, I believe the fix here is to get rid of sagemaker and just make it !pip install captum --upgrade? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It could be, though when manually running the notebook it's not causing an issue. I can try that once. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @aaronmarkham @rdamazon Tried this, bot execution is still failing with the same exception. I was able to execute the notebook in manual run successfully. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're currently testing this on our end |
||
torch==1.7.1 | ||
pandas |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Recommendation generated by Amazon CodeGuru Reviewer. Leave feedback on this recommendation by replying to the comment or by reacting to the comment using emoji.
Problem:
This line of code lacks validation when processing input data through the following parameter: 'request_body' (index: 0 | type: str). The parameter is exposed to external callers, because its enclosing class and method are publicly accessible. This means that upstream validation, if it exists, can be bypassed. Other validated parameters: 'request_content_type'. The same parameter type is validated here for example: amazon-sagemaker-examples/sagemaker-clarify/online_explainability/natural_language_processing/code/inference.py:15. Malicious, malformed, or unbounded inputs can cause unexpected runtime behavior or crashes, and can slow performance.
Fix:
Add checks to ensure the validity of the parameter's value, such as testing it for nullness, emptiness, or equality. Or to prevent direct calls to it, reduce the method's visibility using single or double underscore.
Learn more about potential threats and guidance from the Common Weakness Enumeration website and the OWASP Cheat Sheet series.