Skip to content

Commit 2f9ff55

Browse files
committed
normalize
2 parents b70d651 + 7116f70 commit 2f9ff55

File tree

10 files changed

+31
-33
lines changed

10 files changed

+31
-33
lines changed

deepinv/physics/blur.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -499,9 +499,7 @@ def __init__(
499499
**kwargs,
500500
):
501501
super().__init__(**kwargs)
502-
self.method = "product_convolution2d"
503-
if self.method == "product_convolution2d":
504-
self.update_parameters(filters, multipliers, padding, **kwargs)
502+
self.update_parameters(filters, multipliers, padding, **kwargs)
505503
self.to(device)
506504

507505
def A(
@@ -520,14 +518,8 @@ def A(
520518
otherwise the blurred output has the same size as the image.
521519
:param str device: cpu or cuda
522520
"""
523-
if self.method == "product_convolution2d":
524-
self.update_parameters(filters, multipliers, padding, **kwargs)
525-
526-
return product_convolution2d(
527-
x, self.multipliers, self.filters, self.padding
528-
)
529-
else:
530-
raise NotImplementedError("Method not implemented in product-convolution")
521+
self.update_parameters(filters, multipliers, padding, **kwargs)
522+
return product_convolution2d(x, self.multipliers, self.filters, self.padding)
531523

532524
def A_adjoint(
533525
self, y: Tensor, filters=None, multipliers=None, padding=None, **kwargs
@@ -545,16 +537,12 @@ def A_adjoint(
545537
otherwise the blurred output has the same size as the image.
546538
:param str device: cpu or cuda
547539
"""
548-
if self.method == "product_convolution2d":
549-
self.update_parameters(
550-
filters=filters, multipliers=multipliers, padding=padding, **kwargs
551-
)
552-
553-
return product_convolution2d_adjoint(
554-
y, self.multipliers, self.filters, self.padding
555-
)
556-
else:
557-
raise NotImplementedError("Method not implemented in product-convolution")
540+
self.update_parameters(
541+
filters=filters, multipliers=multipliers, padding=padding, **kwargs
542+
)
543+
return product_convolution2d_adjoint(
544+
y, self.multipliers, self.filters, self.padding
545+
)
558546

559547
def update_parameters(
560548
self,

deepinv/physics/forward.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22
from typing import Union
3+
import warnings
34
import copy
45
import inspect
56
import collections.abc
@@ -49,6 +50,7 @@ def __init__(
4950
solver="gradient_descent",
5051
max_iter=50,
5152
tol=1e-4,
53+
**kwargs,
5254
):
5355
super().__init__()
5456
self.noise_model = noise_model
@@ -59,6 +61,11 @@ def __init__(
5961
self.tol = tol
6062
self.solver = solver
6163

64+
if len(kwargs) > 0:
65+
warnings.warn(
66+
f"Arguments {kwargs} are passed to {self.__class__.__name__} but are ignored."
67+
)
68+
6269
def __mul__(self, other):
6370
r"""
6471
Concatenates two forward operators :math:`A = A_1\circ A_2` via the mul operation
@@ -411,6 +418,7 @@ def __init__(
411418
max_iter=max_iter,
412419
solver=solver,
413420
tol=tol,
421+
**kwargs,
414422
)
415423
self.A_adj = A_adjoint
416424

deepinv/physics/phase_retrieval.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,6 @@ def __init__(
172172
img_size=img_size,
173173
fast=False,
174174
channelwise=channelwise,
175-
unitary=unitary,
176-
compute_inverse=compute_inverse,
177175
dtype=dtype,
178176
device=device,
179177
rng=self.rng,
@@ -291,7 +289,6 @@ def __init__(
291289
B = StructuredRandom(
292290
img_size=self.img_size,
293291
output_size=self.output_size,
294-
mode=self.mode,
295292
n_layers=self.n_layers,
296293
transform_func=transform_func,
297294
transform_func_inv=transform_func_inv,

deepinv/tests/test_deprecated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_deprecated_physics_image_size():
2727
# CS: img_shape is changed to img_size
2828
with pytest.warns(DeprecationWarning, match="img_shape.*deprecated"):
2929
p = dinv.physics.CompressedSensing(
30-
m=m, img_shape=img_size, device="cpu", compute_inverse=True, rng=rng
30+
m=m, img_shape=img_size, device="cpu", rng=rng
3131
)
3232
assert p.img_size == img_size
3333

deepinv/tests/test_physics.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def find_operator(name, device, imsize=None, get_physics_param=False):
123123
if name == "CS":
124124
m = 30
125125
p = dinv.physics.CompressedSensing(
126-
m=m, img_size=img_size, device=device, compute_inverse=True, rng=rng
126+
m=m, img_size=img_size, device=device, rng=rng
127127
)
128128
norm = (
129129
1 + np.sqrt(np.prod(img_size) / m)
@@ -1131,7 +1131,7 @@ def test_noise(device, noise_type):
11311131
r"""
11321132
Tests noise models.
11331133
"""
1134-
physics = dinv.physics.DecomposablePhysics(device=device)
1134+
physics = dinv.physics.DecomposablePhysics()
11351135
physics.noise_model = choose_noise(noise_type, device)
11361136
x = torch.ones((1, 3, 2), device=device).unsqueeze(0)
11371137

@@ -1177,7 +1177,6 @@ def test_blur(device):
11771177
h = torch.ones((1, 1, 5, 5)) / 25.0
11781178

11791179
physics_blur = dinv.physics.Blur(
1180-
img_size=(1, x.shape[-2], x.shape[-1]),
11811180
filter=h,
11821181
device=device,
11831182
padding="circular",
@@ -2065,3 +2064,10 @@ def test_clone(name, device):
20652064
# Restore original values
20662065
physics = saved_physics
20672066
physics_clone = saved_physics_clone
2067+
2068+
2069+
def test_physics_warn_extra_kwargs():
2070+
with pytest.warns(
2071+
UserWarning, match="Arguments {'sigma': 0.5} are passed to Denoising"
2072+
):
2073+
dinv.physics.Denoising(sigma=0.5)

deepinv/tests/test_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def test_algo(name_algo, device):
112112
test_sample = torch.ones((1, 3, 64, 64), device=device)
113113

114114
sigma = 1
115-
physics = dinv.physics.Denoising(device=device)
115+
physics = dinv.physics.Denoising()
116116
physics.noise_model = dinv.physics.GaussianNoise(sigma)
117117
y = physics(test_sample)
118118

examples/basics/demo_blur_tour.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@
292292
)
293293
params_pc = pc_generator.step(batch_size)
294294

295-
physics = SpaceVaryingBlur(method="product_convolution2d", **params_pc)
295+
physics = SpaceVaryingBlur(**params_pc)
296296

297297
dirac_comb = torch.zeros(img_size)[None, None]
298298
dirac_comb[0, 0, ::delta, ::delta] = 1

examples/basics/demo_physics_tour.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@
9393
fast=False,
9494
channelwise=True,
9595
img_size=img_size,
96-
compute_inverse=True,
9796
device=device,
9897
)
9998

examples/external-libraries/demo_ri_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def to_logimage(im, rescale=False, dr=5000):
206206
physics = RadioInterferometry(
207207
img_size=image_gdth.shape[-2:],
208208
samples_loc=uv.permute((1, 0)),
209-
real=True,
209+
real_projection=True,
210210
device=device,
211211
)
212212

examples/patch-priors/demo_epll.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
sigma = 0.1
3838
noise_model = GaussianNoise(sigma)
39-
physics = Denoising(device=device, noise_model=noise_model)
39+
physics = Denoising(noise_model=noise_model)
4040
observation = physics(test_img)
4141

4242
# %%

0 commit comments

Comments
 (0)