@@ -499,9 +499,7 @@ def __init__(
499499 ** kwargs ,
500500 ):
501501 super ().__init__ (** kwargs )
502- self .method = "product_convolution2d"
503- if self .method == "product_convolution2d" :
504- self .update_parameters (filters , multipliers , padding , ** kwargs )
502+ self .update_parameters (filters , multipliers , padding , ** kwargs )
505503 self .to (device )
506504
507505 def A (
@@ -520,14 +518,8 @@ def A(
520518 otherwise the blurred output has the same size as the image.
521519 :param str device: cpu or cuda
522520 """
523- if self .method == "product_convolution2d" :
524- self .update_parameters (filters , multipliers , padding , ** kwargs )
525-
526- return product_convolution2d (
527- x , self .multipliers , self .filters , self .padding
528- )
529- else :
530- raise NotImplementedError ("Method not implemented in product-convolution" )
521+ self .update_parameters (filters , multipliers , padding , ** kwargs )
522+ return product_convolution2d (x , self .multipliers , self .filters , self .padding )
531523
532524 def A_adjoint (
533525 self , y : Tensor , filters = None , multipliers = None , padding = None , ** kwargs
@@ -545,16 +537,12 @@ def A_adjoint(
545537 otherwise the blurred output has the same size as the image.
546538 :param str device: cpu or cuda
547539 """
548- if self .method == "product_convolution2d" :
549- self .update_parameters (
550- filters = filters , multipliers = multipliers , padding = padding , ** kwargs
551- )
552-
553- return product_convolution2d_adjoint (
554- y , self .multipliers , self .filters , self .padding
555- )
556- else :
557- raise NotImplementedError ("Method not implemented in product-convolution" )
540+ self .update_parameters (
541+ filters = filters , multipliers = multipliers , padding = padding , ** kwargs
542+ )
543+ return product_convolution2d_adjoint (
544+ y , self .multipliers , self .filters , self .padding
545+ )
558546
559547 def update_parameters (
560548 self ,
0 commit comments