Skip to content

Refactor ANN code to avoid rvalue references.#2259

Merged
rcurtin merged 3 commits intomlpack:masterfrom
rcurtin:ann-rvalue-reference
Mar 11, 2020
Merged

Refactor ANN code to avoid rvalue references.#2259
rcurtin merged 3 commits intomlpack:masterfrom
rcurtin:ann-rvalue-reference

Conversation

@rcurtin
Copy link
Copy Markdown
Member

@rcurtin rcurtin commented Mar 6, 2020

Previously, all of the ANN code had signatures like

Forward(const arma::Mat<eT>&& input, arma::Mat<eT>&& output);

but the use of rvalue references led to some incorrect memory usages, and it was a little confusing as people would call Forward(std::move(input), std::move(output))... but input would remain unchanged (this is confusing given the expected std::move() semantics).

In this PR I've refactored out all unnecessary use of rvalue references. I think, in some places, this makes copies avoidable and there may be some minor acceleration. I'm running some quick benchmarks in the models/ repo and will post the results here when they're done (I don't expect any serious speedup).

@rcurtin
Copy link
Copy Markdown
Member Author

rcurtin commented Mar 6, 2020

(I posted another timing comment earlier, but I was not able to reproduce the numbers so I deleted it. It must have been a system issue or something.)

I ran the LSTMTimeSeriesMultivariate program from the models repository, with the number of epochs modified to be 50, on the current mlpack master branch code:

# old
$ time ./LSTMTimeSeriesMultivariate
Reading data ...
Training ...
1 - Mean Squared Error := 0.61127
...
50 - Mean Squared Error := 0.0738777
Finished training.
Saving Model
Model saved in lstm_multi.bin
Loading model ...
Mean Squared Error on Prediction data points:= 0.0738777
The predicted Google stock (high, low) for the last day is: 
  (1132.37, 1112.9)

real	1m32.897s
user	6m57.729s
sys	0m1.669s

and on this branch:

# new
$ time ./LSTMTimeSeriesMultivariate
Reading data ... 
Training ...
1 - Mean Squared Error := 0.61127
...
50 - Mean Squared Error := 0.0738777
Finished training.
Saving Model
Model saved in lstm_multi.bin
Loading model ...
Mean Squared Error on Prediction data points:= 0.0738777
The predicted Google stock (high, low) for the last day is:
  (1132.37, 1112.9)

real    1m29.893s
user    6m39.122s
sys     0m1.382s

Over several trials both versions came out with about the same runtime (plus or minus 3 to 5 seconds).

@rcurtin
Copy link
Copy Markdown
Member Author

rcurtin commented Mar 6, 2020

However this doesn't seem to solve #2146, so the solution in #2234 seems to still be necessary (for now).

@kartikdutt18
Copy link
Copy Markdown
Member

Hi @rcurtin, This also doesn't solve #2221, so I'll try valgrind and gdb to gain more insight.

@rcurtin rcurtin added this to the mlpack 3.3.0 milestone Mar 9, 2020
Copy link
Copy Markdown
Member

@zoq zoq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, quite a few changes; looks good to me.

Copy link
Copy Markdown

@mlpack-bot mlpack-bot bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Second approval provided automatically after 24 hours. 👍

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants