DiFacto — Distributed Factorization Machines


1 著者

Mu Li, Ziqi Liu, Alexander J. Smola, Yu-Xiang Wang
Computer Science, Carnegie Mellon University

WSDM 2016 のBest Paper Honorable Mentionに選ばれる

DMLC (xgboost, mxnet, minerva, parameter server...)


2 紹介者

バクフー株式会社 柏野雄太 @yutakashino




3 ざっくり言うと


4 背景と動機


5 Factorization Machine

Screen Shot 2016-03-15 at 21.47.45.png

UIS={Alice(A),Bob(B),Charlie(C),...}={Titanic(TI),NottingHill(NH),StarWars(SW),StarTrek(ST),...}={(A,TI,20101,5),(A,NH,20102,3),(A,SW,20104,1),(B,SW,20095,4),(B,ST,20098,5),(C,TI,20099,1),(C,SW,200912,5)}

5.1 線型モデルとの関係

Screen Shot 2016-03-15 at 22.03.59.png
Screen Shot 2016-03-19 at 8.45.04 AM.png

https://github.com/coreylynch/pyFM/blob/master/pyfm/pylibfm.py

5.2 実装

./libFM -task r -train ml1m-train.libfm -test ml1m-test.libfm -dim ’1,1,8

6 DiFactoモデル

6.1 キャッチ

従来のFactrization Machineを改善

6.2 Memory Adaptive Constraints

Frequency threshold

Screen Shot 2016-03-13 at 18.29.44.png

6.3 Sparse Regularization

l1 shrinkage: 線形モデルのl1正則化のようなものをFMにも導入する

Screen Shot 2016-03-13 at 18.29.59.png

6.4 Frequency Adaptive Regularization

Screen Shot 2016-03-19 at 12.55.35 AM.png

6.5 結局最適化は…

Screen Shot 2016-03-19 at 12.53.54 AM.png


7 分散学習

Parameter Serverの仕組みで,重みV, wのアップデートをサーバで,勾配計算をワーカーに分散化

Screen Shot 2016-03-13 at 18.33.08.png

Screen Shot 2016-03-13 at 18.33.34.png

7.1 gradient

Screen Shot 2016-03-19 at 12.50.57 AM.png

7.2 update V

Screen Shot 2016-03-19 at 12.51.22 AM.png

7.3 update w

Screen Shot 2016-03-19 at 12.51.31 AM.png

7.4 収束解析

Screen Shot 2016-03-13 at 18.33.44.png

Screen Shot 2016-03-13 at 18.33.52.png

7.5 分散学習の実装

https://github.com/dmlc/difacto

#Start: Create one scheduler node, m worker nodes and nserver nodes over multiple machines.
#Scheduler Node:
Assume the data is partitioned into s parts p1, . . . , ps,
for t = 1 to T do
  Work packages P = {p1, . . . , ps}
  Accomplished packages A = ∅
  while P 6= ∅ do
    switch detected event from worker i do
      case idle
        Pick p ∈ P \ A and assign p to worker i
        A = A ∪ {p},
      case finished p
        P = P \ {p}
      case dead or timeout
        A = A \ {p},
  end while
end for
#Worker i:
Receive command “processing p” from the scheduler
while read a minibatch from p do
  Pull wi and Vi from server nodes for all features ithat appear in this minibatch
  Compute the gradient based on (16) and (17)
  Push gradient back to servers
end while
#Server i:
if received gradient from a worker then
  update w and V by using (20) and (19)
end if

8 DiFactoの実際の使い方

build/difacto data_in=data/gisette_scale val_data=data/gisette_scale.t lr=.02 V_dim=2 V_lr=.001

Screen Shot 2016-03-19 at 7.34.15 AM.png
Screen Shot 2016-03-19 at 7.40.10 AM.png
Screen Shot 2016-03-19 at 7.40.24 AM.png

tracker/dmlc_local.py -n 2 -s 1 bin/difacto.dmlc learn/difacto/guide/train.conf
train.conf
train_data = "data/train-part_[0-1].*"
val_data = "data/train-part_2.*"
data_format = "libsvm"
model_out = "model/criteo"
embedding {
  dim = 16
  threshold = 16
  lambda_l2 = 0.0001
}
lambda_l1 = 4
lr_eta = .01
max_data_pass = 1
minibatch = 1000
early_stop = 1
EOF

9 実験結果

9.1 Adaptive Memory

Screen Shot 2016-03-13 at 17.25.55.png

9.2 Fixed-point Compression

デフォルトで浮動小数点32 bitで表現されるgradientやモデルパラメータを精度の低い整数に「圧縮」して,ネットワーク負荷を下げたときに,どうなるか.

(a) 圧縮度が高いとネットワーク負荷は下がる(当たり前)
(b) 正確度について,CTR2は変わらない.Criteo2は一番圧縮度の高いときに6%ほど正確度が下がるのは当然だけれど,圧縮度が低いからといって,正確性が高くなるわけではないことがわかった.

Screen Shot 2016-03-13 at 17.39.57.png

9.3 LibFMとの比較

Screen Shot 2016-03-13 at 17.39.49.png


10 結論

Factorization MachineにAdaptive MemoryとFrequency adaptive正則化を入れ,Parameter Serverの仕組みで分散化させたDiFactoは,大きな問題を高速に取り扱うことができる.