Skip to content

Commit

Permalink
Fixed memory leak, added CumulativeConfidence function, fixed predict…
Browse files Browse the repository at this point in the history
…ion result sorting
  • Loading branch information
Loren Kuich committed Dec 18, 2019
1 parent 199d048 commit 424177f
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 28 deletions.
62 changes: 44 additions & 18 deletions Assets/Coach-ML/Coach.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using Barracuda;
using UnityEngine.Networking;
using System.Net;
using UnityEngine.Serialization;

namespace Coach
{
Expand Down Expand Up @@ -100,25 +99,27 @@ public static class ImageUtil
private static Texture2D Scale(this Texture2D tex, int width, int height, FilterMode mode = FilterMode.Trilinear)
{
Rect texR = new Rect(0, 0, width, height);
GpuScale(tex, width, height, mode);
RenderTexture rt = GpuScale(tex, width, height, mode);

// Update new texture
tex.Resize(width, height);
tex.ReadPixels(texR, 0, 0, true);
tex.Apply(true);

RenderTexture.ReleaseTemporary(rt);

return tex;
}

// Internal unility that renders the source texture into the RTT - the scaling method itself.
private static void GpuScale(Texture2D src, int width, int height, FilterMode fmode)
private static RenderTexture GpuScale(Texture2D src, int width, int height, FilterMode fmode)
{
//We need the source texture in VRAM because we render with it
src.filterMode = fmode;
src.Apply(true);

//Using RTT for best quality and performance. Thanks, Unity 5
RenderTexture rtt = new RenderTexture(width, height, 32);
var rtt = RenderTexture.GetTemporary(width, height, 32);

//Set the RTT in order to render to it
Graphics.SetRenderTarget(rtt);
Expand All @@ -129,6 +130,8 @@ private static void GpuScale(Texture2D src, int width, int height, FilterMode fm
//Then clear & draw the texture to fill the entire RTT.
GL.Clear(true, true, new Color(0, 0, 0, 0));
Graphics.DrawTexture(new Rect(0, 0, 1, 1), src);

return rtt;
}

private static Tensor ToTensor(this Texture2D tex, ImageDims dims)
Expand Down Expand Up @@ -180,30 +183,29 @@ public class CoachResult
///<summary>
//Unsorted prediction results
///</summary>
public List<LabelProbability> Results { get; private set; }
public LabelProbability[] Results { get; private set; }

///<summary>
//Sorted prediction results, descending in Confidence
///</summary>
// public List<LabelProbability> SortedResults { get; private set; }
public LabelProbability[] SortedResults { get; private set; }

public CoachResult(string[] labels, Tensor output)
{
Debug.LogWarning(output);
Results = new List<LabelProbability>();
Results = new LabelProbability[labels.Length];

for (var i = 0; i < labels.Length; i++)
{
string label = labels[i];
float probability = output[i];

Results.Add(new LabelProbability()
Results[i] = new LabelProbability()
{
Label = label,
Confidence = probability
});
};
}
// SortedResults = Results.OrderByDescending(r => r.Confidence).ToList();
SortedResults = Results.OrderByDescending(r => r.Confidence).ToArray();

output.Dispose();
}
Expand All @@ -213,15 +215,27 @@ public CoachResult(string[] labels, Tensor output)
///</summary>
public LabelProbability Best()
{
return Results.FirstOrDefault();
return SortedResults.FirstOrDefault();
}

///<summary>
///Least Confident result
///</summary>
public LabelProbability Worst()
{
return Results.LastOrDefault();
return SortedResults.LastOrDefault();
}
}

public struct CumulativeConfidenceResult
{
public float Threshhold;
public CoachResult LastResult;
public float CumulativeConfidence;

public bool IsPassedThreshold()
{
return CumulativeConfidence >= Threshhold;
}
}

Expand Down Expand Up @@ -314,6 +328,18 @@ public CoachResult Predict(byte[] image, string inputName = "input", string outp
return GetModelResult(imageTensor, inputName, outputName);
}

public void CumulativeConfidence(Texture2D image, float threshhold, ref CumulativeConfidenceResult result)
{
var prediction = Predict(image);
result.LastResult = prediction;
result.Threshhold = threshhold;

if (result.LastResult.Best().Label != prediction.Best().Label)
result.CumulativeConfidence = 0;
else if (result.CumulativeConfidence <= threshhold)
result.CumulativeConfidence += prediction.Best().Confidence;
}

private CoachResult GetModelResult(Tensor imageTensor, string inputName = "input", string outputName = "output")
{
var inputs = new Dictionary<string, Tensor>();
Expand All @@ -338,18 +364,15 @@ public void CleanUp()
[Serializable]
public class StatusDef
{
[FormerlySerializedAs("short")]
public string _short;

[FormerlySerializedAs("long")]
public string _long;
}

[Serializable]
public class ModelDef
{
public string name;
public StatusDef status;
// public StatusDef status;
public int version;
public string module;
public string[] labels;
Expand Down Expand Up @@ -456,7 +479,10 @@ public async Task CacheModel(string modelName, string path = ".", bool skipMatch
if (!IsAuthenticated())
throw new Exception("User is not authenticated");

ModelDef model = this.Profile.models.Single(m => m.name == modelName);
ModelDef model = this.Profile.models.SingleOrDefault(m => m.name == modelName);
if (model == null)
throw new Exception($"{modelName} is an invalid model");

int version = model.version;

string modelDir = Path.Combine(path, modelName);
Expand Down
18 changes: 8 additions & 10 deletions Assets/Scenes/SampleScene.unity
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ LightmapSettings:
m_PVRFilteringAtrousPositionSigmaDirect: 0.5
m_PVRFilteringAtrousPositionSigmaIndirect: 2
m_PVRFilteringAtrousPositionSigmaAO: 1
m_ShowResolutionOverlay: 1
m_ExportTrainingData: 0
m_TrainingDataDestination: TrainingData
m_LightProbeSampleCountMultiplier: 4
m_LightingDataAsset: {fileID: 0}
m_UseShadowmask: 1
--- !u!196 &4
Expand Down Expand Up @@ -170,7 +171,7 @@ MonoBehaviour:
m_GameObject: {fileID: 211247794}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 1301386320, guid: f70555f144d8491a825f0804e09c671c, type: 3}
m_Script: {fileID: 11500000, guid: dc42784cf147c0c48a680349fa168899, type: 3}
m_Name:
m_EditorClassIdentifier:
m_IgnoreReversedGraphics: 1
Expand All @@ -187,7 +188,7 @@ MonoBehaviour:
m_GameObject: {fileID: 211247794}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 1980459831, guid: f70555f144d8491a825f0804e09c671c, type: 3}
m_Script: {fileID: 11500000, guid: 0cd44c1031e13a943bb63640046fad76, type: 3}
m_Name:
m_EditorClassIdentifier:
m_UiScaleMode: 1
Expand Down Expand Up @@ -230,10 +231,9 @@ MonoBehaviour:
m_GameObject: {fileID: 211247794}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 0c0359307f108214fa7f019386a8b150, type: 3}
m_Script: {fileID: 11500000, guid: ac400c37f901663459a68cbfc8c1dfc0, type: 3}
m_Name:
m_EditorClassIdentifier:
Image: {fileID: 211247801}
--- !u!222 &211247800
CanvasRenderer:
m_ObjectHideFlags: 0
Expand All @@ -251,7 +251,7 @@ MonoBehaviour:
m_GameObject: {fileID: 211247794}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: -98529514, guid: f70555f144d8491a825f0804e09c671c, type: 3}
m_Script: {fileID: 11500000, guid: 1344c3c82d62a2a41a3576d8abb8e3ea, type: 3}
m_Name:
m_EditorClassIdentifier:
m_Material: {fileID: 0}
Expand All @@ -260,8 +260,6 @@ MonoBehaviour:
m_OnCullStateChanged:
m_PersistentCalls:
m_Calls: []
m_TypeName: UnityEngine.UI.MaskableGraphic+CullStateChangedEvent, UnityEngine.UI,
Version=1.0.0.0, Culture=neutral, PublicKeyToken=null
m_Texture: {fileID: 2800000, guid: 835e52329d78a3d44b7a1df76348098b, type: 3}
m_UVRect:
serializedVersion: 2
Expand Down Expand Up @@ -380,7 +378,7 @@ MonoBehaviour:
m_GameObject: {fileID: 632981135}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 1077351063, guid: f70555f144d8491a825f0804e09c671c, type: 3}
m_Script: {fileID: 11500000, guid: 4f231c4fb786f3946a6b90b886c48677, type: 3}
m_Name:
m_EditorClassIdentifier:
m_HorizontalAxis: Horizontal
Expand All @@ -399,7 +397,7 @@ MonoBehaviour:
m_GameObject: {fileID: 632981135}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: -619905303, guid: f70555f144d8491a825f0804e09c671c, type: 3}
m_Script: {fileID: 11500000, guid: 76c392e42b5098c458856cdf6ecaaaa1, type: 3}
m_Name:
m_EditorClassIdentifier:
m_FirstSelected: {fileID: 0}
Expand Down

0 comments on commit 424177f

Please sign in to comment.