-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Custom Vision Training and Prediction Samples (Azure-Samples#8)
* Add Custom Vision Training and Prediction Samples * Update readme, requirements to install and keys. * Add the samples dir to import path so imports can be found. * Add missing import
- Loading branch information
Showing
26 changed files
with
108 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import os | ||
import sys | ||
|
||
from azure.cognitiveservices.vision.customvision.training import training_api | ||
from azure.cognitiveservices.vision.customvision.prediction import prediction_endpoint | ||
from azure.cognitiveservices.vision.customvision.prediction.prediction_endpoint import models | ||
|
||
TRAINING_KEY_ENV_NAME = "CUSTOMVISION_TRAINING_KEY" | ||
SUBSCRIPTION_KEY_ENV_NAME = "CUSTOMVISION_PREDICTION_KEY" | ||
|
||
# Add this directory to the path so that custom_vision_training_samples can be found | ||
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), ".")) | ||
|
||
IMAGES_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), "images") | ||
|
||
def find_or_train_project(): | ||
try: | ||
training_key = os.environ[TRAINING_KEY_ENV_NAME] | ||
except KeyError: | ||
raise SubscriptionKeyError("You need to set the {} env variable.".format(TRAINING_KEY_ENV_NAME)) | ||
|
||
# Use the training API to find the SDK sample project created from the training example. | ||
from custom_vision_training_samples import train_project, SAMPLE_PROJECT_NAME | ||
trainer = training_api.TrainingApi(training_key) | ||
|
||
for proj in trainer.get_projects(): | ||
if (proj.name == SAMPLE_PROJECT_NAME): | ||
return proj | ||
|
||
# Or, if not found, we will run the training example to create it. | ||
return train_project(training_key) | ||
|
||
def predict_project(subscription_key): | ||
predictor = prediction_endpoint.PredictionEndpoint(subscription_key) | ||
|
||
# Find or train a new project to use for prediction. | ||
project = find_or_train_project() | ||
|
||
with open(os.path.join(IMAGES_FOLDER, "Test", "test_image.jpg"), mode="rb") as test_data: | ||
results = predictor.predict_image(project.id, test_data.read()) | ||
|
||
# Display the results. | ||
for prediction in results.predictions: | ||
print ("\t" + prediction.tag + ": {0:.2f}%".format(prediction.probability * 100)) | ||
|
||
|
||
if __name__ == "__main__": | ||
import sys, os.path | ||
sys.path.append(os.path.abspath(os.path.join(__file__, "..", ".."))) | ||
from tools import execute_samples, SubscriptionKeyError | ||
execute_samples(globals(), SUBSCRIPTION_KEY_ENV_NAME) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import os | ||
import time | ||
|
||
from azure.cognitiveservices.vision.customvision.training import training_api | ||
|
||
SUBSCRIPTION_KEY_ENV_NAME = "CUSTOMVISION_TRAINING_KEY" | ||
SAMPLE_PROJECT_NAME = "Python SDK Sample" | ||
|
||
IMAGES_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), "images") | ||
|
||
def train_project(subscription_key): | ||
|
||
trainer = training_api.TrainingApi(subscription_key) | ||
|
||
# Create a new project | ||
print ("Creating project...") | ||
project = trainer.create_project(SAMPLE_PROJECT_NAME) | ||
|
||
# Make two tags in the new project | ||
hemlock_tag = trainer.create_tag(project.id, "Hemlock") | ||
cherry_tag = trainer.create_tag(project.id, "Japanese Cherry") | ||
|
||
print ("Adding images...") | ||
hemlock_dir = os.path.join(IMAGES_FOLDER, "Hemlock") | ||
for image in os.listdir(hemlock_dir): | ||
with open(os.path.join(hemlock_dir, image), mode="rb") as img_data: | ||
trainer.create_images_from_data(project.id, img_data.read(), [ hemlock_tag.id ]) | ||
|
||
cherry_dir = os.path.join(IMAGES_FOLDER, "Japanese Cherry") | ||
for image in os.listdir(cherry_dir): | ||
with open(os.path.join(cherry_dir, image), mode="rb") as img_data: | ||
trainer.create_images_from_data(project.id, img_data.read(), [ cherry_tag.id ]) | ||
|
||
print ("Training...") | ||
iteration = trainer.train_project(project.id) | ||
while (iteration.status == "Training"): | ||
iteration = trainer.get_iteration(project.id, iteration.id) | ||
print ("Training status: " + iteration.status) | ||
time.sleep(1) | ||
|
||
# The iteration is now trained. Make it the default project endpoint | ||
trainer.update_iteration(project.id, iteration.id, is_default=True) | ||
print ("Done!") | ||
return project | ||
|
||
if __name__ == "__main__": | ||
import sys, os.path | ||
sys.path.append(os.path.abspath(os.path.join(__file__, "..", ".."))) | ||
from tools import execute_samples | ||
execute_samples(globals(), SUBSCRIPTION_KEY_ENV_NAME) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters