Skip to content

Commit 53bd090

Browse files
committed
update based on review suggestions
1 parent 20a1c8e commit 53bd090

3 files changed

Lines changed: 15 additions & 21 deletions

File tree

cupy/core/core.pxd

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ cdef class ndarray:
1414
readonly object dtype
1515
readonly memory.MemoryPointer data
1616
readonly ndarray base
17-
readonly object _cuda_array_descr
1817

1918
cpdef tolist(self)
2019
cpdef tofile(self, fid, sep=*, format=*)

cupy/core/core.pyx

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -132,24 +132,17 @@ cdef class ndarray:
132132
else:
133133
raise TypeError('order not understood. order={}'.format(order))
134134

135-
self._cuda_array_descr = None
136-
137135
@property
138136
def __cuda_array_interface__(self):
139-
if self._cuda_array_descr is None:
140-
desc = {
141-
'shape': self.shape,
142-
'typestr': self.dtype.str,
143-
'descr': self.dtype.descr,
144-
'data': (self.data.mem.ptr, False),
145-
'version': 0,
146-
}
147-
if not self._c_contiguous:
148-
desc['strides'] = self._strides
149-
150-
self._cuda_array_descr = desc
151-
else:
152-
desc = self._cuda_array_descr
137+
desc = {
138+
'shape': self.shape,
139+
'typestr': self.dtype.str,
140+
'descr': self.dtype.descr,
141+
'data': (self.data.mem.ptr, False),
142+
'version': 0,
143+
}
144+
if not self._c_contiguous:
145+
desc['strides'] = self._strides
153146

154147
return desc
155148

tests/cupy_tests/core_tests/test_ndarray.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,29 +135,31 @@ def test_shape_set_int(self, xp):
135135
class TestNdarrayCudaInterface(unittest.TestCase):
136136

137137
def test_cuda_array_interface(self):
138-
arr = cupy.zeros(shape=(2,3), dtype=cupy.float64)
138+
arr = cupy.zeros(shape=(2, 3), dtype=cupy.float64)
139139
iface = arr.__cuda_array_interface__
140140
self.assertEqual(set(iface.keys()),
141141
set(['shape', 'typestr', 'data', 'version', 'descr']))
142-
self.assertEqual(iface['shape'], (2,3))
142+
self.assertEqual(iface['shape'], (2, 3))
143143
self.assertEqual(iface['typestr'], '<f8')
144144
self.assertIsInstance(iface['data'], tuple)
145145
self.assertEqual(len(iface['data']), 2)
146+
self.assertEqual(iface['data'][0], arr.data.ptr)
146147
self.assertEqual(iface['data'][1], False)
147148
self.assertEqual(iface['version'], 0)
148149
self.assertEqual(iface['descr'], [('', '<f8')])
149150

150151
def test_cuda_array_interface_view(self):
151-
arr = cupy.zeros(shape=(10,20), dtype=cupy.float64)
152+
arr = cupy.zeros(shape=(10, 20), dtype=cupy.float64)
152153
view = arr[::2,::5]
153154
iface = view.__cuda_array_interface__
154155
self.assertEqual(set(iface.keys()),
155156
set(['shape', 'typestr', 'data', 'version',
156157
'strides', 'descr']))
157-
self.assertEqual(iface['shape'], (5,4))
158+
self.assertEqual(iface['shape'], (5, 4))
158159
self.assertEqual(iface['typestr'], '<f8')
159160
self.assertIsInstance(iface['data'], tuple)
160161
self.assertEqual(len(iface['data']), 2)
162+
self.assertEqual(iface['data'][0], arr.data.ptr)
161163
self.assertEqual(iface['data'][1], False)
162164
self.assertEqual(iface['version'], 0)
163165
self.assertEqual(iface['strides'], [320, 40])

0 commit comments

Comments
 (0)