diff --git a/CodeOnlyStoredProcedure/RowFactory/AccessorFactoryBase.cs b/CodeOnlyStoredProcedure/RowFactory/AccessorFactoryBase.cs index 6468eb0..804b96a 100644 --- a/CodeOnlyStoredProcedure/RowFactory/AccessorFactoryBase.cs +++ b/CodeOnlyStoredProcedure/RowFactory/AccessorFactoryBase.cs @@ -275,7 +275,14 @@ protected static Expression CreateUnboxedRetrieval( if (expectedDbType != null) { res = Expression.Call(dbReader, typeof(IDataRecord).GetMethod("Get" + Type.GetTypeCode(expectedDbType)), index); - res = Expression.Convert(res, dbType); + if (type == typeof(bool) || type == typeof(bool?)) + { + res = Expression.NotEqual(res, Zero(expectedDbType)); + if (type == typeof(bool?)) + res = Expression.Convert(res, type); + } + else + res = Expression.Convert(res, dbType); } else res = Expression.Call(dbReader, typeof(IDataRecord).GetMethod("Get" + Type.GetTypeCode(dbType)), index); @@ -324,5 +331,23 @@ protected static Expression CreateUnboxedRetrieval( return Expression.Block(type, parms, body); } + + protected static Expression Zero(Type dbType, bool isNullable = false) + { + if (dbType == typeof(short)) + return Expression.Constant((short)0, isNullable ? typeof(short?) : dbType); + if (dbType == typeof(byte)) + return Expression.Constant((byte)0, isNullable ? typeof(byte?) : dbType); + if (dbType == typeof(double)) + return Expression.Constant(0.0, isNullable ? typeof(double?) : dbType); + if (dbType == typeof(float)) + return Expression.Constant(0f, isNullable ? typeof(float?) : dbType); + if (dbType == typeof(decimal)) + return Expression.Constant(0M, isNullable ? typeof(decimal?) : dbType); + if (dbType == typeof(long)) + return Expression.Constant(0L, isNullable ? typeof(long?) : dbType); + + return Expression.Constant(0, isNullable ? typeof(int?) : dbType); + } } } diff --git a/CodeOnlyStoredProcedure/RowFactory/ValueAccessorFactory.cs b/CodeOnlyStoredProcedure/RowFactory/ValueAccessorFactory.cs index a7b301d..d49a00c 100644 --- a/CodeOnlyStoredProcedure/RowFactory/ValueAccessorFactory.cs +++ b/CodeOnlyStoredProcedure/RowFactory/ValueAccessorFactory.cs @@ -76,7 +76,23 @@ public override Expression CreateExpressionToGetValueFromReader(IDataReader read if (dbColumnType != expectedType) { if (convertNumeric) - res = Expression.Convert(res, type); + { + if (expectedType == typeof(bool)) + { + if (isNullable) + { + // res = res == null ? null : (bool?)(bool)res + res = Expression.Condition( + Expression.Equal(res, Expression.Constant(null)), + Expression.Constant(default(bool?), typeof(bool?)), + Expression.Convert(Expression.NotEqual(res, Zero(dbColumnType, true)), typeof(bool?))); + } + else + res = Expression.NotEqual(res, Zero(dbColumnType)); + } + else + res = Expression.Convert(res, type); + } else throw new StoredProcedureColumnException(type, dbColumnType, propertyName); } diff --git a/CodeOnlyTests/RowFactory/ComplexTypeRowFactoryTests.cs b/CodeOnlyTests/RowFactory/ComplexTypeRowFactoryTests.cs index 64ef161..3af5665 100644 --- a/CodeOnlyTests/RowFactory/ComplexTypeRowFactoryTests.cs +++ b/CodeOnlyTests/RowFactory/ComplexTypeRowFactoryTests.cs @@ -654,6 +654,188 @@ public void EnumPropertiesWillConvertNumericTypes() .And.Match(i => i.FooBar == FooBar.Foo, "the transformer doubles the result from the database"); } + [TestMethod] + public void NumericTypeWillConvertToTrueBoolIfMarkedWithConvert() + { + var reader = CreateDataReader(new Dictionary + { + { "IsEnabled", 1 } + }, false); + + var toTest = RowFactory.Create().ParseRows( + reader, Enumerable.Empty(), CancellationToken.None); + + toTest.Should().ContainSingle("because one row is returned").Which + .Should().Match(i => i.IsEnabled, "the property should be converted to bool"); + } + + [TestMethod] + public void NumericTypeWillConvertToFalseBoolIfMarkedWithConvert() + { + var reader = CreateDataReader(new Dictionary + { + { "IsEnabled", 0 } + }, false); + + var toTest = RowFactory.Create().ParseRows( + reader, Enumerable.Empty(), CancellationToken.None); + + toTest.Should().ContainSingle("because one row is returned").Which + .Should().Match(i => !i.IsEnabled, "the property should be converted to bool"); + } + + [TestMethod] + public void NumericTypeWillConvertToTrueBoolIfMarkedWithConvert_WhenTransformerPassed() + { + var reader = CreateDataReader(new Dictionary + { + { "IsEnabled", 1 } + }, true); + + var toTest = RowFactory.Create().ParseRows( + reader, new[] { Mock.Of() }, CancellationToken.None); + + toTest.Should().ContainSingle("because one row is returned").Which + .Should().Match(i => i.IsEnabled, "the property should be converted to bool"); + } + + [TestMethod] + public void NumericTypeWillConvertToFalseBoolIfMarkedWithConvert_WhenTransformerPassed() + { + var reader = CreateDataReader(new Dictionary + { + { "IsEnabled", 0L } + }, true); + + var toTest = RowFactory.Create().ParseRows( + reader, new[] { Mock.Of() }, CancellationToken.None); + + toTest.Should().ContainSingle("because one row is returned").Which + .Should().Match(i => !i.IsEnabled, "the property should be converted to bool"); + } + + [TestMethod] + public void NumericTypeWillConvertToTrueNullableBoolIfMarkedWithConvert() + { + var reader = CreateDataReader(new Dictionary + { + { "IsEnabled", 1M } + }, false); + + var toTest = RowFactory.Create().ParseRows( + reader, Enumerable.Empty(), CancellationToken.None); + + toTest.Should().ContainSingle("because one row is returned").Which + .Should().Match(i => i.IsEnabled.Value, "the property should be converted to bool"); + } + + [TestMethod] + public void NullNumericTypeWillConvertToNullNullableBoolIfMarkedWithConvert() + { + var reader = CreateDataReader(new Dictionary + { + { "IsEnabled", null } + }, true); + Mock.Get(reader).Setup(rdr => rdr.GetFieldType(0)).Returns(typeof(int)); + + var toTest = RowFactory.Create().ParseRows( + reader, Enumerable.Empty(), CancellationToken.None); + + toTest.Should().ContainSingle("because one row is returned").Which + .Should().Match(i => !i.IsEnabled.HasValue, "the property should not have a value"); + } + + [TestMethod] + public void NumericTypeWillConvertToFalseNullableBoolIfMarkedWithConvert() + { + var reader = CreateDataReader(new Dictionary + { + { "IsEnabled", 0 } + }, false); + + var toTest = RowFactory.Create().ParseRows( + reader, Enumerable.Empty(), CancellationToken.None); + + toTest.Should().ContainSingle("because one row is returned").Which + .Should().Match(i => !i.IsEnabled.Value, "the property should be converted to bool"); + } + + [TestMethod] + public void NumericTypeWillConvertToTrueNullableBoolIfMarkedWithConvert_WhenTransformerPassed() + { + var reader = CreateDataReader(new Dictionary + { + { "IsEnabled", 1.0 } + }, true); + + var toTest = RowFactory.Create().ParseRows( + reader, new[] { Mock.Of() }, CancellationToken.None); + + toTest.Should().ContainSingle("because one row is returned").Which + .Should().Match(i => i.IsEnabled.Value, "the property should be converted to bool"); + } + + [TestMethod] + public void NumericTypeWillConvertToFalseNullableBoolIfMarkedWithConvert_WhenTransformerPassed() + { + var reader = CreateDataReader(new Dictionary + { + { "IsEnabled", 0 } + }, true); + + var toTest = RowFactory.Create().ParseRows( + reader, new[] { Mock.Of() }, CancellationToken.None); + + toTest.Should().ContainSingle("because one row is returned").Which + .Should().Match(i => !i.IsEnabled.Value, "the property should be converted to bool"); + } + + [TestMethod] + public void NullNumericTypeWillConvertToNullNullableBoolIfMarkedWithConvert_WhenTransformerPassed() + { + var reader = CreateDataReader(new Dictionary + { + { "IsEnabled", null } + }, true); + Mock.Get(reader).Setup(rdr => rdr.GetFieldType(0)).Returns(typeof(short)); + + var toTest = RowFactory.Create().ParseRows( + reader, new[] { Mock.Of() }, CancellationToken.None); + + toTest.Should().ContainSingle("because one row is returned").Which + .Should().Match(i => !i.IsEnabled.HasValue, "the property should not have a value"); + } + + [TestMethod] + public void TrueBoolWillConvertToNumericTypeIfMarkedWithConvert() + { + var reader = CreateDataReader(new Dictionary + { + { "IsEnabled", true } + }, false); + + var toTest = RowFactory.Create().ParseRows( + reader, Enumerable.Empty(), CancellationToken.None); + + toTest.Should().ContainSingle("because one row is returned").Which + .Should().Match(i => i.IsEnabled == 1, "the property should be converted from bool"); + } + + [TestMethod] + public void FalseBoolWillConvertToNumericTypeIfMarkedWithConvert() + { + var reader = CreateDataReader(new Dictionary + { + { "IsEnabled", false } + }, false); + + var toTest = RowFactory.Create().ParseRows( + reader, Enumerable.Empty(), CancellationToken.None); + + toTest.Should().ContainSingle("because one row is returned").Which + .Should().Match(i => i.IsEnabled == 0, "the property should be converted from bool"); + } + private static IDataReader CreateDataReader(Dictionary values, bool setupGetValue = true) { var keys = values.Keys.OrderBy(s => s).ToList(); @@ -821,6 +1003,24 @@ private class EnumValueTypesConverted public FooBar FooBar { get; set; } } + private class ConvertToBool + { + [ConvertNumeric] + public bool IsEnabled { get; set; } + } + + private class ConvertToNullableBool + { + [ConvertNumeric] + public bool? IsEnabled { get; set; } + } + + private class ConvertFromBool + { + [ConvertNumeric] + public int IsEnabled { get; set; } + } + private class StaticValueAttribute : DataTransformerAttributeBase { public object Result { get; set; }