From 3d705bf05d3be5e5232089c3524213ebbae2911f Mon Sep 17 00:00:00 2001 From: andrei-faber <89944741+andrei-faber@users.noreply.github.com> Date: Mon, 8 May 2023 23:07:05 -0600 Subject: [PATCH] Added ADO.NET importing/exporting functionality to DataFrame (#5975) * refactoring - removed copy/paste in DataFrame.CreateColumn() * added a universal loading method and export to DataTable * added tests for new loading/saving methods in DataFrame * improved error handling - DataFrame.LoadFrom() * DataFrame - importing and exporting data using ADO.NET providers * DataFrame.LoadFrom() - use async * DataFrame.LoadFrom() - minor refactorings * Update Microsoft.Data.Analysis.Tests.csproj Changed version of System.Data.SQLite * Update Microsoft.Data.Analysis.Tests.csproj * fixed chown command * sql db test path change * sql db test path change * sql db test fix * sql db test fix --------- Co-authored-by: Michael Sharp <51342856+michaelgsharp@users.noreply.github.com> Co-authored-by: Michael Sharp --- eng/Versions.props | 2 +- eng/helix.proj | 2 +- src/Microsoft.Data.Analysis/DataFrame.IO.cs | 186 ++++++++++++++++-- src/Microsoft.Data.Analysis/Extensions.cs | 37 ++++ .../DataFrame.IOTests.cs | 118 +++++++++++ .../Microsoft.Data.Analysis.Tests.csproj | 7 + 6 files changed, 334 insertions(+), 18 deletions(-) create mode 100644 src/Microsoft.Data.Analysis/Extensions.cs diff --git a/eng/Versions.props b/eng/Versions.props index 052539cb303..0c6252b6f54 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -87,7 +87,7 @@ 0.0.6-test 0.0.7-test 4.6.1 - 1.0.112.2 + 1.0.113 1.2.7 2.4.2 diff --git a/eng/helix.proj b/eng/helix.proj index b7d3b600862..ef55768a735 100644 --- a/eng/helix.proj +++ b/eng/helix.proj @@ -96,7 +96,7 @@ - $(HelixPreCommands);export ML_TEST_DATADIR=$HELIX_CORRELATION_PAYLOAD;export MICROSOFTML_RESOURCE_PATH=$HELIX_WORKITEM_ROOT;sudo chmod -R 777 $HELIX_WORKITEM_ROOT;sudo chown -R $(whoami) $HELIX_WORKITEM_ROOT + $(HelixPreCommands);export ML_TEST_DATADIR=$HELIX_CORRELATION_PAYLOAD;export MICROSOFTML_RESOURCE_PATH=$HELIX_WORKITEM_ROOT;sudo chmod -R 777 $HELIX_WORKITEM_ROOT;sudo chown -R $USER $HELIX_WORKITEM_ROOT $(HelixPreCommands);set ML_TEST_DATADIR=%HELIX_CORRELATION_PAYLOAD%;set MICROSOFTML_RESOURCE_PATH=%HELIX_WORKITEM_ROOT% $(HelixPreCommands);install_name_tool -change "/usr/local/opt/libomp/lib/libomp.dylib" "@loader_path/libomp.dylib" libSymSgdNative.dylib diff --git a/src/Microsoft.Data.Analysis/DataFrame.IO.cs b/src/Microsoft.Data.Analysis/DataFrame.IO.cs index 04cacc99a8a..e5256249988 100644 --- a/src/Microsoft.Data.Analysis/DataFrame.IO.cs +++ b/src/Microsoft.Data.Analysis/DataFrame.IO.cs @@ -4,9 +4,12 @@ using System; using System.Collections.Generic; +using System.Data; +using System.Data.Common; using System.Globalization; using System.IO; using System.Text; +using System.Threading.Tasks; namespace Microsoft.Data.Analysis { @@ -109,12 +112,158 @@ public static DataFrame LoadCsv(string filename, } } + public static DataFrame LoadFrom(IEnumerable> vals, IList<(string, Type)> columnInfos) + { + var columnsCount = columnInfos.Count; + var columns = new List(columnsCount); + + foreach (var (name, type) in columnInfos) + { + var column = CreateColumn(type, name); + columns.Add(column); + } + + var res = new DataFrame(columns); + + foreach (var items in vals) + { + for (var c = 0; c < items.Count; c++) + { + items[c] = items[c]; + } + res.Append(items, inPlace: true); + } + + return res; + } + + public void SaveTo(DataTable table) + { + var columnsCount = Columns.Count; + + if (table.Columns.Count == 0) + { + foreach (var column in Columns) + { + table.Columns.Add(column.Name, column.DataType); + } + } + else + { + if (table.Columns.Count != columnsCount) + throw new ArgumentException(); + for (var c = 0; c < columnsCount; c++) + { + if (table.Columns[c].DataType != Columns[c].DataType) + throw new ArgumentException(); + } + } + + var items = new object[columnsCount]; + foreach (var row in Rows) + { + for (var c = 0; c < columnsCount; c++) + { + items[c] = row[c] ?? DBNull.Value; + } + table.Rows.Add(items); + } + } + + public DataTable ToTable() + { + var res = new DataTable(); + SaveTo(res); + return res; + } + + public static DataFrame FromSchema(DbDataReader reader) + { + var columnsCount = reader.FieldCount; + var columns = new DataFrameColumn[columnsCount]; + + for (var c = 0; c < columnsCount; c++) + { + var type = reader.GetFieldType(c); + var name = reader.GetName(c); + var column = CreateColumn(type, name); + columns[c] = column; + } + + var res = new DataFrame(columns); + return res; + } + + public static async Task LoadFrom(DbDataReader reader) + { + var res = FromSchema(reader); + var columnsCount = reader.FieldCount; + + var items = new object[columnsCount]; + while (await reader.ReadAsync()) + { + for (var c = 0; c < columnsCount; c++) + { + items[c] = reader.IsDBNull(c) + ? null + : reader[c]; + } + res.Append(items, inPlace: true); + } + + reader.Close(); + + return res; + } + + public static async Task LoadFrom(DbDataAdapter adapter) + { + using var reader = await adapter.SelectCommand.ExecuteReaderAsync(); + return await LoadFrom(reader); + } + + public void SaveTo(DbDataAdapter dataAdapter, DbProviderFactory factory) + { + using var commandBuilder = factory.CreateCommandBuilder(); + commandBuilder.DataAdapter = dataAdapter; + dataAdapter.InsertCommand = commandBuilder.GetInsertCommand(); + dataAdapter.UpdateCommand = commandBuilder.GetUpdateCommand(); + dataAdapter.DeleteCommand = commandBuilder.GetDeleteCommand(); + + using var table = ToTable(); + + var connection = dataAdapter.SelectCommand.Connection; + var needClose = connection.TryOpen(); + + try + { + using var transaction = connection.BeginTransaction(); + try + { + dataAdapter.Update(table); + } + catch + { + transaction.Rollback(); + transaction.Dispose(); + throw; + } + transaction.Commit(); + } + finally + { + if (needClose) + connection.Close(); + } + } + /// /// return of if not null or empty, otherwise return "Column{i}" where i is . /// /// column names. /// column index. /// + private static string GetColumnName(string[] columnNames, int columnIndex) { var defaultColumnName = "Column" + columnIndex.ToString(); @@ -126,68 +275,68 @@ private static string GetColumnName(string[] columnNames, int columnIndex) return defaultColumnName; } - private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int columnIndex) + private static DataFrameColumn CreateColumn(Type kind, string columnName) { DataFrameColumn ret; if (kind == typeof(bool)) { - ret = new BooleanDataFrameColumn(GetColumnName(columnNames, columnIndex)); + ret = new BooleanDataFrameColumn(columnName); } else if (kind == typeof(int)) { - ret = new Int32DataFrameColumn(GetColumnName(columnNames, columnIndex)); + ret = new Int32DataFrameColumn(columnName); } else if (kind == typeof(float)) { - ret = new SingleDataFrameColumn(GetColumnName(columnNames, columnIndex)); + ret = new SingleDataFrameColumn(columnName); } else if (kind == typeof(string)) { - ret = new StringDataFrameColumn(GetColumnName(columnNames, columnIndex), 0); + ret = new StringDataFrameColumn(columnName, 0); } else if (kind == typeof(long)) { - ret = new Int64DataFrameColumn(GetColumnName(columnNames, columnIndex)); + ret = new Int64DataFrameColumn(columnName); } else if (kind == typeof(decimal)) { - ret = new DecimalDataFrameColumn(GetColumnName(columnNames, columnIndex)); + ret = new DecimalDataFrameColumn(columnName); } else if (kind == typeof(byte)) { - ret = new ByteDataFrameColumn(GetColumnName(columnNames, columnIndex)); + ret = new ByteDataFrameColumn(columnName); } else if (kind == typeof(char)) { - ret = new CharDataFrameColumn(GetColumnName(columnNames, columnIndex)); + ret = new CharDataFrameColumn(columnName); } else if (kind == typeof(double)) { - ret = new DoubleDataFrameColumn(GetColumnName(columnNames, columnIndex)); + ret = new DoubleDataFrameColumn(columnName); } else if (kind == typeof(sbyte)) { - ret = new SByteDataFrameColumn(GetColumnName(columnNames, columnIndex)); + ret = new SByteDataFrameColumn(columnName); } else if (kind == typeof(short)) { - ret = new Int16DataFrameColumn(GetColumnName(columnNames, columnIndex)); + ret = new Int16DataFrameColumn(columnName); } else if (kind == typeof(uint)) { - ret = new UInt32DataFrameColumn(GetColumnName(columnNames, columnIndex)); + ret = new UInt32DataFrameColumn(columnName); } else if (kind == typeof(ulong)) { - ret = new UInt64DataFrameColumn(GetColumnName(columnNames, columnIndex)); + ret = new UInt64DataFrameColumn(columnName); } else if (kind == typeof(ushort)) { - ret = new UInt16DataFrameColumn(GetColumnName(columnNames, columnIndex)); + ret = new UInt16DataFrameColumn(columnName); } else if (kind == typeof(DateTime)) { - ret = new PrimitiveDataFrameColumn(GetColumnName(columnNames, columnIndex)); + ret = new PrimitiveDataFrameColumn(columnName); } else { @@ -196,6 +345,11 @@ private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int return ret; } + private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int columnIndex) + { + return CreateColumn(kind, GetColumnName(columnNames, columnIndex)); + } + private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringReader wrappedReader, char separator = ',', bool header = true, string[] columnNames = null, Type[] dataTypes = null, diff --git a/src/Microsoft.Data.Analysis/Extensions.cs b/src/Microsoft.Data.Analysis/Extensions.cs new file mode 100644 index 00000000000..3e3d20b4a43 --- /dev/null +++ b/src/Microsoft.Data.Analysis/Extensions.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Text; + +namespace Microsoft.Data.Analysis +{ + public static class Extensions + { + public static DbDataAdapter CreateDataAdapter(this DbProviderFactory factory, DbConnection connection, string tableName) + { + var query = connection.CreateCommand(); + query.CommandText = $"SELECT * FROM {tableName}"; + var res = factory.CreateDataAdapter(); + res.SelectCommand = query; + return res; + } + + public static bool TryOpen(this DbConnection connection) + { + if (connection.State == ConnectionState.Closed) + { + connection.Open(); + return true; + } + else + { + return false; + } + } + } +} diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs index 697e38f9e46..398e849907e 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs @@ -4,11 +4,16 @@ using System; using System.Collections.Generic; +using System.Data; +using System.Data.Common; using System.Globalization; using System.IO; using System.Linq; using System.Text; +using System.Data.SQLite; +using System.Data.SQLite.EF6; using Xunit; +using Microsoft.ML.TestFramework.Attributes; namespace Microsoft.Data.Analysis.Tests { @@ -1021,6 +1026,119 @@ public void TestMixedDataTypesInCsv() } } + [Fact] + public void TestLoadFromEnumerable() + { + var (columns, vals) = GetTestData(); + var dataFrame = DataFrame.LoadFrom(vals, columns); + AssertEqual(dataFrame, columns, vals); + } + + [Fact] + public void TestSaveToDataTable() + { + var (columns, vals) = GetTestData(); + var dataFrame = DataFrame.LoadFrom(vals, columns); + + using var table = dataFrame.ToTable(); + + var resColumns = table.Columns.Cast().Select(column => (column.ColumnName, column.DataType)).ToArray(); + Assert.Equal(columns, resColumns); + + var resVals = table.Rows.Cast().Select(row => row.ItemArray).ToArray(); + Assert.Equal(vals, resVals); + } + + [X86X64FactAttribute("The SQLite un-managed code, SQLite.interop, only supports x86/x64 architectures.")] + public async void TestSQLite() + { + var (columns, vals) = GetTestData(); + var dataFrame = DataFrame.LoadFrom(vals, columns); + + try + { + var (factory, connection) = InitSQLiteDb(); + using (factory) + { + using (connection) + { + using var dataAdapter = factory.CreateDataAdapter(connection, TableName); + dataFrame.SaveTo(dataAdapter, factory); + + var resDataFrame = await DataFrame.LoadFrom(dataAdapter); + + AssertEqual(resDataFrame, columns, vals); + } + } + } + finally + { + CleanupSQLiteDb(); + } + } + + static void AssertEqual(DataFrame dataFrame, (string name, Type type)[] columns, object[][] vals) + { + var resColumns = dataFrame.Columns.Select(column => (column.Name, column.DataType)).ToArray(); + Assert.Equal(columns, resColumns); + var resVals = dataFrame.Rows.Select(row => row.ToArray()).ToArray(); + Assert.Equal(vals, resVals); + } + + static ((string name, Type type)[] columns, object[][] vals) GetTestData() + { + const int RowsCount = 10_000; + + var columns = new[] + { + ("ID", typeof(long)), + ("Text", typeof(string)) + }; + + var vals = new object[RowsCount][]; + for (var i = 0L; i < RowsCount; i++) + { + var row = new object[columns.Length]; + row[0] = i; + row[1] = $"test {i}"; + vals[i] = row; + } + + return (columns, vals); + } + + static (SQLiteProviderFactory factory, DbConnection connection) InitSQLiteDb() + { + var connectionString = $"DataSource={SQLitePath};Version=3;New=True;Compress=True;"; + + SQLiteConnection.CreateFile(SQLitePath); + var factory = new SQLiteProviderFactory(); + + var connection = factory.CreateConnection(); + connection.ConnectionString = connectionString; + connection.Open(); + + using var command = connection.CreateCommand(); + command.CommandText = $"CREATE TABLE {TableName} (ID INTEGER NOT NULL PRIMARY KEY ASC, Text VARCHAR(25))"; + command.ExecuteNonQuery(); + + return (factory, connection); + } + + static void CleanupSQLiteDb() + { + if (File.Exists(SQLitePath)) + File.Delete(SQLitePath); + } + + static readonly string BasePath = + Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) + "/"; + + const string DbName = "TestDb"; + const string TableName = "TestTable"; + + static readonly string SQLitePath = $@"{BasePath}/{DbName}.sqlite"; + public readonly struct LoadCsvVerifyingHelper { private readonly int _columnCount; diff --git a/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj b/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj index 1a570416e48..dafae0d9420 100644 --- a/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj +++ b/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj @@ -8,6 +8,8 @@ + + @@ -44,4 +46,9 @@ + + + + +