liuxiaolong
2021-07-20 58d904a328c0d849769b483e901a0be9426b8209
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// Copyright 2018 Hans Dembinski
//
// Distributed under the Boost Software License, version 1.0.
// (See accompanying file LICENSE_1_0.txt
// or copy at http://www.boost.org/LICENSE_1_0.txt)
 
#ifndef BOOST_HISTOGRAM_ACCUMULATORS_WEIGHTED_MEAN_HPP
#define BOOST_HISTOGRAM_ACCUMULATORS_WEIGHTED_MEAN_HPP
 
#include <boost/core/nvp.hpp>
#include <boost/histogram/fwd.hpp> // for weighted_mean<>
#include <boost/histogram/weight.hpp>
#include <cassert>
#include <type_traits>
 
namespace boost {
namespace histogram {
namespace accumulators {
 
/**
  Calculates mean and variance of weighted sample.
 
  Uses West's incremental algorithm to improve numerical stability
  of mean and variance computation.
*/
template <class ValueType>
class weighted_mean {
public:
  using value_type = ValueType;
  using const_reference = const value_type&;
 
  weighted_mean() = default;
 
  /// Allow implicit conversion from other weighted_means
  template <class T>
  weighted_mean(const weighted_mean<T>& o)
      : sum_of_weights_{o.sum_of_weights_}
      , sum_of_weights_squared_{o.sum_of_weights_squared_}
      , weighted_mean_{o.weighted_mean_}
      , sum_of_weighted_deltas_squared_{o.sum_of_weighted_deltas_squared_} {}
 
  /// Initialize to external sum of weights, sum of weights squared, mean, and variance
  weighted_mean(const_reference wsum, const_reference wsum2, const_reference mean,
                const_reference variance)
      : sum_of_weights_(wsum)
      , sum_of_weights_squared_(wsum2)
      , weighted_mean_(mean)
      , sum_of_weighted_deltas_squared_(
            variance * (sum_of_weights_ - sum_of_weights_squared_ / sum_of_weights_)) {}
 
  /// Insert sample x
  void operator()(const_reference x) { operator()(weight(1), x); }
 
  /// Insert sample x with weight w
  void operator()(const weight_type<value_type>& w, const_reference x) {
    sum_of_weights_ += w.value;
    sum_of_weights_squared_ += w.value * w.value;
    const auto delta = x - weighted_mean_;
    weighted_mean_ += w.value * delta / sum_of_weights_;
    sum_of_weighted_deltas_squared_ += w.value * delta * (x - weighted_mean_);
  }
 
  /// Add another weighted_mean
  weighted_mean& operator+=(const weighted_mean& rhs) {
    if (sum_of_weights_ != 0 || rhs.sum_of_weights_ != 0) {
      const auto tmp =
          weighted_mean_ * sum_of_weights_ + rhs.weighted_mean_ * rhs.sum_of_weights_;
      sum_of_weights_ += rhs.sum_of_weights_;
      sum_of_weights_squared_ += rhs.sum_of_weights_squared_;
      weighted_mean_ = tmp / sum_of_weights_;
    }
    sum_of_weighted_deltas_squared_ += rhs.sum_of_weighted_deltas_squared_;
    return *this;
  }
 
  /** Scale by value
 
   This acts as if all samples were scaled by the value.
  */
  weighted_mean& operator*=(const_reference s) {
    weighted_mean_ *= s;
    sum_of_weighted_deltas_squared_ *= s * s;
    return *this;
  }
 
  bool operator==(const weighted_mean& rhs) const noexcept {
    return sum_of_weights_ == rhs.sum_of_weights_ &&
           sum_of_weights_squared_ == rhs.sum_of_weights_squared_ &&
           weighted_mean_ == rhs.weighted_mean_ &&
           sum_of_weighted_deltas_squared_ == rhs.sum_of_weighted_deltas_squared_;
  }
 
  bool operator!=(const weighted_mean& rhs) const noexcept { return !operator==(rhs); }
 
  /// Return sum of weights
  const_reference sum_of_weights() const noexcept { return sum_of_weights_; }
 
  /// Return sum of weights squared (variance of weight distribution)
  const_reference sum_of_weights_squared() const noexcept {
    return sum_of_weights_squared_;
  }
 
  /** Return mean value of accumulated weighted samples.
 
    The result is undefined, if `sum_of_weights() == 0`.
  */
  const_reference value() const noexcept { return weighted_mean_; }
 
  /** Return variance of accumulated weighted samples
 
    The result is undefined, if `sum_of_weights() == 0` or
    `sum_of_weights() == sum_of_weights_squared()`.
  */
  value_type variance() const {
    return sum_of_weighted_deltas_squared_ /
           (sum_of_weights_ - sum_of_weights_squared_ / sum_of_weights_);
  }
 
  template <class Archive>
  void serialize(Archive& ar, unsigned /* version */) {
    ar& make_nvp("sum_of_weights", sum_of_weights_);
    ar& make_nvp("sum_of_weights_squared", sum_of_weights_squared_);
    ar& make_nvp("weighted_mean", weighted_mean_);
    ar& make_nvp("sum_of_weighted_deltas_squared", sum_of_weighted_deltas_squared_);
  }
 
private:
  value_type sum_of_weights_{};
  value_type sum_of_weights_squared_{};
  value_type weighted_mean_{};
  value_type sum_of_weighted_deltas_squared_{};
};
 
} // namespace accumulators
} // namespace histogram
} // namespace boost
 
#ifndef BOOST_HISTOGRAM_DOXYGEN_INVOKED
namespace std {
template <class T, class U>
/// Specialization for boost::histogram::accumulators::weighted_mean.
struct common_type<boost::histogram::accumulators::weighted_mean<T>,
                   boost::histogram::accumulators::weighted_mean<U>> {
  using type = boost::histogram::accumulators::weighted_mean<common_type_t<T, U>>;
};
} // namespace std
#endif
 
#endif