diff --git a/scripts/plugin.py b/scripts/plugin.py index b7fd634..cdff712 100755 --- a/scripts/plugin.py +++ b/scripts/plugin.py @@ -89,9 +89,6 @@ def inline_init() -> str: function_definition = function_definition.splitlines()[1:] function_definition = "\n".join(function_definition) function_definition = dedent(function_definition) - # If __init__ is empty, return an empty string - if function_definition == "pass": - return "" return function_definition # Add a presence check to a function definition diff --git a/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/logs/v1/logs_service_marshaler.py b/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/logs/v1/logs_service_marshaler.py index 52711d2..a1bc8cb 100644 --- a/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/logs/v1/logs_service_marshaler.py +++ b/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/logs/v1/logs_service_marshaler.py @@ -26,6 +26,7 @@ def __init__( resource_logs: List[ResourceLogs] = None, ): self._resource_logs: List[ResourceLogs] = resource_logs + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -58,6 +59,7 @@ def __init__( partial_success: ExportLogsPartialSuccess = None, ): self._partial_success: ExportLogsPartialSuccess = partial_success + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -87,6 +89,7 @@ def __init__( ): self.rejected_log_records: int = rejected_log_records self.error_message: str = error_message + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -94,6 +97,7 @@ def calculate_size(self) -> int: size += len(b"\x08") + Varint.size_varint_i64(self.rejected_log_records) if self.error_message: v = self.error_message.encode("utf-8") + self._marshaler_cache[b"\x12"] = v size += len(b"\x12") + Varint.size_varint_u32(len(v)) + len(v) return size @@ -102,7 +106,7 @@ def write_to(self, out: bytearray) -> None: out += b"\x08" Varint.write_varint_i64(out, self.rejected_log_records) if self.error_message: - v = self.error_message.encode("utf-8") + v = self._marshaler_cache[b"\x12"] out += b"\x12" Varint.write_varint_u32(out, len(v)) out += v diff --git a/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/metrics/v1/metrics_service_marshaler.py b/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/metrics/v1/metrics_service_marshaler.py index 7701775..7de23c5 100644 --- a/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/metrics/v1/metrics_service_marshaler.py +++ b/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/metrics/v1/metrics_service_marshaler.py @@ -26,6 +26,7 @@ def __init__( resource_metrics: List[ResourceMetrics] = None, ): self._resource_metrics: List[ResourceMetrics] = resource_metrics + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -58,6 +59,7 @@ def __init__( partial_success: ExportMetricsPartialSuccess = None, ): self._partial_success: ExportMetricsPartialSuccess = partial_success + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -87,6 +89,7 @@ def __init__( ): self.rejected_data_points: int = rejected_data_points self.error_message: str = error_message + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -94,6 +97,7 @@ def calculate_size(self) -> int: size += len(b"\x08") + Varint.size_varint_i64(self.rejected_data_points) if self.error_message: v = self.error_message.encode("utf-8") + self._marshaler_cache[b"\x12"] = v size += len(b"\x12") + Varint.size_varint_u32(len(v)) + len(v) return size @@ -102,7 +106,7 @@ def write_to(self, out: bytearray) -> None: out += b"\x08" Varint.write_varint_i64(out, self.rejected_data_points) if self.error_message: - v = self.error_message.encode("utf-8") + v = self._marshaler_cache[b"\x12"] out += b"\x12" Varint.write_varint_u32(out, len(v)) out += v diff --git a/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/trace/v1/trace_service_marshaler.py b/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/trace/v1/trace_service_marshaler.py index 2488f6c..bbef204 100644 --- a/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/trace/v1/trace_service_marshaler.py +++ b/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/trace/v1/trace_service_marshaler.py @@ -26,6 +26,7 @@ def __init__( resource_spans: List[ResourceSpans] = None, ): self._resource_spans: List[ResourceSpans] = resource_spans + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -58,6 +59,7 @@ def __init__( partial_success: ExportTracePartialSuccess = None, ): self._partial_success: ExportTracePartialSuccess = partial_success + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -87,6 +89,7 @@ def __init__( ): self.rejected_spans: int = rejected_spans self.error_message: str = error_message + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -94,6 +97,7 @@ def calculate_size(self) -> int: size += len(b"\x08") + Varint.size_varint_i64(self.rejected_spans) if self.error_message: v = self.error_message.encode("utf-8") + self._marshaler_cache[b"\x12"] = v size += len(b"\x12") + Varint.size_varint_u32(len(v)) + len(v) return size @@ -102,7 +106,7 @@ def write_to(self, out: bytearray) -> None: out += b"\x08" Varint.write_varint_i64(out, self.rejected_spans) if self.error_message: - v = self.error_message.encode("utf-8") + v = self._marshaler_cache[b"\x12"] out += b"\x12" Varint.write_varint_u32(out, len(v)) out += v diff --git a/src/snowflake/telemetry/_internal/opentelemetry/proto/common/v1/common_marshaler.py b/src/snowflake/telemetry/_internal/opentelemetry/proto/common/v1/common_marshaler.py index 2d25bcc..78c4c70 100644 --- a/src/snowflake/telemetry/_internal/opentelemetry/proto/common/v1/common_marshaler.py +++ b/src/snowflake/telemetry/_internal/opentelemetry/proto/common/v1/common_marshaler.py @@ -50,11 +50,13 @@ def __init__( self._array_value: ArrayValue = array_value self._kvlist_value: KeyValueList = kvlist_value self.bytes_value: bytes = bytes_value + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 if self.string_value is not None: v = self.string_value.encode("utf-8") + self._marshaler_cache[b"\n"] = v size += len(b"\n") + Varint.size_varint_u32(len(v)) + len(v) if self.bool_value is not None: size += len(b"\x10") + 1 @@ -84,7 +86,7 @@ def calculate_size(self) -> int: def write_to(self, out: bytearray) -> None: if self.string_value is not None: - v = self.string_value.encode("utf-8") + v = self._marshaler_cache[b"\n"] out += b"\n" Varint.write_varint_u32(out, len(v)) out += v @@ -123,6 +125,7 @@ def __init__( values: List[AnyValue] = None, ): self._values: List[AnyValue] = values + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -155,6 +158,7 @@ def __init__( values: List[KeyValue] = None, ): self._values: List[KeyValue] = values + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -191,11 +195,13 @@ def __init__( ): self.key: str = key self._value: AnyValue = value + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 if self.key: v = self.key.encode("utf-8") + self._marshaler_cache[b"\n"] = v size += len(b"\n") + Varint.size_varint_u32(len(v)) + len(v) if self._value is not None: size += ( @@ -207,7 +213,7 @@ def calculate_size(self) -> int: def write_to(self, out: bytearray) -> None: if self.key: - v = self.key.encode("utf-8") + v = self._marshaler_cache[b"\n"] out += b"\n" Varint.write_varint_u32(out, len(v)) out += v @@ -240,14 +246,17 @@ def __init__( self.version: str = version self._attributes: List[KeyValue] = attributes self.dropped_attributes_count: int = dropped_attributes_count + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 if self.name: v = self.name.encode("utf-8") + self._marshaler_cache[b"\n"] = v size += len(b"\n") + Varint.size_varint_u32(len(v)) + len(v) if self.version: v = self.version.encode("utf-8") + self._marshaler_cache[b"\x12"] = v size += len(b"\x12") + Varint.size_varint_u32(len(v)) + len(v) if self._attributes: size += sum( @@ -262,12 +271,12 @@ def calculate_size(self) -> int: def write_to(self, out: bytearray) -> None: if self.name: - v = self.name.encode("utf-8") + v = self._marshaler_cache[b"\n"] out += b"\n" Varint.write_varint_u32(out, len(v)) out += v if self.version: - v = self.version.encode("utf-8") + v = self._marshaler_cache[b"\x12"] out += b"\x12" Varint.write_varint_u32(out, len(v)) out += v diff --git a/src/snowflake/telemetry/_internal/opentelemetry/proto/logs/v1/logs_marshaler.py b/src/snowflake/telemetry/_internal/opentelemetry/proto/logs/v1/logs_marshaler.py index 18996f0..a190d10 100644 --- a/src/snowflake/telemetry/_internal/opentelemetry/proto/logs/v1/logs_marshaler.py +++ b/src/snowflake/telemetry/_internal/opentelemetry/proto/logs/v1/logs_marshaler.py @@ -60,6 +60,7 @@ def __init__( resource_logs: List[ResourceLogs] = None, ): self._resource_logs: List[ResourceLogs] = resource_logs + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -104,6 +105,7 @@ def __init__( self._resource: Resource = resource self._scope_logs: List[ScopeLogs] = scope_logs self.schema_url: str = schema_url + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -122,6 +124,7 @@ def calculate_size(self) -> int: ) if self.schema_url: v = self.schema_url.encode("utf-8") + self._marshaler_cache[b"\x1a"] = v size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) return size @@ -136,7 +139,7 @@ def write_to(self, out: bytearray) -> None: Varint.write_varint_u32(out, v._get_size()) v.write_to(out) if self.schema_url: - v = self.schema_url.encode("utf-8") + v = self._marshaler_cache[b"\x1a"] out += b"\x1a" Varint.write_varint_u32(out, len(v)) out += v @@ -166,6 +169,7 @@ def __init__( self._scope: InstrumentationScope = scope self._log_records: List[LogRecord] = log_records self.schema_url: str = schema_url + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -184,6 +188,7 @@ def calculate_size(self) -> int: ) if self.schema_url: v = self.schema_url.encode("utf-8") + self._marshaler_cache[b"\x1a"] = v size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) return size @@ -198,7 +203,7 @@ def write_to(self, out: bytearray) -> None: Varint.write_varint_u32(out, v._get_size()) v.write_to(out) if self.schema_url: - v = self.schema_url.encode("utf-8") + v = self._marshaler_cache[b"\x1a"] out += b"\x1a" Varint.write_varint_u32(out, len(v)) out += v @@ -250,6 +255,7 @@ def __init__( self.trace_id: bytes = trace_id self.span_id: bytes = span_id self.observed_time_unix_nano: int = observed_time_unix_nano + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -262,6 +268,7 @@ def calculate_size(self) -> int: size += len(b"\x10") + Varint.size_varint_u32(v) if self.severity_text: v = self.severity_text.encode("utf-8") + self._marshaler_cache[b"\x1a"] = v size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) if self._body is not None: size += ( @@ -307,7 +314,7 @@ def write_to(self, out: bytearray) -> None: out += b"\x10" Varint.write_varint_u32(out, v) if self.severity_text: - v = self.severity_text.encode("utf-8") + v = self._marshaler_cache[b"\x1a"] out += b"\x1a" Varint.write_varint_u32(out, len(v)) out += v diff --git a/src/snowflake/telemetry/_internal/opentelemetry/proto/metrics/v1/metrics_marshaler.py b/src/snowflake/telemetry/_internal/opentelemetry/proto/metrics/v1/metrics_marshaler.py index bc7097b..b6e57c3 100644 --- a/src/snowflake/telemetry/_internal/opentelemetry/proto/metrics/v1/metrics_marshaler.py +++ b/src/snowflake/telemetry/_internal/opentelemetry/proto/metrics/v1/metrics_marshaler.py @@ -38,6 +38,7 @@ def __init__( resource_metrics: List[ResourceMetrics] = None, ): self._resource_metrics: List[ResourceMetrics] = resource_metrics + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -82,6 +83,7 @@ def __init__( self._resource: Resource = resource self._scope_metrics: List[ScopeMetrics] = scope_metrics self.schema_url: str = schema_url + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -100,6 +102,7 @@ def calculate_size(self) -> int: ) if self.schema_url: v = self.schema_url.encode("utf-8") + self._marshaler_cache[b"\x1a"] = v size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) return size @@ -114,7 +117,7 @@ def write_to(self, out: bytearray) -> None: Varint.write_varint_u32(out, v._get_size()) v.write_to(out) if self.schema_url: - v = self.schema_url.encode("utf-8") + v = self._marshaler_cache[b"\x1a"] out += b"\x1a" Varint.write_varint_u32(out, len(v)) out += v @@ -144,6 +147,7 @@ def __init__( self._scope: InstrumentationScope = scope self._metrics: List[Metric] = metrics self.schema_url: str = schema_url + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -162,6 +166,7 @@ def calculate_size(self) -> int: ) if self.schema_url: v = self.schema_url.encode("utf-8") + self._marshaler_cache[b"\x1a"] = v size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) return size @@ -176,7 +181,7 @@ def write_to(self, out: bytearray) -> None: Varint.write_varint_u32(out, v._get_size()) v.write_to(out) if self.schema_url: - v = self.schema_url.encode("utf-8") + v = self._marshaler_cache[b"\x1a"] out += b"\x1a" Varint.write_varint_u32(out, len(v)) out += v @@ -244,17 +249,21 @@ def __init__( self._exponential_histogram: ExponentialHistogram = exponential_histogram self._summary: Summary = summary self._metadata: List[KeyValue] = metadata + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 if self.name: v = self.name.encode("utf-8") + self._marshaler_cache[b"\n"] = v size += len(b"\n") + Varint.size_varint_u32(len(v)) + len(v) if self.description: v = self.description.encode("utf-8") + self._marshaler_cache[b"\x12"] = v size += len(b"\x12") + Varint.size_varint_u32(len(v)) + len(v) if self.unit: v = self.unit.encode("utf-8") + self._marshaler_cache[b"\x1a"] = v size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) if self._gauge is not None: size += ( @@ -297,17 +306,17 @@ def calculate_size(self) -> int: def write_to(self, out: bytearray) -> None: if self.name: - v = self.name.encode("utf-8") + v = self._marshaler_cache[b"\n"] out += b"\n" Varint.write_varint_u32(out, len(v)) out += v if self.description: - v = self.description.encode("utf-8") + v = self._marshaler_cache[b"\x12"] out += b"\x12" Varint.write_varint_u32(out, len(v)) out += v if self.unit: - v = self.unit.encode("utf-8") + v = self._marshaler_cache[b"\x1a"] out += b"\x1a" Varint.write_varint_u32(out, len(v)) out += v @@ -350,6 +359,7 @@ def __init__( data_points: List[NumberDataPoint] = None, ): self._data_points: List[NumberDataPoint] = data_points + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -389,6 +399,7 @@ def __init__( self._data_points: List[NumberDataPoint] = data_points self.aggregation_temporality: AggregationTemporality = aggregation_temporality self.is_monotonic: bool = is_monotonic + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -441,6 +452,7 @@ def __init__( ): self._data_points: List[HistogramDataPoint] = data_points self.aggregation_temporality: AggregationTemporality = aggregation_temporality + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -488,6 +500,7 @@ def __init__( ): self._data_points: List[ExponentialHistogramDataPoint] = data_points self.aggregation_temporality: AggregationTemporality = aggregation_temporality + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -531,6 +544,7 @@ def __init__( data_points: List[SummaryDataPoint] = None, ): self._data_points: List[SummaryDataPoint] = data_points + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -589,6 +603,7 @@ def __init__( self.as_int: int = as_int self._attributes: List[KeyValue] = attributes self.flags: int = flags + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -705,6 +720,7 @@ def __init__( self.flags: int = flags self.min: float = min self.max: float = max + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -863,6 +879,7 @@ def __init__( self.min: float = min self.max: float = max self.zero_threshold: float = zero_threshold + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -980,6 +997,7 @@ def __init__( ): self.offset: int = offset self._bucket_counts: List[int] = bucket_counts + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -989,6 +1007,7 @@ def calculate_size(self) -> int: s = sum( Varint.size_varint_u64(uint32) for uint32 in self._bucket_counts ) + self._marshaler_cache[b"\x12"] = s size += len(b"\x12") + s + Varint.size_varint_u32(s) return size @@ -998,12 +1017,7 @@ def write_to(self, out: bytearray) -> None: Varint.write_varint_s32(out, self.offset) if self._bucket_counts: out += b"\x12" - Varint.write_varint_u32( - out, - sum( - Varint.size_varint_u64(uint32) for uint32 in self._bucket_counts - ), - ) + Varint.write_varint_u32(out, self._marshaler_cache[b"\x12"]) for v in self._bucket_counts: Varint.write_varint_u64(out, v) @@ -1045,6 +1059,7 @@ def __init__( self._quantile_values: List[SummaryDataPoint.ValueAtQuantile] = quantile_values self._attributes: List[KeyValue] = attributes self.flags: int = flags + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -1112,6 +1127,7 @@ def __init__( ): self.quantile: float = quantile self.value: float = value + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -1158,6 +1174,7 @@ def __init__( self.trace_id: bytes = trace_id self.as_int: int = as_int self._filtered_attributes: List[KeyValue] = filtered_attributes + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 diff --git a/src/snowflake/telemetry/_internal/opentelemetry/proto/resource/v1/resource_marshaler.py b/src/snowflake/telemetry/_internal/opentelemetry/proto/resource/v1/resource_marshaler.py index 7b43f3c..33e4d97 100644 --- a/src/snowflake/telemetry/_internal/opentelemetry/proto/resource/v1/resource_marshaler.py +++ b/src/snowflake/telemetry/_internal/opentelemetry/proto/resource/v1/resource_marshaler.py @@ -30,6 +30,7 @@ def __init__( ): self._attributes: List[KeyValue] = attributes self.dropped_attributes_count: int = dropped_attributes_count + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 diff --git a/src/snowflake/telemetry/_internal/opentelemetry/proto/trace/v1/trace_marshaler.py b/src/snowflake/telemetry/_internal/opentelemetry/proto/trace/v1/trace_marshaler.py index ce957f6..e51b0c7 100644 --- a/src/snowflake/telemetry/_internal/opentelemetry/proto/trace/v1/trace_marshaler.py +++ b/src/snowflake/telemetry/_internal/opentelemetry/proto/trace/v1/trace_marshaler.py @@ -34,6 +34,7 @@ def __init__( resource_spans: List[ResourceSpans] = None, ): self._resource_spans: List[ResourceSpans] = resource_spans + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -78,6 +79,7 @@ def __init__( self._resource: Resource = resource self._scope_spans: List[ScopeSpans] = scope_spans self.schema_url: str = schema_url + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -96,6 +98,7 @@ def calculate_size(self) -> int: ) if self.schema_url: v = self.schema_url.encode("utf-8") + self._marshaler_cache[b"\x1a"] = v size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) return size @@ -110,7 +113,7 @@ def write_to(self, out: bytearray) -> None: Varint.write_varint_u32(out, v._get_size()) v.write_to(out) if self.schema_url: - v = self.schema_url.encode("utf-8") + v = self._marshaler_cache[b"\x1a"] out += b"\x1a" Varint.write_varint_u32(out, len(v)) out += v @@ -140,6 +143,7 @@ def __init__( self._scope: InstrumentationScope = scope self._spans: List[Span] = spans self.schema_url: str = schema_url + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -158,6 +162,7 @@ def calculate_size(self) -> int: ) if self.schema_url: v = self.schema_url.encode("utf-8") + self._marshaler_cache[b"\x1a"] = v size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) return size @@ -172,7 +177,7 @@ def write_to(self, out: bytearray) -> None: Varint.write_varint_u32(out, v._get_size()) v.write_to(out) if self.schema_url: - v = self.schema_url.encode("utf-8") + v = self._marshaler_cache[b"\x1a"] out += b"\x1a" Varint.write_varint_u32(out, len(v)) out += v @@ -255,6 +260,7 @@ def __init__( self.dropped_links_count: int = dropped_links_count self._status: Status = status self.flags: int = flags + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -272,6 +278,7 @@ def calculate_size(self) -> int: ) if self.trace_state: v = self.trace_state.encode("utf-8") + self._marshaler_cache[b"\x1a"] = v size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) if self.parent_span_id: size += ( @@ -281,6 +288,7 @@ def calculate_size(self) -> int: ) if self.name: v = self.name.encode("utf-8") + self._marshaler_cache[b"*"] = v size += len(b"*") + Varint.size_varint_u32(len(v)) + len(v) if self.kind: v = self.kind @@ -338,7 +346,7 @@ def write_to(self, out: bytearray) -> None: Varint.write_varint_u32(out, len(self.span_id)) out += self.span_id if self.trace_state: - v = self.trace_state.encode("utf-8") + v = self._marshaler_cache[b"\x1a"] out += b"\x1a" Varint.write_varint_u32(out, len(v)) out += v @@ -347,7 +355,7 @@ def write_to(self, out: bytearray) -> None: Varint.write_varint_u32(out, len(self.parent_span_id)) out += self.parent_span_id if self.name: - v = self.name.encode("utf-8") + v = self._marshaler_cache[b"*"] out += b"*" Varint.write_varint_u32(out, len(v)) out += v @@ -426,6 +434,7 @@ def __init__( self.name: str = name self._attributes: List[KeyValue] = attributes self.dropped_attributes_count: int = dropped_attributes_count + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 @@ -433,6 +442,7 @@ def calculate_size(self) -> int: size += len(b"\t") + 8 if self.name: v = self.name.encode("utf-8") + self._marshaler_cache[b"\x12"] = v size += len(b"\x12") + Varint.size_varint_u32(len(v)) + len(v) if self._attributes: size += sum( @@ -452,7 +462,7 @@ def write_to(self, out: bytearray) -> None: out += b"\t" out += struct.pack(" int: size = 0 @@ -511,6 +522,7 @@ def calculate_size(self) -> int: ) if self.trace_state: v = self.trace_state.encode("utf-8") + self._marshaler_cache[b"\x1a"] = v size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) if self._attributes: size += sum( @@ -537,7 +549,7 @@ def write_to(self, out: bytearray) -> None: Varint.write_varint_u32(out, len(self.span_id)) out += self.span_id if self.trace_state: - v = self.trace_state.encode("utf-8") + v = self._marshaler_cache[b"\x1a"] out += b"\x1a" Varint.write_varint_u32(out, len(v)) out += v @@ -565,11 +577,13 @@ def __init__( ): self.message: str = message self.code: Status.StatusCode = code + self._marshaler_cache = {} def calculate_size(self) -> int: size = 0 if self.message: v = self.message.encode("utf-8") + self._marshaler_cache[b"\x12"] = v size += len(b"\x12") + Varint.size_varint_u32(len(v)) + len(v) if self.code: v = self.code @@ -580,7 +594,7 @@ def calculate_size(self) -> int: def write_to(self, out: bytearray) -> None: if self.message: - v = self.message.encode("utf-8") + v = self._marshaler_cache[b"\x12"] out += b"\x12" Varint.write_varint_u32(out, len(v)) out += v diff --git a/src/snowflake/telemetry/_internal/serialize/__init__.py b/src/snowflake/telemetry/_internal/serialize/__init__.py index 0784f7d..cac2834 100644 --- a/src/snowflake/telemetry/_internal/serialize/__init__.py +++ b/src/snowflake/telemetry/_internal/serialize/__init__.py @@ -87,9 +87,11 @@ def write_varint_s32(out: bytearray, value: int) -> None: # Base class for all custom messages class MessageMarshaler: + _marshaler_cache: Dict[bytes, Any] + # Init may be inlined by the code generator def __init__(self) -> None: - pass + self._marshaler_cache = {} def write_to(self, out: bytearray) -> None: ... @@ -103,6 +105,8 @@ def _get_size(self) -> int: return self._size def SerializeToString(self) -> bytes: + # size MUST be calculated before serializing since some preprocessing is done + self._get_size() stream = bytearray() self.write_to(stream) return bytes(stream) @@ -163,8 +167,10 @@ def size_sfixed64(self, TAG: bytes, FIELD_ATTR: int) -> int: def size_bytes(self, TAG: bytes, FIELD_ATTR: bytes) -> int: return len(TAG) + Varint.size_varint_u32(len(FIELD_ATTR)) + len(FIELD_ATTR) + # This function should not be used for repeated strings due to caching by tag def size_string(self, TAG: bytes, FIELD_ATTR: str) -> int: v = FIELD_ATTR.encode("utf-8") + self._marshaler_cache[TAG] = v return len(TAG) + Varint.size_varint_u32(len(v)) + len(v) def size_message(self, TAG: bytes, FIELD_ATTR: MessageMarshaler) -> int: @@ -181,6 +187,7 @@ def size_repeated_fixed64(self, TAG: bytes, FIELD_ATTR: List[int]): def size_repeated_uint64(self, TAG: bytes, FIELD_ATTR: List[int]): s = sum(Varint.size_varint_u64(uint32) for uint32 in FIELD_ATTR) + self._marshaler_cache[TAG] = s return len(TAG) + s + Varint.size_varint_u32(s) def serialize_bool(self, out: bytearray, TAG: bytes, FIELD_ATTR: bool) -> None: @@ -247,8 +254,9 @@ def serialize_bytes(self, out: bytearray, TAG: bytes, FIELD_ATTR: bytes) -> None Varint.write_varint_u32(out, len(FIELD_ATTR)) out += FIELD_ATTR + # This function should not be used for repeated strings due to caching by tag def serialize_string(self, out: bytearray, TAG: bytes, FIELD_ATTR: str) -> None: - v = FIELD_ATTR.encode("utf-8") + v = self._marshaler_cache[TAG] out += TAG Varint.write_varint_u32(out, len(v)) out += v @@ -278,7 +286,7 @@ def serialize_repeated_fixed64(self, out: bytearray, TAG: bytes, FIELD_ATTR: Lis def serialize_repeated_uint64(self, out: bytearray, TAG: bytes, FIELD_ATTR: List[int]) -> None: out += TAG - Varint.write_varint_u32(out, sum(Varint.size_varint_u64(uint32) for uint32 in FIELD_ATTR)) + Varint.write_varint_u32(out, self._marshaler_cache[TAG]) for v in FIELD_ATTR: Varint.write_varint_u64(out, v)