Skip to content

Commit

Permalink
feat: support jstype custom options
Browse files Browse the repository at this point in the history
Allow overriding 64 bit number types with strings or regular js numbers.

Closes #112
  • Loading branch information
achingbrain committed Oct 13, 2023
1 parent 1d6e843 commit a6dfb0a
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 13 deletions.
101 changes: 88 additions & 13 deletions packages/protons/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ const types: Record<string, string> = {
uint64: 'bigint'
}

const jsTypeOverrides: Record<string, string> = {
JS_NUMBER: 'number',
JS_STRING: 'string'
}

const encoderGenerators: Record<string, (val: string) => string> = {
bool: (val) => `w.bool(${val})`,
bytes: (val) => `w.bytes(${val})`,
Expand All @@ -55,6 +60,11 @@ const encoderGenerators: Record<string, (val: string) => string> = {
uint64: (val) => `w.uint64(${val})`
}

const encoderGeneratorsJsTypeOverrides: Record<string, (val: string) => string> = {
number: (val: string) => `BigInt(${val})`,
string: (val: string) => `BigInt(${val})`
}

const decoderGenerators: Record<string, () => string> = {
bool: () => 'reader.bool()',
bytes: () => 'reader.bytes()',
Expand All @@ -73,6 +83,11 @@ const decoderGenerators: Record<string, () => string> = {
uint64: () => 'reader.uint64()'
}

const decoderGeneratorsJsTypeOverrides: Record<string, (original: string) => string> = {
number: (original: string) => `Number(${original})`,
string: (original: string) => `String(${original})`
}

const defaultValueGenerators: Record<string, () => string> = {
bool: () => 'false',
bytes: () => 'new Uint8Array(0)',
Expand All @@ -91,6 +106,11 @@ const defaultValueGenerators: Record<string, () => string> = {
uint64: () => '0n'
}

const defaultValueGeneratorsJsTypeOverrides: Record<string, () => string> = {
number: () => '0',
string: () => "''"
}

const defaultValueTestGenerators: Record<string, (field: string) => string> = {
bool: (field) => `(${field} != null && ${field} !== false)`,
bytes: (field) => `(${field} != null && ${field}.byteLength > 0)`,
Expand All @@ -109,7 +129,28 @@ const defaultValueTestGenerators: Record<string, (field: string) => string> = {
uint64: (field) => `(${field} != null && ${field} !== 0n)`
}

function findTypeName (typeName: string, classDef: MessageDef, moduleDef: ModuleDef): string {
const defaultValueTestGeneratorsJsTypeOverrides: Record<string, (field: string) => string> = {
number: (field) => `(${field} != null && ${field} !== 0)`,
string: (field) => `(${field} != null && ${field} !== '')`
}

function findJsTypeOverride (defaultType: string, fieldDef: FieldDef): string | undefined {
if (fieldDef.options?.jstype != null && jsTypeOverrides[fieldDef.options?.jstype] != null) {
if (!['int64', 'uint64', 'sint64', 'fixed64', 'sfixed64'].includes(defaultType)) {
throw new Error(`jstype is only allowed on int64, uint64, sint64, fixed64 or sfixed64 fields - got "${defaultType}"`)
}

return jsTypeOverrides[fieldDef.options?.jstype]
}
}

function findJsTypeName (typeName: string, classDef: MessageDef, moduleDef: ModuleDef, fieldDef: FieldDef): string {
const override = findJsTypeOverride(typeName, fieldDef)

if (override != null) {
return override
}

if (types[typeName] != null) {
return types[typeName]
}
Expand All @@ -123,7 +164,7 @@ function findTypeName (typeName: string, classDef: MessageDef, moduleDef: Module
}

if (classDef.parent != null) {
return findTypeName(typeName, classDef.parent, moduleDef)
return findJsTypeName(typeName, classDef.parent, moduleDef, fieldDef)
}

if (moduleDef.globals[typeName] != null) {
Expand Down Expand Up @@ -170,9 +211,16 @@ function createDefaultObject (fields: Record<string, FieldDef>, messageDef: Mess

const type: string = fieldDef.type
let defaultValue
let defaultValueGenerator = defaultValueGenerators[type]

if (defaultValueGenerator != null) {
const jsTypeOverride = findJsTypeOverride(type, fieldDef)

if (defaultValueGenerators[type] != null) {
defaultValue = defaultValueGenerators[type]()
if (jsTypeOverride != null && defaultValueGeneratorsJsTypeOverrides[jsTypeOverride] != null) {
defaultValueGenerator = defaultValueGeneratorsJsTypeOverrides[jsTypeOverride]
}

defaultValue = defaultValueGenerator()
} else {
const def = findDef(fieldDef.type, messageDef, moduleDef)

Expand Down Expand Up @@ -292,10 +340,10 @@ interface FieldDef {
function defineFields (fields: Record<string, FieldDef>, messageDef: MessageDef, moduleDef: ModuleDef): string[] {
return Object.entries(fields).map(([fieldName, fieldDef]) => {
if (fieldDef.map) {
return `${fieldName}: Map<${findTypeName(fieldDef.keyType ?? 'string', messageDef, moduleDef)}, ${findTypeName(fieldDef.valueType, messageDef, moduleDef)}>`
return `${fieldName}: Map<${findJsTypeName(fieldDef.keyType ?? 'string', messageDef, moduleDef, fieldDef)}, ${findJsTypeName(fieldDef.valueType, messageDef, moduleDef, fieldDef)}>`
}

return `${fieldName}${fieldDef.optional ? '?' : ''}: ${findTypeName(fieldDef.type, messageDef, moduleDef)}${fieldDef.repeated ? '[]' : ''}`
return `${fieldName}${fieldDef.optional ? '?' : ''}: ${findJsTypeName(fieldDef.type, messageDef, moduleDef, fieldDef)}${fieldDef.repeated ? '[]' : ''}`
})
}

Expand Down Expand Up @@ -383,7 +431,7 @@ export interface ${messageDef.name} {
type = 'message'
}

typeName = findTypeName(fieldDef.type, messageDef, moduleDef)
typeName = findJsTypeName(fieldDef.type, messageDef, moduleDef, fieldDef)
codec = `${typeName}.codec()`
}

Expand All @@ -392,9 +440,17 @@ export interface ${messageDef.name} {
if (fieldDef.map) {
valueTest = `obj.${name} != null && obj.${name}.size !== 0`
} else if (!fieldDef.optional && !fieldDef.repeated) {
let defaultValueTestGenerator = defaultValueTestGenerators[type]

// proto3 singular fields should only be written out if they are not the default value
if (defaultValueTestGenerators[type] != null) {
valueTest = `${defaultValueTestGenerators[type](`obj.${name}`)}`
if (defaultValueTestGenerator != null) {
const jsTypeOverride = findJsTypeOverride(type, fieldDef)

if (jsTypeOverride != null && defaultValueTestGeneratorsJsTypeOverrides[jsTypeOverride] != null) {
defaultValueTestGenerator = defaultValueTestGeneratorsJsTypeOverrides[jsTypeOverride]
}

valueTest = `${defaultValueTestGenerator(`obj.${name}`)}`
} else if (type === 'enum') {
// handle enums
valueTest = `obj.${name} != null && __${fieldDef.type}Values[obj.${name}] !== 0`
Expand All @@ -412,8 +468,20 @@ export interface ${messageDef.name} {
}
}

let writeField = (): string => `w.uint32(${id})
${encoderGenerators[type] == null ? `${codec}.encode(${valueVar}, w)` : encoderGenerators[type](valueVar)}`
let writeField = (): string => {
const encoderGenerator = encoderGenerators[type]

if (encoderGenerator != null) {
const jsTypeOverride = findJsTypeOverride(type, fieldDef)

if (jsTypeOverride != null && encoderGeneratorsJsTypeOverrides[jsTypeOverride] != null) {
valueVar = encoderGeneratorsJsTypeOverrides[jsTypeOverride](valueVar)
}
}

return `w.uint32(${id})
${encoderGenerator == null ? `${codec}.encode(${valueVar}, w)` : encoderGenerator(valueVar)}`
}

if (type === 'message') {
// message fields are only written if they have values. But if a message
Expand Down Expand Up @@ -483,11 +551,18 @@ export interface ${messageDef.name} {
type = 'message'
}

const typeName = findTypeName(fieldDef.type, messageDef, moduleDef)
const typeName = findJsTypeName(fieldDef.type, messageDef, moduleDef, fieldDef)
codec = `${typeName}.codec()`
}

const parseValue = `${decoderGenerators[type] == null ? `${codec}.decode(reader${type === 'message' ? ', reader.uint32()' : ''})` : decoderGenerators[type]()}`
let parseValue = `${decoderGenerators[type] == null ? `${codec}.decode(reader${type === 'message' ? ', reader.uint32()' : ''})` : decoderGenerators[type]()}`

// override setting type on js object
const jsTypeOverride = findJsTypeOverride(fieldDef.type, fieldDef)

if (jsTypeOverride != null && decoderGeneratorsJsTypeOverrides[jsTypeOverride] != null) {
parseValue = decoderGeneratorsJsTypeOverrides[jsTypeOverride](parseValue)
}

if (fieldDef.map) {
return `case ${fieldDef.id}: {
Expand Down
26 changes: 26 additions & 0 deletions packages/protons/test/custom-options.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/* eslint-env mocha */

import { expect } from 'aegir/chai'
import { CustomOptionNumber, CustomOptionString } from './fixtures/custom-option-jstype.js'

describe('custom options', () => {
it('should allow overriding 64 bit numbers with numbers', () => {
const obj: CustomOptionNumber = {
num: 5,
bignum: 5
}

expect(CustomOptionNumber.decode(CustomOptionNumber.encode(obj)))
.to.deep.equal(obj)
})

it('should allow overriding 64 bit numbers with strings', () => {
const obj: CustomOptionString = {
num: 5,
bignum: '5'
}

expect(CustomOptionString.decode(CustomOptionString.encode(obj)))
.to.deep.equal(obj)
})
})
11 changes: 11 additions & 0 deletions packages/protons/test/fixtures/custom-option-jstype.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
syntax = "proto3";

message CustomOptionNumber {
int32 num = 1;
int64 bignum = 2 [jstype = JS_NUMBER];
}

message CustomOptionString {
int32 num = 1;
int64 bignum = 2 [jstype = JS_STRING];
}
145 changes: 145 additions & 0 deletions packages/protons/test/fixtures/custom-option-jstype.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/* eslint-disable import/export */
/* eslint-disable complexity */
/* eslint-disable @typescript-eslint/no-namespace */
/* eslint-disable @typescript-eslint/no-unnecessary-boolean-literal-compare */
/* eslint-disable @typescript-eslint/no-empty-interface */

import { encodeMessage, decodeMessage, message } from 'protons-runtime'
import type { Codec } from 'protons-runtime'
import type { Uint8ArrayList } from 'uint8arraylist'

export interface CustomOptionNumber {
num: number
bignum: number
}

export namespace CustomOptionNumber {
let _codec: Codec<CustomOptionNumber>

export const codec = (): Codec<CustomOptionNumber> => {
if (_codec == null) {
_codec = message<CustomOptionNumber>((obj, w, opts = {}) => {
if (opts.lengthDelimited !== false) {
w.fork()
}

if ((obj.num != null && obj.num !== 0)) {
w.uint32(8)
w.int32(obj.num)
}

if ((obj.bignum != null && obj.bignum !== 0)) {
w.uint32(16)
w.int64(BigInt(obj.bignum))
}

if (opts.lengthDelimited !== false) {
w.ldelim()
}
}, (reader, length) => {
const obj: any = {
num: 0,
bignum: 0
}

const end = length == null ? reader.len : reader.pos + length

while (reader.pos < end) {
const tag = reader.uint32()

switch (tag >>> 3) {
case 1:
obj.num = reader.int32()
break
case 2:
obj.bignum = Number(reader.int64())
break
default:
reader.skipType(tag & 7)
break
}
}

return obj
})
}

return _codec
}

export const encode = (obj: Partial<CustomOptionNumber>): Uint8Array => {
return encodeMessage(obj, CustomOptionNumber.codec())
}

export const decode = (buf: Uint8Array | Uint8ArrayList): CustomOptionNumber => {
return decodeMessage(buf, CustomOptionNumber.codec())
}
}

export interface CustomOptionString {
num: number
bignum: string
}

export namespace CustomOptionString {
let _codec: Codec<CustomOptionString>

export const codec = (): Codec<CustomOptionString> => {
if (_codec == null) {
_codec = message<CustomOptionString>((obj, w, opts = {}) => {
if (opts.lengthDelimited !== false) {
w.fork()
}

if ((obj.num != null && obj.num !== 0)) {
w.uint32(8)
w.int32(obj.num)
}

if ((obj.bignum != null && obj.bignum !== '')) {
w.uint32(16)
w.int64(BigInt(obj.bignum))
}

if (opts.lengthDelimited !== false) {
w.ldelim()
}
}, (reader, length) => {
const obj: any = {
num: 0,
bignum: ''
}

const end = length == null ? reader.len : reader.pos + length

while (reader.pos < end) {
const tag = reader.uint32()

switch (tag >>> 3) {
case 1:
obj.num = reader.int32()
break
case 2:
obj.bignum = String(reader.int64())
break
default:
reader.skipType(tag & 7)
break
}
}

return obj
})
}

return _codec
}

export const encode = (obj: Partial<CustomOptionString>): Uint8Array => {
return encodeMessage(obj, CustomOptionString.codec())
}

export const decode = (buf: Uint8Array | Uint8ArrayList): CustomOptionString => {
return decodeMessage(buf, CustomOptionString.codec())
}
}

0 comments on commit a6dfb0a

Please sign in to comment.