torch.load in FlashAttention training / eval checkpoint paths
MITRE service request: 1988723
Status: RESERVED (pending a qualifying public reference per CNA Rules §5.3).
The flash-attention training framework through commit e724e2588cbe754beb97cf7c011b5e7e34119e62 (2025-04-13) contains an insecure deserialization vulnerability (CWE-502) in its checkpoint-loading mechanism. The load_checkpoint() function in checkpoint.py and the checkpoint-loading code in eval.py call torch.load() without the security-restrictive weights_only=True parameter, which allows deserialization of arbitrary Python objects via the pickle module. An attacker can exploit this by supplying a maliciously crafted checkpoint file; when a victim loads that checkpoint during model warm-starting or evaluation, arbitrary code executes on the victim's system.
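To make the CWE-502 mechanics concrete, the following self-contained sketch uses only the standard-library pickle module (which torch.load wraps) to show why loading an attacker-controlled file runs attacker-chosen code. The MaliciousCheckpoint class and the benign os.getenv payload are illustrative stand-ins, not code from the flash-attention repository:

```python
import os
import pickle
import tempfile

class MaliciousCheckpoint:
    """Stand-in for a crafted .pt payload: __reduce__ tells pickle to
    call an arbitrary callable at load time. Here the callable is the
    harmless os.getenv; a real exploit would use something like
    (os.system, ("<shell command>",))."""
    def __reduce__(self):
        return (os.getenv, ("HOME",))

# "Attacker" writes the crafted checkpoint to a shared path.
path = os.path.join(tempfile.mkdtemp(), "model.pt")
with open(path, "wb") as f:
    pickle.dump(MaliciousCheckpoint(), f)

# "Victim" loads it; the callable runs during deserialization,
# before any model code ever inspects the loaded object.
with open(path, "rb") as f:
    loaded = pickle.load(f)

# loaded is the *return value* of os.getenv("HOME"), not a
# MaliciousCheckpoint instance: the call already happened.
print(type(loaded))
```

The key point is that the victim never has to call anything on the loaded object; deserialization itself triggers execution, which is why a swapped checkpoint file is sufficient.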
Warm-start and evaluation flows call load_checkpoint(), which in turn calls torch.load(..., map_location=device) without weights_only=True. Any checkpoint swapped in on a shared NFS mount or Hugging Face cache therefore executes its pickle gadgets on every node that loads it, a significant blast radius on large GPU clusters.
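The remediation is to pass weights_only=True to torch.load(). Under the hood, that option restricts which globals the unpickler may resolve. The same defense can be sketched with the standard library alone; the RestrictedUnpickler class and its ALLOWED set below are illustrative names, assuming an allowlist approach similar to the one the pickle documentation recommends:

```python
import io
import os
import pickle

class RestrictedUnpickler(pickle.Unpickler):
    """Sketch of the weights_only=True idea: only an explicit allowlist
    of (module, name) globals may be resolved, so a gadget that tries to
    import a callable such as os.system raises instead of executing."""
    ALLOWED = {("collections", "OrderedDict")}  # extend for tensor types etc.

    def find_class(self, module, name):
        if (module, name) in self.ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"blocked global: {module}.{name}")

class Gadget:
    """Stand-in for a crafted payload (harmless callable for the demo)."""
    def __reduce__(self):
        return (os.getenv, ("HOME",))

payload = pickle.dumps(Gadget())

try:
    RestrictedUnpickler(io.BytesIO(payload)).load()
    blocked = False
except pickle.UnpicklingError:
    blocked = True

print("payload blocked:", blocked)  # the gadget callable never runs
```

In the affected code paths, the one-line change is torch.load(path, map_location=device, weights_only=True); checkpoints that contain only tensors and plain containers load unchanged, while gadget-bearing files are rejected.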
Affected commit: e724e2588cbe754beb97cf7c011b5e7e34119e62
Attack vector: train.warmstart.path or eval checkpoint arguments referencing .pt files
Affected files: training/src/utils/checkpoint.py and training/src/eval.py
Risk: High for shared HPC scratch directories