๐Ÿญ

(220210) Diary: Head Selection on Transformer & Variational Inference

Head Selection on Transformer using Variational Inference

(Meta, 2021) Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling

If I had to pick the key component of a Transformer layer, it would have to be Multi-Head Attention. It computes the correlations between data points from several different perspectives, one per head, which lets it capture complex data patterns.
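To make the "several perspectives" idea concrete, here is a minimal pure-Python sketch of multi-head attention. Each head has its own query/key/value projections, so it scores the relationships between positions differently, and the per-head outputs are concatenated. The function names are my own, and the usual output projection W_O is left out to keep it short. It is a toy illustration, not the implementation used in the paper.

```python
import math


def softmax(xs):
    # Numerically stable softmax over a list of floats
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(W, v):
    # W is a list of rows (out_dim x in_dim); v has length in_dim
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def attention_head(X, Wq, Wk, Wv):
    """One head of scaled dot-product attention over a sequence X (list of vectors)."""
    Q = [matvec(Wq, x) for x in X]
    K = [matvec(Wk, x) for x in X]
    V = [matvec(Wv, x) for x in X]
    d_k = len(Q[0])
    out = []
    for q in Q:
        # How strongly this position relates to every position, from this head's "perspective"
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        # Weighted average of the value vectors
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return out


def multi_head_attention(X, heads):
    """heads: list of (Wq, Wk, Wv); the outputs of all heads are concatenated per position.

    The output projection W_O of a real Transformer layer is omitted in this sketch."""
    per_head = [attention_head(X, Wq, Wk, Wv) for Wq, Wk, Wv in heads]
    return [sum((h_out[t] for h_out in per_head), []) for t in range(len(X))]


# Toy example: 3 tokens of dimension 2, two heads that each look at one coordinate
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
heads = [
    ([[1.0, 0.0]], [[1.0, 0.0]], [[1.0, 0.0]]),  # head 1: first coordinate
    ([[0.0, 1.0]], [[0.0, 1.0]], [[0.0, 1.0]]),  # head 2: second coordinate
]
out = multi_head_attention(X, heads)  # 3 positions, each with 1 value per head
```

Because the two heads project onto different coordinates, they produce different attention weights for the same token. This is the sense in which each head offers a separate "perspective".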
(On a separate note) Multi-Domain training is common in NLP. Below I just write "Domain", but this also covers the Multi-Lingual case. When the domains are not very similar to each other, Negative Interference occurs: performance on individual domains degrades.
One possible cause of Negative Interference is gradient conflict in the parameters shared across domains. This paper tackles the problem with a Head Selection method that selects a different set of heads for each domain.
(In other words, the LM is used in a modular way, a bit like a set of domain experts.)
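Here is a toy illustration of the gradient conflict mentioned above. Suppose two domains produce gradients with respect to the same shared parameters. If the cosine similarity between those gradients is negative, the two domains push the parameters in partly opposite directions, and part of the summed update cancels out. The negative-cosine test is a common diagnostic in multi-task learning. I am assuming it here for illustration; it is not necessarily how the paper measures interference, and the gradient values are made up.

```python
import math


def cosine(g1, g2):
    # Cosine similarity between two flattened gradient vectors
    dot = sum(a * b for a, b in zip(g1, g2))
    n1 = math.sqrt(sum(a * a for a in g1))
    n2 = math.sqrt(sum(b * b for b in g2))
    return dot / (n1 * n2)


def conflicts(g1, g2):
    # Treat a negative cosine as a conflict: the two domains pull the
    # shared parameters in (partly) opposite directions
    return cosine(g1, g2) < 0.0


# Hypothetical gradients of two domains w.r.t. the same shared parameters
g_a = [1.0, 0.5]
g_b = [-1.0, 0.2]
summed = [a + b for a, b in zip(g_a, g_b)]  # direction a shared update would follow
```

In this example the first coordinate of `summed` cancels to zero, so neither domain's preference for that parameter survives. Head Selection sidesteps this by giving each domain its own heads, so fewer parameters receive conflicting gradients.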
To learn which heads to select for each domain, the paper uses Variational Inference.
(Yes, the very same one from VAEs!)
Concretely, the paper models a Bernoulli-distributed discrete latent variable z for the input sequence. The n-th dimension of z indicates whether the n-th head is selected. The inference network is trained under a prior that assumes every head is equally likely to be selected.
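The standard variational lower bound (ELBO) for a setup like this can be written as follows. This is a generic sketch, not the paper's exact objective. The per-head logits φ_n, the shared prior probability π, and the number of heads N are my own notation. "Every head is equally likely" is encoded by giving all heads the same prior probability π. Because the heads are independent Bernoullis, the KL term splits into a sum of per-head terms.

```latex
\begin{aligned}
\log p(y \mid x) &\ge \mathbb{E}_{q_\phi(z)}\big[\log p_\theta(y \mid x, z)\big] - \mathrm{KL}\big(q_\phi(z) \,\|\, p(z)\big) \\
q_\phi(z) &= \prod_{n=1}^{N} \mathrm{Bernoulli}\big(z_n;\, \sigma(\phi_n)\big) \\
p(z) &= \prod_{n=1}^{N} \mathrm{Bernoulli}\big(z_n;\, \pi\big) \\
\mathrm{KL}\big(q_\phi(z) \,\|\, p(z)\big) &= \sum_{n=1}^{N} \Big[ q_n \log\frac{q_n}{\pi} + (1 - q_n) \log\frac{1 - q_n}{1 - \pi} \Big], \qquad q_n = \sigma(\phi_n)
\end{aligned}
```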
In addition, the paper uses Gumbel-Softmax so that the whole model stays differentiable end to end.
(I have meant to review Gumbel-Softmax ever since I got interested in tabular data, but still haven't gotten to it.. I plan to study it this year!)
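Since I want to study Gumbel-Softmax anyway, here is a minimal sketch of the trick. Add i.i.d. Gumbel(0, 1) noise to the logits, divide by a temperature τ, and apply a softmax. The argmax of the noisy logits is an exact categorical sample (the Gumbel-max trick). The softmax version is a smooth relaxation that becomes closer to one-hot as τ shrinks. For a "select head n or not" Bernoulli decision, I assume the two-class case with logits [l, 0], which selects with probability σ(l). The function names are hypothetical. Plain Python only shows the forward computation; in an autograd framework, gradients would flow through the softmax.

```python
import math
import random


def sample_gumbel(rng, eps=1e-20):
    # Gumbel(0, 1) noise via the inverse CDF: -log(-log(U)), U ~ Uniform(0, 1)
    u = rng.random()
    return -math.log(-math.log(u + eps) + eps)


def gumbel_softmax(logits, tau, rng):
    """Relaxed one-hot sample: softmax((logits + Gumbel noise) / tau)."""
    noisy = [(l + sample_gumbel(rng)) / tau for l in logits]
    m = max(noisy)
    exps = [math.exp(x - m) for x in noisy]
    total = sum(exps)
    return [e / total for e in exps]


def bernoulli_relaxed(logit, tau, rng):
    """Binary case (hypothetical helper): relaxed probability that a head is selected.

    Uses Gumbel-Softmax over the two logits [logit, 0]; its hard version selects
    the head with probability sigmoid(logit)."""
    return gumbel_softmax([logit, 0.0], tau, rng)[0]


rng = random.Random(0)
soft = gumbel_softmax([1.0, 2.0, 0.5], tau=0.5, rng=rng)  # sums to 1, near one-hot for small tau
p_select = bernoulli_relaxed(0.3, tau=0.5, rng=rng)         # a value in [0, 1]
```

A common variant, the straight-through estimator, uses the hard argmax in the forward pass and the gradient of the soft sample in the backward pass.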
At selection time, the Top-H heads with the highest selection probabilities are chosen.
There are two strategies: the Subset Strategy, which ignores the order of the heads, and the Group Strategy, which takes the order into account when selecting. (See the figure above.)
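Here is a sketch of the two selection strategies, given per-head selection probabilities for one domain (for example σ(φ_n) from the inference network). The Subset Strategy takes the Top-H heads over all N heads and treats the result as an unordered set. For the Group Strategy I assume one plausible order-aware reading: split the N heads into H contiguous groups and keep the best head of each group, so the i-th selected head always comes from the i-th group. The paper's exact grouping may differ. The function names and probabilities are hypothetical.

```python
def subset_select(probs, H):
    """Subset strategy: Top-H heads over all N heads; order is ignored,
    so the indices are returned as a sorted set."""
    ranked = sorted(range(len(probs)), key=lambda n: probs[n], reverse=True)
    return sorted(ranked[:H])


def group_select(probs, H):
    """Group strategy (assumed reading): split the N heads into H contiguous groups
    and keep the most probable head of each group, so the i-th selected head
    always comes from the i-th group (its position is preserved)."""
    N = len(probs)
    if N % H != 0:
        raise ValueError("this sketch assumes N is divisible by H")
    size = N // H
    picks = []
    for g in range(H):
        group = range(g * size, (g + 1) * size)
        picks.append(max(group, key=lambda n: probs[n]))
    return picks


# Hypothetical selection probabilities for one domain (N = 6 heads, H = 3)
probs = [0.9, 0.8, 0.1, 0.2, 0.7, 0.3]
print(subset_select(probs, 3))  # [0, 1, 4]
print(group_select(probs, 3))   # [0, 3, 4]
```

The two results differ because heads 0 and 1 both rank highly but fall in the same group. The Group Strategy can keep only one of them.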
Honestly, FAIR's Multi-Lingual and Multi-Domain experiments are too unfamiliar to me for a proper analysis..
The method seems to perform well overall, so I'll skip the details!
I reviewed this paper for two reasons. I remember finding Variational Inference (VAE) fascinating when I first studied it. I also thought that using a latent with an easy-to-handle distribution to control the output might give me ideas for my current work (Controlled Generation).
It was as interesting as I expected, and I plan to look for and study similar research for a while.