정화 코딩

MDETR 모델 주요 코드 분석 본문

AI

MDETR 모델 주요 코드 분석

jungh150c 2025. 7. 18. 13:54

논문: https://arxiv.org/abs/2104.12763

 

MDETR -- Modulated Detection for End-to-End Multi-Modal Understanding

Multi-modal reasoning systems rely on a pre-trained object detector to extract regions of interest from the image. However, this crucial module is typically used as a black box, trained independently of the downstream task and on a fixed vocabulary of obje

arxiv.org

깃허브(코드): https://github.com/ashkamath/mdetr

 

GitHub - ashkamath/mdetr

Contribute to ashkamath/mdetr development by creating an account on GitHub.

github.com

 


1. 전체 디렉토리 구조

MDETR/
├── .github/
├── configs/
├── datasets/  # 다양한 데이터셋(CLEVR, COCO, GQA 등)을 처리하는 로더 및 평가 관련 코드
│   ├── phrasecut_utils/
│   ├── __init__.py
│   ├── clevr.py
│   ├── clevrref.py
│   ├── coco.py
│   ├── coco_eval.py
│   ├── flickr.py
│   ├── flickr_eval.py
│   ├── gqa.py
│   ├── lvis.py
│   ├── lvis_eval.py
│   ├── lvis_modulation.py
│   ├── mixed.py
│   ├── phrasecut.py
│   ├── phrasecut_eval.py
│   ├── refexp.py
│   ├── transforms.py
│   └── vg.py
├── models/  # MDETR의 핵심 모델 아키텍처 (백본, 트랜스포머, 매처, 세그멘테이션 등)
│   ├── __init__.py
│   ├── backbone.py  # 다양한 백본(ResNet, Timm 등)을 정의하고, 위치 인코딩을 결합해 특징 맵을 추출하는 역할
│   ├── matcher.py  # 예측된 객체와 정답 객체 사이의 최적 1:1 매칭을 위해 Hungarian Algorithm 기반 매칭 비용을 계산하는 모듈
│   ├── mdetr.py  # MDETR 모델의 전체 구조(백본, 트랜스포머, 객체 쿼리, 손실 함수 등) 정의
│   ├── position_encoding.py  # 2D 이미지용 위치 인코딩(Position Embedding) 방식 정의
│   ├── postprocessors.py  # MDETR 모델의 출력 결과를 데이터셋별 평가 형식에 맞게 후처리하는 모듈
│   ├── segmentation.py  # MDETR의 마스크 예측과 손실 계산을 위한 mask head와 loss 함수를 정의한 모듈
│   └── transformer.py  # 이미지와 텍스트를 함께 처리하기 위한 커스텀 Transformer 모듈
├── scripts/  # 학습 및 파인튜닝, 실험 스크립트들
│   ├── clevr/
│   ├── fine-tuning/
│   ├── pre-training/
│   └── utils/
├── util/  # 박스 연산, 분산 학습 도구, 시각화 등 유틸리티 함수들
│   ├── __init__.py
│   ├── box_ops.py  # 바운딩 박스 변환, IoU 및 GIoU 계산, 마스크로부터 박스 추출 등 박스 연산 유틸리티 정의
│   ├── dist.py  # 분산 학습을 지원하기 위한 유틸리티 함수들(프로세스 초기화, all_gather, reduce, rank 확인 등) 정의
│   ├── metrics.py  # 학습 중 지표 추적 및 로깅을 위한 유틸리티(SmoothedValue, MetricLogger, accuracy 등) 정의
│   ├── misc.py  # 데이터 전처리, 배치 구성, Git 정보 추적, 텐서 보간 등 다양한 보조 유틸리티 함수들 정의
│   ├── optim.py  # EMA 업데이트와 다양한 학습률 스케줄링 전략을 지원하는 최적화 유틸리티 함수들 정의
│   └── plot_utils.py  # 학습 로그 및 평가 결과를 시각화하기 위한 그래프 유틸리티 함수들 정의
├── engine.py
├── hubconf.py
├── LICENSE
├── main.py  # 메인 실행 스크립트
├── README.md  # 학습 엔진 (훈련 및 평가 루프 정의)
├── requirements.txt  # 필요한 파이썬 패키지 명세
├── run_with_submitit.py
├── run_with_submitit_gqa_eval.py
└── run_with_submitit_lvis_eval.py

 


2. mdetr.py 파일

__init__ 함수

MDETR 모델의 구조를 정의하고 필요한 구성 요소들 초기화

forward 함수

forward pass 정의

- encode_and_save=True 모드: 인코딩-only 모드 → 이미지+텍스트를 인코딩하고 memory_cache를 생성하여 반환 (학습 또는 추론 시 디코딩 전에 반드시 먼저 호출해야 함)
- encode_and_save=False 모드: 디코딩-only 모드 (캐시 기반) → encode_and_save=True 모드에서 저장된 기존 memory_cache를 받아서 쿼리 디코딩을 수행 (최종 예측 결과를 생성하는 단계)

ContrastiveCriterion 클래스

이미지와 텍스트 간의 대조 학습을 위한 손실을 계산하는 클래스

- 글로벌 이미지 ↔ 문장 수준의 대응 관계를 학습할 때 사용됨
- 이미지와 텍스트가 서로 매칭되는 경우엔 유사도 높게, 다른 쌍들은 유사도 낮게 학습하도록 유도

QACriterionGQA 클래스

GQA의 QA 태스크에서의 손실 함수를 정의한 클래스

QACriterionClevr 클래스

CLEVR의 QA 태스크에서의 손실 함수를 정의한 클래스

SetCriterion 클래스

MDETR의 통합 손실 계산 클래스 (DETR의 Set Prediction 손실 클래스를 확장함)

cf. MDETR의 세가지 Loss: Set Prediction Loss, Soft Token Prediction Loss, Contrastive Alignment Loss
- loss_isfinal 함수: 최종적으로 지목된 객체가 어떤 것인지 예측하는 binary classification loss 계산 함수 (CLEVR-Ref+)
- loss_labels 함수: 각 object query가 어떤 텍스트 토큰(span)에 해당하는지를 예측하는 soft classification loss 계산 함수 ⇒ Soft Token Prediction Loss 계산 함수
- loss_contrastive_align 함수: 이미지 object query와 텍스트 토큰의 임베딩이 의미상 일치하도록 학습하기 위한 contrastive loss 계산 함수 ⇒ Contrastive Alignment Loss 계산 함수
- loss_cardinality 함수: 예측한 객체 수와 실제 객체 수의 차이를 측정하는 loss 계산 함수 (참고용, 로깅용) → 역전파 X
- loss_boxes 함수: 예측된 박스(바운딩 박스)와 정답 박스 간의 위치 정확도를 평가하기 위해 L1 & GIoU loss 계산 함수
- loss_masks 함수: segmentation 마스크에 대해 Focal Loss와 Dice Loss를 계산하는 함수 (선택적)
- forward 함수: 모델 출력과 정답 간의 매칭을 통해 모든 손실을 계산하고 반환하는 함수

MLP 클래스

다층 퍼셉트론(Multi-Layer Perceptron, MLP) 또는 Feed-Forward Network (FFN) 클래스

build 함수

MDETR 모델과 해당 구성 옵션에 따라 손실 함수 및 weight dict을 초기화하는 빌더 함수

 


3. transformer.py 파일

forward 함수

forward pass 정의

- encode_and_save=True 모드: 이미지 + 텍스트를 인코딩하고 memory_cache를 반환
- encode_and_save=False 모드: 인코딩된 memory_cache를 받아 디코더 실행 (query 예측)

TransformerEncoder 클래스

여러 층의 TransformerEncoderLayer를 순차적으로 적용하여 입력 시퀀스(이미지 + 텍스트)를 인코딩하는 클래스

TransformerDecoder 클래스

여러 층의 TransformerDecoderLayer를 통해, object query를 받아 이미지/텍스트 memory로부터 cross-attention 기반 정보 추론을 수행하는 클래스

TransformerEncoderLayer 클래스

- forward_post 함수: [순서] 입력 전체(src)에 대해 Self-Attention -> FFN
- forward_pre 함수
- forward 함수: LayerNorm 위치를 Residual 뒤로 할지, 앞으로 할지에 따라 각각 forward_post 함수, forward_pre 함수 호출 (기본: Residual 뒤 => forward_post)

TransformerDecoderLayer 클래스

- forward_post 함수: [순서] Self-attention on tgt -> Cross-attention with memory (image) -> Feed Forward Network (FFN)
- forward_pre 함수
- forward 함수: LayerNorm 위치를 Residual 뒤로 할지, 앞으로 할지에 따라 각각 forward_post 함수, forward_pre 함수 호출 (기본: Residual 뒤 => forward_post)

 

Comments