
(210829) Review: DEMIX Layers: Disentangling Domains for Modular Language Modeling

210829 Review Domain Adaptation
This paper is a new work from the author of the DAPT/TAPT paper, which I read with interest around last year. As expected, it again studies domain adaptation of language models: it builds an expert for each domain and uses the experts as a mixture model.

Domain Adaptation

๋งŽ์€ ์–‘์˜, General Data๋กœ ํ•™์Šต๋œ Language Model(LM)๋“ค์€ ํŠน์ • Domain์˜ Task์—์„œ ์ข‹์€ ์„ฑ๋Šฅ์„ ๋ณด์ด์ง€ ๋ชปํ•œ๋‹ค. ํ•˜์—ฌ, ํ•ด๋‹น Domain Data๋กœ LM์˜ ์ถ”๊ฐ€์ ์ธ ํ•™์Šต(Domain Adaptation)์„ ์ˆ˜ํ–‰ํ•˜๋Š”๋ฐ, (๋…ผ๋ฌธ์—์„œ Dense Training์ด๋ผ๊ณ  ์นญํ•˜๋Š”) ์ผ๋ฐ˜์ ์ธ ๋ฐฉ๋ฒ•์€ LM์˜ ๋ชจ๋“  Parameter๋“ค์„ ๋ชจ๋“  Domain Data์—์„œ Loss๋ฅผ ์ค„์ด๋Š” ๋ฐฉํ–ฅ์œผ๋กœ Updateํ•˜๋Š” ๊ฒƒ์ด๋‹ค. ํ•˜์ง€๋งŒ ์ด๋Ÿฌํ•œ ๋ฐฉ๋ฒ•์€ ๋ช‡ ๊ฐ€์ง€ ๋ฌธ์ œ๋ฅผ ๊ฐ–๋Š”๋‹ค.
• The computational cost grows as the LM gets larger.
• Training on data from a new domain can degrade model performance on other domains, for example by forgetting information from previously learned domains.
To address this, the paper proposes a modular LM with specialized components (layers) for each domain. The modular LM has the following characteristics.
• When a new domain is learned, only the specialized parameters are updated and the shared parameters are left untouched, so training is computationally efficient and performance on the other domains does not change.
• The per-domain layers trained this way are called experts.
• An expert is a modified version of the feed-forward network (FFN) component of a Transformer LM such as GPT-3.
• At inference time the model can be customized, for example by mixing several experts.

Multi-Domain Corpus

The table lists the corpora that make up the training data and the novel data, i.e. data from new domains that is used only at inference time.
During training, a total of 8 experts are built from the training data. The experiments then evaluate the LM's (generalization) performance on 8 in-domain datasets and 8 out-of-domain (novel) datasets.

Proposed Model(Expert): DEMIX Layer

Demix Routing
๊ธฐ์กด์˜ Mixture-Of-Experts Transformer์—์„œ FFN Components๋Š” (ํŠน์ • Domain์œผ๋กœ์˜) Routing Function(g_n)๊ณผ Domain๋ณ„ FFN Layer๋กœ ๊ตฌ์„ฑ๋œ๋‹ค.
DEMIX Layer๋Š” ์œ„ Component์™€ ๋™์ผํ•œ ํ˜•ํƒœ๋ฅผ ๊ฐ–์ง€๋งŒ, Token Level์—์„œ Routing์ด ์ด๋ฃจ์–ด์ง€๋ฉฐ+Token๋‹น ์ตœ๋Œ€ 2๊ฐœ์˜ Domain์ด ๋ถ€์—ฌ๋˜๋Š” ๊ธฐ์กด์˜ ๋ฐฉ์‹๊ณผ ๋‹ค๋ฅด๊ฒŒ, ๊ฐ™์€ Sequence์— ์กด์žฌํ•˜๋Š” Token๋“ค์— ๋ชจ๋‘ ๋™์ผํ•œ ํ•˜๋‚˜์˜ Domain์„ ๋ถ€์—ฌํ•œ๋‹ค. (Domain๋ณ„ Metadata ํ™œ์šฉ).
(์ค‘์š”: Sequence ์ „์ฒด์— ํ•˜๋‚˜์˜ Domain์„ ๋ถ€์—ฌํ•˜๋Š” ๊ฒƒ์€ ํ•™์Šต ๊ณผ์ •์—๋งŒ ์ ์šฉ๋˜๋ฉฐ, Inference ๊ณผ์ •์—์„œ๋Š” ์—ฌ๋Ÿฌ Domain ์ •๋ณด๋ฅผ Mixingํ•˜์—ฌ ์‚ฌ์šฉํ•œ๋‹ค.)
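The training-time routing described above can be sketched as follows. This is a minimal illustration, not the paper's implementation: the experts are toy callables standing in for per-domain FFNs, and the names `make_scaling_expert` and `demix_layer_train` are hypothetical.

```python
from typing import Callable, List, Sequence

Vector = List[float]
Expert = Callable[[Vector], Vector]


def make_scaling_expert(scale: float) -> Expert:
    """Toy stand-in for a domain FFN: multiplies every feature by `scale`."""
    return lambda h: [scale * x for x in h]


def demix_layer_train(hidden: Sequence[Vector], domain_id: int,
                      experts: Sequence[Expert]) -> List[Vector]:
    """Sequence-level routing: the domain label (taken from metadata) picks
    one expert, and every token in the sequence goes through that expert."""
    expert = experts[domain_id]
    return [expert(h) for h in hidden]


# Two toy experts and a sequence of two token vectors.
experts = [make_scaling_expert(1.0), make_scaling_expert(2.0)]
sequence = [[1.0, 2.0], [3.0, 4.0]]

# The whole sequence is labeled as domain 1, so only expert 1 is used.
routed = demix_layer_train(sequence, domain_id=1, experts=experts)
```

The point of the sketch is that no learned gate is involved during training: the routing decision is a lookup on the sequence's domain label, so all tokens in the sequence share one expert.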
Demix Architecture
Every FFN layer in the Transformer is replaced with a Demix layer, and shared FFN layers, which hurt performance, are dropped entirely.
The FFN inside a Demix layer is a 2-layer MLP with the same dimensions as in a standard Transformer.
Demix Training
A Demix layer increases the LM's total parameter count, but the training runtime does not increase; throughput is actually higher.
e.g.) With PyTorch DDP (multi-GPU), the same model is normally replicated onto every GPU and the replicas are synchronized. In Demix training (with 32 GPUs in total and 8 experts), each expert is replicated n(GPU)/n(Expert) = 4 times and those replicas are synchronized with each other, so an expert's parameters are synchronized across 4 GPUs rather than all 32, which raises throughput. To make this work, each batch is also built from a single domain, and a larger batch size is reportedly used.
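The replica arithmetic in the example above can be written out as a small sketch. The contiguous GPU-to-expert assignment and the function name `assign_gpus_to_experts` are assumptions made for illustration; the review does not say how GPUs are actually grouped.

```python
from typing import Dict, List


def assign_gpus_to_experts(num_gpus: int, num_experts: int) -> Dict[int, List[int]]:
    """Split GPU ranks into equal groups, one group per expert.

    Each group holds the replicas of one expert; expert parameters would be
    synchronized only within the group (assumed contiguous ranks).
    """
    if num_gpus % num_experts != 0:
        raise ValueError("num_gpus must be a multiple of num_experts")
    replicas = num_gpus // num_experts
    return {
        expert: list(range(expert * replicas, (expert + 1) * replicas))
        for expert in range(num_experts)
    }


# The setting from the example: 32 GPUs and 8 experts give 4 replicas each.
groups = assign_gpus_to_experts(num_gpus=32, num_experts=8)
```

With this grouping, a batch for a given domain only needs to reach the GPUs in that domain's group, which is consistent with building each batch from a single domain.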

In-Domain Performance

Perplexity is measured on the test sets of the training domains. The LMs compared in the experiments are as follows.
• DENSE: trained with ordinary dense training
• DENSE (Balanced): dense training, with the same amount of data seen from each domain
• +DOMAIN-TOKEN: DENSE (Balanced), trained with a token indicating the domain prepended to each sequence
• DEMIX (naive): when the domain of the test data is known, the corresponding expert is used
• DEMIX (cached): the expert mixing described later, applied to in-domain data
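Since every comparison here is reported in perplexity (PPL), a short reminder of how it is computed may help. This is the standard definition (the exponential of the average negative log-likelihood per token), written as a small sketch; the function name is illustrative.

```python
import math
from typing import Sequence


def perplexity(token_log_probs: Sequence[float]) -> float:
    """Perplexity = exp(-(1/N) * sum of natural-log token probabilities)."""
    if not token_log_probs:
        raise ValueError("need at least one token")
    return math.exp(-sum(token_log_probs) / len(token_log_probs))


# A model that gives every token probability 1/4 has perplexity 4.
ppl = perplexity([math.log(0.25)] * 4)
```

Lower is better: a perplexity of 4 means the model is, on average, as uncertain as if it were choosing uniformly among 4 tokens.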
In the results, for LMs up to 760M parameters, DEMIX (naive) outperforms the DENSE family. At 1.3B, however, there is no improvement, which suggests that a large enough LM can learn domain-specific information well enough through dense training alone.
The next result shows what happens when the 8 experts of the 1.3B LM are applied to other domains. For the Web Text and Real News experts, PPL does not rise much even on data from other domains. In other words, these two domains overlap heavily with the others, which explains why DENSE (which in effect gets to train on more of this similar data) performs best on Web Text and Real News.

Mixing Experts (Novel-Domain Performance)

In real inference, the input data does not always fall clearly into one of the training domains. In that case inference mixes the information from the experts, which requires computing how much the input data overlaps with each existing domain.
The paper expresses inference as Eq. (4), where P(D_t=j) is the probability that the token at timestep t belongs to domain j. In Eq. (4) the likelihood is available from the experts, so what remains is computing the prior. This is related to defining the routing function (g_n) mentioned earlier. The paper introduces three ways of defining the prior (which is then turned into a posterior).
• Uniform: every domain gets the same constant value
• Updating: reflects the posterior values from earlier timesteps (moving average)
• Cached: the posterior is estimated in advance on a sample of the data and then kept fixed
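The mixing step and the three priors can be sketched as below. The sketch assumes Eq. (4) has the usual mixture form, p(x_t | x_<t) = sum_j p(x_t | x_<t, D_t=j) * P(D_t=j | x_<t), with the posterior obtained by Bayes' rule from each expert's likelihood of the history and a prior. All function names and the `decay` value are hypothetical and are not taken from the paper.

```python
import math
from typing import List, Sequence


def uniform_prior(num_domains: int) -> List[float]:
    """Uniform: every domain gets the same prior weight."""
    return [1.0 / num_domains] * num_domains


def updating_prior(past_posteriors: Sequence[Sequence[float]],
                   decay: float = 0.5) -> List[float]:
    """Updating: exponentially decayed average of earlier posteriors."""
    steps = len(past_posteriors)
    num_domains = len(past_posteriors[0])
    acc = [0.0] * num_domains
    for i, post in enumerate(past_posteriors):
        weight = decay ** (steps - i)  # older timesteps get smaller weight
        for j in range(num_domains):
            acc[j] += weight * post[j]
    total = sum(acc)
    return [a / total for a in acc]


def posterior(history_log_liks: Sequence[float],
              prior: Sequence[float]) -> List[float]:
    """Bayes' rule: P(D=j | history) is proportional to p(history | D=j) * P(D=j).
    At least one prior entry must be positive."""
    log_unnorm = [ll + (math.log(p) if p > 0 else float("-inf"))
                  for ll, p in zip(history_log_liks, prior)]
    top = max(log_unnorm)  # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in log_unnorm]
    total = sum(exps)
    return [e / total for e in exps]


def mix_next_token_prob(expert_probs: Sequence[float],
                        domain_weights: Sequence[float]) -> float:
    """Mixture: sum over domains of expert probability times domain weight."""
    return sum(w * p for w, p in zip(domain_weights, expert_probs))


# Two experts; expert 0 explains the history four times better than expert 1.
weights = posterior([math.log(0.04), math.log(0.01)], uniform_prior(2))
p_next = mix_next_token_prob([0.5, 0.1], weights)

# Cached: estimate the posterior once on a held-out sample, then reuse it.
cached_weights = list(weights)

# Updating: the most recent posterior counts more than the older one.
ema_prior = updating_prior([[1.0, 0.0], [0.0, 1.0]], decay=0.5)
```

In this reading, the three methods differ only in where the prior comes from; the Bayes update and the final weighted sum are the same in all three cases.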
Measuring PPL on the novel data shows that mixing with a prior defined by the Cached method performs best. As in the in-domain experiments, the performance gain is most pronounced for smaller LMs, so the DEMIX layer itself appears to be most useful for small LMs.

Adding+Removing Experts

This is the case where a new domain is learned additionally. Dense-DAPT further pre-trains (DAPT) a model trained with dense training, while DEMIX-DAPT builds a new expert. In DEMIX-DAPT, the new expert is initialized with the parameters of the expert for the existing domain most similar to the new domain. In the results, both approaches lower PPL on the target (new) domain, but Dense-DAPT raises PPL on the other domains. DEMIX-DAPT, by contrast, matches or beats Dense-DAPT without changing PPL on the other domains. The accompanying table indicates that building a new expert yields a larger improvement than mixing the existing experts.
Conversely, this is the case where an existing domain is removed. +EXPERT has all experts active, -EXPERT has the expert for that domain disabled, and -DOMAIN is an LM trained from scratch without that domain's data.
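Adding and removing experts can be pictured with a plain dictionary of per-domain parameters. This is only a sketch under stated assumptions: parameters are toy lists, the most similar existing domain is passed in by the caller (how it is chosen is not covered here), and `add_expert` and `active_experts` are hypothetical names.

```python
import copy
from typing import Dict, List, Set

ExpertParams = Dict[str, List[float]]


def add_expert(experts: Dict[str, ExpertParams], new_domain: str,
               nearest_domain: str) -> ExpertParams:
    """DEMIX-DAPT style start: the new expert begins as a copy of the most
    similar existing expert; the other experts are not touched."""
    if new_domain in experts:
        raise ValueError(f"expert for {new_domain!r} already exists")
    experts[new_domain] = copy.deepcopy(experts[nearest_domain])
    return experts[new_domain]


def active_experts(experts: Dict[str, ExpertParams],
                   disabled: Set[str]) -> Dict[str, ExpertParams]:
    """-EXPERT style removal: leave the disabled experts out at inference."""
    return {name: p for name, p in experts.items() if name not in disabled}


experts = {"news": {"w": [0.1, 0.2]}, "code": {"w": [0.3, 0.4]}}

# Add a "legal" expert initialized from "news", then update only the new one.
legal = add_expert(experts, new_domain="legal", nearest_domain="news")
legal["w"][0] = 9.9  # stands in for training on the new domain

# Remove the "code" expert without retraining anything.
remaining = active_experts(experts, disabled={"code"})
```

Because the new expert is a deep copy, training it leaves the source expert unchanged, which mirrors the observation that DEMIX-DAPT does not change PPL on the other domains.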