AI/Models

Wav2Vec 2.0 Model

rolling_rolling 2024. 10. 13. 23:23

이해한 대로 적는 Wav2Vec 2.0 모델 매커니즘

참고

https://zerojsh00.github.io/posts/Wav2Vec2/

https://nongnongai.tistory.com/34

0. intro


이 모델은 2020년 facebook에서 발표한 음성 인식 모델이다. 기존의 Wav2Vec 모델과 VQ-Wav2Vec모델처럼 음성 신호가 전사되어 있는 labeled data가 부족하다는 문제를 *self-supervised learning을 통한 pre-training으로 보완했다. 또한 pre-training 이후 CTC loss를 활용하여 적은 labeled data로부터 Fine-Tuning이 가능하다는 장점을 가지고 있다.

약 10분 분량의 labeled data로 fine-tuning을 진행했을 때 LibriSpeech를 기준으로 WER(Word Error Rate)가 깨끗한 음성은 4.8%, 나머지 음성에 대해선 8.2%를 기록할 정도로 매우 훌륭한 결과를 보여주었다.

시간에 따른 Wav2Vec 2.0 학습 결과

 

또한 여러 언어에 공통적으로 사용되는 음성을 학습할 수 있는 교차 학습 방식인 XLSR도 개발되어 있다. 즉 적은 음성 데이터만 가지고 있을 때, 이 기법을 사용하여 학습에 도움을 줄 수 있다.

 

*self-supervised learning: 비지도 학습의 한 분야로 라벨이 없는 데이터를 이용하여 자기 자신의 특성(representation)을 배우는 학습 방법이다. 즉 적은 데이터로 성능을 향상시킬 수 있는 방법이다.


 

1. 모델설명


pre-training과정에서의 Architecture

 

Feature Encoder - f : X  

기존의 Wav2Vec 모델과 같이 Feature Encoder Network가 존재한다. 이 단계에서는 정보를 추출하는 단계로, Multi-Layer CNN으로 구성되어 있다.

Raw waveform이 입력되면, Multi-Layer CNN Encoder를 거쳐 매 시점 T마다 represent vector, 즉 Latent Speech representations인 z1, z2 ... zT를 출력한다. 그리고 출력된 Representations들은 transformer, quantizer 두 모듈에 전달된다.

 

 

Module1. Masking Transformer - Z  

Transformer모듈은 contextualized representation을 위한 모듈이다. contextualized representation이란 문맥에 따라 변하는 단어나 문장 표현을 의미한다.

representation인 z 시퀀스들을 이 Masking Tranformer 모듈에 넣으면 주변 정보를 이용하여 복원 context representation 이 생성된다. 이후 context network transformer을 사용하여 전체 시퀀스에서 c1, c2 ... cT를 capturing한다.

여기서 transformer블록에서는 absolute positional embedding 대신 convolution 연산을 통해 relative positional embedding 효과를 주었다고 한다.

 

transformer의 출력은 contrasive task를 풀기 위함이다. 즉, 해당 모델은 mask된 위치에 대한 correct quantized speech units를 식별해야 하며, context representations와 latent speech representations가 서로 유사하도록 학습하는 것이 목표이다. 이를 통해 contrasive loss를 최소화시킬 수 있다. contrasive loss를 최소화시키게 되면 mutual information 을 최대치로 끌어올릴 수 있는데, 이러한 학습을 통해 양질의 representation vector를 추출 할 수 있다.

 

 

Masking

masking은 transformer block의 self-attention 효과를 보기 위해 적용시킨다. 따라서 transformer 모듈에서만 masking을 수행한다.

masking 인덱스의 시작점을 선택함

 

우선 전체 구간에서 6.5를 랜덤하게 고른다.(p = 0.065)

 

masking 진행

 

그리고 M = 10만큼 마스킹을 수행한다. 이때 p와 M은 하이퍼파라미터이다. 시작지점이 어디냐에 따라 마스킹이 겹칠 수도 있고, 안겹칠 수도 있다. 또한 마스킹은 trained feature vector로 마스킹되는 부위를 대체하는 방식이며 모든 마스킹 부위에는 동일하게 해당 *feature vector를 사용한다.

예시를 들어보자면, "??? 나는 인공지능이야" 라는 문장이 있을 때, "???"이 마스킹되었다고 가정해본다. 이때 모델은 마스킹된 부분을 예측해야 하며, 그 예측 결과가 "안녕?"이라면 "안녕하세요. 오늘 날씨는..."이라는 문장에서 "안녕하세요"를 마스킹했을 때 "안녕?"으로 대체한다는 뜻이다.

 

 

*feature vector: 음성, 이미지 등의 데이터를 벡터로 변형시킨 것

 

 

Module2. Quantization Module - ZQ

G개의 codebook, V개의 code work

 

CNN Encoder를 통과한 출력값 zT는 transformer 모듈 이외에도 Quantizer Module(양자화 모듈) ZQ 을 사용하여 qT로 이산화를 수행한다. quantizer는 codebook에서 zT를 위한 code words를 선택한다. 그리고 latent audio representation (zT)의 절반은 transformer에 입력되기 전에 masking된다.

 

**codebook 이 G개 존재 --> G x V 크기의 multiple codebooks 형성, 이 행렬은 모두 학습 가능한 파라미터로 구현

**codebook --> *embedding matrix / code word --> *embedding vector로 생각하면 됨. 정확하게 말하면 code word는 음소에 대한 representation으로, 음소는 언어가 달라도 발음할 수 있는 개수는 유한하다고 판단, 마치 embedding vector로 표현한 것이다.

 

*embedding matrix : embedding vector를 모아둔 행렬

*embedding vector : 음성, 문장과 같은 데이터를 의미와 관계를 포착하는 숫자로 나타낸 것

 

 

Quantization Module 작동과정

1. encoding된 zT가 codebook 내부의 레이어를 통과하면서 logit으로 변환된다.

 

2. gumbel softmax / argmax로 one-hot encoding 진행(이산화 과정)

 

3. 각 codebook 내에서 하나의 code word vector를 골라내는데, 총 G개의 e1,,eGRd/G 벡터들로 추출됨

 

4. 벡터들 모두 concatenate 진행 --> etRd를 만듦(위 그림에서 3번째 과정)

 

5. 여기서 linear transformation RdRf 를 통해 quantized representation 


 

2. 학습 방식


masking 기법을 사용하여 오디오의 마스킹된 부분에 대한 올바른 speech unit을 예측, 그 speech unit이 무엇이어야 하는지 학습하도록 되어 있다. pre-training을 마치면, labeled data로 fine-tuning을 진행한다.

 

 

Fine-Tuning

이 단계부터는 Quantization을 활용하지 않는다. 대신 무작위로 초기화된 *linear projection Layer를 모델 최상단에 배치하고, context represntation C를 통과시키는 방식이다(처음 architecture 사진 참고). 이후 CTC loss, specAugment 등을 이용하여 labeled data에 대해 fine-tuning을 진행한다. 참고로 이 단계에서도 masking기법은 유지한다.

 

*linear projection Layer : 풀고자 하는 task의 어휘 수를 의미하는 C개 class만큼의 차원으로 projection


 

 

참고한 블로그를 많이 따라쓰고, 모르는 단어와 문장들은 찾아보고 공부하며 나름대로의 예시도 몇개 끄적여봤다. 좀 더 공부하여 모델에 대한 이해를 확실하게 하고, Wav2Vec 2.0 모델을 구현하는 것은 추후 Project Nova에서 다룰 예정이다.