[논문 리뷰] FedMSplit: Correlation-Adaptive Federated Multi-Task Learning across Multimodal Split Networks
https://dl.acm.org/doi/10.1145/3534678.3539384
1. Introduction
연합학습 (Federated Learning, FL)
여러 클라이언트가 서로의 데이터를 공유하지 않고도 함께 모델을 학습하는 분산 학습 프레임워크
의의: 통신 비용 절감과 프라이버시 보호
멀티모달 연합학습 (Multimodal Federated Learning, MFL)
배경: 센서 기술의 발전, 다양한 형태의 데이터 증가 -> FL에서 확장되어 MFL 등장
여러 클라이언트가 다양한 센서 조합(모달리티)으로 수집한 데이터를 기반으로, 데이터를 공유하지 않고도 함께 모델을 학습
기존 FL/MFL 연구의 한계
대부분 통계적 이질성(Statistical Heterogeneity), 즉 클라이언트마다 데이터 분포가 다른 문제(non-IID 문제)에 집중
한계: 클라이언트들이 동일한 모달리티(센서 구성)를 가지고 있다는 가정(=modality congruity) 전제
But, 현실에서는 모달리티 불일치(Modality Incongruity), 즉 클라이언트들의 센서 종류/개수가 다르고, 그들의 데이터의 모달리티 조합이 다른 문제 존재
문제 접근 방식: 멀티태스크 연합학습(FMTL) 기반으로 시작
멀티태스크 연합학습 (Federated Multi-Task Learning, FMTL): 각 클라이언트가 별도의, 하지만 관련 있는 모델을 학습 by 글로벌 정규화 항 -> 통계적 이질성(Statistical Heterogeneity) 다룸
목표: 기존 FMTL에 모달리티 불일치(Modality Incongruity)까지 함께 다루는 멀티모달 FMTL 프레임워크 만들기
멀티모달 FMTL 프레임워크의 도전 과제
1. 클라이언트마다 다른 작업, 다른 파라미터 공간 -> 집계가 어려움
기존 방식: 없는 모달리티에 대해 padding해서 파라미터 채우기. But, 이렇게 하면 노이즈 크고 비효율적임.
-> 그래서 아예 서로 다른 파라미터 공간 위에서의 개별 모델 학습을 목표로 함.
2. 제한된 통신 자원 -> 일부 클라이언트만 참여 가능
기존 방식: 클라이언트를 랜덤으로 선택 or 전체 선택. But, 이렇게 하면 비효율적임.
-> 그래서 클라이언트 타입이 균형 있게 참여하도록 설계하는 것을 목표로 함.
해결책: FedMSplit 프레임워크
멀티모달 클라리언트 간 연합학습, 센서 구조가 동일하다는 가정 없이!
FedMSplit의 핵심 아이디어: 동적 멀티뷰 그래프 구조 (dynamic and multi-view graph structure)
- 클라이언트 모델을 작은 블록 단위 쪼갬
- 그래프의 엣지 특징이 클라이언트 사이의 연관관계를 반영함
- 그래프 기반 메세지 전달을 통해 로컬 모델 관계를 향상함
이 연구의 기여
1. 모달리티 불일치 문제를 FMTL을 활용하여 다룬 첫 연구
2. 그래프를 통해 다른 타입의 클라이언트 간 관계를 동적으로 포착하는 FedMSplit 제안
3. 두 멀티모달 연합학습 데이터셋으로 FedMSplit 평가
같이 생각해본 한계
- 서버에서 하는 일이 좀 많아서 부담이 큰 편. 서버에 올리기 전에 인접한 두 클라이언트가 너무 다르거나 하면 필터링을 한번 해주는 게 있으면 서버 부담을 줄이는 데에 도움이 될 것 같다. or 클라이언트 쪽에서 정보 요약하는 부분이 있으면 도움이 될 것 같다.
- 모달리티가 같은 블록이 존재하면 단순히 블록 간 유사도를 계산해서 반영 (학습 목표의 이질성은 고려하지 X) -> 이것까지 고려한게 아마도 FedCola인듯!!
활용 가능성
FedCola에 client 선택 알고리즘을 반영하면 더 성능이 좋아질 수 있을듯
질문1
ModelNet40이 (클래스 수가 더 많음에도) 유독 정확도가 높은 이유가 뭘까? 다른 데이터셋들은 음향이 들어가서 그럴까? 좀 더 구체적으로는.. 다른 데이터셋들에 비해 모달리티 간 이질성(Modality Incongruity)이 상대적으로 적어서 그런게 아닐까 싶음. (두 modality가 시점이 다르지만 둘 다 시각적 정보이기 때문에)
질문2
논문에 한계가 명시적으로 적혀있지는 않은 것 같아서 같이 얘기해보면 좋을 것 같음. 하나는 람다를 업데이트하고 클라이언트를 선택하는 과정에서 오버헤드가 크다는 한계임. 추가로 in-modality gap (같은 모달리티여도 클라이언트 별로 이질성이 있을 수 있다, 멀티모달 클라이언트와 유니모달 클라이언트의 학습 목적이 다르기 때문에) 을 고려하지 않은 점도 한계로 볼 수 있을 것 같음.
교수님 피드백
논문 세미나 시, 10% 정도만 ppt 또는 스크립트 보자. 나머지 90%는 청중과 아이컨택하자.
그리고 현재 말하는 내용을 발표 자료와 synchronize 필요.
깃허브에서 논문 관련 코드를 찾아서 세미나의 experiment 전 혹은 후에 코드를 첨부해서 설명하자.
다른 분들의 질문과 의견
- classifier를 모두 공통으로 사용하는데, dimension을 어떻게 맞춰주는 거지?
local block과 global block 사이에 뭔가 있어서 거기에서 맞춰주는 듯. 아마도 fully connected로 맞춰주는 것 같고, 이것들은 서버에 보내서 학습시키는 게 아니라 로컬에서 학습되는 것 같음.
- classification 작업에만 한정됨. 그리고 통신 쪽 (communication overhead) 성능 언급이 없어서 아쉬웠음.
- 그래프 구조: 클라이언트 관계를 더 잘 포착할 수 있었을 것 같은데, 단순히 유사도만 계산하는 것 같아서 그런 부분은 아쉬움.
- ModelNet40 같은 경우에 싱글 모달리티의 헤테로지니어스 데이터라고 보는 게 맞을듯. 그래서 특히 정확성이 높은 것 같음.
- 전체적으로 아이디어가 엄청 특별하다기 보다는, 부분 부분 차용할 부분은 많은 것 같음.