본문 바로가기

scFoundation

Geneformer - BERT

Geneformer 코드에서 아래 부분은 BERT의 인코더 부분은 공통이고, task별로 다른 output head, 즉 task-specific head가 붙은 모델 클래스를 불러오는 부분이다.

BERT가 transformer의 encoder-only model인데, Geneformer는 6개의 transformer encoder units으로 구성된 모델(=encoder only)이므로 그냥 BERT를 불러온 것. 

from transformers import (
    BertForMaskedLM,
    BertForSequenceClassification,
    BertForTokenClassification,
)

 

 

task-specific head라는 것은 즉, 

MLM head -> BertForMaskedLM

Sequence Classifier Head (BertForSequenceClassification)

Token Classifier Head (BertForTokenClassification)

 

즉, BERT의 핵심 encoder 부분과 해당 task에 맞는 출력층 전체를 loading하고 있다. (전체 파라미터 로드) 

 

 

BertForMaskedLM, BertForSequenceClassification, BertForTokenClassification 이 3개 모델 클래스는 BERT 표준 아키텍쳐 클래스인데, 역할은 아래와 같다.

 

  • BertForMaskedLM
    • BERT encoder + MLM(Masked Language Modeling) head
    • 사전학습(pretraining)이나 마스크 복원 작업에 사용
    • Geneformer의 pretraining 단계에서 사용; MLM 방식으로 gene context 학습 (self-supervised learning?)
    • Multi-Task learning에서 MLM+다른 task를 동시에 학습하는 변형 구조
  • BertForSequenceClassification
    • BERT encoder + 문장 단위 classification head
    • 텍스트 분류, 감정 분석, 문장 카테고리 분류 등에 사용
    • cell 단위 classification task에 사용
  • BertForTokenClassification
    • BERT encoder + token-level classification head
    • 개체명 인식(NER), BIO 태깅, token별 라벨링 작업에 사용
    • token 단위 prediction 작업, 예를 들어 gene classifier로서 사용