Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
5278339
Fix to allow M
JohnStott Jul 10, 2018
1c5a6cd
Merge pull request #1 from JohnStott/mae_sample_wts
JohnStott Jul 10, 2018
0dacd2e
Updated MAE test to consider sample_weights in calculation
JohnStott Jul 10, 2018
d7e8161
Merge branch 'master' of https://github.com/JohnStott/scikit-learn
JohnStott Jul 10, 2018
16bd695
Removed comment
JohnStott Jul 10, 2018
2bddc6a
Fixed: E501 line too long (82 > 79 characters)
JohnStott Jul 10, 2018
37badb8
syntax correction
JohnStott Jul 10, 2018
a404983
Added fix details
JohnStott Jul 11, 2018
6ad17c0
Changed to use consistent datatypes during calculations
JohnStott Jul 12, 2018
2d0a97e
Corrected formatting
JohnStott Jul 12, 2018
1a36123
local testing
JohnStott Jul 12, 2018
f49ef59
Requested Changes
JohnStott Jul 13, 2018
af98aeb
changes as per review
JohnStott Jul 13, 2018
db74c0e
check for empty stack
JohnStott Jul 14, 2018
c35624f
fixed issue
JohnStott Jul 14, 2018
a136cf5
removed explicit casts
JohnStott Jul 14, 2018
ad6201b
removed explicit casts
JohnStott Jul 14, 2018
88ade1e
removed debug info
JohnStott Jul 15, 2018
aa073d5
Removed unnecessary explicits
JohnStott Jul 15, 2018
5f90f71
Removed unnecessary explicit casts
JohnStott Jul 15, 2018
bd417e9
merge conflict resolution
JohnStott Jul 16, 2018
0912207
added additional test
JohnStott Jul 16, 2018
947d54e
Merge branch 'master' into median_fix
JohnStott Jul 16, 2018
6c8ff77
updated comments
JohnStott Jul 16, 2018
2760423
Merge branch 'master' into median_fix
JohnStott Jul 16, 2018
100157e
Requested changes incl additional unit test
JohnStott Jul 16, 2018
28663d9
Merge branch 'master' into median_fix
JohnStott Jul 16, 2018
de00b02
fix mistake
JohnStott Jul 16, 2018
fed3117
formatting
JohnStott Jul 16, 2018
3752d35
Merge branch 'master' into median_fix
JohnStott Jul 16, 2018
bca9282
removed whitespace
JohnStott Jul 16, 2018
8ad1414
removed whitespace
JohnStott Jul 16, 2018
74c9791
added test notes
JohnStott Jul 16, 2018
42a050b
formatting
JohnStott Jul 16, 2018
cba8bf2
Requested changes
JohnStott Jul 16, 2018
eeee051
Trailing space fix attempt
JohnStott Jul 17, 2018
fdb30ff
Trailing whitespace fix attempt #2
JohnStott Jul 17, 2018
82bcba0
remove whitespace #3
JohnStott Jul 17, 2018
ad8409f
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn
JohnStott Jul 17, 2018
d7d8dee
merge
JohnStott Jul 17, 2018
d0c503d
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn
JohnStott Jul 19, 2018
af73020
Merge branch 'master' into median_fix
JohnStott Jul 19, 2018
0d50a23
clean up / extras
JohnStott Jul 21, 2018
9b4b88b
clean up
JohnStott Jul 21, 2018
caf017e
missing closing bracket
JohnStott Jul 21, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions sklearn/tree/_utils.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ cdef class WeightedMedianCalculator:
cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil except -1
cdef int reset(self) nogil except -1
cdef int update_median_parameters_post_push(
self, DOUBLE_t data, DOUBLE_t weight,
DOUBLE_t original_median) nogil
self, DOUBLE_t data, DOUBLE_t weight, DOUBLE_t original_median,
int push_index) nogil
cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) nogil
cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) nogil
cdef int update_median_parameters_post_remove(
self, DOUBLE_t data, DOUBLE_t weight,
DOUBLE_t original_median) nogil
self, DOUBLE_t data, DOUBLE_t weight, DOUBLE_t original_median,
int removal_index) nogil
cdef DOUBLE_t get_median(self) nogil
109 changes: 77 additions & 32 deletions sklearn/tree/_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,8 @@ cdef class WeightedPQueue:

cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil except -1:
"""Push record on the array.

Return -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
or returns index of the item added.
"""
cdef SIZE_t array_ptr = self.array_ptr
cdef WeightedPQueueRecord* array = NULL
Expand All @@ -370,11 +369,12 @@ cdef class WeightedPQueue:

# Increase element count
self.array_ptr = array_ptr + 1
return 0
return i

cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) nogil:
"""Remove a specific value/weight record from the array.
Returns 0 if successful, -1 if record not found."""
Returns the index of the item removed, -1 if record not
found."""
cdef SIZE_t array_ptr = self.array_ptr
cdef WeightedPQueueRecord* array = self.array_
cdef SIZE_t idx_to_remove = -1
Expand All @@ -398,11 +398,12 @@ cdef class WeightedPQueue:
array[i] = array[i+1]

self.array_ptr = array_ptr - 1
return 0
return idx_to_remove

cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) nogil:
"""Remove the top (minimum) element from array.
Returns 0 if successful, -1 if nothing to remove."""
Returns the index of the item popped if successful (will always be
zero), -1 if nothing to remove."""
cdef SIZE_t array_ptr = self.array_ptr
cdef WeightedPQueueRecord* array = self.array_
cdef SIZE_t i
Expand All @@ -419,6 +420,8 @@ cdef class WeightedPQueue:
array[i] = array[i+1]

self.array_ptr = array_ptr - 1

# the index of the popped item will always be zero:
return 0

cdef int peek(self, DOUBLE_t* data, DOUBLE_t* weight) nogil:
Expand Down Expand Up @@ -518,20 +521,22 @@ cdef class WeightedMedianCalculator:
Return -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
cdef int return_value
cdef int push_index
cdef DOUBLE_t original_median

if self.size() != 0:
original_median = self.get_median()
# samples.push (WeightedPQueue.push) uses safe_realloc, hence except -1
return_value = self.samples.push(data, weight)
self.update_median_parameters_post_push(data, weight,
original_median)
return return_value
push_index = self.samples.push(data, weight)
if push_index == -1:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if push_index == -1:
return -1

I added this to replicate what was previously being returned, though we should only get a -1 when an exception occurs, i.e., a MemoryError. So in hindsight I think I can remove this check, since an exception in self.samples.push should terminate immediately...? I am not 100% sure, though, as I am new to Cython.

return -1
self.update_median_parameters_post_push(data, weight, original_median,
push_index)
return 0

cdef int update_median_parameters_post_push(
self, DOUBLE_t data, DOUBLE_t weight,
DOUBLE_t original_median) nogil:
DOUBLE_t original_median, int push_index) nogil:
"""Update the parameters used in the median calculation,
namely `k` and `sum_w_0_k` after an insertion"""

Expand All @@ -545,6 +550,25 @@ cdef class WeightedMedianCalculator:
# get the original weighted median
self.total_weight += weight

if data == original_median:
if push_index < self.k:
self.sum_w_0_k += weight
self.k += 1

while(self.k > 1 and ((self.sum_w_0_k -
self.samples.get_weight_from_index(self.k-1))
>= self.total_weight / 2.0)):
self.k -= 1
self.sum_w_0_k -= self.samples.get_weight_from_index(self.k)

while(self.k < self.samples.size() and
(self.sum_w_0_k < self.total_weight / 2.0)):

self.k += 1
self.sum_w_0_k += self.samples.get_weight_from_index(self.k-1)

return 0

if data < original_median:
# inserting below the median, so increment k and
# then update self.sum_w_0_k accordingly by adding
Expand All @@ -562,7 +586,7 @@ cdef class WeightedMedianCalculator:
self.sum_w_0_k -= self.samples.get_weight_from_index(self.k)
return 0

if data >= original_median:
if data > original_median:
# inserting strictly above the median (equality is handled separately)
# minimize k such that sum(W[0:k]) >= total_weight / 2
while(self.k < self.samples.size() and
Expand All @@ -575,40 +599,40 @@ cdef class WeightedMedianCalculator:
"""Remove a value from the MedianHeap, removing it
from consideration in the median calculation
"""
cdef int return_value
cdef int removal_index
cdef DOUBLE_t original_median

if self.size() != 0:
original_median = self.get_median()
# no elements to remove
if self.size() == 0:
return -1

original_median = self.get_median()

return_value = self.samples.remove(data, weight)
self.update_median_parameters_post_remove(data, weight,
original_median)
return return_value
removal_index = self.samples.remove(data, weight)
self.update_median_parameters_post_remove(data, weight, original_median, removal_index)
return 0

cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) nogil:
"""Pop a value from the MedianHeap, starting from the
left and moving to the right.
"""
cdef int return_value
cdef int removal_index
cdef double original_median

if self.size() != 0:
original_median = self.get_median()

# no elements to pop
if self.samples.size() == 0:
if self.size() == 0:
return -1

return_value = self.samples.pop(data, weight)
self.update_median_parameters_post_remove(data[0],
weight[0],
original_median)
return return_value
original_median = self.get_median()

removal_index = self.samples.pop(data, weight)
self.update_median_parameters_post_remove(data[0], weight[0],
original_median, removal_index)
return 0

cdef int update_median_parameters_post_remove(
self, DOUBLE_t data, DOUBLE_t weight,
double original_median) nogil:
double original_median, int removal_index) nogil:
"""Update the parameters used in the median calculation,
namely `k` and `sum_w_0_k` after a removal"""
# reset parameters because there are no elements
Expand All @@ -628,6 +652,27 @@ cdef class WeightedMedianCalculator:
# get the current weighted median
self.total_weight -= weight

# if data removed was part of the original median then we need
# to look on both sides of k because there may be duplicates
# with differing weights:
if data == original_median:
if removal_index < self.k:
self.k -= 1
self.sum_w_0_k -= weight

while(self.k < self.samples.size() and
(self.sum_w_0_k < self.total_weight / 2.0)):
self.k += 1
self.sum_w_0_k += self.samples.get_weight_from_index(self.k-1)

while(self.k > 1 and ((self.sum_w_0_k -
self.samples.get_weight_from_index(self.k-1))
>= self.total_weight / 2.0)):
self.k -= 1
self.sum_w_0_k -= self.samples.get_weight_from_index(self.k)

return 0

if data < original_median:
# removing below the median, so decrement k and
# then update self.sum_w_0_k accordingly by subtracting
Expand All @@ -646,7 +691,7 @@ cdef class WeightedMedianCalculator:
self.sum_w_0_k += self.samples.get_weight_from_index(self.k-1)
return 0

if data >= original_median:
if data > original_median:
# removing above the median
# minimize k such that sum(W[0:k]) >= total_weight / 2
while(self.k > 1 and ((self.sum_w_0_k -
Expand Down
8 changes: 8 additions & 0 deletions sklearn/tree/tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1789,6 +1789,14 @@ def test_mae():
assert_array_equal(dt_mae.tree_.impurity, [1.4, 1.5, 4.0 / 3.0])
assert_array_equal(dt_mae.tree_.value.flat, [4, 4.5, 4.0])

# This test ensures the correct median is being calculated when
# we have duplicate y values and non-uniform sample weights.
# Bug fixed in version 0.20 (see issue #10725):
dt_mae.fit(X=[[1.42055744], [0.958369], [0.38367319], [0.83952129]],
y=[1, 2, 1, 1], sample_weight=[1, 2, 2, 1])
assert_array_equal(dt_mae.tree_.impurity, [1.0 / 3.0, 0.0, 1.0 / 3.0])
assert_array_equal(dt_mae.tree_.value.flat, [1.0, 1.0, 2.0])


def test_criterion_copy():
# Let's check whether copy of our criterion has the same type
Expand Down