Issue 28893/infer schema csv (airbytehq#29099)

maxi297 authored Aug 14, 2023
1 parent 2d5939c commit 12f1304

Showing 16 changed files with 768 additions and 317 deletions.
@@ -82,7 +82,7 @@ def _check_parse_record(self, stream: "AbstractFileBasedStream", file: RemoteFil
         parser = stream.get_parser(stream.config.file_type)

         try:
-            record = next(iter(parser.parse_records(stream.config, file, self.stream_reader, logger)))
+            record = next(iter(parser.parse_records(stream.config, file, self.stream_reader, logger, discovered_schema=None)))
         except StopIteration:
             # The file is empty. We've verified that we can open it, so will
             # consider the connection check successful even though it means
@@ -17,6 +17,11 @@ class QuotingBehavior(Enum):
     QUOTE_NONE = "Quote None"


+class InferenceType(Enum):
+    NONE = "None"
+    PRIMITIVE_TYPES_ONLY = "Primitive Types Only"
+
+
 DEFAULT_TRUE_VALUES = ["y", "yes", "t", "true", "on", "1"]
 DEFAULT_FALSE_VALUES = ["n", "no", "f", "false", "off", "0"]
@@ -81,6 +86,12 @@ class Config:
         default=DEFAULT_FALSE_VALUES,
         description="A set of case-sensitive strings that should be interpreted as false values.",
     )
+    inference_type: InferenceType = Field(
+        title="Inference Type",
+        default=InferenceType.NONE,
+        description="How to infer the types of the columns. If none, inference defaults to strings.",
+        airbyte_hidden=True,
+    )

     @validator("delimiter")
     def validate_delimiter(cls, v: str) -> str:
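For context, a minimal usage sketch of the new option. The import path is assumed from the other airbyte_cdk imports in this commit, and constructing CsvFormat with all defaults is also an assumption, not something this excerpt confirms:

from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat, InferenceType

# Default: no inference, so every column is typed as a string.
default_format = CsvFormat()
assert default_format.inference_type is InferenceType.NONE

# Opt in to inferring primitive column types instead.
inferring_format = CsvFormat(inference_type=InferenceType.PRIMITIVE_TYPES_ONLY)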
@@ -4,14 +4,15 @@

 import logging
 import uuid
-from typing import Any, Dict, Iterable, Mapping
+from typing import Any, Dict, Iterable, Mapping, Optional

 import fastavro
 from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat
 from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
 from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
 from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
 from airbyte_cdk.sources.file_based.remote_file import RemoteFile
+from airbyte_cdk.sources.file_based.schema_helpers import SchemaType

 AVRO_TYPE_TO_JSON_TYPE = {
     "null": "null",

@@ -47,7 +48,7 @@ async def infer_schema(
         file: RemoteFile,
         stream_reader: AbstractFileBasedStreamReader,
         logger: logging.Logger,
-    ) -> Dict[str, Any]:
+    ) -> SchemaType:
         avro_format = config.format or AvroFormat()
         if not isinstance(avro_format, AvroFormat):
             raise ValueError(f"Expected ParquetFormat, got {avro_format}")

@@ -132,6 +133,7 @@ def parse_records(
         file: RemoteFile,
         stream_reader: AbstractFileBasedStreamReader,
         logger: logging.Logger,
+        discovered_schema: Optional[Mapping[str, SchemaType]],
     ) -> Iterable[Dict[str, Any]]:
         avro_format = config.format or AvroFormat()
         if not isinstance(avro_format, AvroFormat):

Large diffs are not rendered by default.

@@ -4,13 +4,13 @@

 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Iterable
+from typing import Any, Dict, Iterable, Mapping, Optional

 from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
 from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
 from airbyte_cdk.sources.file_based.remote_file import RemoteFile
+from airbyte_cdk.sources.file_based.schema_helpers import SchemaType

-Schema = Dict[str, str]
 Record = Dict[str, Any]


@@ -27,7 +27,7 @@ async def infer_schema(
         file: RemoteFile,
         stream_reader: AbstractFileBasedStreamReader,
         logger: logging.Logger,
-    ) -> Schema:
+    ) -> SchemaType:
         """
         Infer the JSON Schema for this file.
         """

@@ -40,6 +40,7 @@ def parse_records(
         file: RemoteFile,
         stream_reader: AbstractFileBasedStreamReader,
         logger: logging.Logger,
+        discovered_schema: Optional[Mapping[str, SchemaType]],
     ) -> Iterable[Record]:
         """
         Parse and emit each record.
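Because every parser must now accept discovered_schema, a sketch of a conforming subclass may help. The class name and return values are made up for illustration, only the two methods shown above are implemented, and any other abstract members of the real FileTypeParser are omitted:

import logging
from typing import Any, Dict, Iterable, Mapping, Optional

from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
from airbyte_cdk.sources.file_based.schema_helpers import SchemaType

class SingleColumnParser(FileTypeParser):  # hypothetical, for illustration
    async def infer_schema(self, config, file, stream_reader, logger: logging.Logger) -> SchemaType:
        # With no richer inference available, type every column as a string.
        return {"column_1": {"type": "string"}}

    def parse_records(
        self, config, file, stream_reader, logger: logging.Logger,
        discovered_schema: Optional[Mapping[str, SchemaType]],
    ) -> Iterable[Dict[str, Any]]:
        # A parser may consult the previously discovered schema (e.g., to cast
        # raw values to the discovered types) or ignore it entirely.
        yield {"column_1": "value"}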
@@ -4,14 +4,14 @@

 import json
 import logging
-from typing import Any, Dict, Iterable
+from typing import Any, Dict, Iterable, Mapping, Optional

 from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
 from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, RecordParseError
 from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
 from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
 from airbyte_cdk.sources.file_based.remote_file import RemoteFile
-from airbyte_cdk.sources.file_based.schema_helpers import PYTHON_TYPE_MAPPING, merge_schemas
+from airbyte_cdk.sources.file_based.schema_helpers import PYTHON_TYPE_MAPPING, SchemaType, merge_schemas


 class JsonlParser(FileTypeParser):

@@ -25,12 +25,12 @@ async def infer_schema(
         file: RemoteFile,
         stream_reader: AbstractFileBasedStreamReader,
         logger: logging.Logger,
-    ) -> Dict[str, Any]:
+    ) -> SchemaType:
         """
         Infers the schema for the file by inferring the schema for each line, and merging
         it with the previously-inferred schema.
         """
-        inferred_schema: Dict[str, Any] = {}
+        inferred_schema: Mapping[str, Any] = {}

         for entry in self._parse_jsonl_entries(file, stream_reader, logger, read_limit=True):
             line_schema = self._infer_schema_for_record(entry)

@@ -44,6 +44,7 @@ def parse_records(
         file: RemoteFile,
         stream_reader: AbstractFileBasedStreamReader,
         logger: logging.Logger,
+        discovered_schema: Optional[Mapping[str, SchemaType]],
     ) -> Iterable[Dict[str, Any]]:
         """
         This code supports parsing json objects over multiple lines even though this does not align with the JSONL format. This is for
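For intuition, a self-contained toy version of the per-line inference loop the docstring describes. It inlines a simplified type map and uses last-write-wins merging; the real parser uses PYTHON_TYPE_MAPPING and merge_schemas, which widens conflicting types rather than overwriting them:

import json

TYPE_MAP = {bool: "boolean", int: "integer", float: "number", str: "string"}

def infer_jsonl_schema(lines):
    # Infer a schema per line and fold it into the running schema.
    schema = {}
    for line in lines:
        for key, value in json.loads(line).items():
            schema[key] = {"type": TYPE_MAP.get(type(value), "string")}
    return schema

print(infer_jsonl_schema(['{"a": 1, "b": "x"}', '{"c": true}']))
# {'a': {'type': 'integer'}, 'b': {'type': 'string'}, 'c': {'type': 'boolean'}}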
@@ -5,7 +5,7 @@

 import json
 import logging
 import os
-from typing import Any, Dict, Iterable, List, Mapping
+from typing import Any, Dict, Iterable, List, Mapping, Optional
 from urllib.parse import unquote

 import pyarrow as pa

@@ -15,6 +15,7 @@
 from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
 from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
 from airbyte_cdk.sources.file_based.remote_file import RemoteFile
+from airbyte_cdk.sources.file_based.schema_helpers import SchemaType
 from pyarrow import Scalar


@@ -28,7 +29,7 @@ async def infer_schema(
         file: RemoteFile,
         stream_reader: AbstractFileBasedStreamReader,
         logger: logging.Logger,
-    ) -> Dict[str, Any]:
+    ) -> SchemaType:
         parquet_format = config.format or ParquetFormat()
         if not isinstance(parquet_format, ParquetFormat):
             raise ValueError(f"Expected ParquetFormat, got {parquet_format}")

@@ -51,6 +52,7 @@ def parse_records(
         file: RemoteFile,
         stream_reader: AbstractFileBasedStreamReader,
         logger: logging.Logger,
+        discovered_schema: Optional[Mapping[str, SchemaType]],
     ) -> Iterable[Dict[str, Any]]:
         parquet_format = config.format or ParquetFormat()
         if not isinstance(parquet_format, ParquetFormat):
@@ -11,7 +11,7 @@
 from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError, SchemaInferenceError

 JsonSchemaSupportedType = Union[List[str], Literal["string"], str]
-SchemaType = Dict[str, Dict[str, JsonSchemaSupportedType]]
+SchemaType = Mapping[str, Mapping[str, JsonSchemaSupportedType]]

 schemaless_schema = {"type": "object", "properties": {"data": {"type": "object"}}}

@@ -99,7 +99,7 @@ def merge_schemas(schema1: SchemaType, schema2: SchemaType) -> SchemaType:
         if not isinstance(t, dict) or "type" not in t or not _is_valid_type(t["type"]):
             raise SchemaInferenceError(FileBasedSourceError.UNRECOGNIZED_TYPE, key=k, type=t)

-    merged_schema: Dict[str, Any] = deepcopy(schema1)
+    merged_schema: Dict[str, Any] = deepcopy(schema1)  # type: ignore # as of 2023-08-08, deepcopy can copy Mapping
     for k2, t2 in schema2.items():
         t1 = merged_schema.get(k2)
         if t1 is None:

@@ -116,7 +116,7 @@ def _is_valid_type(t: JsonSchemaSupportedType) -> bool:
     return t == "array" or get_comparable_type(t) is not None


-def _choose_wider_type(key: str, t1: Dict[str, Any], t2: Dict[str, Any]) -> Dict[str, Any]:
+def _choose_wider_type(key: str, t1: Mapping[str, Any], t2: Mapping[str, Any]) -> Mapping[str, Any]:
     if (t1["type"] == "array" or t2["type"] == "array") and t1 != t2:
         raise SchemaInferenceError(
             FileBasedSourceError.SCHEMA_INFERENCE_ERROR,
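A usage sketch of the merge helper with the now-Mapping-typed alias. The integer-to-number widening shown in the final comment is an assumption about _choose_wider_type's behavior, not something this excerpt confirms:

from airbyte_cdk.sources.file_based.schema_helpers import merge_schemas

schema_from_file_1 = {"id": {"type": "integer"}, "name": {"type": "string"}}
schema_from_file_2 = {"id": {"type": "number"}}

# Keys unique to either schema are carried over; for overlapping keys,
# _choose_wider_type picks the wider of the two types.
merged = merge_schemas(schema_from_file_1, schema_from_file_2)
# Expected (assumption): {"id": {"type": "number"}, "name": {"type": "string"}}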
@@ -6,7 +6,7 @@

 import itertools
 import traceback
 from functools import cache
-from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Set, Union
+from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Set, Union

 from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level
 from airbyte_cdk.models import Type as MessageType

@@ -20,7 +20,7 @@
     StopSyncPerValidationPolicy,
 )
 from airbyte_cdk.sources.file_based.remote_file import RemoteFile
-from airbyte_cdk.sources.file_based.schema_helpers import merge_schemas, schemaless_schema
+from airbyte_cdk.sources.file_based.schema_helpers import SchemaType, merge_schemas, schemaless_schema
 from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream
 from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
 from airbyte_cdk.sources.file_based.types import StreamSlice

@@ -84,7 +84,7 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Mapping
         n_skipped = line_no = 0

         try:
-            for record in parser.parse_records(self.config, file, self._stream_reader, self.logger):
+            for record in parser.parse_records(self.config, file, self._stream_reader, self.logger, schema):
                 line_no += 1
                 if self.config.schemaless:
                     record = {"data": record}

@@ -231,8 +231,8 @@ async def _infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
         Each file type has a corresponding `infer_schema` handler.
         Dispatch on file type.
         """
-        base_schema: Dict[str, Any] = {}
-        pending_tasks: Set[asyncio.tasks.Task[Dict[str, Any]]] = set()
+        base_schema: SchemaType = {}
+        pending_tasks: Set[asyncio.tasks.Task[SchemaType]] = set()

         n_started, n_files = 0, len(files)
         files_iterator = iter(files)

@@ -251,7 +251,7 @@ async def _infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:

         return base_schema

-    async def _infer_file_schema(self, file: RemoteFile) -> Dict[str, Any]:
+    async def _infer_file_schema(self, file: RemoteFile) -> SchemaType:
         try:
             return await self.get_parser(self.config.file_type).infer_schema(self.config, file, self._stream_reader, self.logger)
         except Exception as exc:
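The Set[asyncio.Task[SchemaType]] bookkeeping above exists to cap concurrent schema inference. A condensed sketch of the same fan-out-and-merge idea, without the concurrency cap and reusing the private _infer_file_schema from the diff (both simplifications are assumptions, for brevity only):

import asyncio
from typing import List

from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.schema_helpers import SchemaType, merge_schemas

async def infer_all(stream, files: List[RemoteFile]) -> SchemaType:
    # Infer each file's schema concurrently, then merge the results into one.
    schemas = await asyncio.gather(*(stream._infer_file_schema(f) for f in files))
    base_schema: SchemaType = {}
    for schema in schemas:
        base_schema = merge_schemas(base_schema, schema)
    return base_schema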
@@ -15,7 +15,7 @@
         pytest.param(0, False, None, id="test_no_skip_rows_before_header_and_no_autogenerate_column_names"),
     ]
 )
-def test_csv_format(skip_rows_before_header, autogenerate_column_names, expected_error):
+def test_csv_format_skip_rows_and_autogenerate_column_names(skip_rows_before_header, autogenerate_column_names, expected_error) -> None:
     if expected_error:
         with pytest.raises(expected_error):
             CsvFormat(skip_rows_before_header=skip_rows_before_header, autogenerate_column_names=autogenerate_column_names)