Skip to content

Conversation

@zitongzhan
Copy link
Contributor

@zitongzhan zitongzhan commented Dec 7, 2023

Add cg solver
Solve #311

@zitongzhan zitongzhan changed the title Bsr/cg Add CG solver Dec 7, 2023
@zitongzhan zitongzhan marked this pull request as draft December 7, 2023 05:39
@zitongzhan zitongzhan marked this pull request as ready for review December 7, 2023 06:29
@zitongzhan zitongzhan requested a review from wang-chen December 7, 2023 06:30
@zitongzhan
Copy link
Contributor Author

@zitongzhan zitongzhan requested a review from hxu296 December 7, 2023 19:32
Copy link
Member

@hxu296 hxu296 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@zitongzhan zitongzhan requested review from wang-chen and removed request for wang-chen December 7, 2023 21:04
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test doesn't follow the existing style. Refer to this test or this test with predefined data. Use class style.

Comment on lines 304 to 332
if self.maxiter is None:
maxiter = n*10
else:
maxiter = self.maxiter
r = b - bmv(A, x) if x.any() else b.clone()
rho_prev, p = None, None

for iteration in range(maxiter):
if (torch.linalg.norm(r, dim=-1) < atol).all():
return x

z = bmv(M, r) if M is not None else r
rho_cur = vecdot(r, z)
if iteration > 0:
beta = rho_cur / rho_prev
p = p * beta.unsqueeze(-1) + z
else: # First spin
p = torch.empty_like(r)
p[:] = z[:]

q = bmv(A, p)
alpha = rho_cur / vecdot(p, q)
x += alpha.unsqueeze(-1)*p
r -= alpha.unsqueeze(-1)*q
rho_prev = rho_cur

else: # for loop exhausted
# Return incomplete progress
return x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about

        if self.maxiter is None:
            maxiter = n * 10
        else:
            maxiter = self.maxiter
        r = b - bmv(A, x) if x.any() else b.clone()
        p = torch.empty_like(r)
        rho_prev, p[:]= None, z[:]

        for iteration in range(maxiter):
            if (torch.linalg.norm(r, dim=-1) < atol).all():
                return x

            z = bmv(M, r) if M is not None else r
            rho_cur = vecdot(r, z)
            beta = rho_cur / rho_prev
            p = p * beta.unsqueeze(-1) + z
            q = bmv(A, p)
            alpha = rho_cur / vecdot(p, q)
            x += alpha.unsqueeze(-1)*p
            r -= alpha.unsqueeze(-1)*q
            rho_prev = rho_cur

        return x

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

z has to be defined before it is used in rho_prev, p[:]= None, z[:]

@zitongzhan zitongzhan merged commit 375740e into main Dec 8, 2023
@zitongzhan zitongzhan deleted the bsr/cg branch December 8, 2023 03:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants