
Commit 9cc6cf9

DB Tsai committed
Removed the miniBatch in LBFGS.
1 parent 1ba6a33 commit 9cc6cf9


2 files changed: 16 additions, 39 deletions


mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

Lines changed: 11 additions & 29 deletions
@@ -42,7 +42,6 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
   private var convergenceTol = 1E-4
   private var maxNumIterations = 100
   private var regParam = 0.0
-  private var miniBatchFraction = 1.0
 
   /**
    * Set the number of corrections used in the LBFGS update. Default 10.
@@ -57,14 +56,6 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
     this
   }
 
-  /**
-   * Set fraction of data to be used for each L-BFGS iteration. Default 1.0.
-   */
-  def setMiniBatchFraction(fraction: Double): this.type = {
-    this.miniBatchFraction = fraction
-    this
-  }
-
   /**
    * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
    * A smaller value leads to higher accuracy at the cost of more iterations.
@@ -110,15 +101,14 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
   }
 
   override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
-    val (weights, _) = LBFGS.runMiniBatchLBFGS(
+    val (weights, _) = LBFGS.runLBFGS(
       data,
       gradient,
       updater,
       numCorrections,
       convergenceTol,
       maxNumIterations,
       regParam,
-      miniBatchFraction,
       initialWeights)
     weights
   }
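For orientation, a hedged caller-side sketch of the builder API after this change. The gradient/updater choices, trainingData, and initialWeights are illustrative stand-ins; setNumCorrections and setConvergenceTol are the setters whose doc comments appear in the hunks above:

    // Hypothetical usage: there is no longer a setMiniBatchFraction, so each
    // L-BFGS iteration evaluates the loss and gradient on the full dataset.
    val optimizer = new LBFGS(new LogisticGradient(), new SquaredL2Updater())
      .setNumCorrections(10)
      .setConvergenceTol(1e-4)
    val weights = optimizer.optimize(trainingData, initialWeights)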
@@ -132,10 +122,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
 @DeveloperApi
 object LBFGS extends Logging {
   /**
-   * Run Limited-memory BFGS (L-BFGS) in parallel using mini batches.
-   * In each iteration, we sample a subset (fraction miniBatchFraction) of the total data
-   * in order to compute a gradient estimate.
-   * Sampling, and averaging the subgradients over this subset is performed using one standard
+   * Run Limited-memory BFGS (L-BFGS) in parallel.
+   * Averaging the subgradients over different partitions is performed using one standard
    * spark map-reduce in each iteration.
    *
    * @param data - Input data for L-BFGS. RDD of the set of data examples, each of
@@ -147,31 +135,27 @@ object LBFGS extends Logging {
    * @param convergenceTol - The convergence tolerance of iterations for L-BFGS
    * @param maxNumIterations - Maximal number of iterations that L-BFGS can be run.
    * @param regParam - Regularization parameter
-   * @param miniBatchFraction - Fraction of the input data set that should be used for
-   *                            one iteration of L-BFGS. Default value 1.0.
    *
    * @return A tuple containing two elements. The first element is a column matrix containing
    *         weights for every feature, and the second element is an array containing the loss
    *         computed for every iteration.
    */
-  def runMiniBatchLBFGS(
+  def runLBFGS(
       data: RDD[(Double, Vector)],
       gradient: Gradient,
       updater: Updater,
       numCorrections: Int,
       convergenceTol: Double,
       maxNumIterations: Int,
       regParam: Double,
-      miniBatchFraction: Double,
       initialWeights: Vector): (Vector, Array[Double]) = {
 
     val lossHistory = new ArrayBuffer[Double](maxNumIterations)
 
     val numExamples = data.count()
-    val miniBatchSize = numExamples * miniBatchFraction
 
     val costFun =
-      new CostFun(data, gradient, updater, regParam, miniBatchFraction, miniBatchSize)
+      new CostFun(data, gradient, updater, regParam, numExamples)
 
     val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)
 
@@ -190,7 +174,7 @@ object LBFGS extends Logging {
     lossHistory.append(state.value)
     val weights = Vectors.fromBreeze(state.x)
 
-    logInfo("LBFGS.runMiniBatchLBFGS finished. Last 10 losses %s".format(
+    logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format(
       lossHistory.takeRight(10).mkString(", ")))
 
     (weights, lossHistory.toArray)
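Grounded in the signature above and the test-suite call sites below, a minimal sketch of invoking the renamed entry point; the gradient/updater instances and argument values are illustrative:

    // Sketch: runLBFGS no longer takes a mini-batch fraction; all args positional.
    val (weights, lossHistory) = LBFGS.runLBFGS(
      dataRDD,                 // RDD[(Double, Vector)] of (label, features)
      new LogisticGradient(),  // illustrative Gradient
      new SquaredL2Updater(),  // illustrative Updater
      10,                      // numCorrections
      1e-4,                    // convergenceTol
      100,                     // maxNumIterations
      0.1,                     // regParam
      initialWeights)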
@@ -205,8 +189,7 @@ object LBFGS extends Logging {
       gradient: Gradient,
       updater: Updater,
       regParam: Double,
-      miniBatchFraction: Double,
-      miniBatchSize: Double) extends DiffFunction[BDV[Double]] {
+      numExamples: Long) extends DiffFunction[BDV[Double]] {
 
     private var i = 0
 
@@ -215,8 +198,7 @@ object LBFGS extends Logging {
       val localData = data
       val localGradient = gradient
 
-      val (gradientSum, lossSum) = localData.sample(false, miniBatchFraction, 42 + i)
-        .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+      val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
         seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
           val l = localGradient.compute(
             features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
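Since this hunk shows only the seqOp, here is a self-contained toy (plain Scala collections, made-up numbers) of the aggregate pattern the full-data pass relies on: a seqOp folds one example into the running (gradientSum, lossSum), and a combOp, elided from this diff, merges the per-partition partial sums:

    // Toy stand-in for RDD.aggregate over (label, feature) pairs.
    val examples = Seq((1.0, 2.0), (0.0, 3.0))
    val (gradSum, lossSum) = examples.aggregate((0.0, 0.0))(
      (c, v) => (c, v) match {                       // seqOp: add one example
        case ((g, l), (label, x)) => (g + x * label, l + label)
      },
      (c1, c2) => (c1._1 + c2._1, c1._2 + c2._2))    // combOp: merge partials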
@@ -234,7 +216,7 @@ object LBFGS extends Logging {
         Vectors.fromBreeze(weights),
         Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
 
-      val loss = lossSum / miniBatchSize + regVal
+      val loss = lossSum / numExamples + regVal
       /**
        * It will return the gradient part of regularization using updater.
        *
@@ -256,8 +238,8 @@ object LBFGS extends Logging {
         Vectors.fromBreeze(weights),
         Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze
 
-      // gradientTotal = gradientSum / miniBatchSize + gradientTotal
-      axpy(1.0 / miniBatchSize, gradientSum, gradientTotal)
+      // gradientTotal = gradientSum / numExamples + gradientTotal
+      axpy(1.0 / numExamples, gradientSum, gradientTotal)
 
       i += 1
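The axpy used here is Breeze's in-place update y := a * x + y, as the comment in the hunk states; a quick standalone illustration with made-up numbers:

    import breeze.linalg.{axpy, DenseVector => BDV}

    val gradientSum   = BDV(4.0, 8.0)
    val gradientTotal = BDV(1.0, 1.0)
    axpy(1.0 / 4, gradientSum, gradientTotal)  // in place: (1/4) * x + y
    // gradientTotal is now BDV(2.0, 3.0)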

mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala

Lines changed: 5 additions & 10 deletions
@@ -59,15 +59,14 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
     val convergenceTol = 1e-12
     val maxNumIterations = 10
 
-    val (_, loss) = LBFGS.runMiniBatchLBFGS(
+    val (_, loss) = LBFGS.runLBFGS(
       dataRDD,
       gradient,
       simpleUpdater,
       numCorrections,
       convergenceTol,
       maxNumIterations,
       regParam,
-      miniBatchFrac,
       initialWeightsWithIntercept)
 
     // Since the cost function is convex, the loss is guaranteed to be monotonically decreasing
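The assertion implied by that comment falls outside this hunk; one hedged way to express a monotonic-decrease check over the returned loss array:

    // Sketch only; the suite's actual assertion may be phrased differently.
    assert(loss.zip(loss.tail).forall { case (prev, next) => prev > next })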
@@ -104,15 +103,14 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
     val convergenceTol = 1e-12
     val maxNumIterations = 10
 
-    val (weightLBFGS, lossLBFGS) = LBFGS.runMiniBatchLBFGS(
+    val (weightLBFGS, lossLBFGS) = LBFGS.runLBFGS(
       dataRDD,
       gradient,
       squaredL2Updater,
       numCorrections,
       convergenceTol,
       maxNumIterations,
       regParam,
-      miniBatchFrac,
       initialWeightsWithIntercept)
 
     val numGDIterations = 50
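numGDIterations sets up the gradient-descent baseline this test compares against. The actual call is outside this diff; presumably it resembles MLlib's SGD entry point, sketched here with an assumed step size:

    // Hedged sketch of the GD baseline (exact call not shown in this commit).
    val (weightGD, lossGD) = GradientDescent.runMiniBatchSGD(
      dataRDD,
      gradient,
      squaredL2Updater,
      1.0,                    // stepSize (assumed value)
      numGDIterations,
      regParam,
      miniBatchFrac,
      initialWeightsWithIntercept)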
@@ -150,47 +148,44 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
     val maxNumIterations = 8
     var convergenceTol = 0.0
 
-    val (_, lossLBFGS1) = LBFGS.runMiniBatchLBFGS(
+    val (_, lossLBFGS1) = LBFGS.runLBFGS(
       dataRDD,
       gradient,
       squaredL2Updater,
       numCorrections,
       convergenceTol,
       maxNumIterations,
       regParam,
-      miniBatchFrac,
       initialWeightsWithIntercept)
 
     // Note that the first loss is computed with the initial weights,
     // so the total number of losses will be the number of iterations + 1.
     assert(lossLBFGS1.length == 9)
 
     convergenceTol = 0.1
-    val (_, lossLBFGS2) = LBFGS.runMiniBatchLBFGS(
+    val (_, lossLBFGS2) = LBFGS.runLBFGS(
       dataRDD,
       gradient,
       squaredL2Updater,
       numCorrections,
       convergenceTol,
       maxNumIterations,
       regParam,
-      miniBatchFrac,
       initialWeightsWithIntercept)
 
     // Based on observation, lossLBFGS2 runs 3 iterations, though this is not theoretically guaranteed.
     assert(lossLBFGS2.length == 4)
     assert((lossLBFGS2(2) - lossLBFGS2(3)) / lossLBFGS2(2) < convergenceTol)
 
     convergenceTol = 0.01
-    val (_, lossLBFGS3) = LBFGS.runMiniBatchLBFGS(
+    val (_, lossLBFGS3) = LBFGS.runLBFGS(
       dataRDD,
       gradient,
       squaredL2Updater,
       numCorrections,
       convergenceTol,
       maxNumIterations,
       regParam,
-      miniBatchFrac,
       initialWeightsWithIntercept)
 
     // With a smaller convergenceTol, it takes more steps.
