
(210905) Review: Self-Guided Contrastive Learning for BERT Sentence Representations

210905 Review Sentence Embedding Contrastive Learning
This paper was presented by Seoul National University and Naver at ACL 2021. Its main topic is learning BERT sentence embeddings with contrastive learning. What makes it interesting is that, instead of conventional data augmentation, it uses BERT's own hidden-layer representations.

Problems: BERT as Sentence Encoder

To use BERT as a sentence encoder, it normally has to be fine-tuned on a downstream task (with supervision). When no labeled data is available, the common practice is to mean-pool BERT's last layer(s), but this does not perform well (see the table below). The table below also shows that sentence-embedding performance varies widely depending on the pooling method and on which layers are pooled. This suggests that BERT-based sentence embeddings are not yet solid and that there is room to exploit more of BERT's expressive power.
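The two knobs the table varies, the pooling method and the layers that are pooled, can be made concrete with a minimal pure-Python sketch. It assumes the per-layer token vectors are already available as nested lists (hidden_states[k][t], with layer 0 the embedding output) together with a 0/1 padding mask; the function names and the choice to average the pooled vectors across layers are illustrative assumptions, not the paper's or any library's API.

```python
from typing import List, Sequence

Vector = List[float]


def mean_pool(tokens: Sequence[Vector], mask: Sequence[int]) -> Vector:
    """Average the token vectors of one layer, skipping padding (mask == 0)."""
    kept = [vec for vec, m in zip(tokens, mask) if m]
    dim = len(kept[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]


def cls_pool(tokens: Sequence[Vector]) -> Vector:
    """Take the first token of the layer, which is [CLS] in BERT's input format."""
    return list(tokens[0])


def sentence_vector(hidden_states: Sequence[Sequence[Vector]],
                    mask: Sequence[int],
                    layers: Sequence[int] = (-1,),
                    method: str = "mean") -> Vector:
    """Pool each selected layer, then average the pooled vectors across those layers.

    hidden_states[k][t] is the vector of token t at layer k (k = 0: embedding output).
    """
    pooled = [cls_pool(hidden_states[k]) if method == "cls"
              else mean_pool(hidden_states[k], mask)
              for k in layers]
    dim = len(pooled[0])
    return [sum(p[d] for p in pooled) / len(pooled) for d in range(dim)]


# Toy input: 2 layers, 3 tokens (the last one is padding), 2 dimensions.
hs = [
    [[1.0, 0.0], [3.0, 2.0], [9.0, 9.0]],  # layer 0
    [[0.0, 1.0], [2.0, 3.0], [9.0, 9.0]],  # layer 1 (last)
]
mask = [1, 1, 0]
print(sentence_vector(hs, mask))                  # [1.0, 2.0]
print(sentence_vector(hs, mask, method="cls"))    # [0.0, 1.0]
print(sentence_vector(hs, mask, layers=(0, -1)))  # [1.5, 1.5]
```

Even on this toy input, the three choices yield three different sentence vectors, which is the kind of sensitivity the table measures.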
์ €์ž๋Š” Unsupervisedํ•˜๊ฒŒ BERT๋ฅผ Sentence Encoder๋กœ ํ•™์Šต์‹œํ‚ค๋Š” ๋ฐฉ์•ˆ์œผ๋กœ Contrastive Learning์— ์ฃผ๋ชฉํ•˜๋Š”๋ฐ, ์ผ๋ฐ˜์ ์ธ Data Augmentation์ด ์•„๋‹Œ (์œ„์˜ Intuition์— ์ฐฉ์•ˆํ•˜์—ฌ) BERT์˜ ์ดˆ๊ธฐ Layer Representation์„ ํ™œ์šฉํ•˜๋Š” ๋ฐฉ์‹์„ ์ œ์•ˆํ•œ๋‹ค.

Proposed Method: Self-Guided Contrastive Learning

1. Copy BERT into two versions, BERT_fixed and BERT_tuned.
   • BERT_fixed: a model with frozen parameters that supplies the training signal (its hidden-layer representations) during training.
   • BERT_tuned: the model that is fine-tuned to produce sentence embeddings.
   • The two are kept separate to stop the training signal from degrading as training proceeds, which is what would happen if BERT_fixed = BERT_tuned (i.e. if the signal came from the model being updated). Separation also fits the authors' idea of aggregating information from BERT's multiple layers.
2. Sample hidden representations for the sentences in the mini-batch.
   • H_{i,k}: the layer-k representation of the i-th of the b sentences in the mini-batch.
   • Pool each layer representation (max pooling in the paper): h_{i,k} = pooling(H_{i,k}).
   • Sample one of the pooled representations (a uniform sampler in the paper): h_i = sampler({h_{i,k} | 0 ≤ k ≤ l}), where l is the number of layers in BERT_fixed.
3. Extract sentence embeddings from BERT_tuned.
   • Only the [CLS] token of the last layer is used as the sentence embedding: c_i.
   • X = {c_i} ∪ {h_i}, i.e. the 2b vectors (one c_i and one h_i per sentence) that enter the loss.
4. Compute the loss (NT-Xent loss).
   • The sampled layer representation h_i and the sentence embedding c_i each pass through a projection head f and are then used to compute cosine similarities g with the other vectors.
   • Training makes the similarity between h_i and c_i of the same sentence larger than their similarities to the other sentences' vectors: L^base.
   • The final loss takes the form shown above; the regularization term keeps BERT_tuned from drifting too far from BERT_fixed.
5. Optimize the learning objective.
   • L^base defines the similarity between two sentences s_i and s_j through four factors ((a) is pulled together, (b) to (d) are pushed apart):
     (a) c_i →← h_i: similarity between the sentence embedding and the layer representation of the same sentence.
     (b) c_i ←→ c_j
     (c) c_i ←→ h_j
     (d) h_i ←→ h_j
   • The authors regard only factor (a) as essential, so they modify L^base by removing the other factors in order to focus training on (a).
   • Option 1: remove (d).
   • Option 2: remove (b).
   • Option 3: compare c_i with h_i (or h_j) from more perspectives, i.e. against representations from several different layers.
   • The final loss takes the form shown above: L^opt.
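Steps 1 to 3 can be sketched in pure Python as below. This is a minimal sketch under stated assumptions, not the authors' implementation: ToyBert is a hypothetical placeholder that only holds parameters and a trainable flag, and the token vectors of every layer of both copies are assumed to be precomputed nested lists. Max pooling, the uniform sampler over layers 0..l, and the last-layer [CLS] vector come from the description above; all names are illustrative.

```python
import copy
import random
from typing import Dict, List, Sequence, Tuple

Vector = List[float]
Layers = Sequence[Sequence[Vector]]  # layers[k][t] = vector of token t at layer k


class ToyBert:
    """Hypothetical stand-in for BERT: only parameters and a trainable flag."""

    def __init__(self, params: Dict[str, List[float]]):
        self.params = params
        self.trainable = True


def split_fixed_tuned(bert: ToyBert) -> Tuple[ToyBert, ToyBert]:
    """Step 1: independent copies; BERT_fixed is frozen, BERT_tuned is trained."""
    fixed = copy.deepcopy(bert)
    fixed.trainable = False  # never updated, so its training signal cannot drift
    tuned = copy.deepcopy(bert)
    return fixed, tuned


def max_pool(tokens: Sequence[Vector]) -> Vector:
    """Step 2: h_{i,k} = element-wise max over the token vectors of H_{i,k}."""
    return [max(column) for column in zip(*tokens)]


def sample_h(fixed_layers: Layers, rng: random.Random) -> Vector:
    """Step 2: pool every layer 0..l of BERT_fixed, then pick one uniformly."""
    pooled = [max_pool(layer) for layer in fixed_layers]
    return rng.choice(pooled)


def build_x(fixed_out: Sequence[Layers], tuned_out: Sequence[Layers],
            rng: random.Random) -> List[Vector]:
    """Step 3: X = [c_1..c_b, h_1..h_b], so the positive of c_i sits at index i + b.

    c_i is the last-layer [CLS] vector (token 0) of BERT_tuned for sentence i.
    """
    cs = [list(layers[-1][0]) for layers in tuned_out]
    hs = [sample_h(layers, rng) for layers in fixed_out]
    return cs + hs


bert = ToyBert({"w": [1.0, 2.0]})
fixed, tuned = split_fixed_tuned(bert)
tuned.params["w"][0] = 5.0  # an "update" applied to BERT_tuned only
print(fixed.params["w"])    # [1.0, 2.0]
```

Deep-copying, rather than pointing both names at the same object, is what keeps BERT_fixed's signal intact when BERT_tuned is updated, which is the point of step 1.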

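The losses in steps 4 and 5 can be sketched as a standard NT-Xent over X = [c_1..c_b, h_1..h_b]: each anchor x_i contributes -log(exp(g(f(x_i), f(x_pos))/τ) / Σ_{j≠i} exp(g(f(x_i), f(x_j))/τ)), where x_pos is the other view of the same sentence. In this minimal pure-Python sketch the projection head f is assumed to be a single linear map W, τ (tau) is a temperature whose default value here is arbitrary, and the flags drop_hh / drop_cc model Options 1 and 2 by removing the (d) and (b) pairs from the denominator. That is one reading of the description; the paper's exact L^opt, Option 3, and the form of the regularizer (assumed here to be a squared L2 distance between the parameters of the two copies, weighted by a hypothetical lam) may differ.

```python
import math
from typing import List, Sequence

Vector = List[float]
Matrix = List[List[float]]


def project(w: Matrix, x: Sequence[float]) -> Vector:
    """Projection head f, assumed here to be a single linear map W @ x."""
    return [sum(wv * xv for wv, xv in zip(row, x)) for row in w]


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Similarity g: cosine similarity (assumes non-zero vectors)."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / math.sqrt(sum(a * a for a in u) * sum(b * b for b in v))


def self_guided_loss(x: Sequence[Vector], w: Matrix, tau: float = 0.1,
                     drop_hh: bool = False, drop_cc: bool = False) -> float:
    """NT-Xent over x = [c_1..c_b, h_1..h_b]; the positive of index i is i +/- b.

    drop_hh=True removes the (d) h_i<->h_j negatives (Option 1);
    drop_cc=True removes the (b) c_i<->c_j negatives (Option 2).
    """
    n = len(x)
    b = n // 2
    z = [project(w, v) for v in x]
    total = 0.0
    for i in range(n):
        pos = (i + b) % n  # c_i <-> h_i
        denom = 0.0
        for j in range(n):
            if j == i:
                continue
            both_c = i < b and j < b
            both_h = i >= b and j >= b
            if (drop_cc and both_c) or (drop_hh and both_h):
                continue
            denom += math.exp(cosine(z[i], z[j]) / tau)
        total -= math.log(math.exp(cosine(z[i], z[pos]) / tau) / denom)
    return total / n


def l2_regularizer(theta_fixed: Sequence[float], theta_tuned: Sequence[float],
                   lam: float) -> float:
    """Assumed regularizer: lam * ||theta_tuned - theta_fixed||^2."""
    return lam * sum((t - f) ** 2 for t, f in zip(theta_tuned, theta_fixed))


identity = [[1.0, 0.0], [0.0, 1.0]]
aligned = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]  # c_1, c_2, h_1, h_2
swapped = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]  # h_1 and h_2 exchanged
base = self_guided_loss(aligned, identity)
opt = self_guided_loss(aligned, identity, drop_hh=True, drop_cc=True)
bad = self_guided_loss(swapped, identity)
print(opt < base < bad)  # True
```

Matching c_i with its own h_i gives a low loss and mismatched pairs a high one. Dropping same-type negatives only shrinks the denominator, so training pressure concentrates on factor (a) plus the cross pairs (c).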
Experiments & Results

์‹คํ—˜์€ 7๊ฐ€์ง€ STS Datasets์—์„œ ์ˆ˜ํ–‰๋˜๋ฉฐ, BERT+SBERT(-base, -large)์— ๊ธฐ์กด Sentence Embedding Method๋“ค๊ณผ ์ œ์•ˆ ๊ธฐ๋ฒ•์„ ์ ์šฉํ•œ ๊ฒฐ๊ณผ๋ฅผ ๋น„๊ตํ•˜์—ฌ ์„ฑ๋Šฅ์„ ํ‰๊ฐ€ํ•œ๋‹ค. Baseline์œผ๋กœ ์‚ฌ์šฉํ•˜๋Š” Method๋“ค์—๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์€ ๊ฒƒ๋“ค์ด ์žˆ์œผ๋ฉฐ,
• CLS Pooling
• Mean Pooling
• WK Pooling
• Flow
• Contrastive (BT: Back Translation)
The proposed method has two variants: SG, trained with L^base, and SG-OPT, trained with L^opt.
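The review does not spell out the STS metric. A common protocol for STS benchmarks, assumed here, scores a method by the Spearman rank correlation between the cosine similarity of each sentence pair's embeddings and the gold similarity label, often reported ×100. The pure-Python sketch below implements that protocol; the function names are illustrative.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity (assumes non-zero vectors)."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / math.sqrt(sum(a * a for a in u) * sum(b * b for b in v))


def ranks(values: Sequence[float]) -> List[float]:
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        for pos in range(start, end + 1):
            result[order[pos]] = (start + end) / 2 + 1
        start = end + 1
    return result


def spearman(x: Sequence[float], y: Sequence[float]) -> float:
    """Spearman correlation = Pearson correlation of the ranks."""
    rx, ry = ranks(x), ranks(y)
    n = len(rx)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    var = sum((a - mx) ** 2 for a in rx) * sum((b - my) ** 2 for b in ry)
    return cov / math.sqrt(var)


def sts_score(pairs: Sequence[Tuple[Vector, Vector]], gold: Sequence[float]) -> float:
    """Spearman correlation (x100) between embedding cosines and gold similarity scores."""
    predictions = [cosine(u, v) for u, v in pairs]
    return 100 * spearman(predictions, gold)


pairs = [([1.0, 0.0], [1.0, 0.0]), ([1.0, 0.0], [1.0, 1.0]), ([1.0, 0.0], [0.0, 1.0])]
print(round(sts_score(pairs, [5.0, 2.5, 0.0]), 6))  # 100.0
```

Because only ranks matter, an embedding method is rewarded for ordering sentence pairs like the human labels, not for matching the label scale.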
The results are shown in the table above. Apart from some SBERT cases, the proposed methods outperform the baselines. In particular, SG-OPT outperforms SG, which confirms the benefit of the learning objective optimization.
The ablation study also supports the effect of the learning objective optimization: the upper part of the table above reports results for different losses. The lower part shows that hyperparameter values also have a large impact on performance; notably, whether or not a projection head is used makes a large difference.
The graph above shows the results of training BERT-base on NLI (or STS-B) and testing on the STS datasets. SG-OPT outperforms Flow overall and also shows a smaller performance gap between in-domain and out-of-domain training, i.e. it is robust to domain shifts.