Data Science/Paper Review

CONTROL PREFIXES for Parameter-Efficient Text Generation (2021) 논문리뷰

준나이 2023. 5. 13. 21:40

Abstract

prefix-tuning
  • large pre-trained model(PLM)을 downstream task로 adaptation 시키기위한 가볍지만 강력한 기술
  • 하지만 같은 level의 dataset으로 학습된 prompt를 모든 examples에 사용함
control-prefixes
  • prefix-tuning을 확장해서, input-dependent한 information을 추가로 포함시켜 dynamic prompt화 시킨 기술
  • prompt-tuning과 controlled generation의 이점들을 모두 이용할 수 있게됨
  • attribute-level representation을 PML의 layers와 통합시키고 text를 원하는 방향으로 생성 수 있도록 guide할 수 있음
  • GEM benchmarks 중에서 5 dataset을 실험에 활용
  • 추가 params가 0.1 ~ 3%에 불과한 parameter-efficient한 방법

1. Introduction

  • PLM을 downstream tasks로 adaptation 시키기는게 text generation분야의 일반적 접근 방식
fine-tuning의 한계
  • PLM의 모든 parameters를 대상으로 tuning시켜 task마다 LM을 별도로 만들어야 함
  • PLM의 크기가 수백만에서 수십억에 달하기 때문에 현실적으로 이용하기 힘듦
  • 다른 task를 푸는데도 도움이 될 수 있는 정보들을 그대로 overwriting 시킴
기존 prompt-tuning의 한계
  • 많은 researchers는 PLM의 parameters를 고정시켜 문제를 해결하려고함
  • prmpt tuning은 input에 tuned prompt를 덧붙여 downstream tasks로 adaptation 시키는 방법
  • 하지만 input-dependent한 dynamic prompt 영역에 대한 연구는 부족
  • 기존 controlled generation 기법들은 task에 대한 성능과는 관련없이 특정한 text를 생성시키는게 목표이거나, attribute-level parmameters 뿐만 아니라 PLM까지 모두 update됨
dynamic prompting method "control prefix"
  • prefix-tuning을 확장시켜 PLM의 모든 lyaers에 static task-specific prompt를 통합시킴
  • PLM의 parameters는 fixed된 상태에서 datapoint-specific attributes가 input-level에서 가이드하는 역할을 함
  • 이를 위해 control prefixes가 주어진 guidance에 따라 input을 변경시킴
  • 이러한 dynamic prompts는 static prompt parameters와 같이 작동하여 frozen PLM을 finer-grained control로 확장시킴
  • 선택된 attributes에는 input에 대한 추가정보(e.g. data-to-text triple set의 domain)를 주거나 생성 됐으면 하는 output의 특성(e.g. text길이)을 명시할 수 있음
  • 각각의 dataset에 specific한 추가적인 input-level information을 이용해서 다양한 text generation tasks를 평가
experiment
  • 기존 방법론들을 상회하는 성능을 보여줌 (WebNLG, DART, E2E Clean)
  • higher human-assessed performance for summerisation on SUM
  • attribute-level information을 이용할 수 있는 dataset에 focus를 둠 (실험 시)
  • 하지만 attribute-level information을 이용할 수 없는 경우가 일반적이기 때문에 control prefixes를 이용한 zero-shot learning이 효율적일 수도 있다는 것을 함께 보여줌

2. Related Work

prompt tuning
  • prompt tuning: discrete한 prompting과는 다르게 soft prompt는 gradient descent를 이용한 labelled data의 정보를 최대로 이끌어냄
  • prompt embedding tuning: input embedding에 덧붙여지는 prompt embedding을 학습시킴
  • prefix-tuning: NLG에 특화된 prompt embedding tuning으로 모든 examples이 함께 사용하는 학습가능한 K-V pairs가 존재하고 left context를 attention 시키는데 사용됨
Controlled generation
  • controlled generation은 다양한 종류의 guiance를 이용해서 prompt learning이 부족한 부분을 채워줌
  • 각각의 language를 인코딩한 control tokens으로 이용한 multilingual translation model
  • 하지만 이러한 모델은 contorl prefixes 뿐만 아니라 PLM의 paramters도 학습시켜야함
  • 대안으로 plug-and-play 방식들도 존재하는데, topic이나 sentiment같은 생성되는 글의 성격을 control 할 수 있음
  • 하지만 추가적은 computation으로 inference시에 느린 단점이 존재
Dynamic prompts
  • input에 denpendent하다는 특징을 갖고있는 dynamic prmpts에 대한 연구는 거의 없음
  • 가장 비슷한 연구로는 dynamic prompts를 구성하기 위해 attribute alignment function을 사용한 방법이 존재
  • 하지만 이 방법은 static한 component는 사용하지 않고, task에 대한 성능은 고려하지 않고 특정한 target 속성을 생성하는 것을 목적으로 함
  • control prefixes는 task를 명시할 수 있는 static prompt component도 사용해서 task 성능도 향상시킴

3. Contorl Prefixes

3.1. Background

  • 이 연구는 조건부 확률 $P(Y|X)$를 objective로 하는 seq2seq tasks를 대상으로 함
  • $X$:tokenised input, $Y$: output sequences (e.g. summerisation: $X$=article, $Y$=highlight)
PLM ($\phi$)
  • enc-dec models decoding auto-regressively - T5-large, BART$_LARGE
  • $\phi$ 는 학습 시 frozen
  • $d$: hidden state dimension, $L$: the number of layers
attention class $(E, D_c, D_m)$
  • 각 layer에 존재하는 3가지 attention ($E$ = encoder self-attention, $D_c$: decoder cross-attention, $D_m$: decoder masked-attention)
  • l-th layer의 attention computation을 위해 $Q, K, V$ 존재 ($Q_l \in \mathbb{R}^{N \times d}$, $K_l, V_l \in \mathbb{R}^{M \times d}$)
  • $N$: the number of tokens relating queries, $M$: the number of tokens relating keys and values

3.2. Intuition

  • frozen PLM은 자연어에 대한 폭넓은 이해를 갖고있고 이는 tasks에 공유될 수 있는 parameter-efficient한 방법에 시작점을 제공함
  • 여기에 trainable task-specific parameters를 더해서 model이 특정한 task와 관련된 정보를 학습할 수 있도록 함
  • 여기에 attribute-level parameters를 추가함으로써 모델이 원하는 방향대로 output을 생성할 수 있도록 datapoint-level information을 제공해줌
  • general task-specific parameters는 각각의 input $X$의 모듈화된 가이드 신호(guide signal)에 따라 따라 변화하는 control prefixes에 맞게 adaptation 됨
  • 여기서 guide signal로 attributes를 discrete labels로 이용

3.3. Description

  • $P_{\theta}$: a general task prefix (=task-specific parameters)
  • $C_{\theta}$: a set of control prefixes (=attribute-level parameters)
  • input $X$를 처리할 때 어떤 control prefixes를 사용할지 나타내는 guidance $G$를 필요로 함
  • corpus $Z = \{\langle X^j,Y^j,G^j \rangle \}_{j=1,\ldots,N}$
  • $G^j$: 모든 conditional attribute-level information for the sample $j$
  • gradient descent를 통해 최종적으로 inference parameters $\phi$를 최적화시키는 것이 목적 ($\phi$는 fixed)
    $$ \theta^* = \arg\max_{\theta} \sum^{N}_{j=1} \log p(X^j, G^j; P_{\theta}, C_{\theta}, \phi)$$
General Prefix
  • attention $(E, D_c, D_m)$마다 K-V pairs의 prefix가 각각 학습됨
  • $P=\{P_1, \ldots , P_2 \}, P_l \in \textbb{R}^{\rho \times 2d}$
  • $\rho$ = prompt length
  • prefix-tuning에서는 l-th layer에서 일어나는 attention을 위해 $K^l$, $V^l$은 다음과 같이 agumented 됨
    $$ K_{l}' = [P_{l,K} ; K_l], V_{l}' = [P_{l,V} ; K_l] $$
  • $ K_l', V_l' \in \mathbb{R}^{(\rho + M)\times dL}$
Control Prefixes
  • 하나의 attribute가 R개의 labels를 가질 수 있다고 가정 (e.g. news domain: sport, technology, ...)
  • $C_{\theta} = \{C_{\theta,1},\ldots,C_{\theta,R}\}$
  • $C_{\theta,r} \in \mathcal{R}^{\rho_c \times 6dL} $: r-th attribute label에 대해서 학습한 control prefix
  • $\rho_c$: control prompt length (attribute에 따라 달라짐)
  • $\mathcal{A}$: $G$가 가리키는 attribute label에 상응하는 control prefix를 반환하는 함수
    $$ K_{l}'' = [\mathcal{A}(G)_{l,K} ; P_{l,K} ; K_l], V_{l}'' = [\mathcal{A}(G)_{l,V} ; P_{l,V} ; K_l]$$
  • $ K_l'', V_l'' \in \mathbb{R}^{(\rho_c + \rho + M)\times d} $
Shared Re-parameterisation
  • 기존 연구에 의해 학습가능한 parameters 수를 늘려서 prefix optimisation을 안정화 시킬수 있다는 것이 밝혀짐
  • 위의 연구에서는 prefix를 re-parameterisation하기 위해 1층 짜리 MLP를 사용했는데, 이 논문에서는 각각의 attention에 대응되도록 2층짜리 MLP를 3개 이용함
  • 각각의 attnetion class $[E, D_c, D_m]$ 마다, $P=MLP(\tilde{P})$ 이고 중간 hidden size $k$는 800으로 설정 ($\tilde{P} \in \mathbb{R}^{\rho \times d}$)
  • MLP와 $\tilde{P}$는 $\tilde{\theta}$에 의해 parameterised 됨
  • 여기서 $\theta$는 $\tilde{\theta}$의 function이고 $|\theta| < |\tilde{\theta}|$
  • 학습완료 후 $\theta$는 저장하고 $\tilde{\theta}$는 폐기해도 됨
  • general prefix $P_{\theta}$처럼 contorl prefix $C_{r,\theta} = \{C_{r}^{E}, C_{r}^{D_c}, C_{r}^{D_m}\}$는 각각의 attention를 위해 구성됨
  • MLP 또한 위처럼 $MLP^E, MLP^{D_c}, MLP^{D_m}$ 존재
  • over-parameterisation은 optimisation landscape을 더 부드럽게 하는 효과가 있음
  • 3개의 분리된 re-parameterisation은 각각의 prefix element가 control과 시너지를 낼 수 있게함