Jit fuse clamp #11574

Closed
t-vi wants to merge 14 commits into pytorch:master from t-vi:jit_fuse_clamp
Conversation

@t-vi
Collaborator

@t-vi t-vi commented Sep 12, 2018

This patch adds fused forward and backward for clamp to the jit.
This is one item of #11118 . If it's OK, I'd be happy to also add some more of #11118 .

The patch depends on #11150 , which I merged into master as a base. I'll rebase it when that or #10981 is merged.

This is my first serious jit patch. Thank you, @ngimel and the others, for your guidance. All errors are my own.

@pytorchbot pytorchbot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Sep 12, 2018
{aten::sub, "(${0} - ${2}*${1})"},
{aten::rand_like, "uniform(rnd())"},
//min, max
{aten::clamp, "fmaxf(fminf(${0}, ${2}), ${1})"},
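For reference, the C99 `fminf`/`fmaxf` pair is NaN-aware: when one operand is NaN, the other is returned. A minimal scalar Python sketch of what the fused template `fmaxf(fminf(${0}, ${2}), ${1})` computes (the function names mirror the C math functions; `fused_clamp` is an illustrative helper, not PyTorch API):

```python
import math

def fminf(a, b):
    # C99 fminf: if exactly one operand is NaN, return the other one
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return min(a, b)

def fmaxf(a, b):
    # C99 fmaxf: same NaN rule, but taking the maximum
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return max(a, b)

def fused_clamp(x, lo, hi):
    # the fuser template: fmaxf(fminf(x, hi), lo)
    return fmaxf(fminf(x, hi), lo)
```

Note that a NaN boundary simply drops out (it imposes no constraint), which is why the backward treats a NaN boundary as a pass-through factor of 1.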

This comment was marked as off-topic.

@apaszke
Contributor

apaszke commented Sep 12, 2018

Unfortunately this will conflict with #10981 😕

@t-vi
Collaborator Author

t-vi commented Sep 12, 2018

I thought the plan was to have #10981 go first, then #11150 and then this...
That gives me some time to look at the other ones, at least threshold, and to fix the above.

// boundary and the factor is 1 when the boundary is NaN
// the ! is expressed as "1-" for lack of a "not" function and
// the fuser insisting on float
return {(inputs.at(0).isnan() ? inputs.at(0) : grads.at(0))
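The masking rule described in the comments above can be sketched in scalar Python (hedged: `clamp_backward` here is an illustrative helper under the stated assumptions, not the PR's actual symbolic-script code):

```python
import math

def clamp_backward(grad, x, lo, hi):
    # sketch of the rule in the comments:
    # - a NaN input gets 0 gradient,
    # - a NaN boundary imposes no constraint (its factor is 1),
    # - in the fuser, "not" is spelled "1 - ..." on float masks
    if math.isnan(x):
        return 0.0
    lo_ok = 1.0 if (math.isnan(lo) or x >= lo) else 0.0
    hi_ok = 1.0 if (math.isnan(hi) or x <= hi) else 0.0
    return grad * lo_ok * hi_ok
```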


Contributor

@apaszke apaszke left a comment


LGTM, but I'd like to simplify the derivative formula

return {grads.at(0) * (outputs.at(0) > at::Scalar(0)).type_as(outputs.at(0))};

} else if (node->matches("aten::clamp(Tensor self, Scalar min, Scalar max) -> Tensor")) {
// we do two type_as as it's free (hopefully) and the "*" only works with float
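The simplified derivative formula just masks the incoming gradient with `(output > 0)`. A scalar Python sketch of that masking (illustrative only; the `float(...)` cast stands in for the C++ `type_as`):

```python
def threshold_backward(grad, out):
    # grad flows only where the forward output was positive;
    # (out > 0) plays the role of the C++ comparison against
    # at::Scalar(0), cast to float so "*" works on float masks
    return grad * float(out > 0)
```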


// but that is hard to reliably code here, so we have 0 as gradient
// when the input is NaN (unless grads is NaN or infinite)
return {grads.at(0)
* (1-(inputs.at(0).isnan()).type_as(inputs.at(0)))


ezyang and others added 2 commits September 18, 2018 00:53
Thank you, @apaszke, for your feedback!

Also put back support for DifferentiableGraph in assertAllFused.
@t-vi
Collaborator Author

t-vi commented Sep 18, 2018

@ezyang I think it is good, but I don't know what to make of the CI failure. It doesn't look related at first sight, but from clicking "previous build" a couple of times, it's apparently not shared by other PRs...

Contributor

@facebook-github-bot facebook-github-bot left a comment


apaszke has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Labels

oncall: jit (Add this issue/PR to JIT oncall triage queue), open source


7 participants