-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Classification metrics #72
base: main
Are you sure you want to change the base?
Conversation
return [ yTrueTensor as Tensor1D, yPredTensor as Tensor1D, yTrueCount ]; | ||
}; | ||
|
||
export const accuracyScore = (yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[]): number => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto.
await labelEncoder.init(concat([ yTrueTensor, yPredTensor ])); | ||
const yTrueEncode = await labelEncoder.encode(yTrueTensor); | ||
const yPredEncode = await labelEncoder.encode(yPredTensor); | ||
const numClasses = labelEncoder.categories.shape[0]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const numClasses = labelEncoder.categories.shape[0]; | |
const numOfClasses = labelEncoder.categories.shape[0]; |
src/metrics/classifier.ts
Outdated
const averageF1 = average == 'weighted' ? mul(f1s, weights).dataSync()[0] : divNoNan(sum(f1s), numClasses).dataSync()[0]; | ||
return { | ||
precisions: precisions, | ||
recalls: recalls, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
recalls: recalls, | |
recalls, |
src/preprocess/encoder.ts
Outdated
@@ -1,4 +1,4 @@ | |||
import { Tensor, unique, oneHot, cast, tensor, argMax, reshape, slice, stack, sub, squeeze, greaterEqual, topk } from "@tensorflow/tfjs-core"; | |||
import { Tensor, unique, oneHot, cast, tensor, argMax, reshape, slice, stack, sub, squeeze, greaterEqual, topk, Tensor1D, tidy } from "@tensorflow/tfjs-core"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import { Tensor, unique, oneHot, cast, tensor, argMax, reshape, slice, stack, sub, squeeze, greaterEqual, topk, Tensor1D, tidy } from "@tensorflow/tfjs-core"; | |
import { Tensor, unique, oneHot, cast, tensor, argMax, reshape, slice, stack, sub, squeeze, greaterEqual, topk, Tensor1D, tidy } from '@tensorflow/tfjs-core'; |
} | ||
this.cateMap = cateMap; | ||
} | ||
abstract encode(x: Tensor | number[] | string[]): Promise<Tensor>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we use TensorLike1D
instead of the type expression?
src/preprocess/encoder.ts
Outdated
export class OneHotEncoder { | ||
public categories: Tensor; | ||
public cateMap: CateMap; | ||
export class OneHotEncoder extends EncoderBase{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
export class OneHotEncoder extends EncoderBase{ | |
export class OneHotEncoder extends EncoderBase { |
src/preprocess/encoder.ts
Outdated
@@ -126,3 +132,38 @@ export class OneHotEncoder { | |||
return reshape(stack(cateTensors), [ -1 ]); | |||
} | |||
} | |||
|
|||
export class LabelEncoder extends EncoderBase{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
export class LabelEncoder extends EncoderBase{ | |
export class LabelEncoder extends EncoderBase { |
*/ | ||
public async encode(x: Tensor | number[] | string[]): Promise<Tensor> { | ||
if (!this.categories) { | ||
throw TypeError('Please init encoder using init()'); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
throw TypeError('Please init encoder using init()'); | |
throw new TypeError('Please initialize an encoder using `init()`'); |
src/preprocess/encoder.ts
Outdated
throw TypeError('Please init encoder using init()'); | ||
} | ||
const xTensor = checkArray(x, 'any', 1); | ||
const xData = await xTensor.dataSync(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const xData = await xTensor.dataSync(); | |
const xData = await xTensor.data(); |
const xTensor = checkArray(x, 'any', 1); | ||
const xData = await xTensor.dataSync(); | ||
xTensor.dispose(); | ||
return tensor(xData.map((d: number|string) => this.cateMap[d])); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return tensor(xData.map((d: number|string) => this.cateMap[d])); | |
return tensor(xData.map((d) => this.cateMap[d])); |
Is the type required?
*/ | ||
public async decode(x: Tensor | number[]): Promise<Tensor> { | ||
if (!this.categories) { | ||
throw TypeError('Please init encoder using init()'); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
throw TypeError('Please init encoder using init()'); | |
throw new TypeError('Please initialize an encoder using `init()`'); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about implementing this in the base class?
test/node/preprocess/encoder.ts
Outdated
const encoder = new LabelEncoder(); | ||
await encoder.init(x); | ||
const xEncode = await encoder.encode(x); | ||
assert.deepEqual(xEncode.dataSync(), xLabelEncode.dataSync()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
assert.deepEqual(xEncode.dataSync(), xLabelEncode.dataSync()); | |
assert.deepEqual(await xEncode.data(), await xLabelEncode.data()); |
There is no need to use dataSync()
here.
test/node/preprocess/encoder.ts
Outdated
it('encode', async () => { | ||
const encoder = new LabelEncoder(); | ||
await encoder.init(x); | ||
const xEncode = await encoder.encode(x); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The xEncode
is declared at line#9, how about a new name?
test/node/preprocess/encoder.ts
Outdated
const encoder = new LabelEncoder(); | ||
await encoder.init(x); | ||
const xDecode = await encoder.decode(xLabelEncode); | ||
assert.deepEqual(x, xDecode.dataSync() as any); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto, use await *.data()
instead of dataSync()
.
src/metrics/classifier.ts
Outdated
@@ -1,15 +1,77 @@ | |||
import { Tensor, equal, sum, div } from '@tensorflow/tfjs-core'; | |||
import { Tensor, equal, sum, div, math, Tensor1D, divNoNan, concat, mul, add, cast } from '@tensorflow/tfjs-core'; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import { Tensor, equal, sum, div, math, Tensor1D, divNoNan, concat, mul, add, cast } from '@tensorflow/tfjs-core'; | |
import { Tensor, Tensor1D, equal, sum, div, math, divNoNan, concat, mul, add, cast } from '@tensorflow/tfjs-core'; |
src/metrics/classifier.ts
Outdated
* @param yPred predicted labels | ||
* @returns classification report object, the struct of report will be like following | ||
*/ | ||
export const classificationReport = async(yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[], average: ClassificationAverageTypes = 'weighted'): Promise<ClassificationReport> => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And this line is too long(over 80 chars).
src/metrics/classifier.ts
Outdated
* @param yTrue true labels | ||
* @param yPred predicted labels | ||
*/ | ||
export const checkSameLength = (yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[]): [ Tensor1D, Tensor1D, number ] => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function must be moved to utils/validation.
src/metrics/classifier.ts
Outdated
const f1s = divNoNan(divNoNan(mul(precisions, recalls), add(precisions, recalls)), 2); | ||
const accuracy = accuracyScore(yTrue, yPred); | ||
const weights = divNoNan(sum(confusionMatrix, 0), sum(confusionMatrix)); | ||
const averagePrecision = average == 'weighted' ? mul(precisions, weights).dataSync()[0] : divNoNan(sum(precisions), numClasses).dataSync()[0]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const averagePrecision = average == 'weighted' ? mul(precisions, weights).dataSync()[0] : divNoNan(sum(precisions), numClasses).dataSync()[0]; | |
const averagePrecision = average === 'weighted' ? mul(precisions, weights).dataSync()[0] : divNoNan(sum(precisions), numClasses).dataSync()[0]; |
src/metrics/classifier.ts
Outdated
* @param yPred predicted labels | ||
* @returns classification report object, the struct of report will be like following | ||
*/ | ||
export const classificationReport = async(yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[], average: ClassificationAverageTypes = 'weighted'): Promise<ClassificationReport> => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
export const classificationReport = async(yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[], average: ClassificationAverageTypes = 'weighted'): Promise<ClassificationReport> => { | |
export const classificationReport = async (yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[], average: ClassificationAverageTypes = 'weighted'): Promise<ClassificationReport> => { |
src/metrics/classifier.ts
Outdated
const f1s = divNoNan(divNoNan(mul(precisions, recalls), add(precisions, recalls)), 2); | ||
const accuracy = accuracyScore(yTrue, yPred); | ||
const weights = divNoNan(sum(confusionMatrix, 0), sum(confusionMatrix)); | ||
const averagePrecision = average == 'weighted' ? mul(precisions, weights).dataSync()[0] : divNoNan(sum(precisions), numClasses).dataSync()[0]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const averagePrecision = average == 'weighted' ? mul(precisions, weights).dataSync()[0] : divNoNan(sum(precisions), numClasses).dataSync()[0]; | |
const averagePrecision = await (average == 'weighted' ? mul(precisions, weights).data() : divNoNan(sum(precisions), numClasses).data())[0]; |
src/metrics/classifier.ts
Outdated
const weights = divNoNan(sum(confusionMatrix, 0), sum(confusionMatrix)); | ||
const averagePrecision = average == 'weighted' ? mul(precisions, weights).dataSync()[0] : divNoNan(sum(precisions), numClasses).dataSync()[0]; | ||
const averageRecall = average == 'weighted' ? mul(recalls, weights).dataSync()[0] : divNoNan(sum(recalls), numClasses).dataSync()[0]; | ||
const averageF1 = average == 'weighted' ? mul(f1s, weights).dataSync()[0] : divNoNan(sum(f1s), numClasses).dataSync()[0]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto.
|
There are still some comments not getting resolved. |
No description provided.