Geneformer 코드에서 아래 부분은 BERT의 인코더 부분은 공통이고, task별로 다른 output head, 즉 task-specific head가 붙은 모델 클래스를 불러오는 부분이다.
BERT가 transformer의 encoder-only model인데, Geneformer는 6개의 transformer encoder units으로 구성된 모델(=encoder only)이므로 그냥 BERT를 불러온 것.
from transformers import (
BertForMaskedLM,
BertForSequenceClassification,
BertForTokenClassification,
)
task-specific head라는 것은 즉,
MLM head -> BertForMaskedLM
Sequence Classifier Head (BertForSequenceClassification)
Token Classifier Head (BertForTokenClassification)
즉, BERT의 핵심 encoder 부분과 해당 task에 맞는 출력층 전체를 loading하고 있다. (전체 파라미터 로드)
BertForMaskedLM, BertForSequenceClassification, BertForTokenClassification 이 3개 모델 클래스는 BERT 표준 아키텍쳐 클래스인데, 역할은 아래와 같다.
- BertForMaskedLM
- BERT encoder + MLM(Masked Language Modeling) head
- 사전학습(pretraining)이나 마스크 복원 작업에 사용
- Geneformer의 pretraining 단계에서 사용; MLM 방식으로 gene context 학습 (self-supervised learning?)
- Multi-Task learning에서 MLM+다른 task를 동시에 학습하는 변형 구조
- BertForSequenceClassification
- BERT encoder + 문장 단위 classification head
- 텍스트 분류, 감정 분석, 문장 카테고리 분류 등에 사용
- cell 단위 classification task에 사용
- BertForTokenClassification
- BERT encoder + token-level classification head
- 개체명 인식(NER), BIO 태깅, token별 라벨링 작업에 사용
- token 단위 prediction 작업, 예를 들어 gene classifier로서 사용