Skip to content

Commit

Permalink
PR #11202 from Eran: add rsutils::number::running_average<>
Browse files Browse the repository at this point in the history
  • Loading branch information
maloel authored Dec 15, 2022
2 parents f2b2434 + f7aad2c commit b761504
Show file tree
Hide file tree
Showing 5 changed files with 344 additions and 43 deletions.
108 changes: 108 additions & 0 deletions third-party/rsutils/include/rsutils/number/running-average.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// License: Apache 2.0. See LICENSE file in root directory.
// Copyright(c) 2022 Intel Corporation. All Rights Reserved.

#pragma once

#include <type_traits>
#include <limits>
#include <stdlib.h> // size_t


namespace rsutils {
namespace number {


// Compute the average of a set of numbers, one at a time, without overflow!
// We adapt this to a signed integral T type, where we must start counting leftovers.
// We also add some rounding.
//
template< class T, typename Enable = void >
class running_average;


// Compute the average of a set of numbers, one at a time.
// This is the basic implementation, using doubles. See:
// https://www.heikohoffmann.de/htmlthesis/node134.html
// The basic code:
// double average( double[] ary )
// {
// double avg = 0;
// int n = 0;
// for( double x : ary )
// avg += (x - avg) / ++n;
// return avg;
// }
//
template<>
class running_average< double >
{
double _avg = 0.;
size_t _n = 0;

public:
running_average() = default;

size_t size() const { return _n; }
double get() const { return _avg; }

void add( double x ) { _avg += (x - _avg) / ++_n; }
};


// Compute the average of a set of numbers, one at a time.
//
// Adapted to a signed integral T type, where we must start counting leftovers.
// We also add some rounding.
// And we must do all that WITHOUT OVERFLOW!
//
template< class T >
class running_average< T, typename std::enable_if< std::is_integral< T >::value >::type >
{
T _avg = 0;
size_t _n = 0;
T _leftover = 0;

public:
running_average() = default;

size_t size() const { return _n; }
T get() const { return _avg; }
T leftover() const { return _leftover; }

double fraction() const { return _n ? double( _leftover ) / double( _n ) : 0.; }
double get_double() const { return _avg + fraction(); }

void add( T x ) { _avg += int_div_mod( x - _avg, ++_n, _leftover ); }

private:
static T add_no_overflow( T a, T b )
{
if( a > 0 )
{
if( b > std::numeric_limits< T >::max() - a )
return a; // discard b
}
else if( a < 0 )
{
if( b < std::numeric_limits< T >::min() - a )
return a; // discard b
}
return a + b;
}
static T int_div_mod( int64_t dividend_, size_t n, T & remainder )
{
// We need the modulo sign to be the same as the dividend!
// And, more importantly, modulo can be implemented differently based on the compiler, so we cannot use it!
T dividend = add_no_overflow( dividend_, remainder );
// We want 6.5 to be rounded to 7, but have to be careful with the sign:
T rounding = n / 2;
T rounded = add_no_overflow( dividend, dividend < 0 ? -rounding : rounding );
T result = rounded / (T)n;
remainder = dividend - n * result;
return result;
}
};


} // namespace number
} // namespace rsutils
42 changes: 41 additions & 1 deletion third-party/rsutils/py/pyrsutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
#include <rsutils/py/pybind11.h>
#include <rsutils/easylogging/easyloggingpp.h>
#include <rsutils/string/split.h>
#include <rsutils/string/from.h>
#include <rsutils/version.h>
#include <rsutils/number/running-average.h>
#include <rsutils/number/stabilized-value.h>
#include <rsutils/string/from.h>


#define NAME pyrsutils
#define SNAME "pyrsutils"
Expand Down Expand Up @@ -65,6 +67,44 @@ PYBIND11_MODULE(NAME, m) {
.def( py::self > py::self )
.def( "is_between", &version::is_between );

using int_avg = rsutils::number::running_average< int64_t >;
py::class_< int_avg >( m, "running_average_i" )
.def( py::init<>() )
.def( "__nonzero__", &int_avg::size ) // Called to implement truth value testing in Python 2
.def( "__bool__", &int_avg::size ) // Called to implement truth value testing in Python 3
.def( "size", &int_avg::size )
.def( "get", &int_avg::get )
.def( "leftover", &int_avg::leftover )
.def( "fraction", &int_avg::fraction )
.def( "get_double", &int_avg::get_double )
.def( "__int__", &int_avg::get )
.def( "__float__", &int_avg::get_double )
.def( "__str__", []( int_avg const & self ) -> std::string { return rsutils::string::from( self.get_double() ); } )
.def( "__repr__",
[]( int_avg const & self ) -> std::string {
return rsutils::string::from() << "<" SNAME ".running_average<int64_t>"
<< " " << self.get() << " "
<< ( self.leftover() < 0 ? "" : "+" ) << self.leftover()
<< "/" << self.size() << ">";
} )
.def( "add", &int_avg::add );

using double_avg = rsutils::number::running_average< double >;
py::class_< double_avg >( m, "running_average" )
.def( py::init<>() )
.def( "__nonzero__", &double_avg::size ) // Called to implement truth value testing in Python 2
.def( "__bool__", &double_avg::size ) // Called to implement truth value testing in Python 3
.def( "size", &double_avg::size )
.def( "get", &double_avg::get )
.def( "__float__", &double_avg::get )
.def( "__str__", []( double_avg const & self ) -> std::string { return rsutils::string::from( self.get() ); } )
.def( "__repr__",
[]( double_avg const & self ) -> std::string {
return rsutils::string::from() << "<" SNAME ".running_average<double>"
<< " " << self.get() << " /" << self.size() << ">";
} )
.def( "add", &double_avg::add );

using stabilized_value = rsutils::number::stabilized_value< double >;
auto not_empty = []( stabilized_value const & self ) {
return ! self.empty();
Expand Down
104 changes: 62 additions & 42 deletions unit-tests/py/rspy/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,34 @@ def print_stack():
"""


def check_failed():
def _count_check():
global n_assertions
n_assertions += 1


def check_passed():
"""
Function for when a check fails
:return: always False (so you can 'return check_failed()'
"""
_count_check()
reset_info()
return True


def check_failed( abort = False ):
"""
Function for when a check fails
:return: always False (so you can 'return check_failed()'
"""
_count_check()
global n_failed_assertions, test_failed
n_failed_assertions += 1
test_failed = True
print_info()
if abort:
abort()
return False


def abort():
Expand All @@ -159,20 +179,15 @@ def check( exp, description = None, abort_if_failed = False):
:param abort_if_failed: If True and assertion failed the test will be aborted
:return: True if assertion passed, False otherwise
"""
global n_assertions
n_assertions += 1
if not exp:
print_stack()
if description:
log.out( f" {description}" )
else:
log.out( f" check failed; received {exp}" )
check_failed()
if abort_if_failed:
abort()
return False
return check_failed( abort_if_failed )
reset_info()
return True
return check_passed()


def check_false( exp, description = None, abort_if_failed = False):
Expand All @@ -192,43 +207,58 @@ def check_equal(result, expected, abort_if_failed = False):
"""
if type(expected) == list:
log.out("check_equal should not be used for lists. Use check_equal_lists instead")
if abort_if_failed:
abort()
return False
global n_assertions
n_assertions += 1
return check_failed( abort_if_failed )
if result != expected:
print_stack()
log.out( " left :", result )
log.out( " right :", expected )
check_failed()
if abort_if_failed:
abort()
return False
reset_info()
return True
return check_failed( abort_if_failed )
return check_passed()


def check_between( result, min, max, abort_if_failed = False ):
"""
Used for asserting a variable is between two values
:param result: The actual value of a variable
:param min: The minimum expected value of the result
:param max: The maximum expected value of the result
:param abort_if_failed: If True and assertion failed the test will be aborted
:return: True if assertion passed, False otherwise
"""
if result < min or result > max:
print_stack()
log.out( " result :", result )
log.out( " between :", min, '-', max )
return check_failed( abort_if_failed )
return check_passed()


def check_approx_abs( result, expected, abs_err, abort_if_failed = False ):
"""
Used for asserting a variable has the expected value, plus/minus 'abs_err'
:param result: The actual value of a variable
:param expected: The expected value of the result
:param abs_err: How far away from expected we're allowed to get
:param abort_if_failed: If True and assertion failed the test will be aborted
:return: True if assertion passed, False otherwise
"""
return check_between( result, expected - abs_err, expected + abs_err, abort_if_failed )


def unreachable( abort_if_failed = False ):
"""
Used to assert that a certain section of code (exp: an if block) is not reached
:param abort_if_failed: If True and this function is reached the test will be aborted
"""
global n_assertions
n_assertions += 1
print_stack()
check_failed()
if abort_if_failed:
abort()
check_failed( abort_if_failed )


def unexpected_exception():
"""
Used to assert that an except block is not reached. It's different from unreachable because it expects
to be in an except block and prints the stack of the error and not the call-stack for this function
"""
global n_assertions
n_assertions += 1
traceback.print_exc( file = sys.stdout )
check_failed()

Expand All @@ -242,8 +272,6 @@ def check_equal_lists(result, expected, abort_if_failed = False):
:param abort_if_failed: If True and assertion failed the test will be aborted
:return: True if assertion passed, False otherwise
"""
global n_assertions
n_assertions += 1
failed = False
if len(result) != len(expected):
failed = True
Expand All @@ -260,12 +288,8 @@ def check_equal_lists(result, expected, abort_if_failed = False):
print_stack()
log.out( " result list :", result )
log.out( " expected list:", expected )
check_failed()
if abort_if_failed:
abort()
return False
reset_info()
return True
return check_failed( abort_if_failed )
return check_passed()


def check_exception(exception, expected_type, expected_msg = None, abort_if_failed = False):
Expand All @@ -288,12 +312,8 @@ def check_exception(exception, expected_type, expected_msg = None, abort_if_fail
if failed:
print_stack()
log.out( *failed )
check_failed()
if abort_if_failed:
abort()
return False
reset_info()
return True
return check_failed( abort_if_failed )
return check_passed()


def check_throws( _lambda, expected_type, expected_msg = None, abort_if_failed = False ):
Expand All @@ -306,8 +326,8 @@ def check_throws( _lambda, expected_type, expected_msg = None, abort_if_failed =
_lambda()
except Exception as e:
check_exception( e, expected_type, expected_msg, abort_if_failed )
else:
unexpected_exception()
return check_passed()
return check_failed( abort_if_failed )


def check_frame_drops(frame, previous_frame_number, allowed_drops = 1, allow_frame_counter_reset = False):
Expand Down
Loading

0 comments on commit b761504

Please sign in to comment.