210905 Review Sentence Embedding Contrastive Learning
์ด ๋
ผ๋ฌธ์ ์์ธ๋, Naver๊ฐ ACL 2021์์ ๋ฐํํ ๋
ผ๋ฌธ์ด๋ค. Contrastive Learning์ ํ์ฉํ BERT์ Sentence Embedding ํ์ต์ด ์ฃผ์ ๋ด์ฉ์ด๋ฉฐ, ์ผ๋ฐ์ ์ธ Data Augmentation์ด ์๋ BERT์ ์ด๊ธฐ Layers Representation์ ์ฌ์ฉํ ์ ์ด ํฅ๋ฏธ๋กญ๋ค.
Problems: BERT as Sentence Encoder
BERT๋ฅผ Sentence Encoder๋ก ์ฌ์ฉํ๊ธฐ ์ํด์๋ Downstream Task๋ก (Supervision์ ํ์ฉํ) Fine-Tuning์ ์ํํ์ฌ์ผ ํ๋ค. ๋ง์ฝ Labeled Data๊ฐ ์กด์ฌํ์ง ์๋ ๊ฒฝ์ฐ, BERT์ ๋ง์ง๋ง Layer(s)๋ฅผ Mean Poolingํ์ฌ ์ฌ์ฉํ๋ ๊ฒ์ด ์ผ๋ฐ์ ์ธ๋ฐ, ์ด๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์ด์ง ๋ชปํ๋ค. (์๋ ํ ์ฐธ์กฐ). ์๋์ ํ๋ฅผ ํตํด Pooling๊ธฐ๋ฒ๊ณผ Pooling์ ์ฌ์ฉํ๋ Layer์ ๋ฐ๋ผ Sentence Embedding ์ฑ๋ฅ์ ํธ์ฐจ๊ฐ ๋งค์ฐ ํฐ ์ ์ ์ ์ ์๋๋ฐ, ์ด๋ ํ์ฌ BERT๋ฅผ ํ์ฉํ Sentence Embedding์ด ์ถฉ๋ถํ Solidํ์ง ์์ผ๋ฉฐ, BERT์ Expressive Power๋ฅผ ๋ ํ์ฉํ ์ฌ์ง๊ฐ ๋จ์ ์์์ ์๋ฏธํ๋ค.
์ ์๋ Unsupervisedํ๊ฒ BERT๋ฅผ Sentence Encoder๋ก ํ์ต์ํค๋ ๋ฐฉ์์ผ๋ก Contrastive Learning์ ์ฃผ๋ชฉํ๋๋ฐ, ์ผ๋ฐ์ ์ธ Data Augmentation์ด ์๋ (์์ Intuition์ ์ฐฉ์ํ์ฌ) BERT์ ์ด๊ธฐ Layer Representation์ ํ์ฉํ๋ ๋ฐฉ์์ ์ ์ํ๋ค.
Proposed Method: Self-Guided Contrastive Learning
1.
BERT๋ฅผ BERT_fixed, BERT_tuned 2๊ฐ์ง ๋ฒ์ ์ผ๋ก ๋ณต์ฌ.
โข
BERT_fixed๋ ํ์ต ์์ Training Signal(์ด๊ธฐ Layer์ Representation)์ ์ ๊ณตํ๋ ๊ณ ์ ๋ Parameter์ ๋ชจ๋ธ
โข
BERT_tuned๋ Sentence Embedding์ ์ํด Fine-Tuningํ๋ ๋ชจ๋ธ
โข
๋ ๋ชจ๋ธ์ ๋ถ๋ฆฌํ๋ ์ด์ ๋ ํ์ต ๊ณผ์ ์์ BERT_fixed์ Training Signal์ ์ฑ๋ฅ์ด ์ ํ๋๋ ํ์(BERT_fixed=BERT_tuned)์ ๋ฐฉ์งํ๊ณ , BERT์ ์ฌ๋ฌ Layer๋ค์ ์ ๋ณด๋ฅผ ์ทจํฉํ์๋ ์ ์์ ์ฒ ํ์ ๋ถํฉํ๊ธฐ ๋๋ฌธ
2.
Mini-Batch ๋ฌธ์ฅ๋ค์ Hidden Representation Sampling.
โข
Mini-Batch์ ํฌํจ๋ b๊ฐ ์ค i๋ฒ์งธ ๋ฌธ์ฅ์ k๋ฒ์งธ Layer Representation: H_i,k
โข
๊ฐ Layer Representation์ Pooling(๋
ผ๋ฌธ์์๋ Max Pooling) ์ํ: h_i,k=pooling(H_i,k)
โข
Pooling๋ Representation๋ค ์ค Sampling(๋
ผ๋ฌธ์์๋ Uniform Sampler) ์ํ: h_i=sampler({h_i,k | 0โคkโคnum(BERT_fixed's Layers)})
3.
BERT_tuned๋ก๋ถํฐ Sentence Embedding ์ถ์ถ.
โข
๋ง์ง๋ง Layer์ CLS Token๋ง์ Sentence Embedding์ผ๋ก ์ฌ์ฉ: c_i
โข
X={x | {c_i} U {h_i}}
4.
Loss(NT-Xent Loss) ๊ณ์ฐ!
โข
Sampling๋ ์ด๊ธฐ Layer Representation, h_i์ Sentence Embedding, c_i๋ ๊ฐ๊ฐ Projection Head(f)๋ฅผ ๊ฑฐ์น ํ, ๋ค๋ฅธ ๋ฒกํฐ๋ค๊ณผ์ Cosine Similarity(g) ๊ณ์ฐ์ ํ์ฉ๋๋ค.
โข
ํ์ต์ ๊ฐ์ ๋ฌธ์ฅ์ h_i์ c_i์ ์ ์ฌ๋๊ฐ ๋ค๋ฅธ ๋ฌธ์ฅ๋ค๊ณผ์ ์ ์ฌ๋๋ณด๋ค ํฐ ๊ฐ์ ๊ฐ๋๋ก ์ํ๋๋ค: L^base.
โข
์ต์ข
Loss์ ํํ๋ ์์ ๊ฐ์ผ๋ฉฐ, Regularizationํญ์ BERT_fixed์ BERT_tuned๊ฐ ๋๋ฌด ๋ค๋ฅธ ๊ฐ์ ๊ฐ์ง ์๋๋ก ์กฐ์ ํ๋ ์ญํ ์ ํ๋ค.
5.
Learning Objective Optimization!
โข
L^base๋ ๋ ๋ฌธ์ฅ s_i, s_j์ ์ ์ฌ์ฑ์ 4๊ฐ์ง ์์๋ค์ ๊ณ ๋ คํ์ฌ ์ ์ํ๋ค.
(a) c_i โโ h_i: (๋์ผ ๋ฌธ์ฅ) Sentence Embedding๊ณผ ์ด๊ธฐ Layer Representation์ ์ ์ฌ์ฑ.
(b) c_i โโ c_j
(c) c_i โโ h_j
(d) h_i โโ h_j
โข
์ ์๋ (a)์์๋ง์ด ํ์์ ์ด๊ณ , ํ์ฌ (a)์ ํนํ ์ง์คํ๊ธฐ ์ํด ๋ค๋ฅธ ์์๋ค์ ์ ๊ฑฐํ๋ ๋ฐฉํฅ์ผ๋ก L^base๋ฅผ ์์ ํ๋ค.
โข
Option1: (d)์ ๊ฑฐ.
โข
Option2: (b)์ ๊ฑฐ.
โข
Option3: c_i์ h_i(or h_j)๋ฅผ ๋ ๋ง์ ๊ด์ (๋ค์ํor๋ณต์์ ์ด๊ธฐ Layer๋ค์ Representation)์์ ๋น๊ตํ๊ธฐ ์ํจ.
โข
์ต์ข
Loss์ ํํ๋ ์์ ๊ฐ๋ค: L^opt.
Experiments & Results
์คํ์ 7๊ฐ์ง STS Datasets์์ ์ํ๋๋ฉฐ, BERT+SBERT(-base, -large)์ ๊ธฐ์กด Sentence Embedding Method๋ค๊ณผ ์ ์ ๊ธฐ๋ฒ์ ์ ์ฉํ ๊ฒฐ๊ณผ๋ฅผ ๋น๊ตํ์ฌ ์ฑ๋ฅ์ ํ๊ฐํ๋ค. Baseline์ผ๋ก ์ฌ์ฉํ๋ Method๋ค์๋ ๋ค์๊ณผ ๊ฐ์ ๊ฒ๋ค์ด ์์ผ๋ฉฐ,
โข
CLS Pooling
โข
Mean Pooling
โข
WK Pooling
โข
Flow
โข
Contrastive (BT: Back Translation)
์ ์ ๊ธฐ๋ฒ์๋ L^base๋ก ํ์ตํ SG, L^opt๋ก ํ์ตํ SG-OPT๊ฐ ์๋ค.
์คํ ๊ฒฐ๊ณผ๋ ์ ํ์ ๊ฐ์๋ฐ, SBERT์์์ ์ผ๋ถ Case๋ฅผ ์ ์ธํ๋ฉด, ์ ์ ๊ธฐ๋ฒ์ด Baseline Methods์ ๋นํด ์ข์ ์ฑ๋ฅ์ ๋ณด์์ ์ ์ ์๋ค. ํนํ, SG-OPT๊ฐ SG์ ์ฑ๋ฅ์ ๋ฅ๊ฐํจ์ผ๋ก์จ Learning Objective Optimization์ ํจ๊ณผ๋ฅผ ํ์ธํ ์ ์๋ค.
Learning Objective Optimization์ ํจ๊ณผ๋ Ablation Study๋ฅผ ํตํด์๋ ์ฆ๋ช
๋๋ค. ์์ ํ ์๋จ์ Loss๋ฅผ ๋ณ๊ฒฝํ๋ฉฐ ํ์ต์ ์ํํ ๊ฒฐ๊ณผ๊ฐ ํฌํจ๋์ด ์๋ค. ํ ํ๋จ์ ํตํด Hyperparameter๊ฐ๋ค๋ ์ฑ๋ฅ์ ํฐ ์ํฅ์ ๋ฏธ์นจ์ ์ ์ ์๋ค. Projection Head์ ์ ๋ฌด๊ฐ ์ฑ๋ฅ์ ํฐ ์ฐจ์ด๋ฅผ ๋ง๋๋ ์ ์ด ์ฃผ๋ชฉํ ๋งํ๋ค.
์ ๊ทธ๋ํ๋ BERT-base ๋ชจ๋ธ์ NLI(or STS-B) Dataset์ผ๋ก ํ์ต์ํค๊ณ STS Datasets์์ ํ
์คํ
ํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ธ๋ค. SG-OPT๊ฐ Flow์ ๋นํด ๊ธฐ๋ณธ์ ์ธ ์ฑ๋ฅ๋ ์ข๊ณ , In-Domain & Out-of-Domain ํ์ต์ ์ฑ๋ฅ ์ฐจ์ด๊ฐ ์์์ ํ์ธํ ์ ์๋ค: Robust to Domain Shifts.