카테고리 없음

FLASHOMNI: A UNIFIED SPARSE ATTENTION ENGINEFOR DIFFUSION TRANSFORMERS 논문리뷰

jinuklee 2025. 10. 16. 23:47
반응형
FLASH-OMNI는 Diffusion Transformers, DiTs의 이미지 및 비디오 합성 시 발생하는 높은 계산 비용 문제를 해결하기 위한 unified sparse attention engine

 

1. 문제 제기 및 목표
Multi-Modal Diffusion Transformers (MMDiTs)는 시각적 합성에 뛰어난 성능을 보이지만, 높은 계산 복잡성 때문에 고해상도 이미지 및 긴 비디오 생성에서 추론 효율성이 제한
이를 위한 기존의 sparsity는 다음과 같은 한계
Inconsistent sparsity granularity:

 기존의 방식이 다양하다 Existing methods vary from coarse-grained caching to fine-grained block skipping

-> 통일된 프레임워크 없이 결합하기 어렵다

 

Fragmented design space: 각 희소성 패턴은 특정 작업에 맞춰 설계되어, 최적 전략을 찾기 어렵고 다른 애플리케이션에 재사용하기 어렵다. Current methods introduce sparsity patterns tailored to specific tasks (e.g., dynamic vs. pattern-based attention) This fragmentation complicates the search for optimal strategies and prevents reuse across different applications.
Lack of kernel generality: 대부분의 희소성 접근 방식은 특정 구조에 최적화된 **맞춤형 커널(customized kernels)**을 필요로 하여 유연성을 떨어뜨리고 확장 가능한 배포를 저해 most sparse approaches require dedicated kernels optimized for specific sparsity structures.
FLASHOMNI는 이러한 문제를 해결하고, 임의의 DiT 아키텍처와 호환되는 attention 엔진을 제공하여 계산 병목 현상을 완화하는 것을 목표

2. 핵심 방법론: 통합 및 효율성 디자인
FLASHOMNI는 "Update–Dispatch" 패러다임을 사용하여 여러 sparse 전략들을 통합해, 세 가지 핵심 디자인을 도입
A.Unified Sparse Symbols
다양한 수준의 희소성을 통일된 형식으로 표현하기 위해 **8-bit sparse symbols()**을 도입
(Feature Caching Symbols): 캐시된 출력 블록을 표시
(Block-Sparse Skipping Symbols): 불필요한 계산을 건너뛸 블록 타일 쌍을 표시

 

다중 양자화 통합: 이 symbols은 실행 시 디코딩되어 유연한 다중 양자화 통합(multi-granularity integration)을 가능하게 하고, 캐시된 feature가 선택적으로 업데이트 하게 한다

We introduce compact 8-bit sparse symbols to represent multiple levels of sparsity in a unified format. These symbols guide the selective update of cached features

MMDiT에 대한 통찰: 텍스트-투-비전 DiT(MMDiT)에서는 Vision-to-Text 및 Text-to-Vision 영역이 멀티모달 융합에 필수적이므로, FLASHOMNI는 텍스트 토큰에 크게 영향을 미치거나 텍스트 안내를 강하게 받는 이미지 토큰을 캐싱에서 제외하여 교차 모달 일관성(cross-modal consistency)을 보장합니다.
B. 일반 희소 어텐션 커널 (General Sparse Attention Kernel)
FLASHOMNI는 런타임에 희소 심볼을 디코딩하고 다양한 희소성 전략을 효율적으로 실행하는 단일 커널을 설계
Update 단계: 인접한 타임스텝 시퀀스를 기반으로 희소 심볼과 피처 캐시를 업데이트
Dispatch 단계: 갱신 단계에서 생성된 희소 심볼을 활용하여 어텐션 계산을 가속화, CTA(Cooperative Thread Arrays)는 희소성 단위에 따라 캐시 재사용(cache-then-reuse) 또는 주문형 계산(compute-on-demand) 모드
효율성: 피처 캐싱(Feature Caching, FC)은 BSS(Block-Sparse Skipping)보다 더 높은 성능을 보이는데, FC는 CTA당 한 번만 디코딩이 필요하지만 BSS는 축소 과정(reduction process) 전반에 걸쳐 반복적으로 디코딩해야 하기 때문
C. 최적화된 희소 GEMM (Optimized Sparse GEMMs)
FLASHOMNI는 어텐션 모듈의 선형 레이어에서 중복 계산을 제거하고 캐시 저장 로직을 개선하기 위해 GEMM-QGEMM-O를  사용
GEMM-Q (Query Projection 최적화): Dispatch 단계에서 가 특정 블록 가 캐시에서 검색되도록 지정하는 경우, 해당 쿼리 투영 계산을 건너뛸 수 있다.
GEMM-O (Output Projection 최적화): 아웃풋 투영() 단계에서, 캐시에서 검색된 헤드들의 가중치가 적용된 합(weighted sum)을 cache bias() 항으로 캐싱하여 저장. Dispatch 단계에서 이 를 재사용하고 나머지 계산에만 집중함으로써, 중복 계산을 제거하고 메모리 소비를 감소. GEMM-O는 이 전략을 통해 **이론적 한계의 약 87.5%**에 달하는 가속을 달성
3. 실험 결과 및 성능
실험은 FLUX.1(이미지 생성), HunyuanVideo(비디오 생성), FLUX.1-Kontext(이미지 편집) 등 최신 DiT 모델에서 
정량적 결과: FLASHOMNI는 PSNR, LPIPS, SSIM, FID 등 모든 정량적 품질 지표에서 기존의 모든 희소 가속 방법론(DiTFastAttnV2, SpargeAttn, FORA, ToCa, TaylorSeer)보다 지속적으로 우수한 성능
가속 및 효율성:
    ◦ 어텐션 및 GEMM-Q에서 희소성 비율에 거의 선형적으로 일치하는(1:1) 가속을 제공
    ◦ GEMM-O에서는 2.5배에서 3.8배의 가속을 달성
    ◦ Hunyuan 모델(33K)에 적용했을 때, 시각적 품질 저하 없이 end-to-end acceleration)약 1.5배를 달성

 

특히 워밍업 단계(warmup steps)가 낮을 때 TaylorSeer와 같은 캐싱 방법보다 훨씬 높은 이미지 품질을 유지
반응형