๐Ÿญ

(211002) Diary: Controlled Text Generation Feat. CTRL & PPLM

Large Transformer-decoder LMs, including Codex which I reviewed a while ago, are said to generate text at a level comparable to human writing. However, steering an LM to produce text of a specific domain, style, or sentiment is not a simple matter. The obvious approach is to pass the desired condition to the LM explicitly (e.g., as special tokens alongside the prompt), but this requires additional training of the LM and is therefore inefficient. Recently, much like a paper I reviewed in the past, research seems to be focusing on generating text with the desired attributes without any further parameter updates to the pre-trained LM. Since I have so far mostly used Transformer-encoder-style denoising auto-encoding LMs, I want to take this opportunity to properly(?) study text generation. With that in mind, I read the two most representative papers and summarize them briefly here.

(Salesforce, 2019) CTRL: A Conditional Transformer Language Model for Controllable Generation

CTRL์€ ์ฃผ์–ด์ง„ Domain, Style ๊ทธ๋ฆฌ๊ณ  ์งˆ์˜์‘๋‹ต, ๋ฒˆ์—ญ ๋“ฑ์˜ Task์— ํ•ด๋‹นํ•˜๋Š” Text๋ฅผ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•™์Šต๋œ LM์ด๋‹ค. CTRL์€ Prompt์™€ ํ•จ๊ป˜ ์ „๋‹ฌ๋ฐ›๋Š” Control Code(s)๋กœ๋ถ€ํ„ฐ ์ƒ์„ฑํ•  Text์˜ ํŠน์„ฑ์„ ๊ฒฐ์ •ํ•œ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, "์ด ์นผ์€"์ด๋ผ๋Š” Prompt๊ฐ€ ์žˆ์„ ๋•Œ, Horror๋ผ๋Š” Code๋ฅผ ํ•จ๊ป˜ ์ „๋‹ฌํ•˜๋ฉด ๊ณตํฌ์Šค๋Ÿฌ์šด ์žฅ๋ฉด์„ ๋ฌ˜์‚ฌํ•œ Text๋ฅผ ์ƒ์„ฑํ•  ๊ฒƒ์ด๊ณ , Code๊ฐ€ Reviews์ธ ๊ฒฝ์šฐ์—๋Š” ๊ตฌ๋งคํ•œ ์นผ์— ๋Œ€ํ•œ ํ›„๊ธฐ์™€ ๊ฐ™์€ Text๋ฅผ ์ƒ์„ฑํ•˜๊ฒŒ ๋˜๋Š” ์‹์ด๋‹ค. Reviews์™€ ํ•จ๊ป˜ Rating: 3.0์ด๋ผ๋Š” ์ถ”๊ฐ€ Code๋ฅผ ์ „๋‹ฌํ•  ์ˆ˜ ์žˆ๋Š”๋ฐ, ์ด ๋•Œ์—๋Š” ์ ์ˆ˜์— ๋”ฐ๋ผ ๋‹ค๋ฅธ ๋‰˜์•™์Šค์˜ Text๋ฅผ ์ƒ์„ฑํ•œ๋‹ค๊ณ  ํ•œ๋‹ค. ์‚ดํŽด๋ณธ ๋ฐ”์™€ ๊ฐ™์ด Control Code(s)๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์€ 2์ข…๋ฅ˜๋กœ ๊ตฌ์„ฑ๋˜๋ฉฐ, ์ด๋“ค์„ ์กฐํ•ฉํ•˜์—ฌ ๋‹ค์–‘ํ•œ ํŠน์„ฑ์˜ Text๋ฅผ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋‹ค.
โ€ข
Domain Control Code(s): Prompt์˜ ๋งจ ์•ž์— Prepend๋จ. ์ดํ›„ ๋ชจ๋“  Sequence๋Š” ์ด๋กœ๋ถ€ํ„ฐ Propagate๋˜๊ธฐ ๋•Œ๋ฌธ์— ํŠน๋ณ„ํ•œ ์˜๋ฏธ์˜ Token์œผ๋กœ ์ทจ๊ธ‰๋œ๋‹ค๊ณ  ํ•จ.
โ€ข
Non-Domain Control Code(s): Rating: 3.0๊ณผ ๊ฐ™์€ Code.
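In Hugging Face transformers, using CTRL comes down to simply prepending the code text to the prompt. A minimal sketch, assuming the "Salesforce/ctrl" checkpoint id and a code format taken from the paper's examples (not verified against the released model):

```python
from transformers import CTRLTokenizer, CTRLLMHeadModel

tokenizer = CTRLTokenizer.from_pretrained("Salesforce/ctrl")
model = CTRLLMHeadModel.from_pretrained("Salesforce/ctrl")

# Domain code "Reviews" plus the non-domain code "Rating: 3.0", prepended to the prompt.
prompt = "Reviews Rating: 3.0 I bought this knife"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# repetition_penalty corresponds to the penalized sampling discussed below.
output = model.generate(input_ids, max_length=60, repetition_penalty=1.2)
print(tokenizer.decode(output[0]))
```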
CTRL์˜ Training+Inference์—์„œ ์ฃผ๋ชฉํ•  ๋งŒํ•œ ์ ๋“ค์€ ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.
โ€ข
140GB์˜ Data๋กœ ํ•™์Šต: ์ด ์ค‘์—๋Š” General Corpus(Wikipedia ๋“ฑ), Reddit, QA+Translation Tasks ๋“ฑ์ด ํฌํ•จ๋จ.
โ€ข
Vocab Size๋Š” 250K๋กœ ๋งค์šฐ ํฐ ํŽธ, ํ•˜์ง€๋งŒ Sequence Length๋Š” 256, 512๋กœ ์ž‘์€ ํŽธ. Vocab Size๊ฐ€ ํฌ๊ธฐ ๋•Œ๋ฌธ์— Subwords๋กœ์˜ ๋ถ„ํ•ด๊ฐ€ ๋œํ•  ๊ฒƒ์ด๋ฉฐ, ์ด๋ฅผ ํ†ตํ•ด ์ ์€ ์ˆ˜์˜ Token์œผ๋กœ๋„ Long Sequence๋ฅผ ์ถฉ๋ถ„ํžˆ ํ‘œํ˜„ํ•  ์ˆ˜ ์žˆ์Œ.
โ€ข
Inference์—์„œ Temperature-Controlled ํ˜น์€ Nucleus Sampling๋ณด๋‹ค Greedyํ•˜์ง€๋งŒ, ์ค‘๋ณต ์ƒ์„ฑํ•œ Token์— Penalty๋ฅผ ์ฃผ๋Š” Penalized Sampling ํ™œ์šฉ. (Sampling ๊ธฐ๋ฒ•๋“ค์„ ์ž˜ ์ •๋ฆฌํ•ด ์ค€ Blog!)

(Uber, 2019) Plug and Play Language Models: A Simple Approach to Controlled Text Generation

CTRL๊ณผ ๊ฐ™์€ ๋ฐฉ์‹์€ LM์ด Special Token์„ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋„๋ก (์ถ”๊ฐ€์ ์ธ)ํ•™์Šต์„ ์ง„ํ–‰ํ•ด์•ผ ํ•˜๋Š” ์ œ์•ฝ์„ ๊ฐ–๋Š”๋‹ค. GPT ๊ณ„์—ด์˜ Model Size๋ฅผ ์ƒ๊ฐํ•˜๋ฉด, ์ด๋Š” ์ƒ๋‹นํ•œ Resource๋ฅผ ํ•„์š”๋กœ ํ•˜๋Š” ์ž‘์—…์ด๋‹ค. PPLM์€ Pre-Trained LM์˜ Parameter๋Š” Updateํ•˜์ง€ ์•Š๊ณ , Small Size์˜ Attribute Model์„ ํ™œ์šฉํ•˜์—ฌ LM์ด ์ƒ์„ฑํ•˜๋Š” Text์˜ ํŠน์„ฑ์„ ์กฐ์ ˆํ•œ๋‹ค. PPLM์€ Text์˜ Topic๊ณผ Sentiment๋ฅผ ์„ค์ •ํ•  ์ˆ˜ ์žˆ๊ณ , ๊ฐ๊ฐ Bag-of-Words(BoW), Single-Layer Classifier์˜ Attribute Model์„ ํ†ตํ•ด ์กฐ์ ˆํ•œ๋‹ค. ์กฐ๊ธˆ ๋” ๊ตฌ์ฒด์ ์ธ Process๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.
1. Run a forward pass of the pre-trained LM to infer the next token (sequence), and use the attribute model to predict the likelihood of the desired attribute, P(a|x).
2. Using the attribute model's gradient, update the LM's latent representations in the direction that increases P(a|x) (the update rule is sketched just after this list).
3. Re-run inference with the updated latents.
์œ„ ๋‚ด์šฉ์„ ๋…ผ๋ฌธ์œผ๋กœ ์ฝ์—ˆ์„ ๋•Œ, ์ดํ•ดํ•˜๋Š” ๋ฐ ๋ณ„๋‹ค๋ฅธ ์–ด๋ ค์›€์ด ์—†์—ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์ด๋ฅผ Code๋กœ ๊ตฌํ˜„ํ•œ๋‹ค๊ณ  ์ƒ๊ฐํ•˜๋‹ˆ ๋ง‰์ƒ ๊ฐ์ด ์ž˜ ์˜ค์ง€ ์•Š๋Š” ๊ฒƒ์ด๋‹ค.. (์ด๋Š” Code๊ฐ€ ์–ด๋ ค์›Œ์„œ๊ฐ€ ์•„๋‹ˆ๋ผ ๋‚ด๊ฐ€ ํŠนํžˆ ์•ฝํ•˜๊ฑฐ๋‚˜ ๋ถ€์กฑํ•œ ์š”์†Œ๋ฅผ ํฌํ•จํ•˜๊ณ  ์žˆ๊ธฐ ๋•Œ๋ฌธ์ด ์•„๋‹๊นŒ ์ƒ๊ฐํ•œ๋‹ค!) ๊ณต์‹ GitHub์„ ๋ณด๋ฉด์„œ ์ถ”๊ฐ€ ๊ณต๋ถ€๋ฅผ ํ•˜์˜€๊ณ , ๋‚ด๊ฐ€ ์ดํ•ดํ•œ ๊ฒƒ๋“ค์€ ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.
โ€ข
Inference์—์„œ Hugging Face GPT๋Š” ํ˜„์žฌ Time Step์˜ Token, x_T์™€ Cached(์ด์ „ Time Step๊นŒ์ง€ ๊ณ„์‚ฐ)๋œ Key-Values๊ฐ’, H_T๋ฅผ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›์•„, x_(T+1)๊ณผ H_(T+1)์„ ์ถœ๋ ฅํ•จ(์—„๋ฐ€ํžˆ ๋งํ•˜๋ฉด Token์ด ์•„๋‹Œ Logits์„ ์ถœ๋ ฅ).
โ€ข
๋…ผ๋ฌธ์—์„œ ์–ธ๊ธ‰ํ•˜๋Š” LM์˜ Latent๋Š” H_T๋ฅผ ๋งํ•จ.
โ€ข
H_T์™€ ๋™์ผํ•œ Shape์˜, ๋ชจ๋“  Param์ด 0๊ฐ’์ธ Tensor, Delta(H_T)๋ฅผ ์ƒ์„ฑํ•จ.
โ€ข
Delta(H_T)์— Gradient๋ฅผ ๋ถ™์ด๊ณ (requires_grad=True), H_T์— ๋”ํ•ด์ค€ ํ›„ Forward Pass.
โ€ข
Backward Pass๋ฅผ ํ†ตํ•ด Delta(H_T)์˜ Gradient๋ฅผ ๊ณ„์‚ฐํ•˜๊ณ , Delta(H_T)๊ฐ’์„ Update.
โ€ข
H_T+Delta(H_T)๋กœ Inference๋ฅผ ๋‹ค์‹œ ์ˆ˜ํ–‰ํ•จ.
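Putting those bullets together, the Delta(H_T) trick looks roughly like this. It is my own condensed sketch: the official PPLM code additionally uses KL and geometric-mean fusion terms, and in Hugging Face GPT-2 `past_key_values` is a nested tuple of per-layer key/value tensors rather than the flat list of tensors assumed here.

```python
import torch

def perturb_past(model, attribute_model, last_token_id, past,
                 step_size=0.02, num_steps=3):
    """PPLM-style latent update sketch. `past` plays the role of H_T and is
    assumed to be a flat list of tensors; `attribute_model` maps the next-token
    distribution to a scalar log p(a|x). Names are illustrative, not official."""
    # Delta(H_T): all-zero tensors shaped like the cached key-values,
    # with requires_grad=True so they can be updated by backprop.
    delta = [torch.zeros_like(h, requires_grad=True) for h in past]

    for _ in range(num_steps):
        perturbed_past = [h + d for h, d in zip(past, delta)]
        out = model(last_token_id, past_key_values=perturbed_past)
        next_token_probs = torch.softmax(out.logits[:, -1, :], dim=-1)

        # Maximize log p(a|x): the backward pass touches only Delta(H_T);
        # the LM parameters themselves are never updated.
        loss = -attribute_model(next_token_probs)
        loss.backward()

        with torch.no_grad():
            for d in delta:
                d -= step_size * d.grad
                d.grad.zero_()

    # Re-run inference with H_T + Delta(H_T).
    return [h + d.detach() for h, d in zip(past, delta)]
```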