210829 Review Domain Adaptation
๋ณธ ๋
ผ๋ฌธ์ ์๋
์ฆ์์ ํฅ๋ฏธ๋กญ๊ฒ ์ฝ์๋ DAPT/TAPT ๋
ผ๋ฌธ ์ ์์ ์ ์์ด๋ค. ์ญ์๋ Language Model์ Domain Adaptation์ ๊ดํ ์ฐ๊ตฌ์ด๋ฉฐ, Domain๋ณ๋ก Expert๋ฅผ ๊ตฌ์ถํ๊ณ Mixture Model๋ก ํ์ฉํ๋ ๋ด์ฉ์ด๋ค.
Domain Adaptation
๋ง์ ์์, General Data๋ก ํ์ต๋ Language Model(LM)๋ค์ ํน์ Domain์ Task์์ ์ข์ ์ฑ๋ฅ์ ๋ณด์ด์ง ๋ชปํ๋ค. ํ์ฌ, ํด๋น Domain Data๋ก LM์ ์ถ๊ฐ์ ์ธ ํ์ต(Domain Adaptation)์ ์ํํ๋๋ฐ, (๋
ผ๋ฌธ์์ Dense Training์ด๋ผ๊ณ ์นญํ๋) ์ผ๋ฐ์ ์ธ ๋ฐฉ๋ฒ์ LM์ ๋ชจ๋ Parameter๋ค์ ๋ชจ๋ Domain Data์์ Loss๋ฅผ ์ค์ด๋ ๋ฐฉํฅ์ผ๋ก Updateํ๋ ๊ฒ์ด๋ค. ํ์ง๋ง ์ด๋ฌํ ๋ฐฉ๋ฒ์ ๋ช ๊ฐ์ง ๋ฌธ์ ๋ฅผ ๊ฐ๋๋ค.
โข
LM์ ํฌ๊ธฐ๊ฐ ์ปค์ง์๋ก Computational Cost๊ฐ ๋น์ธ์ง๋ค.
โข
์๋ก์ด Domain์ Data๋ก ํ์ต์ ์ํํ๋ฉด ์ด์ Domain์ ์ ๋ณด๋ฅผ ์๋ ๋ฑ, ๋ค๋ฅธ Domain์์์ Model Performance๊ฐ ์ ํ๋ ์ ์๋ค.
๋
ผ๋ฌธ์์๋ ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด Domain๋ณ๋ก Specialized Components(Layers)๋ฅผ ๊ฐ๋ Modularํ LM์ ์ ์ํ๋ค. Modular LM์ ํน์ง์ผ๋ก๋ ๋ค์๊ณผ ๊ฐ์ ์ ๋ค์ด ์๋ค.
โข
์๋ก์ด Domain์ ํ์ตํ ๋ Shared Parameter ์ ์ธ, Specialized Parameter๋ค๋ง Update๋๋ฏ๋ก Computationally ํจ์จ์ ์ด๋ฉฐ, ๋ค๋ฅธ Domain์ Performance์ ๋ณํ๊ฐ ์๊ธฐ์ง ์๋๋ค.
โข
์ด๋ ๊ฒ ํ์ตํ Domain๋ณ Layer๋ค์ Expert๋ผ๊ณ ์นญํ๋ค.
โข
Expert๋ GPT-3์ ๊ฐ์ Transformer LM์ Feed-Forward Network(FFN) Component๋ฅผ ๋ณํํ ๊ฒ์ด๋ค.
โข
Inference์์๋ ์ฌ๋ฌ Expert๋ค์ Mixingํ๋ ๋ฑ Customizing์ด ๊ฐ๋ฅํ๋ค.
Multi-Domain Corpus
Training Data์ Inference์์ ์ฌ์ฉํ๋ ์๋ก์ด Domain์ Data: Novel Data๋ฅผ ๊ตฌ์ฑํ๋ Corpus์ ์ ๋ณด๋ฅผ ๋ํ๋ธ ํ์ด๋ค.
ํ์ต ๊ณผ์ ์์ Training Data๋ก ์ด 8๊ฐ์ Expert๋ฅผ ๊ตฌ์ถํ๊ณ , ์คํ์์๋ 8์ข
๋ฅ์ In-Domain Data์ 8์ข
๋ฅ์ Out-Of-Domain(Novel) Data๋ก LM์ (Generalization)์ฑ๋ฅ์ ํ๊ฐํ๋ค.
Proposed Model(Expert): DEMIX Layer
Demix Routing
๊ธฐ์กด์ Mixture-Of-Experts Transformer์์ FFN Components๋ (ํน์ Domain์ผ๋ก์) Routing Function(g_n)๊ณผ Domain๋ณ FFN Layer๋ก ๊ตฌ์ฑ๋๋ค.
DEMIX Layer๋ ์ Component์ ๋์ผํ ํํ๋ฅผ ๊ฐ์ง๋ง, Token Level์์ Routing์ด ์ด๋ฃจ์ด์ง๋ฉฐ+Token๋น ์ต๋ 2๊ฐ์ Domain์ด ๋ถ์ฌ๋๋ ๊ธฐ์กด์ ๋ฐฉ์๊ณผ ๋ค๋ฅด๊ฒ, ๊ฐ์ Sequence์ ์กด์ฌํ๋ Token๋ค์ ๋ชจ๋ ๋์ผํ ํ๋์ Domain์ ๋ถ์ฌํ๋ค. (Domain๋ณ Metadata ํ์ฉ).
(
์ค์: Sequence ์ ์ฒด์ ํ๋์ Domain์ ๋ถ์ฌํ๋ ๊ฒ์ ํ์ต ๊ณผ์ ์๋ง ์ ์ฉ๋๋ฉฐ, Inference ๊ณผ์ ์์๋ ์ฌ๋ฌ Domain ์ ๋ณด๋ฅผ Mixingํ์ฌ ์ฌ์ฉํ๋ค.)
Demix Architecture
Transformer์ ๋ชจ๋ FFN Layer๋ค์ Demix Layer๋ก ๊ต์ฒดํ๋ฉฐ, ์ฑ๋ฅ์ ์
์ํฅ์ ์ฃผ๋ Shared Layer๋ ์์ ์ญ์ ํ๋ค.
Demix Layer์ FFN Layer๋ ์ผ๋ฐ์ ์ธ Transformer์์์ ๋์ผํ ์ฐจ์์ 2-Layer MLP์ด๋ค.
Demix Training
Demix Layer๋ LM์ ์ด Parameter ์๋ฅผ ์ฆ๊ฐ์ํค์ง๋ง, ํ์ต Runtime์ ์ฆ๊ฐํ์ง ์์ผ๋ฉฐ, ์คํ๋ ค Higher Throughput์ ๊ฐ๊ฒ ํ๋ค.
e.g.) ์ผ๋ฐ์ ์ผ๋ก PyTorch DDP(Multi-GPU)๋ฅผ ์ฌ์ฉํ ๋ ๋์ผํ ๋ชจ๋ธ์ ๋ณต์ฌํ์ฌ ๊ฐ GPU์ ํ ๋นํ๊ณ , Synchronization์ ์ํํ๋ค. Demix Training์์๋ (์ ์ฒด GPU ์: 32, Expert ์: 8 ์ผ ๋) Expert๋ฅผ n(GPU)/n(Expert)=4 ๋งํผ ๋ณต์ฌํ์ฌ ์ด๋ค์ Synchornizeํ๋ค(โThroughput ์ฆ๊ฐ). ์ด๋ฅผ ์ํด Batch๋ ํ๋์ Domain์ผ๋ก ๊ตฌ์ฑํ๋ฉฐ, Larger Batch Size๋ฅผ ์ ์ฉํ๋ค๊ณ ํ๋ค.
In-Domain Performance
Training Data์ Test Set์์ Perplexity๋ฅผ ์ธก์ ํ๋ฉฐ, ์คํ์ ์ฌ์ฉํ Baseline LM์ ๋ค์๊ณผ ๊ฐ๋ค.
โข
DENSE: ์ผ๋ฐ์ ์ธ Dense Training์ผ๋ก ํ์ต
โข
DENSE (Balanced): Dense Training+Domain๋ณ Data์ ์๊ฐ ๋์ผํ๊ฒ ํ์ต
โข
+DOMAIN-TOKEN: DENSE (Balanced)+Sequence ์์ Domain์ ๋ํ๋ด๋ Token์ ๋ถ์ฌ ํ์ต
โข
DEMIX (naive): Test Data์ Domain์ ์๋ ๊ฒฝ์ฐ, ํด๋น Expert๋ฅผ ์ฌ์ฉ
โข
DEMIX (cached): ์ถํ์ ์ค๋ช
ํ Experts-Mixing์ In-Domain Data์ ์ ์ฉ
์คํ ๊ฒฐ๊ณผ๋ฅผ ์ดํด๋ณด๋ฉด ํฌ๊ธฐ๊ฐ ~760M์ธ LM๋ค์์๋ DEMIX (naive)๊ฐ DENSE ๊ณ์ด LM๋ค์ ๋นํด ์ข์ ์ฑ๋ฅ์ ๋ณด์ด๋ ๊ฒ์ ํ์ธํ ์ ์๋ค. ํ์ง๋ง 1.3B LM์์๋ ์ฑ๋ฅ์ด ๊ฐ์ ๋์ง ์์๋๋ฐ, ์ด๋ LM์ ํฌ๊ธฐ๊ฐ ์ปค์ง๋ฉด Dense Training์ ํตํด Domain-Specificํ ์ ๋ณด๋ฅผ ์ถฉ๋ถํ ํ์ตํ ์ ์์์ ์ฆ๋ช
ํ๋ค.
1.3B LM์ 8๊ฐ Expert๋ฅผ ๋ค๋ฅธ Domain์ ์ ์ฉํ์ ๋ ๊ฒฐ๊ณผ์ด๋ค. Web Text์ Real News์ ๊ฒฝ์ฐ ๋ค๋ฅธ Domain Data์์๋ PPL์ด ํฌ๊ฒ ์ฆ๊ฐํ์ง ์์์ ์ ์ ์๋ค. ์ฆ, ๋ Domain์ ๋ค๋ฅธ Domain๊ณผ ๊ฒน์น๋ ๋ถ๋ถ์ด ๋ง๋ค๋ ๋ป์ผ๋ก, Web Text์ Real News์์ (๊ฒฐ๊ณผ์ ์ผ๋ก ๋น์ทํ Data๋ฅผ ๋ ๋ง์ด ํ์ตํ๊ฒ ๋ ํํ์) DENSE๊ฐ ๊ฐ์ฅ ์ข์ ์ฑ๋ฅ์ ๋ณด์ด๋ ์ด์ ๊ฐ ์ค๋ช
๋๋ค.
Mixing Experts (Novel-Domain Performance)
์ค์ Inference ๊ณผ์ ์์๋ Input Data๊ฐ ํ์ตํ Domain๋ค ์ค ํ๋์ ๋ถ๋ช
ํ๊ฒ ํฌํจ๋์ง ์๋ ๊ฒฝ์ฐ๊ฐ ์กด์ฌํ๋ค. ์ด ๋์๋ Expert๋ค์ ์ ๋ณด๋ฅผ Mixingํ์ฌ Inference๋ฅผ ์ํํ๋๋ฐ, ์ด๋ฅผ ์ํด Input Data๊ฐ ๊ธฐ์กด์ Domain๋ค๊ณผ Overlap๋๋ ์ ๋๋ฅผ ๊ณ์ฐํ ํ์๊ฐ ์๋ค.
๋
ผ๋ฌธ์์๋ Inference ๊ณผ์ ์ ์(4)์ ๊ฐ์ด ํํํ๋๋ฐ, P(D_t=j)๋ Timestep t์์์ Token์ด Domain j์ ์ํ ํ๋ฅ ์ ์๋ฏธํ๋ค. ์(4)์์ Likelihood๋ Expert๋ค๋ก๋ถํฐ ์ป์ ์ ์๊ณ , Prior๋ฅผ ๊ณ์ฐํด์ผ ํ๋ ๋ฌธ์ ๊ฐ ๋จ๋๋ค. ์ด๋ ์์ ์ธ๊ธํ๋ Routing Function(g_n)์ ์ ์ํ๋ ๊ฒ๊ณผ ์ฐ๊ด์ด ์๋ค. ๋
ผ๋ฌธ์์๋ Prior๋ฅผ ์ ์ (Posterior๋ก ๋ณํ) ํ๋ ๋ฐฉ๋ฒ์ผ๋ก ์ด 3๊ฐ์ง๋ฅผ ์๊ฐํ๋ค.
โข
Uniform: Domain๋ณ๋ก ๋ชจ๋ ์ผ์ ํ ๊ฐ์ ๊ฐ์ง
โข
Updating: ์์ Timestep์์์ Posterior๊ฐ์ ๋ฐ์ (Moving Average)
โข
Cached: ๋ฏธ๋ฆฌ ์ผ๋ถ Data๋ฅผ ์ถ์ถํ์ฌ Posterior๋ฅผ ๊ตฌํ๊ณ , ์ด๋ฅผ ๊ณ ์
Novel Data์์ PPL์ ์ธก์ ํ ๊ฒฐ๊ณผ, Cached ๋ฐฉ์์ผ๋ก Prior๋ฅผ ์ ์ํ์ฌ Mixingํ ๋ ๊ฐ์ฅ ์ข์ ์ฑ๋ฅ์ ๋ณด์์ ์ ์ ์๋ค. In-Domain ์คํ ๋์ ๋ง์ฐฌ๊ฐ์ง๋ก Performance Gain์ด LM์ ํฌ๊ธฐ๊ฐ ์์ ๋ ๋๋๋ฌ์ง๋ ๊ฒ์ผ๋ก ๋ณด์, DEMIX Layer ์์ฒด๊ฐ Small LM์์ ์๋ฏธ๊ฐ ์๋ค๊ณ ์๊ฐํ ์ ์๊ฒ ๋ค.
Adding+Removing Experts
์๋ก์ด Domain์ ์ถ๊ฐ๋ก ํ์ต์ํค๋ ๊ฒฝ์ฐ์ด๋ค. Dense-DAPT๋ Dense Training์ผ๋ก ํ์ต๋ ๋ชจ๋ธ์ ์ถ๊ฐ Pre-Training(DAPT)์ํค๋ ๊ฒ์ด๊ณ , DEMIX-DAPT๋ ์๋ก์ด Expert๋ฅผ ๊ตฌ์ถํ๋ ๊ฒ์ ๋งํ๋ค. DEMIX-DAPT์์๋ ์๋ก์ด Expert๋ฅผ ์๋ก์ด Domain๊ณผ ๊ฐ์ฅ ์ ์ฌํ ๊ธฐ์กด Domain์ Expert Parameter๋ก ์ด๊ธฐํํ๋ค. ์ ์คํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด๋ฉด, ๋ ๊ฒฝ์ฐ ๋ชจ๋ Target(์๋ก์ด) Domain์ PPL์ด ๊ฐ์ํ์ง๋ง, Dense-DAPT๋ ๋ค๋ฅธ Domain์ PPL์ ์ฆ๊ฐ์ํด์ ์ ์ ์๋ค. ๋ฐ๋ฉด, DEMIX-DAPT๋ ๋ค๋ฅธ Domain์ PPL ๋ณํ ์์ด Dense-DAPT์ ๋ง๋จน๋ ํน์ ๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์์ ํ์ธํ ์ ์๋ค. ์๋์ ํ๋ Experts-Mixing๋ณด๋ค Expert๋ฅผ ์๋ก ๊ตฌ์ถํ๋ ํธ์ด ๋ ๋ง์ ์ฑ๋ฅ ํฅ์์ ์ด๋์ด๋์ ์ฆ๋ช
ํ๋ค.
๋ฐ๋๋ก ๊ธฐ์กด Domain์ ์ญ์ ํ๋ ๊ฒฝ์ฐ์ด๋ค. +EXPERT๋ ๋ชจ๋ Expert๊ฐ ํ์ฑํ ๋ ๊ฒฝ์ฐ, -EXPERT๋ ํด๋น Expert๋ฅผ ๋นํ์ฑํ ์ํจ ๊ฒฝ์ฐ, -DOMAIN์ LM์ ํด๋น Domain Data๋ฅผ ์ ์ธํ๊ณ From Scratch๋ก ํ์ต์ํจ ๊ฒฝ์ฐ์ด๋ค.















