Skip to content

Conversation

@toilaluan
Copy link
Contributor

What does this PR do?

Adding TaylorSeer Caching method to accelerate inference speed mentioned in #12569

Author's codebase: https://github.com/Shenyi-Z/TaylorSeer

This PR structure will heavily mimic FasterCache (https://github.com/huggingface/diffusers/pull/10163/files) behaviour
I prioritze to make it work on image model pipelines (Flux, Qwen Image) for ease of evaluation

Expected Output

4->5x speeding up by these settings while keep output images are qualified

image

State Design

Core of this algorithm is about predict features of step t by using real computed features from previous step using Taylor Expansion Approximation.
We design a State class, include predict & update method and taylor_factors: Tensor to maintain iteration information. Each feature tensor will be bounded to a state instance (in double stream attention class in Flux & QwenImage, output of this module is image_features & txt_features, we will create 2 state instances for them)

  • update method will be called from real compute timestep and update taylor_factors using math formular referenced to original implementation
  • predict method will be called to predict feature from current taylor_factors using math formular referenced to original implementation

@seed93
Copy link

seed93 commented Nov 14, 2025

Will you adapt this great PR for flux kontext controlnet or flux controlnet? It would be nice if it is implemented and I am very eager to try it out.

@toilaluan
Copy link
Contributor Author

@seed93 yes, i am prioritizing for flux series and qwen image

@toilaluan
Copy link
Contributor Author

Here is analysis about TaylorSeer for Flux
Comparing with baseline, the output image is different, although PAB method give pretty close result
This result is match with author's implementation

model_id cache_method compute_dtype compile time model_memory model_max_memory_reserved inference_memory inference_max_memory_reserved
flux none fp16 False 22.318 33.313 33.322 33.322 34.305
flux pyramid_attention_broadcast fp16 False 18.394 33.313 33.322 33.322 35.789
flux taylorseer_cache fp16 False 6.457 33.313 33.322 33.322 38.18

Flux visual results

Baseline

image

Pyramid Attention Broadcast

image

TaylorSeer Cache (this implementation)

image

TaylorSeer Original (https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-Diffusers/taylorseer_flux/diffusers_taylorseer_flux.py)

image

Benchmark code is based on #10163

import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
from diffusers import (
    AllegroPipeline,
    CogVideoXPipeline,
    FluxPipeline,
    HunyuanVideoPipeline,
    LattePipeline,
    MochiPipeline,
)
from diffusers.models import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate


repo = git.Repo(path="/root/diffusers")
branch = repo.active_branch

from diffusers import (
    apply_taylorseer_cache, 
    TaylorSeerCacheConfig, 
    apply_faster_cache, 
    FasterCacheConfig, 
    apply_pyramid_attention_broadcast, 
    PyramidAttentionBroadcastConfig,
)

def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output

def prepare_flux(dtype: torch.dtype) -> None:
    model_id = "black-forest-labs/FLUX.1-dev"
    print(f"Loading {model_id} with {dtype} dtype")
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, use_safetensors=True)
    pipe.to("cuda")
    generation_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
    }

    return pipe, generation_kwargs

def prepare_flux_config(cache_method: str, pipe: FluxPipeline):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(100, 950),
            spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
            current_timestep_callback=lambda: pipe.current_timestep,
        )
    elif cache_method == "taylorseer_cache":
        return TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float16, architecture="flux")
    elif cache_method == "fastercache":
        return FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 681),
        low_frequency_weight_update_timestep_range=(99, 641),
        high_frequency_weight_update_timestep_range=(-1, 301),
        spatial_attention_block_identifiers=["transformer_blocks"],
        attention_weight_callback=lambda _: 0.3,
        tensor_format="BFCHW",
    )
    elif cache_method == "none":
        return None


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


MODEL_MAPPING = {
    "flux": {
        "prepare": prepare_flux,
        "config": prepare_flux_config,
        "decode": decode_flux,
    },
}

STR_TO_COMPUTE_DTYPE = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator(device="cuda").manual_seed(181201)
    print(f"Generator: {generator}")
    print(f"Generation kwargs: {generation_kwargs}")
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


@torch.no_grad()
def main(model_id: str, cache_method: str, output_dir: str, dtype: str):
    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    compute_dtype = STR_TO_COMPUTE_DTYPE[dtype]
    model = MODEL_MAPPING[model_id]

    try:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=compute_dtype)

        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        model_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 2. Apply attention approximation technique
        config = model["config"](cache_method, pipe)
        if cache_method == "pyramid_attention_broadcast":
            apply_pyramid_attention_broadcast(pipe.transformer, config)
        elif cache_method == "fastercache":
            apply_faster_cache(pipe.transformer, config)
        elif cache_method == "taylorseer_cache":
            apply_taylorseer_cache(pipe.transformer, config)
        elif cache_method == "none":
            pass
        else:
            raise ValueError(f"Invalid {cache_method=} provided.")

        # 4. Benchmark
        time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        inference_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 5. Decode latents
        filename = output_dir / f"{model_id}---dtype-{dtype}---cache_method-{cache_method}---compile-{compile}"
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            video_length=generation_kwargs.get("video_length", None),
        )

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "time": time,
            "model_memory": model_memory,
            "model_max_memory_reserved": model_max_memory_reserved,
            "inference_memory": inference_memory,
            "inference_max_memory_reserved": inference_max_memory_reserved,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }

    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "time": None,
            "model_memory": None,
            "model_max_memory_reserved": None,
            "inference_memory": None,
            "inference_max_memory_reserved": None,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux"],
        help="Model to run benchmark for.",
    )
    parser.add_argument(
        "--cache_method",
        type=str,
        default="pyramid_attention_broadcast",
        choices=["pyramid_attention_broadcast", "fastercache", "taylorseer_cache", "none"],
        help="Cache method to use.",
    )
    parser.add_argument(
        "--output_dir", type=str, help="Path where the benchmark artifacts and outputs are the be saved."
    )
    parser.add_argument("--dtype", type=str, help="torch.dtype to use for inference")
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(args.model_id, args.cache_method, args.output_dir, args.dtype)
    

@toilaluan
Copy link
Contributor Author

More comparison between this impl, baseline, author's impl

image

@toilaluan
Copy link
Contributor Author

I think current implementation is unified for every models that have attention modules, but to achieve full optimization, we have to config regex for which layer to cache or skip compute
Example in a sequence of Linear1, Act1, Linear2, Act2: we need to add hook for Linear1,act1,linear2 to do nothing (return an empty tensor) but cache output of act2
I already fix template for flux, but for other models, user have to write their own and pass it to the config init
@sayakpaul how do you think about this mechanism? I need some advises here

@sayakpaul sayakpaul requested a review from DN6 November 14, 2025 17:41
@toilaluan
Copy link
Contributor Author

toilaluan commented Nov 15, 2025

Tuning cache config really helps!

TaylorSeer cache configuration comparison

In the original code, they use 3 warmup steps and no cooldown. The output image differs significantly from the baseline, as shown in the report above.

As suggested in Shenyi-Z/TaylorSeer#12, increasing the warmup steps to 10 helps narrow the gap, but the cached output still has noticeable artifacts. This naturally suggested adding a cooldown phase (running the last steps without caching).

All runs below use the same prompt and 50 inference steps.

Visual comparison

Baseline vs. 3 warmup / 0 cooldown

Baseline (no cache) 3 warmup steps, 0 cooldown (cache)
Baseline output 3 warmup, 0 cooldown output

With only 3 warmup steps and 0 cooldown steps, the image content is not very close to the baseline.

10 warmup / 0 cooldown vs. 10 warmup / 5 cooldown

10 warmup steps, 0 cooldown (cache) 10 warmup steps, 5 cooldown (cache)
10 warmup, 0 cooldown output 10 warmup, 5 cooldown output

With 10 warmup steps, the content is closer to the baseline, but there are still many artifacts and noise.
By running the last 5 steps without caching (cooldown), most of these issues are resolved.


Hardware usage comparison

The table below shows the hardware usage comparison:

cache_method predict_steps max_order warmup_steps stop_predicts time (s) model_memory_gb inference_memory_gb max_memory_reserved_gb compute_dtype
none - - - - 22.781 33.313 33.321 37.943 fp16
taylorseer_cache 5.0 1.0 3.0 - 7.099 55.492 55.492 70.283 fp16
taylorseer_cache 5.0 1.0 3.0 45.0 9.024 55.490 55.490 70.283 fp16
taylorseer_cache 5.0 1.0 10.0 - 9.451 55.492 55.492 70.283 fp16
taylorseer_cache 5.0 1.0 10.0 45.0 11.000 55.490 55.490 70.283 fp16
taylorseer_cache 6.0 1.0 3.0 - 6.701 55.492 55.492 70.285 fp16
taylorseer_cache 6.0 1.0 3.0 45.0 8.651 55.490 55.490 70.285 fp16
taylorseer_cache 6.0 1.0 10.0 - 9.053 55.492 55.492 70.283 fp16
taylorseer_cache 6.0 1.0 10.0 45.0 11.001 55.490 55.490 70.283 fp16
image

Code

import gc
import pathlib
import pandas as pd
import torch
from itertools import product

from diffusers import FluxPipeline
from diffusers.utils.logging import set_verbosity_info

from diffusers import apply_taylorseer_cache, TaylorSeerCacheConfig

def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output

def prepare_flux(dtype: torch.dtype):
    model_id = "black-forest-labs/FLUX.1-dev"
    print(f"Loading {model_id} with {dtype} dtype")
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, use_safetensors=True)
    pipe.to("cuda")
    prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
    generation_kwargs = {
        "prompt": prompt,
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
    }

    return pipe, generation_kwargs

def run_inference(pipe, generation_kwargs):
    generator = torch.Generator(device="cuda").manual_seed(181201)
    output = pipe(generator=generator, output_type="pil", **generation_kwargs).images[0]
    torch.cuda.synchronize()
    return output

def main(output_dir: str):
    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    compute_dtype = torch.float16
    taylor_factors_dtype = torch.float16

    param_grid = {
        'predict_steps': [5, 6],
        'max_order': [1],
        'warmup_steps': [3, 10],
        'stop_predicts': [None, 45]
    }
    combinations = list(product(*param_grid.values()))
    param_keys = list(param_grid.keys())

    results = []

    # Reset before each run
    def reset_cuda():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

    # Baseline (no cache)
    print("Running baseline...")
    reset_cuda()
    pipe, generation_kwargs = prepare_flux(compute_dtype)
    model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
    time, image = benchmark_fn(run_inference, pipe, generation_kwargs)
    inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
    max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
    image_filename = output_dir / "baseline.png"
    image.save(image_filename)
    print(f"Baseline image saved to {image_filename}")

    info = {
        'cache_method': 'none',
        'predict_steps': None,
        'max_order': None,
        'warmup_steps': None,
        'stop_predicts': None,
        'time': time,
        'model_memory_gb': model_memory,
        'inference_memory_gb': inference_memory,
        'max_memory_reserved_gb': max_memory_reserved,
        'compute_dtype': 'fp16'
    }
    results.append(info)

    # TaylorSeer cache configurations
    for combo in combinations:
        ps, mo, ws, sp = combo
        sp_str = 'None' if sp is None else str(sp)
        print(f"Running TaylorSeer with predict_steps={ps}, max_order={mo}, warmup_steps={ws}, stop_predicts={sp}...")
        reset_cuda()
        pipe, generation_kwargs = prepare_flux(compute_dtype)
        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        config = TaylorSeerCacheConfig(
            predict_steps=ps,
            max_order=mo,
            warmup_steps=ws,
            stop_predicts=sp,
            taylor_factors_dtype=taylor_factors_dtype,
            architecture="flux"
        )
        apply_taylorseer_cache(pipe.transformer, config)
        time, image = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
        image_filename = output_dir / f"taylorseer_p{ps}_o{mo}_w{ws}_s{sp_str}.jpg"
        image.save(image_filename)
        print(f"TaylorSeer image saved to {image_filename}")

        info = {
            'cache_method': 'taylorseer_cache',
            'predict_steps': ps,
            'max_order': mo,
            'warmup_steps': ws,
            'stop_predicts': sp,
            'time': time,
            'model_memory_gb': model_memory,
            'inference_memory_gb': inference_memory,
            'max_memory_reserved_gb': max_memory_reserved,
            'compute_dtype': 'fp16'
        }
        results.append(info)

    # Save CSV
    df = pd.DataFrame(results)
    csv_path = output_dir / 'benchmark_results.csv'
    df.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")

    # Plot latency
    import matplotlib.pyplot as plt
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(20, 8))

    baseline_row = df[df['cache_method'] == 'none'].iloc[0]
    baseline_time = baseline_row['time']

    labels = ['baseline']
    times = [baseline_time]

    taylor_df = df[df['cache_method'] == 'taylorseer_cache']
    for _, row in taylor_df.iterrows():
        sp_str = 'None' if pd.isna(row['stop_predicts']) else str(int(row['stop_predicts']))
        label = f"p{row['predict_steps']}-o{row['max_order']}-w{row['warmup_steps']}-s{sp_str}"
        labels.append(label)
        times.append(row['time'])

    bars = ax.bar(labels, times)
    ax.set_xlabel('Configuration')
    ax.set_ylabel('Latency (s)')
    ax.set_title('Inference Latency: Baseline vs TaylorSeer Cache Configurations')
    ax.tick_params(axis='x', rotation=90)
    plt.tight_layout()

    plot_path = output_dir / 'latency_comparison.png'
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Plot saved to {plot_path}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save CSV, plot, and images.")
    args = parser.parse_args()

    set_verbosity_info()
    main(args.output_dir)

@toilaluan
Copy link
Contributor Author

Similar behavior with Qwen Image

cache_method predict_steps max_order warmup_steps stop_predicts time model_memory_gb inference_memory_gb max_memory_reserved_gb compute_dtype
none 23.01 53.791 53.807 64.359 fp16
taylorseer_cache 5.0 1.0 3.0 13.457 53.807 53.813 67.303 fp16
taylorseer_cache 5.0 1.0 3.0 45.0 14.562 53.813 53.819 67.303 fp16
taylorseer_cache 5.0 1.0 10.0 14.775 53.819 53.825 67.303 fp16
taylorseer_cache 5.0 1.0 10.0 45.0 15.628 53.825 53.832 67.322 fp16
taylorseer_cache 6.0 1.0 3.0 13.214 53.832 53.839 67.322 fp16
taylorseer_cache 6.0 1.0 3.0 45.0 14.349 53.838 53.845 67.322 fp16
taylorseer_cache 6.0 1.0 10.0 14.595 53.844 53.851 67.322 fp16
taylorseer_cache 6.0 1.0 10.0 45.0 15.707 53.851 53.858 67.342 fp16

@toilaluan
Copy link
Contributor Author

toilaluan commented Nov 16, 2025

Flux Kontext – Cache vs Baseline Comparison

@seed93
I tested Flux Kontext using several cache configurations, and the results look promising. Below is a comparison of the baseline output and the cached versions.


Original Image

original image

Image Comparison (Side-by-Side)

Baseline — “add a hat to the cat” predict_steps=7, O=1, warmup=10, cooldown=5 predict_steps=8

Processing Times (CSV → Markdown Table)

| cache_method     | predict_steps | max_order | warmup_steps | stop_predicts | time    | model_memory_gb | inference_memory_gb | max_memory_reserved_gb | compute_dtype |
|------------------|---------------|-----------|--------------|---------------|---------|------------------|----------------------|--------------------------|---------------|
| none             |               |           |              |               | 48.391  | 31.438           | 31.446               | 36.209                   | fp16          |
| taylorseer_cache | 7.0           | 1.0       | 10.0         | 45.0          | 21.468  | 31.447           | 31.447               | 44.625                   | fp16          |
| taylorseer_cache | 8.0           | 1.0       | 10.0         | 45.0          | 20.633  | 31.447           | 31.447               | 44.625                   | fp16          |

Reproduce Code

import gc
import pathlib
import pandas as pd
import torch
from itertools import product

from diffusers import DiffusionPipeline
from diffusers.utils.logging import set_verbosity_info

from diffusers import apply_taylorseer_cache, TaylorSeerCacheConfig

def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output

def prepare_flux(dtype: torch.dtype):
    from diffusers.utils import load_image
    model_id = "black-forest-labs/FLUX.1-Kontext-dev"
    print(f"Loading {model_id} with {dtype} dtype")
    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, use_safetensors=True)
    pipe.to("cuda")
    prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang, Ultra HD, 4K, cinematic composition."
    edit_prompt = "Add a hat to the cat"
    input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
    generation_kwargs = {
        "prompt": edit_prompt,
        "num_inference_steps": 50,
        "guidance_scale": 2.5,
        "image": input_image,
    }

    return pipe, generation_kwargs

def run_inference(pipe, generation_kwargs):
    generator = torch.Generator(device="cuda").manual_seed(181201)
    output = pipe(generator=generator, output_type="pil", **generation_kwargs).images[0]
    torch.cuda.synchronize()
    return output

def main(output_dir: str):
    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    compute_dtype = torch.bfloat16
    taylor_factors_dtype = torch.bfloat16

    param_grid = {
        'predict_steps': [7, 8],
        'max_order': [1],
        'warmup_steps': [10],
        'stop_predicts': [45]
    }
    combinations = list(product(*param_grid.values()))
    param_keys = list(param_grid.keys())

    results = []

    # Reset before each run
    def reset_cuda():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

    # Baseline (no cache)
    print("Running baseline...")
    reset_cuda()
    pipe, generation_kwargs = prepare_flux(compute_dtype)
    model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
    time, image = benchmark_fn(run_inference, pipe, generation_kwargs)
    inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
    max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
    image_filename = output_dir / "baseline.png"
    image.save(image_filename)
    print(f"Baseline image saved to {image_filename}")

    info = {
        'cache_method': 'none',
        'predict_steps': None,
        'max_order': None,
        'warmup_steps': None,
        'stop_predicts': None,
        'time': time,
        'model_memory_gb': model_memory,
        'inference_memory_gb': inference_memory,
        'max_memory_reserved_gb': max_memory_reserved,
        'compute_dtype': 'fp16'
    }
    results.append(info)

    # TaylorSeer cache configurations
    for combo in combinations:
        ps, mo, ws, sp = combo
        sp_str = 'None' if sp is None else str(sp)
        print(f"Running TaylorSeer with predict_steps={ps}, max_order={mo}, warmup_steps={ws}, stop_predicts={sp}...")
        del pipe
        reset_cuda()
        pipe, generation_kwargs = prepare_flux(compute_dtype)
        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        config = TaylorSeerCacheConfig(
            predict_steps=ps,
            max_order=mo,
            warmup_steps=ws,
            stop_predicts=sp,
            taylor_factors_dtype=taylor_factors_dtype,
            architecture="flux"
        )
        apply_taylorseer_cache(pipe.transformer, config)
        time, image = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
        image_filename = output_dir / f"taylorseer_p{ps}_o{mo}_w{ws}_s{sp_str}.jpg"
        image.save(image_filename)
        print(f"TaylorSeer image saved to {image_filename}")

        info = {
            'cache_method': 'taylorseer_cache',
            'predict_steps': ps,
            'max_order': mo,
            'warmup_steps': ws,
            'stop_predicts': sp,
            'time': time,
            'model_memory_gb': model_memory,
            'inference_memory_gb': inference_memory,
            'max_memory_reserved_gb': max_memory_reserved,
            'compute_dtype': 'fp16'
        }
        results.append(info)

    # Save CSV
    df = pd.DataFrame(results)
    csv_path = output_dir / 'benchmark_results.csv'
    df.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")

    # Plot latency
    import matplotlib.pyplot as plt
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(20, 8))

    baseline_row = df[df['cache_method'] == 'none'].iloc[0]
    baseline_time = baseline_row['time']

    labels = ['baseline']
    times = [baseline_time]

    taylor_df = df[df['cache_method'] == 'taylorseer_cache']
    for _, row in taylor_df.iterrows():
        sp_str = 'None' if pd.isna(row['stop_predicts']) else str(int(row['stop_predicts']))
        label = f"p{row['predict_steps']}-o{row['max_order']}-w{row['warmup_steps']}-s{sp_str}"
        labels.append(label)
        times.append(row['time'])

    bars = ax.bar(labels, times)
    ax.set_xlabel('Configuration')
    ax.set_ylabel('Latency (s)')
    ax.set_title('Inference Latency: Baseline vs TaylorSeer Cache Configurations')
    ax.tick_params(axis='x', rotation=90)
    plt.tight_layout()

    plot_path = output_dir / 'latency_comparison.png'
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Plot saved to {plot_path}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save CSV, plot, and images.")
    args = parser.parse_args()

    # set_verbosity_info()
    main(args.output_dir)

@toilaluan toilaluan marked this pull request as ready for review November 17, 2025 06:21
@seed93
Copy link

seed93 commented Nov 17, 2025

Flux Kontext – Cache vs Baseline Comparison

@seed93 I tested Flux Kontext using several cache configurations, and the results look promising. Below is a comparison of the baseline output and the cached versions.

Original Image

original image ## **Image Comparison (Side-by-Side)** Baseline — “add a hat to the cat” predict_steps=7, O=1, warmup=10, cooldown=5 predict_steps=8 ## **Processing Times (CSV → Markdown Table)** ``` | cache_method | predict_steps | max_order | warmup_steps | stop_predicts | time | model_memory_gb | inference_memory_gb | max_memory_reserved_gb | compute_dtype | |------------------|---------------|-----------|--------------|---------------|---------|------------------|----------------------|--------------------------|---------------| | none | | | | | 48.391 | 31.438 | 31.446 | 36.209 | fp16 | | taylorseer_cache | 7.0 | 1.0 | 10.0 | 45.0 | 21.468 | 31.447 | 31.447 | 44.625 | fp16 | | taylorseer_cache | 8.0 | 1.0 | 10.0 | 45.0 | 20.633 | 31.447 | 31.447 | 44.625 | fp16 | ```

Reproduce Code

This is amazing!

@seed93
Copy link

seed93 commented Nov 19, 2025

I am not sure why it uses so much gpu memory? I have only 24 GB gpu memory.

@seed93
Copy link

seed93 commented Nov 19, 2025

Could you please try using the taylorseer-lite as an option? refer to Shenyi-Z/TaylorSeer#5

@toilaluan
Copy link
Contributor Author

@seed93 yeah it seems to not complicated, i will try and post some report here

@toilaluan
Copy link
Contributor Author

Taylorseer-lite

@seed93, you can use lite version with minimal extra memory by following this script but it works for Hunyuan model, not Flux.
Flux TS-lite's output is purely noise

  • Hunyuan Output
image
  • Flux Output
image
import torch
from diffusers import FluxPipeline, HunyuanImagePipeline
from diffusers import TaylorSeerCacheConfig


model = "hunyuanimage"  # or "flux"
if model == "flux":
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
elif model == "hunyuanimage":
    pipeline = HunyuanImagePipeline.from_pretrained(
        "hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
print(pipeline)

cache_config = TaylorSeerCacheConfig(
    skip_identifiers=[r"^(?!proj_out$)[^.]+\.[^.]+$"],
    cache_identifiers=[r"proj_out"],
    predict_steps=5,
    max_order=2,
    warmup_steps=10,
    stop_predicts=48,
    taylor_factors_dtype=torch.bfloat16,
)

pipeline.transformer.enable_cache(cache_config)

prompt = "A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys."
image = pipeline(prompt=prompt, width=1024, height=1024, num_inference_steps=50, generator=torch.Generator(device="cuda").manual_seed(181201)).images[0]

image.save("teddy_bear.jpg")

@toilaluan
Copy link
Contributor Author

@DN6 This feature is ready for reviewing, could you take a look 🙇

@DN6
Copy link
Collaborator

DN6 commented Dec 3, 2025

@toilaluan Just one more thing to consider is if the the cache is compatible with torch.compile. Conditional checks usually don't play well with torch.compile. We've disabled it for certain checks in previous cacheing approaches

@torch.compiler.disable

@toilaluan
Copy link
Contributor Author

@DN6, @sayakpaul I added a similar compiler disable to taylorseer cache, but i have an observation that can be applied to both FBCache and TaylorSeer:

  • graph will break at torch.compile.disable then it requires recompiling every block since hook is applied to block level
    By default recompile limit is 8, while total blocks of transformer is much higher (57 in flux), we have to set this limit to higher than number of blocks to achieve best performance and similar graphs to regional compiling. Check the code below:
import torch
from diffusers import FluxPipeline, HunyuanImagePipeline
from diffusers import TaylorSeerCacheConfig, FirstBlockCacheConfig

# torch._logging.set_logs(graph_code=True)

import torch._dynamo as dynamo
dynamo.config.recompile_limit = 100

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")
print(pipeline)

cache_config = TaylorSeerCacheConfig(
    cache_interval=5,
    max_order=1,
    disable_cache_before_step=50, # assume we will run full compute to see compile effect
    disable_cache_after_step=48,
    taylor_factors_dtype=torch.bfloat16,
    use_lite_mode=True
)
fbconfig = FirstBlockCacheConfig(
    threshold=1e-6, # set this value to very small so cache will not be applied to see compile effect
)

pipeline.transformer.enable_cache(fbconfig) # or cache_config

pipeline.transformer.compile(fullgraph=False, dynamic=True)

prompt = "A laptop on top of a teddy bear, realistic, high quality, 4k"
# warmup
image = pipeline(prompt=prompt, width=1024, height=1024, num_inference_steps=50, generator=torch.Generator(device="cuda").manual_seed(181201)).images[0]
# monitor this call
image = pipeline(prompt=prompt, width=1024, height=1024, num_inference_steps=50, generator=torch.Generator(device="cuda").manual_seed(181201)).images[0]
image.save("teddy_bear.jpg")

@toilaluan
Copy link
Contributor Author

toilaluan commented Dec 4, 2025

Comparison: Baseline, Baseline-22steps, FBCache, TaylorSeer Cache

FLUX.1-dev

Memory & Speed Metrics (GPU: H100, 50 steps, compiled)

Prompt Index Variant Load Time (s) Load Memory (GB) Compile Time (s) Warmup Time (s) Main Time (s) Peak Memory (GB)
0 baseline 5.812418 31.437537 1.162642 3.940589 11.744348 33.851628
0 baseline(steps=22) 5.445316 31.469763 0.054121 2.662160 5.271367 33.851628
0 firstblock(threshold=0.05) 5.618118 31.469763 0.053683 30.769011 8.686095 33.928777
0 taylorseer(max_order=1, cache_interval=5, disable_cache_before_step=10) 5.487217 31.469763 0.051885 59.841501 4.865684 33.852117

Visual Outputs

  • Women’s Health magazine cover, April 2025 issue, ‘Spring forward’ headline, woman in green outfit sitting on orange blocks, white sneakers, ‘Covid: five years on’ feature text, ‘15 skincare habits’ callout, professional editorial photography, magazine layout with multiple text elements
Baseline Baseline-22steps FBCache TaylorSeer Cache
Baseline Baseline 22 steps FBCache TaylorCache
  • Soaking wet tiger cub taking shelter under a banana leaf in the rainy jungle, close up photo
Baseline Baseline-22steps FBCache TaylorSeer Cache
Baseline Baseline 22 steps FBCache TaylorCache

Analyze TaylorSeer Configurations

Memory & Speed (no compile)

Prompt Index Variant Steps Max Order Cache Interval Load Time (s) Load Memory (GB) Compile Time (s) Warmup Time (s) Main Time (s) Peak Memory (GB) Speedup vs baseline_50 Speedup vs baseline_22
0 baseline_50 50 N/A N/A 5.68 31.44 0.00 2.38 15.32 33.85 1.00x 1.00x
0 baseline_22 22 N/A N/A 5.39 31.47 0.00 1.70 6.86 33.85 2.23x 1.00x
0 taylor_o0_ci5 50 0 5 5.37 31.47 0.00 1.71 5.70 33.85 2.69x 1.20x
0 taylor_o0_ci10 50 0 10 5.38 31.47 0.00 1.71 4.64 33.85 3.30x 1.48x
0 taylor_o0_ci15 50 0 15 5.44 31.47 0.00 1.71 4.40 33.85 3.48x 1.56x
0 taylor_o1_ci5 50 1 5 5.67 31.47 0.00 1.71 5.70 33.85 2.69x 1.20x
0 taylor_o1_ci8 50 1 8 5.68 31.47 0.00 1.71 4.88 33.85 3.14x 1.41x
0 taylor_o1_ci10 50 1 10 5.68 31.47 0.00 1.71 4.64 33.85 3.30x 1.48x
0 taylor_o1_ci15 50 1 15 5.68 31.47 0.00 1.71 4.40 33.85 3.48x 1.56x
0 taylor_o2_ci5 50 2 5 5.66 31.47 0.00 1.71 5.69 33.85 2.69x 1.21x
0 taylor_o2_ci10 50 2 10 5.73 31.47 0.00 1.70 4.64 33.85 3.30x 1.48x
0 taylor_o2_ci15 50 2 15 5.36 31.47 0.00 1.70 4.40 33.85 3.48x 1.56x

Visual Comparison (o1,ci5 means max_order=1, cache_interval=5)

image

Reproduce Code

  1. Baselines Vs. TaylorSeer variants
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig
import time
import os
import matplotlib.pyplot as plt
import pandas as pd
import gc

# Set dynamo config
import torch._dynamo as dynamo
dynamo.config.recompile_limit = 200

prompts = [
    "Black cat hiding behind a watermelon slice, professional studio shot, bright red and turquoise background with summer mystery vibe",
]

# Create output folder
os.makedirs("outputs", exist_ok=True)

# ============================================================================
# CONFIGURATION SECTION - Easily modify these parameters
# ============================================================================

# Fixed config parameters (applied to all TaylorSeer configs)
FIXED_CONFIG = {
    'disable_cache_before_step': 10,
    'taylor_factors_dtype': torch.bfloat16,
    'use_lite_mode': True
}

# Variable parameters to test - modify these as needed
# Format: (max_order, cache_interval)
TAYLOR_CONFIGS = [
    (0, 5),   # max_order=0, cache_interval=5
    (0, 10),
    (0, 15),
    (1, 5),   # max_order=1, cache_interval=5
    (1, 8),   # max_order=1, cache_interval=6
    (1, 10),   # max_order=1, cache_interval=7
    (1, 15),  # max_order=1, cache_interval=10
    (2, 5),   # max_order=2, cache_interval=5
    (2, 10),   # max_order=2, cache_interval=6
    (2, 15),
]

# Baseline configurations
BASELINES = [
    {'name': 'baseline_50', 'steps': 50},
    {'name': 'baseline_22', 'steps': 22},
]

# Main inference steps for TaylorSeer variants
MAIN_STEPS = 50
WARMUP_STEPS = 5

# ============================================================================

# Build TaylorSeer configs
taylor_configs = {}
for max_order, cache_interval in TAYLOR_CONFIGS:
    config_name = f'taylor_o{max_order}_ci{cache_interval}'
    taylor_configs[config_name] = TaylorSeerCacheConfig(
        max_order=max_order,
        cache_interval=cache_interval,
        **FIXED_CONFIG
    )

# Collect results
results = []

for i, prompt in enumerate(prompts):
    print(f"\n{'='*80}")
    print(f"Processing Prompt {i}: {prompt[:50]}...")
    print(f"{'='*80}\n")
    
    images = {}
    baseline_times = {}
    
    # Run all baseline variants first
    for baseline_config in BASELINES:
        variant = baseline_config['name']
        num_steps = baseline_config['steps']
        
        print(f"Running {variant} (steps={num_steps})...")
        
        # Clear cache before loading
        gc.collect()
        torch.cuda.empty_cache()
        
        # Load pipeline with timing
        start_load = time.time()
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        load_time = time.time() - start_load
        load_mem_gb = torch.cuda.memory_allocated() / (1024 ** 3)
        
        # Compile with timing
        start_compile = time.time()
        # pipeline.transformer.compile_repeated_blocks(fullgraph=False)
        compile_time = time.time() - start_compile
        
        # Warmup with 5 steps
        gen_warmup = torch.Generator(device="cuda").manual_seed(181201)
        start_warmup = time.time()
        _ = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=WARMUP_STEPS,
            guidance_scale=3.0,
            generator=gen_warmup
        ).images[0]
        warmup_time = time.time() - start_warmup
        
        # Main run
        gen_main = torch.Generator(device="cuda").manual_seed(181201)
        
        torch.cuda.reset_peak_memory_stats()
        start_main = time.time()
        image = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=num_steps,
            guidance_scale=3.0,
            generator=gen_main
        ).images[0]
        end_main = time.time()
        
        peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
        main_time = end_main - start_main
        
        # Save image
        image_path = f"outputs/{variant}_prompt{i}.jpg"
        image.save(image_path)
        images[variant] = image
        
        # Store baseline time
        baseline_times[variant] = main_time
        
        # Record results
        results.append({
            'Prompt Index': i,
            'Variant': variant,
            'Steps': num_steps,
            'Max Order': 'N/A',
            'Cache Interval': 'N/A',
            'Load Time (s)': f"{load_time:.2f}",
            'Load Memory (GB)': f"{load_mem_gb:.2f}",
            'Compile Time (s)': f"{compile_time:.2f}",
            'Warmup Time (s)': f"{warmup_time:.2f}",
            'Main Time (s)': f"{main_time:.2f}",
            'Peak Memory (GB)': f"{peak_mem_gb:.2f}",
            'Speedup vs baseline_50': '1.00x' if variant == 'baseline_50' else f"{baseline_times['baseline_50']/main_time:.2f}x",
            'Speedup vs baseline_22': '1.00x' if variant == 'baseline_22' else f"{baseline_times.get('baseline_22', main_time)/main_time:.2f}x"
        })
        
        print(f"  Load: {load_time:.2f}s, Compile: {compile_time:.2f}s, Warmup: {warmup_time:.2f}s")
        print(f"  Main: {main_time:.2f}s, Peak Memory: {peak_mem_gb:.2f} GB\n")
        
        # Clean up
        pipeline.to("cpu")
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()
        dynamo.reset()
    
    # TaylorSeer variants with different configurations
    for config_name, tsconfig in taylor_configs.items():
        variant = config_name
        max_order = tsconfig.max_order
        cache_interval = tsconfig.cache_interval
        print(f"Running {variant} (max_order={max_order}, cache_interval={cache_interval})...")
        
        # Clear cache before loading
        gc.collect()
        torch.cuda.empty_cache()
        
        # Load pipeline with timing
        start_load = time.time()
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        load_time = time.time() - start_load
        load_mem_gb = torch.cuda.memory_allocated() / (1024 ** 3)
        
        # Enable TaylorSeer cache
        pipeline.transformer.enable_cache(tsconfig)
        
        # Compile with timing
        start_compile = time.time()
        # pipeline.transformer.compile_repeated_blocks(fullgraph=False)
        compile_time = time.time() - start_compile
        
        # Warmup with 5 steps
        gen_warmup = torch.Generator(device="cuda").manual_seed(181201)
        start_warmup = time.time()
        _ = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=WARMUP_STEPS,
            guidance_scale=3.0,
            generator=gen_warmup
        ).images[0]
        warmup_time = time.time() - start_warmup
        
        # Main run
        gen_main = torch.Generator(device="cuda").manual_seed(181201)
        
        torch.cuda.reset_peak_memory_stats()
        start_main = time.time()
        image = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=MAIN_STEPS,
            guidance_scale=3.0,
            generator=gen_main
        ).images[0]
        end_main = time.time()
        
        peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
        main_time = end_main - start_main
        speedup_50 = baseline_times['baseline_50'] / main_time
        speedup_22 = baseline_times['baseline_22'] / main_time
        
        # Save image
        image_path = f"outputs/{variant}_prompt{i}.jpg"
        image.save(image_path)
        images[variant] = image
        
        # Record results
        results.append({
            'Prompt Index': i,
            'Variant': variant,
            'Steps': MAIN_STEPS,
            'Max Order': max_order,
            'Cache Interval': cache_interval,
            'Load Time (s)': f"{load_time:.2f}",
            'Load Memory (GB)': f"{load_mem_gb:.2f}",
            'Compile Time (s)': f"{compile_time:.2f}",
            'Warmup Time (s)': f"{warmup_time:.2f}",
            'Main Time (s)': f"{main_time:.2f}",
            'Peak Memory (GB)': f"{peak_mem_gb:.2f}",
            'Speedup vs baseline_50': f"{speedup_50:.2f}x",
            'Speedup vs baseline_22': f"{speedup_22:.2f}x"
        })
        
        print(f"  Load: {load_time:.2f}s, Compile: {compile_time:.2f}s, Warmup: {warmup_time:.2f}s")
        print(f"  Main: {main_time:.2f}s, Peak Memory: {peak_mem_gb:.2f} GB")
        print(f"  Speedup vs baseline_50: {speedup_50:.2f}x, vs baseline_22: {speedup_22:.2f}x\n")
        
        # Clean up
        pipeline.to("cpu")
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()
        dynamo.reset()
    
    # Plot image comparison for this prompt (select key variants)
    key_variants = ['baseline_50', 'baseline_22'] + [list(taylor_configs.keys())[j] for j in range(min(4, len(taylor_configs)))]
    num_variants = len(key_variants)
    
    fig, axs = plt.subplots(1, num_variants, figsize=(10*num_variants, 10))
    if num_variants == 1:
        axs = [axs]
    
    for j, var in enumerate(key_variants):
        if var in images:
            axs[j].imshow(images[var])
            axs[j].set_title(f"{var}", fontsize=24)
            axs[j].axis('off')
    
    plt.tight_layout()
    plt.savefig(f"outputs/comparison_prompt{i}.png", dpi=100)
    plt.close()

# Print results table
print("\n" + "="*140)
print("BENCHMARK RESULTS")
print("="*140 + "\n")

df = pd.DataFrame(results)
print(df.to_string(index=False))

# Save results to CSV
df.to_csv("outputs/benchmark_results.csv", index=False)
print("\nResults saved to outputs/benchmark_results.csv")

# Calculate and display averages per variant
print("\n" + "="*140)
print("AVERAGE METRICS BY VARIANT")
print("="*140 + "\n")

# Convert numeric columns back to float for averaging
numeric_cols = ['Load Time (s)', 'Load Memory (GB)', 'Compile Time (s)', 
                'Warmup Time (s)', 'Main Time (s)', 'Peak Memory (GB)']

df_numeric = df.copy()
for col in numeric_cols:
    df_numeric[col] = df_numeric[col].astype(float)

# Group by variant and calculate means
avg_df = df_numeric.groupby('Variant')[numeric_cols + ['Steps']].mean()

# Add configuration info
avg_df['Max Order'] = df.groupby('Variant')['Max Order'].first()
avg_df['Cache Interval'] = df.groupby('Variant')['Cache Interval'].first()

# Calculate average speedups
speedup_50_df = df.groupby('Variant')['Speedup vs baseline_50'].apply(
    lambda x: f"{sum(float(v.rstrip('x')) for v in x) / len(x):.2f}x"
)
speedup_22_df = df.groupby('Variant')['Speedup vs baseline_22'].apply(
    lambda x: f"{sum(float(v.rstrip('x')) for v in x) / len(x):.2f}x"
)
avg_df['Avg Speedup vs baseline_50'] = speedup_50_df
avg_df['Avg Speedup vs baseline_22'] = speedup_22_df

# Reorder columns
avg_df = avg_df[['Steps', 'Max Order', 'Cache Interval'] + numeric_cols + 
                ['Avg Speedup vs baseline_50', 'Avg Speedup vs baseline_22']]

# Format numeric columns
avg_df['Steps'] = avg_df['Steps'].apply(lambda x: f"{x:.0f}")
for col in numeric_cols:
    avg_df[col] = avg_df[col].apply(lambda x: f"{x:.2f}")

print(avg_df.to_string())

# Create comprehensive visualizations
fig, axes = plt.subplots(2, 2, figsize=(20, 16))

# Extract data for plotting
variants = []
main_times = []
peak_memories = []
speedups_50 = []
speedups_22 = []
labels = []

for variant in df['Variant'].unique():
    variant_data = df_numeric[df_numeric['Variant'] == variant]
    variants.append(variant)
    main_times.append(variant_data['Main Time (s)'].mean())
    peak_memories.append(variant_data['Peak Memory (GB)'].mean())
    
    # Calculate average speedups
    speedup_50_values = df[df['Variant'] == variant]['Speedup vs baseline_50'].apply(
        lambda x: float(x.rstrip('x'))
    )
    speedup_22_values = df[df['Variant'] == variant]['Speedup vs baseline_22'].apply(
        lambda x: float(x.rstrip('x'))
    )
    speedups_50.append(speedup_50_values.mean())
    speedups_22.append(speedup_22_values.mean())
    
    # Create readable labels
    if 'baseline' in variant:
        labels.append(variant)
    else:
        parts = variant.split('_')
        order = parts[1].replace('o', 'O')
        ci = parts[2].replace('ci', 'CI')
        labels.append(f"{order}_{ci}")

# Assign colors
colors = ['#1f77b4', '#ff7f0e'] + ['#2ca02c', '#d62728', '#9467bd', '#8c564b', 
                                     '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'] * 3
colors = colors[:len(variants)]

# Plot 1: Main Time Comparison
ax1 = axes[0, 0]
bars1 = ax1.bar(labels, main_times, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax1.set_ylabel('Main Generation Time (seconds)', fontsize=12, fontweight='bold')
ax1.set_xlabel('Configuration', fontsize=12, fontweight='bold')
ax1.set_title('Average Generation Time Comparison', fontsize=14, fontweight='bold')
ax1.grid(axis='y', alpha=0.3, linestyle='--')
ax1.tick_params(axis='x', rotation=45)

# Add value labels on bars
for bar, time in zip(bars1, main_times):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height,
             f'{time:.2f}s', ha='center', va='bottom', fontsize=9, fontweight='bold')

# Plot 2: Peak Memory Comparison
ax2 = axes[0, 1]
bars2 = ax2.bar(labels, peak_memories, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax2.set_ylabel('Peak Memory Usage (GB)', fontsize=12, fontweight='bold')
ax2.set_xlabel('Configuration', fontsize=12, fontweight='bold')
ax2.set_title('Average Peak Memory Comparison', fontsize=14, fontweight='bold')
ax2.grid(axis='y', alpha=0.3, linestyle='--')
ax2.tick_params(axis='x', rotation=45)

# Add value labels on bars
for bar, mem in zip(bars2, peak_memories):
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height,
             f'{mem:.2f}', ha='center', va='bottom', fontsize=9, fontweight='bold')

# Plot 3: Speedup vs baseline_50
ax3 = axes[1, 0]
bars3 = ax3.bar(labels, speedups_50, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax3.axhline(y=1.0, color='red', linestyle='--', linewidth=2, label='Baseline (50 steps)')
ax3.set_ylabel('Speedup Factor', fontsize=12, fontweight='bold')
ax3.set_xlabel('Configuration', fontsize=12, fontweight='bold')
ax3.set_title('Speedup vs Baseline (50 steps)', fontsize=14, fontweight='bold')
ax3.grid(axis='y', alpha=0.3, linestyle='--')
ax3.tick_params(axis='x', rotation=45)
ax3.legend()

# Add value labels on bars
for bar, speedup in zip(bars3, speedups_50):
    height = bar.get_height()
    ax3.text(bar.get_x() + bar.get_width()/2., height,
             f'{speedup:.2f}x', ha='center', va='bottom', fontsize=9, fontweight='bold')

# Plot 4: Speedup vs baseline_22
ax4 = axes[1, 1]
bars4 = ax4.bar(labels, speedups_22, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax4.axhline(y=1.0, color='orange', linestyle='--', linewidth=2, label='Baseline (22 steps)')
ax4.set_ylabel('Speedup Factor', fontsize=12, fontweight='bold')
ax4.set_xlabel('Configuration', fontsize=12, fontweight='bold')
ax4.set_title('Speedup vs Baseline (22 steps)', fontsize=14, fontweight='bold')
ax4.grid(axis='y', alpha=0.3, linestyle='--')
ax4.tick_params(axis='x', rotation=45)
ax4.legend()

# Add value labels on bars
for bar, speedup in zip(bars4, speedups_22):
    height = bar.get_height()
    ax4.text(bar.get_x() + bar.get_width()/2., height,
             f'{speedup:.2f}x', ha='center', va='bottom', fontsize=9, fontweight='bold')

plt.tight_layout()
plt.savefig("outputs/metrics_comparison.png", dpi=150, bbox_inches='tight')
plt.close()

print("\n" + "="*140)
print("Benchmark completed! Check the outputs/ folder for results and visualizations.")
print("="*140)
  1. Baseline, TaylorSeer, FirstBlockCache
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig, FirstBlockCacheConfig, FasterCacheConfig
import time
import os
import matplotlib.pyplot as plt
import pandas as pd
import gc  # Added for explicit garbage collection

# Set dynamo config
import torch._dynamo as dynamo
dynamo.config.recompile_limit = 200

prompts = [
    "Soaking wet tiger cub taking shelter under a banana leaf in the rainy jungle, close up photo",
]

# Create output folder
os.makedirs("outputs", exist_ok=True)

# Define cache configs
fbconfig = FirstBlockCacheConfig(
    threshold=0.05
)

tsconfig = TaylorSeerCacheConfig(
    cache_interval=5,
    max_order=1,
    disable_cache_before_step=10,
    disable_cache_after_step=48,
    taylor_factors_dtype=torch.bfloat16,
    use_lite_mode=True
)

# Collect results
results = []

for i, prompt in enumerate(prompts):
    images = {}
    for variant in ['baseline', 'baseline_reduce', 'firstblock', 'taylor']:
        # Clear cache before loading
        gc.collect()
        torch.cuda.empty_cache()
        
        # Load pipeline with timing
        start_load = time.time()
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        load_time = time.time() - start_load
        load_mem_gb = torch.cuda.memory_allocated() / (1024 ** 3)  # GB
        
        # Enable cache if applicable
        if variant == 'firstblock':
            pipeline.transformer.enable_cache(fbconfig)
        elif variant == 'taylor':
            pipeline.transformer.enable_cache(tsconfig)
        # No cache for baseline and baseline_reduce
        
        # Compile with timing
        start_compile = time.time()
        pipeline.transformer.compile_repeated_blocks(fullgraph=False)
        compile_time = time.time() - start_compile
        
        # Warmup with 10 steps
        gen_warmup = torch.Generator(device="cuda").manual_seed(181201)
        start_warmup = time.time()
        _ = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=5,
            guidance_scale=3.0,
            generator=gen_warmup
        ).images[0]
        warmup_time = time.time() - start_warmup
        
        # Main run
        steps = 22 if variant == 'baseline_reduce' else 50
        
        gen_main = torch.Generator(device="cuda").manual_seed(181201)
        
        torch.cuda.reset_peak_memory_stats()
        start_main = time.time()
        image = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=steps,
            guidance_scale=3.0,
            generator=gen_main
        ).images[0]
        end_main = time.time()
        
        peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)  # GB
        main_time = end_main - start_main
        
        # Save image
        image_path = f"outputs/{variant}_prompt{i}.jpg"
        image.save(image_path)
        images[variant] = image
        
        # Record results
        results.append({
            'Prompt Index': i,
            'Variant': variant,
            'Load Time (s)': load_time,
            'Load Memory (GB)': load_mem_gb,
            'Compile Time (s)': compile_time,
            'Warmup Time (s)': warmup_time,
            'Main Time (s)': main_time,
            'Peak Memory (GB)': peak_mem_gb
        })
        
        # Clean up
        pipeline.to("cpu")
        del pipeline
        gc.collect()  # Force garbage collection
        torch.cuda.empty_cache()  # Empty CUDA cache after GC
        dynamo.reset()  # Reset Dynamo cache (harmless even if not compiling)

    # Plot image comparison for this prompt
    fig, axs = plt.subplots(1, 4, figsize=(40, 10))
    variants_order = ['baseline', 'baseline_reduce', 'firstblock', 'taylor']
    for j, var in enumerate(variants_order):
        axs[j].imshow(images[var])
        axs[j].set_title(var)
        axs[j].axis('off')
    plt.tight_layout()
    plt.savefig(f"outputs/comparison_prompt{i}.png")
    plt.close()

# Print speed and memory comparison as a table
df = pd.DataFrame(results)
print("Speed and Memory Comparison:")
print(df.to_string(index=False))

# Optionally, plot bar charts for averages
avg_df = df.groupby('Variant').mean().reset_index()
fig, ax1 = plt.subplots(figsize=(10, 6))
ax1.bar(avg_df['Variant'], avg_df['Main Time (s)'], color='b', label='Main Time (s)')
ax1.set_ylabel('Main Time (s)')
ax2 = ax1.twinx()
ax2.plot(avg_df['Variant'], avg_df['Peak Memory (GB)'], color='r', marker='o', label='Peak Memory (GB)')
ax2.set_ylabel('Peak Memory (GB)')
fig.suptitle('Average Speed and Memory Comparison')
fig.legend()
plt.savefig("outputs/metrics_comparison.png")
plt.close()

@sayakpaul
Copy link
Member

graph will break at torch.compile.disable then it requires recompiling every block since hook is applied to block level
By default recompile limit is 8, while total blocks of transformer is much higher (57 in flux), we have to set this limit to higher than number of blocks to achieve best performance and similar graphs to regional compiling. Check the code below:

Yes, increasing the compile limit is fine here.

Some questions / notes:

  • In the code snippet provided in [Feat] TaylorSeer Cache #12648 (comment), why do we need dynamic=True?
  • Could we also add the compilation timing in here to see if that helps at all (especially with the recompilations)?
  • Let's try to add this comparison (just a link your comment is fine) in the docs? I think this is golden information!

@sayakpaul
Copy link
Member

@bot /style

@github-actions
Copy link
Contributor

github-actions bot commented Dec 4, 2025

Style bot fixed some files and pushed the changes.

Copy link
Collaborator

@DN6 DN6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work @toilaluan. Small update to docstring and we should be good to merge once tests pass.

toilaluan and others added 2 commits December 5, 2025 14:26
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
@toilaluan
Copy link
Contributor Author

@DN6 @sayakpaul (cc @Shenyi-Z) I added more intensive comparison for TaylorSeer in #12648 (comment)

Interesting finding is that increasing max_order doesn't give better result, while max_order=0 (means we reuse v predicted from timestep i for next N steps still work well but i think max_order=1 is the optimal

It also includes complication timing

@Shenyi-Z
Copy link

Shenyi-Z commented Dec 5, 2025

@DN6 @sayakpaul (cc @Shenyi-Z) I added more intensive comparison for TaylorSeer in #12648 (comment)

Interesting finding is that increasing max_order doesn't give better result, while max_order=0 (means we reuse v predicted from timestep i for next N steps still work well but i think max_order=1 is the optimal

It also includes complication timing

This is basically consistent with our recent experimental results. This is because TaylorSeer essentially only simulates derivatives from differences; it is actually just a simple prediction method. Excessively high orders are not stable enough, which naturally introduces more numerical errors, making the model's predictions overly arbitrary. On the other hand, the zeroth-order simple reuse results in noticeable lack of detail. In practical applications, TaylorSeer with order=1 often achieves relatively stable performance.

@sayakpaul
Copy link
Member

@toilaluan I opened a PR here: toilaluan#1. I believe merging that should help us fix the CI issues.

@toilaluan
Copy link
Contributor Author

I merged it

@sayakpaul
Copy link
Member

Tests failing seem related. @toilaluan anything we're missing?

@toilaluan
Copy link
Contributor Author

@sayakpaul Sorry I can't find where to look into, any suggestion?

@sayakpaul
Copy link
Member

@toilaluan
Copy link
Contributor Author

@sayakpaul It fails due to this issue #12648 (comment)

Removing test on flux kontext will help

@sayakpaul sayakpaul merged commit 6290fdf into huggingface:main Dec 6, 2025
26 of 28 checks passed
@sayakpaul
Copy link
Member

An immense amount of thanks for shipping this! We will get back to you for the MVP stuff!

@toilaluan
Copy link
Contributor Author

🤗

@Trgtuan10
Copy link
Contributor

I love you @toilaluan

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants