diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md
index 940144538a35..fa6c2a82ead2 100644
--- a/docs/source/en/api/pipelines/ltx_video.md
+++ b/docs/source/en/api/pipelines/ltx_video.md
@@ -136,7 +136,7 @@ export_to_video(video, "output.mp4", fps=24)
- The recommended dtype for the transformer, VAE, and text encoder is `torch.bfloat16`. The VAE and text encoder can also be `torch.float32` or `torch.float16`.
- For guidance-distilled variants of LTX-Video, set `guidance_scale` to `1.0`. The `guidance_scale` for any other model should be set higher, like `5.0`, for good generation quality.
- For timestep-aware VAE variants (LTX-Video 0.9.1 and above), set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
- - For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitionts in the generated video.
+ - For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitions in the generated video.
- LTX-Video 0.9.7 includes a spatial latent upscaler and a 13B parameter transformer. During inference, a low resolution video is quickly generated first and then upscaled and refined.
@@ -414,6 +414,91 @@ export_to_video(video, "output.mp4", fps=24)
+
+ Long image-to-video generation with multi-prompt sliding windows (ComfyUI parity)
+
+ ```py
+ import torch
+ from diffusers import LTXI2VLongMultiPromptPipeline, LTXLatentUpsamplePipeline
+ from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
+ from diffusers.utils import export_to_video
+ from PIL import Image
+
+
+ # Stage A: long I2V with sliding windows and multi-prompt scheduling
+ pipe = LTXI2VLongMultiPromptPipeline.from_pretrained(
+ "LTX-Video-0.9.8-13B-distilled",
+ torch_dtype=torch.bfloat16
+ ).to("cuda")
+
+ schedule = "a chimpanzee walks in the jungle | a chimpanzee stops and eats a snack | a chimpanzee lies on the ground"
+ cond_image = Image.open("chimpanzee_l.jpg").convert("RGB")
+
+ latents = pipe(
+ prompt=schedule,
+ negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
+ width=768,
+ height=512, # height and width must be divisible by 32
+ num_frames=361,
+ temporal_tile_size=120,
+ temporal_overlap=32,
+ sigmas=[1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250, 0.4219, 0.0],
+ guidance_scale=1.0, # distilled variants typically use 1.0
+ cond_image=cond_image, # hard-conditions the first frame
+ adain_factor=0.25, # cross-window normalization
+ output_type="latent", # return latent-space video for downstream processing
+ ).frames
+
+ # Optional: decode with VAE tiling
+ video_pil = pipe.vae_decode_tiled(latents, decode_timestep=0.05, decode_noise_scale=0.025, output_type="pil")[0]
+ export_to_video(video_pil, "ltx_i2v_long_base.mp4", fps=24)
+
+ # Stage B (optional): spatial latent upsampling + short refinement
+ upsampler = LTXLatentUpsamplerModel.from_pretrained("LTX-Video-spatial-upscaler-0.9.8/latent_upsampler", torch_dtype=torch.bfloat16)
+ pipe_upsample = LTXLatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=upsampler).to(torch.bfloat16).to("cuda")
+
+ up_latents = pipe_upsample(
+ latents=latents,
+ adain_factor=1.0,
+ tone_map_compression_ratio=0.6,
+ output_type="latent"
+ ).frames
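+
+ # Optionally fuse the IC-LoRA detailer before the refinement pass; the try/except below skips it if unavailable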
+ try:
+ pipe.load_lora_weights(
+ "LTX-Video-ICLoRA-detailer-13b-0.9.8/ltxv-098-ic-lora-detailer-diffusers.safetensors",
+ adapter_name="ic-detailer",
+ )
+ pipe.fuse_lora(components=["transformer"], lora_scale=1.0)
+ print("[Info] IC-LoRA detailer adapter loaded and fused.")
+ except Exception as e:
+ print(f"[Warn] Failed to load IC-LoRA: {e}. Skipping the second refinement sampling.")
+
+ # Short refinement pass (distilled; low steps)
+ frames_refined = pipe(
+ negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
+ width=768,
+ height=512,
+ num_frames=361,
+ temporal_tile_size=80,
+ temporal_overlap=24,
+ seed=1625,
+ adain_factor=0.0, # disable AdaIN in refinement
+ latents=up_latents, # start from upscaled latents
+ guidance_latents=up_latents,
+ sigmas=[0.99, 0.9094, 0.0], # short sigma schedule
+ output_type="pil",
+ ).frames[0]
+
+ export_to_video(frames_refined, "ltx_i2v_long_refined.mp4", fps=24)
+ ```
+
+ Notes:
+ - Seeding: window-local hard-condition noise uses `seed + w_start` when `seed` is provided; otherwise the passed-in `generator` drives stochasticity.
+ - Height/width must be divisible by 32; latent shapes follow the pipeline docstrings.
+ - Use VAE tiled decoding to avoid OOM for high resolutions or long sequences.
+ - Distilled variants generally prefer `guidance_scale=1.0` and short schedules for refinement.
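+
+ As an alternative to bar-separated prompts, window-level prompts can be supplied explicitly. A minimal sketch (assumes the `pipe` object from the example above; `prompt_segments` entries follow the `{"start_window", "end_window", "text"}` convention from the pipeline docstring, and windows beyond the last segment reuse its text):
+
+ ```py
+ segments = [
+ {"start_window": 0, "end_window": 1, "text": "a chimpanzee walks in the jungle"},
+ {"start_window": 2, "end_window": 3, "text": "a chimpanzee lies on the ground"},
+ ]
+ latents = pipe(
+ prompt=None,
+ prompt_segments=segments,
+ num_frames=361,
+ output_type="latent",
+ ).frames
+ ```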
+
+
- LTX-Video supports LoRAs with [`~loaders.LTXVideoLoraLoaderMixin.load_lora_weights`].
@@ -474,6 +559,12 @@ export_to_video(video, "output.mp4", fps=24)
+## LTXI2VLongMultiPromptPipeline
+
+[[autodoc]] LTXI2VLongMultiPromptPipeline
+ - all
+ - __call__
+
## LTXPipeline
[[autodoc]] LTXPipeline
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 572aad4bd3f1..d65a74c812fa 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -514,6 +514,7 @@
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
"LTXPipeline",
+ "LTXI2VLongMultiPromptPipeline",
"LucyEditPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
@@ -1191,6 +1192,7 @@
+ LTXI2VLongMultiPromptPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
LucyEditPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 87d953845e21..e72f47ceb7eb 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -285,6 +285,7 @@
"LTXImageToVideoPipeline",
"LTXConditionPipeline",
"LTXLatentUpsamplePipeline",
+ "LTXI2VLongMultiPromptPipeline",
]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
@@ -691,7 +692,7 @@
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
- from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
+ from .ltx import LTXConditionPipeline, LTXI2VLongMultiPromptPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py
index 6001867916b3..e2979e4f3bcf 100644
--- a/src/diffusers/pipelines/ltx/__init__.py
+++ b/src/diffusers/pipelines/ltx/__init__.py
@@ -27,6 +27,7 @@
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
_import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"]
+ _import_structure["pipeline_ltx_i2v_long_multi_prompt"] = ["LTXI2VLongMultiPromptPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -41,6 +42,7 @@
from .pipeline_ltx_condition import LTXConditionPipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline
+ from .pipeline_ltx_i2v_long_multi_prompt import LTXI2VLongMultiPromptPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py b/src/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py
new file mode 100644
index 000000000000..59e3fe3e7199
--- /dev/null
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py
@@ -0,0 +1,1233 @@
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+LTXI2VLongMultiPromptPipeline
+
+Overview:
+- Long-duration Image-to-Video (I2V) pipeline with multi-prompt segmentation and temporal sliding windows.
+- ComfyUI parity: window stride (tile_size - overlap), first-frame hard conditioning via per-token mask,
+ optional AdaIN cross-window normalization, "negative index" and guidance latent injection,
+ and VAE tiled decoding to control VRAM without spatial sharding during denoising.
+
+Key components:
+- scheduler: FlowMatchEulerDiscreteScheduler
+- vae: AutoencoderKLLTXVideo (supports optional timestep_conditioning)
+- text_encoder: T5EncoderModel
+- tokenizer: T5TokenizerFast
+- transformer: LTXVideoTransformer3DModel
+
+Seeding:
+- Per-window hard-condition initialization noise uses a window-local seed when `seed` is provided:
+ local_seed = seed + w_start (w_start is the latent-frame start index of the window).
+- When `seed` is None, the passed-in `generator` (if any) drives stochasticity.
+- Global initial latents are sampled with randn_tensor using the (normalized) `generator`.
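+- Example (illustrative): seed=42 with a window starting at latent frame 7 uses local_seed = 49.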
+
+I/O and shapes:
+- Decoded output has `num_frames` frames at `height x width`.
+- Latent sizes for internal sampling:
+ F_lat = (num_frames - 1) // vae_temporal_compression_ratio + 1
+ H_lat = height // vae_spatial_compression_ratio
+ W_lat = width // vae_spatial_compression_ratio
+- Height and width must be divisible by 32.
+
+Usage:
+ >>> from diffusers import LTXI2VLongMultiPromptPipeline
+ >>> import torch
+ >>> pipe = LTXI2VLongMultiPromptPipeline.from_pretrained("LTX-Video-0.9.8-13B-distilled")
+ >>> pipe = pipe.to("cuda").to(dtype=torch.bfloat16)
+ >>> # Example A: get decoded frames (PIL)
+ >>> out = pipe(prompt="a chimpanzee walks | a chimpanzee eats", num_frames=161, height=512, width=704,
+ ... temporal_tile_size=80, temporal_overlap=24, output_type="pil", return_dict=True)
+ >>> frames = out.frames[0] # list of PIL.Image.Image
+ >>> # Example B: get latent video and decode later (saves VRAM during sampling)
+ >>> out_latent = pipe(prompt="a chimpanzee walking", output_type="latent", return_dict=True).frames
+ >>> frames = pipe.vae_decode_tiled(out_latent, output_type="pil")[0]
+
+See also:
+- LTXPipeline.encode_prompt, LTXPipeline._pack_latents, LTXPipeline._unpack_latents,
+ rescale_noise_cfg, FlowMatchEulerDiscreteScheduler.set_timesteps/step.
+"""
+import copy
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from einops import rearrange
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
+from ...models.autoencoders import AutoencoderKLLTXVideo
+from ...models.transformers import LTXVideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import LTXPipelineOutput
+from .pipeline_ltx import LTXPipeline, rescale_noise_cfg
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def get_latent_coords(
+ latent_num_frames, latent_height, latent_width, batch_size, device, rope_interpolation_scale, latent_idx
+):
+ """
+ Compute latent patch top-left coordinates in (t, y, x) order.
+
+ Args:
+ latent_num_frames: int. Number of latent frames (T_lat).
+ latent_height: int. Latent height (H_lat).
+ latent_width: int. Latent width (W_lat).
+ batch_size: int. Batch dimension (B).
+ device: torch.device for the resulting tensor.
+ rope_interpolation_scale: tuple[int|float, int|float, int|float]. Scale per (t, y, x) latent step to pixel coords.
+ latent_idx: Optional[int]. When not None, shifts the time coordinate to align segments:
+ - <= 0 uses step multiples of rope_interpolation_scale[0]
+ - > 0 starts at 1 then increments by rope_interpolation_scale[0]
+
+ Returns:
+ Tensor of shape [B, 3, T_lat * H_lat * W_lat] containing top-left coordinates per latent patch,
+ repeated for each batch element.
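+
+ Example (illustrative):
+ get_latent_coords(2, 2, 2, 1, device, (8, 32, 32), None) returns a [1, 3, 8] tensor whose
+ time row is [0, 0, 0, 0, 8, 8, 8, 8] (x is fastest, t slowest).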
+ """
+ latent_sample_coords = torch.meshgrid(
+ torch.arange(0, latent_num_frames, 1, device=device),
+ torch.arange(0, latent_height, 1, device=device),
+ torch.arange(0, latent_width, 1, device=device),
+ indexing="ij",
+ )
+ latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
+ latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+ latent_coords = rearrange(latent_coords, "b c f h w -> b c (f h w)", b=batch_size)
+ pixel_coords = latent_coords * torch.tensor(rope_interpolation_scale, device=latent_coords.device)[None, :, None]
+ if latent_idx is not None:
+ if latent_idx <= 0:
+ frame_idx = latent_idx * rope_interpolation_scale[0]
+ else:
+ frame_idx = 1 + (latent_idx - 1) * rope_interpolation_scale[0]
+ if frame_idx == 0:
+ pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - rope_interpolation_scale[0]).clamp(min=0)
+ pixel_coords[:, 0] += frame_idx
+ return pixel_coords
+
+
+class LTXI2VLongMultiPromptPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
+ """
+ Long-duration I2V (image-to-video) multi-prompt pipeline with ComfyUI parity.
+
+ Key features:
+ - Temporal sliding-window sampling only (no spatial H/W sharding); autoregressive fusion across windows.
+ - Multi-prompt segmentation per window with smooth transitions at window heads.
+ - First-frame hard conditioning via per-token mask for I2V.
+ - VRAM control via temporal windowing and VAE tiled decoding.
+
+ Components:
+ - scheduler: FlowMatchEulerDiscreteScheduler
+ - vae: AutoencoderKLLTXVideo
+ - text_encoder: T5EncoderModel
+ - tokenizer: T5TokenizerFast
+ - transformer: LTXVideoTransformer3DModel
+
+ Reused Diffusers components:
+ - LTXPipeline.encode_prompt() (diffusers/src/diffusers/pipelines/ltx/pipeline_ltx.py:283)
+ - LTXPipeline._pack_latents() (diffusers/src/diffusers/pipelines/ltx/pipeline_ltx.py:419)
+ - LTXPipeline._unpack_latents() (diffusers/src/diffusers/pipelines/ltx/pipeline_ltx.py:443)
+ - LTXImageToVideoPipeline.prepare_latents() cond_mask semantics (diffusers/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py:502)
+ - FlowMatchEulerDiscreteScheduler.set_timesteps() (diffusers/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py:249)
+ - FlowMatchEulerDiscreteScheduler.step() (diffusers/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py:374)
+ - rescale_noise_cfg() (diffusers/src/diffusers/pipelines/ltx/pipeline_ltx.py:144)
+
+ Defaults:
+ - default_height=512, default_width=704, default_frames=121
+
+ Example:
+ >>> from diffusers import LTXI2VLongMultiPromptPipeline
+ >>> pipe = LTXI2VLongMultiPromptPipeline.from_pretrained("LTX-Video-0.9.8-13B-distilled")
+ >>> pipe = pipe.to("cuda").to(dtype=torch.bfloat16)
+ >>> out = pipe(prompt="a chimpanzee walking", num_frames=161, height=512, width=704, output_type="pil")
+ >>> frames = out.frames[0] # list of PIL frames
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLLTXVideo,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: LTXVideoTransformer3DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
+ )
+ self.transformer_temporal_patch_size = (
+ self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
+ )
+
+ self.default_height = 512
+ self.default_width = 704
+ self.default_frames = 121
+ self._current_tile_T = None
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 256,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ Tokenize and encode prompts using the T5 encoder.
+
+ Args:
+ prompt: str or list[str]. If str, it is internally wrapped as a single-element list.
+ num_videos_per_prompt: number of generations per prompt; embeddings are duplicated accordingly.
+ max_sequence_length: tokenizer max length; longer inputs are truncated with a warning.
+ device: optional device override; defaults to the pipeline execution device.
+ dtype: optional dtype override for embeddings; defaults to text_encoder dtype.
+
+ Returns:
+ - prompt_embeds: Tensor of shape [B * num_videos_per_prompt, seq_len, dim]
+ - prompt_attention_mask: Bool Tensor of shape [B * num_videos_per_prompt, seq_len]
+
+ Notes:
+ - Truncation: if inputs exceed `max_sequence_length`, the overflow part is removed and a warning is logged.
+ - Embeddings are duplicated per `num_videos_per_prompt` to match generation count.
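+
+ Example (illustrative):
+ A single prompt with num_videos_per_prompt=2 and max_sequence_length=128 yields
+ prompt_embeds of shape [2, 128, dim] and a matching [2, 128] attention mask.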
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ return prompt_embeds, prompt_attention_mask
+
+ def _split_into_temporal_windows(
+ self,
+ latent_len: int,
+ temporal_tile_size: int,
+ temporal_overlap: int,
+ compression: int,
+ ) -> List[Tuple[int, int]]:
+ """
+ Split latent frames into sliding windows.
+
+ Args:
+ latent_len: int. Number of latent frames (T_lat).
+ temporal_tile_size: int. Window size in latent frames (> 0).
+ temporal_overlap: int. Overlap between windows in latent frames (>= 0).
+ compression: int. VAE temporal compression ratio (unused here; kept for parity).
+
+ Returns:
+ list[tuple[int, int]]: inclusive-exclusive (start, end) indices per window.
+
+ Notes:
+ - Mirrors ComfyUI stride = tile_size - overlap.
+ - Windows do not exceed [0, latent_len).
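+
+ Example (illustrative, stride = 15 - 4 = 11):
+ _split_into_temporal_windows(46, 15, 4, 8) -> [(0, 15), (11, 26), (22, 37), (33, 46)]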
+ """
+ if temporal_tile_size <= 0:
+ raise ValueError("temporal_tile_size must be > 0")
+ stride = max(temporal_tile_size - temporal_overlap, 1)
+ windows = []
+ start = 0
+ while start < latent_len:
+ end = min(start + temporal_tile_size, latent_len)
+ windows.append((start, end))
+ if end == latent_len:
+ break
+ start = start + stride
+ return windows
+
+ # [Deprecated] Spatial tiling weights removed: sampling performs full-frame prediction without H/W sharding.
+
+ def _linear_overlap_fuse(self, prev: torch.Tensor, new: torch.Tensor, overlap: int) -> torch.Tensor:
+ """
+ Temporal linear crossfade between two latent clips over the overlap region.
+
+ Args:
+ prev: Tensor [B, C, F, H, W]. Previous output segment.
+ new: Tensor [B, C, F, H, W]. New segment to be appended.
+ overlap: int. Number of frames to crossfade (overlap <= 1 concatenates without blend).
+
+ Returns:
+ Tensor [B, C, F_prev + F_new - overlap, H, W] after crossfade at the seam.
+
+ Notes:
+ - Crossfade weights are linear in time from 1→0 (prev) and 0→1 (new).
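+
+ Example (illustrative):
+ With overlap=4, a 10-frame `prev` and an 8-frame `new` fuse into 10 + 8 - 4 = 14 frames;
+ the 4 seam frames use alpha = linspace(1, 0, 6)[1:-1] for `prev` and (1 - alpha) for `new`.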
+ """
+ if overlap <= 1:
+ return torch.cat([prev, new], dim=2)
+ alpha = torch.linspace(1, 0, overlap + 2, device=prev.device, dtype=prev.dtype)[1:-1]
+ shape = [1] * prev.ndim
+ shape[2] = alpha.size(0)
+ alpha = alpha.reshape(shape)
+ blended = alpha * prev[:, :, -overlap:] + (1 - alpha) * new[:, :, :overlap]
+ return torch.cat([prev[:, :, :-overlap], blended, new[:, :, overlap:]], dim=2)
+
+ def _adain_normalize_latents(
+ self,
+ curr_latents: torch.Tensor,
+ ref_latents: Optional[torch.Tensor],
+ factor: float,
+ ) -> torch.Tensor:
+ """
+ Optional AdaIN normalization: channel-wise mean/variance matching of curr_latents to ref_latents, controlled by factor.
+
+ Args:
+ curr_latents: Tensor [B, C, T, H, W]. Current window latents.
+ ref_latents: Optional[Tensor] [B, C, T_ref, H, W]. Reference latents (e.g., first window) used to compute target stats.
+ factor: float in [0, 1]. 0 keeps current stats; 1 matches reference stats.
+
+ Returns:
+ Tensor with per-channel mean/std blended towards the reference.
+
+ Details:
+ - Statistics are computed over (T, H, W), i.e., per-channel mean/variance.
+ - Output = (curr - mu_curr) / sigma_curr * sigma_blend + mu_blend.
+ - If ref is None or factor <= 0, returns curr_latents unchanged.
+
+ ComfyUI parity:
+ - Matches LTXVAdainLatent.batch_normalize behavior (ComfyUI-LTXVideo/easy_samplers.py:449).
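+
+ Example (illustrative):
+ With factor=0.25, each channel's target statistics move a quarter of the way toward the
+ reference: mu_blend = 0.75 * mu_curr + 0.25 * mu_ref (likewise for sigma).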
+ """
+ if ref_latents is None or factor is None or factor <= 0:
+ return curr_latents
+
+ eps = torch.tensor(1e-6, device=curr_latents.device, dtype=curr_latents.dtype)
+
+ # Compute per-channel means/stds for current and reference over (T, H, W)
+ mu_curr = curr_latents.mean(dim=(2, 3, 4), keepdim=True)
+ sigma_curr = curr_latents.std(dim=(2, 3, 4), keepdim=True)
+
+ mu_ref = ref_latents.mean(dim=(2, 3, 4), keepdim=True).to(device=curr_latents.device, dtype=curr_latents.dtype)
+ sigma_ref = ref_latents.std(dim=(2, 3, 4), keepdim=True).to(device=curr_latents.device, dtype=curr_latents.dtype)
+
+ # Blend target statistics
+ mu_blend = (1.0 - float(factor)) * mu_curr + float(factor) * mu_ref
+ sigma_blend = (1.0 - float(factor)) * sigma_curr + float(factor) * sigma_ref
+ sigma_blend = torch.clamp(sigma_blend, min=float(eps))
+
+ # Apply AdaIN
+ curr_norm = (curr_latents - mu_curr) / (sigma_curr + eps)
+ return curr_norm * sigma_blend + mu_blend
+
+ def _inject_prev_tail_latents(
+ self,
+ window_latents: torch.Tensor,
+ prev_tail_latents: Optional[torch.Tensor],
+ window_cond_mask_5d: torch.Tensor,
+ overlap_lat: int,
+ strength: Optional[float],
+ prev_overlap_len: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
+ """
+ Append the tail latents from the previous window as k extra conditioning frames after the current window,
+ where k = min(overlap_lat, T_curr, T_prev_tail); their time coordinates are later remapped to the window
+ head by `_build_video_coords_for_window`.
+
+ Args:
+ window_latents: Tensor [B, C, T, H, W]. Current window latents.
+ prev_tail_latents: Optional[Tensor] [B, C, T_prev, H, W]. Tail segment from the previous window.
+ window_cond_mask_5d: Tensor [B, 1, T, H, W]. Per-token conditioning mask (1 = free, 0 = hard condition).
+ overlap_lat: int. Number of latent frames to inject from the previous tail.
+ strength: Optional[float] in [0, 1]. Blend strength; 1.0 replaces, 0.0 keeps original.
+ prev_overlap_len: int. Accumulated overlap length so far (used for trimming later).
+
+ Returns:
+ Tuple[Tensor, Tensor, int]: (updated_window_latents, updated_cond_mask, updated_prev_overlap_len)
+
+ ComfyUI parity:
+ - Matches LTXVExtendSampler.sample() behavior by injecting last_overlap_latents at the new window start (latent_idx=0)
+ (ComfyUI-LTXVideo/easy_samplers.py:379-390).
+
+ Notes:
+ - This is an initialization prior to denoising; the k appended frames are trimmed after sampling via `prev_overlap_len`.
+ - Reuses input dtype/device; pure weighted ops avoid numerical instability.
+ """
+ if prev_tail_latents is None or overlap_lat <= 0 or strength is None or strength <= 0:
+ return window_latents, window_cond_mask_5d, prev_overlap_len
+
+ # Expected shape: [B, C, T, H, W]
+ T = int(window_latents.shape[2])
+ k = min(int(overlap_lat), T, int(prev_tail_latents.shape[2]))
+ if k <= 0:
+ return window_latents, window_cond_mask_5d, prev_overlap_len
+
+ tail = prev_tail_latents[:, :, -k:]
+ mask = torch.full(
+ (window_cond_mask_5d.shape[0], 1, tail.shape[2], window_cond_mask_5d.shape[3], window_cond_mask_5d.shape[4]),
+ 1.0 - strength,
+ dtype=window_cond_mask_5d.dtype,
+ device=window_cond_mask_5d.device,
+ )
+
+ window_latents = torch.cat([window_latents, tail], dim=2)
+ window_cond_mask_5d = torch.cat([window_cond_mask_5d, mask], dim=2)
+ return window_latents, window_cond_mask_5d, prev_overlap_len + k
+
+ def _build_video_coords_for_window(
+ self,
+ latents: torch.Tensor,
+ overlap_len: int,
+ guiding_len: int,
+ negative_len: int,
+ rope_interpolation_scale: torch.Tensor,
+ frame_rate: int,
+ ) -> torch.Tensor:
+ """
+ Build video_coords: [B, 3, S] with order [t, y, x].
+
+ Args:
+ latents: Tensor [B, C, T, H, W]. Current window latents (before any trimming).
+ overlap_len: int. Number of previous-tail frames appended after the window latents.
+ guiding_len: int. Number of guidance frames appended after the window latents.
+ negative_len: int. Number of negative-index frames appended after the window latents (typically 1 or 0).
+ rope_interpolation_scale: tuple[int|float, int|float, int|float]. Scale for (t, y, x).
+ frame_rate: int. Used to convert time indices into seconds (t /= frame_rate).
+
+ Returns:
+ Tensor [B, 3, T*H*W] of fractional pixel coordinates per latent patch.
+
+ Notes:
+ - Base grid matches the transformer default (flatten t-h-w; w is fastest).
+ - The appended segments (overlap → guiding → negative) occupy the last tokens of the sequence; their time coordinates are remapped to the window head (or a negative index) and scaled by rope_interpolation_scale.
+ """
+
+ b, c, f, h, w = latents.shape
+ pixel_coords = get_latent_coords(f, h, w, b, latents.device, rope_interpolation_scale, 0)
+ replace_coords = []
+ if overlap_len > 0:
+ replace_coords.append(get_latent_coords(overlap_len, h, w, b, latents.device, rope_interpolation_scale, 0))
+ if guiding_len > 0:
+ replace_coords.append(get_latent_coords(guiding_len, h, w, b, latents.device, rope_interpolation_scale, overlap_len))
+ if negative_len > 0:
+ replace_coords.append(get_latent_coords(negative_len, h, w, b, latents.device, rope_interpolation_scale, -1))
+ if len(replace_coords) > 0:
+ replace_coords = torch.cat(replace_coords, dim=2)
+ pixel_coords[:, :, -replace_coords.shape[2] :] = replace_coords
+ fractional_coords = pixel_coords.to(torch.float32)
+ fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
+ return fractional_coords
+
+ def _filter_sigmas_by_threshold(
+ self, sigmas: Union[List[float], torch.Tensor], threshold: float
+ ) -> Union[List[float], torch.Tensor]:
+ """
+ Drop sigma steps below the given threshold; preserves order.
+
+ Args:
+ sigmas: list[float] | Tensor. Sigma schedule.
+ threshold: float. Minimum sigma to keep.
+
+ Returns:
+ list[float] | Tensor filtered by threshold.
+
+ Notes:
+ - Mirrors Comfy STGGuiderAdvanced `skip_steps_sigma_threshold` semantics.
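+
+ Example (illustrative):
+ _filter_sigmas_by_threshold([1.0, 0.9, 0.4, 0.0], threshold=0.5) -> [1.0, 0.9]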
+ """
+ if threshold is None:
+ return sigmas
+ if isinstance(sigmas, torch.Tensor):
+ return sigmas[sigmas >= threshold]
+ return [s for s in sigmas if s >= threshold]
+
+ def _parse_prompt_segments(
+ self, prompt: Union[str, List[str]], prompt_segments: Optional[List[Dict[str, Any]]]
+ ) -> List[str]:
+ """
+ Return a list of positive prompts per window index.
+
+ Args:
+ prompt: str | list[str]. If str contains '|', parts are split by bars and trimmed.
+ prompt_segments: list[dict], optional. Each dict with {"start_window", "end_window", "text"} overrides prompts per window.
+
+ Returns:
+ list[str] containing the positive prompt for each window index.
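+
+ Example (illustrative, bar-split mode):
+ _parse_prompt_segments("walk | eat | rest", None) -> ["walk", "eat", "rest"]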
+ """
+ if prompt is None and not prompt_segments:
+ return []
+ if prompt_segments:
+ max_w = 0
+ for seg in prompt_segments:
+ max_w = max(max_w, int(seg.get("end_window", 0)))
+ texts = [""] * (max_w + 1)
+ for seg in prompt_segments:
+ s = int(seg.get("start_window", 0))
+ e = int(seg.get("end_window", s))
+ txt = seg.get("text", "")
+ for w in range(s, e + 1):
+ texts[w] = txt
+ # fill empty by last non-empty
+ last = ""
+ for i in range(len(texts)):
+ if texts[i] == "":
+ texts[i] = last
+ else:
+ last = texts[i]
+ return texts
+
+ # bar-split mode
+ if isinstance(prompt, str):
+ parts = [p.strip() for p in prompt.split("|")]
+ else:
+ parts = prompt
+ parts = [p for p in parts if p is not None]
+ return parts
+
+ def batch_normalize(self, latents, reference, factor):
+ """
+ Batch AdaIN-like normalization for latents in dict format (ComfyUI-compatible).
+ Args:
+ latents: dict containing "samples" shaped [B, C, F, H, W]
+ reference: dict containing "samples" used to compute target stats
+ factor: float in [0, 1]; 0 = no change, 1 = full match to reference
+ Returns:
+ Tuple[dict]: a single-element tuple with the updated latents dict.
+ Notes:
+ - Operates channel-wise across all temporal and spatial dims (T, H, W).
+ - Uses torch.std_mean for numerical stability.
+ - Final samples are linearly interpolated with the original via `factor`.
+ """
+ latents_copy = copy.deepcopy(latents)
+ t = latents_copy["samples"] # B x C x F x H x W
+
+ for i in range(t.size(0)): # batch
+ for c in range(t.size(1)): # channel
+ r_sd, r_mean = torch.std_mean(reference["samples"][i, c], dim=None) # index by original dim order
+ i_sd, i_mean = torch.std_mean(t[i, c], dim=None)
+
+ t[i, c] = ((t[i, c] - i_mean) / i_sd) * r_sd + r_mean
+
+ latents_copy["samples"] = torch.lerp(latents["samples"], t, factor)
+ return (latents_copy,)
+
+ def apply_cfg(
+ self,
+ noise_pred: torch.Tensor,
+ noise_pred_uncond: Optional[torch.Tensor],
+ noise_pred_text: Optional[torch.Tensor],
+ guidance_scale: float,
+ guidance_rescale: float,
+ ) -> torch.Tensor:
+ """
+ Unified classifier-free guidance (CFG) composition.
+
+ Args:
+ noise_pred: Tensor. Base prediction to return if CFG components are None.
+ noise_pred_uncond: Optional[Tensor]. Unconditional prediction.
+ noise_pred_text: Optional[Tensor]. Text-conditional prediction.
+ guidance_scale: float. CFG scale w.
+ guidance_rescale: float. Optional rescale to mitigate overexposure (see `rescale_noise_cfg`).
+
+ Returns:
+ Tensor: combined = uncond + w * (text - uncond), optionally rescaled.
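+
+ Example (illustrative):
+ With guidance_scale=5.0 and guidance_rescale=0.0: combined = uncond + 5.0 * (text - uncond).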
+ """
+ if noise_pred_uncond is None or noise_pred_text is None:
+ return noise_pred
+ combined = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ if guidance_rescale > 0:
+ combined = rescale_noise_cfg(combined, noise_pred_text, guidance_rescale=guidance_rescale)
+ return combined
+
+ @torch.no_grad()
+ def vae_decode_tiled(
+ self,
+ latents: torch.Tensor,
+ decode_timestep: Optional[float] = None,
+ decode_noise_scale: Optional[float] = None,
+ horizontal_tiles: int = 4,
+ vertical_tiles: int = 4,
+ overlap: int = 3,
+ last_frame_fix: bool = True,
+ generator: Optional[torch.Generator] = None,
+ output_type: str = "pt",
+ auto_denormalize: bool = True,
+ compute_dtype: torch.dtype = torch.float32,
+ ) -> Union[torch.Tensor, np.ndarray, List[PIL.Image.Image]]:
+ """
+ VAE-based spatial tiled decoding (ComfyUI parity) implemented in Diffusers style.
+ - Linearly feather and blend overlapping tiles to avoid seams.
+ - Optional last_frame_fix: duplicate the last latent frame before decoding, then drop time_scale_factor frames at the end.
+ - Supports timestep_conditioning and decode_noise_scale injection.
+ - By default, "normalized latents" (the denoising output) are de-normalized internally (auto_denormalize=True).
+ - Tile fusion is computed in compute_dtype (float32 by default) to reduce blur and color shifts.
+
+ Args:
+ latents: [B, C_latent, F_latent, H_latent, W_latent]
+ decode_timestep: Optional decode timestep (effective only if VAE supports timestep_conditioning)
+ decode_noise_scale: Optional decode noise interpolation (effective only if VAE supports timestep_conditioning)
+ horizontal_tiles, vertical_tiles: Number of tiles horizontally/vertically (>= 1)
+ overlap: Overlap in latent space (in latent pixels, >= 0)
+ last_frame_fix: Whether to enable the "repeat last frame" fix
+ generator: Random generator (used for decode_noise_scale noise)
+ output_type: "latent" | "pt" | "np" | "pil"
+ - "latent": return latents unchanged (useful for downstream processing)
+ - "pt": return tensor in VAE output space
+ - "np"/"pil": post-processed outputs via VideoProcessor.postprocess_video
+ auto_denormalize: If True, apply LTX de-normalization to `latents` internally (recommended)
+ compute_dtype: Precision used during tile fusion (float32 default; significantly reduces seam blur)
+
+ Returns:
+ - If output_type="latent": returns input `latents` unchanged
+ - If output_type="pt": returns [B, C, F, H, W] (values roughly in [-1, 1])
+ - If output_type="np"/"pil": returns post-processed outputs via postprocess_video
+ """
+ if output_type == "latent":
+ return latents
+ if horizontal_tiles < 1 or vertical_tiles < 1:
+ raise ValueError("horizontal_tiles and vertical_tiles must be >= 1")
+ overlap = max(int(overlap), 0)
+
+ # Device and precision
+ device = self._execution_device
+ latents = latents.to(device=device, dtype=compute_dtype)
+
+ # De-normalize to VAE space (avoid color artifacts)
+ if auto_denormalize:
+ latents = LTXPipeline._denormalize_latents(
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+ )
+ # dtype required for VAE forward pass
+ latents = latents.to(dtype=self.vae.dtype)
+
+ # Temporal/spatial upscaling ratios (parity with ComfyUI's downscale_index_formula)
+ tsf = int(self.vae_temporal_compression_ratio)
+ sf = int(self.vae_spatial_compression_ratio)
+
+ # Optional: last_frame_fix (repeat last latent frame)
+ if last_frame_fix:
+ latents = torch.cat([latents, latents[:, :, -1:].contiguous()], dim=2)
+
+ b, c_lat, f_lat, h_lat, w_lat = latents.shape
+ f_out = 1 + (f_lat - 1) * tsf
+ h_out = h_lat * sf
+ w_out = w_lat * sf
+
+ # timestep_conditioning + decode-time noise injection (aligned with pipeline)
+ if getattr(self.vae.config, "timestep_conditioning", False):
+ dt = float(decode_timestep) if decode_timestep is not None else 0.0
+ vt = torch.tensor([dt], device=device, dtype=latents.dtype)
+ if decode_noise_scale is not None:
+ dns = torch.tensor([float(decode_noise_scale)], device=device, dtype=latents.dtype)[:, None, None, None, None]
+ noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+ latents = (1 - dns) * latents + dns * noise
+ else:
+ vt = None
+
+ # Compute base tile sizes (in latent space)
+ base_tile_h = (h_lat + (vertical_tiles - 1) * overlap) // vertical_tiles
+ base_tile_w = (w_lat + (horizontal_tiles - 1) * overlap) // horizontal_tiles
+
+ output: Optional[torch.Tensor] = None # [B, C_img, F, H, W], fused using compute_dtype
+ weights: Optional[torch.Tensor] = None # [B, 1, F, H, W], fused using compute_dtype
+
+ # Iterate tiles in latent space (no temporal tiling)
+ for v in range(vertical_tiles):
+ for h in range(horizontal_tiles):
+ h_start = h * (base_tile_w - overlap)
+ v_start = v * (base_tile_h - overlap)
+
+ h_end = min(h_start + base_tile_w, w_lat) if h < horizontal_tiles - 1 else w_lat
+ v_end = min(v_start + base_tile_h, h_lat) if v < vertical_tiles - 1 else h_lat
+
+ # Slice latent tile and decode
+ tile_latents = latents[:, :, :, v_start:v_end, h_start:h_end]
+ decoded_tile = self.vae.decode(tile_latents, vt, return_dict=False)[0] # [B, C, F, Ht, Wt]
+ # Cast to high precision to reduce blending blur
+ decoded_tile = decoded_tile.to(dtype=compute_dtype)
+
+ # Initialize output buffers (compute_dtype)
+ if output is None:
+ output = torch.zeros(
+ (b, decoded_tile.shape[1], f_out, h_out, w_out),
+ device=decoded_tile.device,
+ dtype=compute_dtype,
+ )
+ weights = torch.zeros(
+ (b, 1, f_out, h_out, w_out),
+ device=decoded_tile.device,
+ dtype=compute_dtype,
+ )
+
+ # Tile placement in output pixel space
+ out_h_start = v_start * sf
+ out_h_end = v_end * sf
+ out_w_start = h_start * sf
+ out_w_end = h_end * sf
+
+ tile_out_h = out_h_end - out_h_start
+ tile_out_w = out_w_end - out_w_start
+
+ # Linear feathering weights [B, 1, F, Ht, Wt] (compute_dtype)
+ tile_weights = torch.ones(
+ (b, 1, decoded_tile.shape[2], tile_out_h, tile_out_w),
+ device=decoded_tile.device,
+ dtype=compute_dtype,
+ )
+
+ overlap_out_h = overlap * sf
+ overlap_out_w = overlap * sf
+
+ # Horizontal feathering: left/right overlaps
+ if overlap_out_w > 0:
+ if h > 0:
+ h_blend = torch.linspace(0, 1, steps=overlap_out_w, device=decoded_tile.device, dtype=compute_dtype)
+ tile_weights[:, :, :, :, :overlap_out_w] *= h_blend.view(1, 1, 1, 1, -1)
+ if h < horizontal_tiles - 1:
+ h_blend = torch.linspace(1, 0, steps=overlap_out_w, device=decoded_tile.device, dtype=compute_dtype)
+ tile_weights[:, :, :, :, -overlap_out_w:] *= h_blend.view(1, 1, 1, 1, -1)
+
+ # Vertical feathering: top/bottom overlaps
+ if overlap_out_h > 0:
+ if v > 0:
+ v_blend = torch.linspace(0, 1, steps=overlap_out_h, device=decoded_tile.device, dtype=compute_dtype)
+ tile_weights[:, :, :, :overlap_out_h, :] *= v_blend.view(1, 1, 1, -1, 1)
+ if v < vertical_tiles - 1:
+ v_blend = torch.linspace(1, 0, steps=overlap_out_h, device=decoded_tile.device, dtype=compute_dtype)
+ tile_weights[:, :, :, -overlap_out_h:, :] *= v_blend.view(1, 1, 1, -1, 1)
+
+ # Accumulate blended tile
+ output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += decoded_tile * tile_weights
+ weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += tile_weights
+
+ # Normalize, then clamp to [-1, 1] in compute_dtype to avoid color artifacts
+ output = output / (weights + 1e-8)
+ output = output.clamp(-1.0, 1.0)
+ output = output.to(dtype=self.vae.dtype)
+
+ # Optional: drop the last tsf frames after last_frame_fix
+ if last_frame_fix:
+ output = output[:, :, :-tsf, :, :]
+
+ if output_type in ("np", "pil"):
+ return self.video_processor.postprocess_video(output, output_type=output_type)
+ return output
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_segments: Optional[List[Dict[str, Any]]] = None,
+ height: int = 512,
+ width: int = 704,
+ num_frames: int = 161,
+ frame_rate: float = 25,
+ guidance_scale: float = 3.0,
+ guidance_rescale: float = 1.0,
+ num_inference_steps: Optional[int] = 8,
+ sigmas: Optional[Union[List[float], torch.Tensor]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ seed: Optional[int] = 0,
+ cond_image: Optional[Union["PIL.Image.Image", torch.Tensor]] = None,
+ cond_strength: float = 1.0,
+ latents: Optional[torch.Tensor] = None,
+ temporal_tile_size: int = 80,
+ temporal_overlap: int = 24,
+ temporal_overlap_cond_strength: float = 0.5,
+ adain_factor: float = 0.25,
+ guidance_latents: Optional[torch.Tensor] = None,
+ guiding_strength: float = 1.0,
+ negative_index_latents: Optional[torch.Tensor] = None,
+ negative_index_strength: float = 1.0,
+ skip_steps_sigma_threshold: Optional[float] = 0.997,
+ decode_timestep: Optional[float] = 0.05,
+ decode_noise_scale: Optional[float] = 0.025,
+ decode_horizontal_tiles: int = 4,
+ decode_vertical_tiles: int = 4,
+ decode_overlap: int = 3,
+ output_type: Optional[str] = "latent", # "latent" | "pt" | "np" | "pil"
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 128,
+ ):
+ """
+ Generate an image-to-video sequence via temporal sliding windows and multi-prompt scheduling.
+
+ Args:
+ prompt: str | list[str]. Positive text prompt(s) per window. If a single string contains '|', parts are split by bars.
+ negative_prompt: str | list[str], optional. Negative prompt(s) to suppress undesired content.
+ prompt_segments: list[dict], optional. Segment mapping with {"start_window", "end_window", "text"} to override prompts per window.
+ height: int. Output image height in pixels; must be divisible by 32.
+ width: int. Output image width in pixels; must be divisible by 32.
+ num_frames: int. Number of output frames (in decoded pixel space).
+ frame_rate: float. Frames-per-second; used to normalize temporal coordinates in `video_coords`.
+ guidance_scale: float. CFG scale; values > 1 enable classifier-free guidance.
+ guidance_rescale: float. Optional rescale to mitigate overexposure under CFG (see `rescale_noise_cfg`).
+ num_inference_steps: int, optional. Denoising steps per window. Ignored if `sigmas` is provided.
+ sigmas: list[float] | Tensor, optional. Explicit sigma schedule per window; if set, overrides `num_inference_steps`.
+ generator: torch.Generator | list[torch.Generator], optional. Controls stochasticity; list accepted but first element is used (batch=1).
+ seed: int, optional. If provided, per-window hard-condition noise uses `seed + w_start` (latent-frame index).
+ cond_image: PIL.Image.Image | Tensor, optional. Conditioning image; fixes frame 0 via per-token mask when `cond_strength > 0`.
+ cond_strength: float. Strength of first-frame hard conditioning (smaller cond_mask ⇒ stronger preservation).
+ latents: Tensor, optional. Initial latents [B, C_lat, F_lat, H_lat, W_lat]; if None, sampled with `randn_tensor`.
+ temporal_tile_size: int. Temporal window size (in decoded frames); internally scaled by VAE temporal compression.
+ temporal_overlap: int. Overlap between consecutive windows (in decoded frames); internally scaled by compression.
+ temporal_overlap_cond_strength: float. Strength for injecting previous window tail latents at new window head.
+ adain_factor: float. AdaIN normalization strength for cross-window consistency (0 disables).
+ guidance_latents: Tensor, optional. Reference latents injected at window head; length trimmed by overlap for subsequent windows.
+ guiding_strength: float. Injection strength for `guidance_latents`.
+ negative_index_latents: Tensor, optional. A single-frame latent appended at window head for "negative index" semantics.
+ negative_index_strength: float. Injection strength for `negative_index_latents`.
+ skip_steps_sigma_threshold: float, optional. Skip steps whose sigma exceeds this threshold.
+ decode_timestep: float, optional. Decode-time timestep (if VAE supports timestep_conditioning).
+ decode_noise_scale: float, optional. Decode-time noise mix scale (if VAE supports timestep_conditioning).
+ decode_horizontal_tiles: int. Number of horizontal tiles during VAE decoding.
+ decode_vertical_tiles: int. Number of vertical tiles during VAE decoding.
+ decode_overlap: int. Overlap (in latent pixels) between tiles during VAE decoding.
+ output_type: str, optional. "latent" | "pt" | "np" | "pil". If "latent", returns latents without decoding.
+ return_dict: bool. If True, return LTXPipelineOutput; else return tuple(frames,).
+ attention_kwargs: dict, optional. Extra attention parameters forwarded to the transformer.
+ callback_on_step_end: PipelineCallback | MultiPipelineCallbacks, optional. Per-step callback hook.
+ callback_on_step_end_tensor_inputs: list[str]. Keys from locals() to pass into the callback.
+ max_sequence_length: int. Tokenizer max length for prompt encoding.
+
+ Returns:
+ - LTXPipelineOutput when `return_dict=True`:
+ frames: Tensor | np.ndarray | list[PIL.Image.Image]
+ • "latent"/"pt": Tensor [B, C, F, H, W]; "latent" is in normalized latent space, "pt" is VAE output space.
+ • "np": np.ndarray post-processed; "pil": list of PIL images.
+ - tuple(frames,) when `return_dict=False`.
+
+ Shapes:
+ - Latent sizes (when auto-generated):
+ F_lat = (num_frames - 1) // vae_temporal_compression_ratio + 1
+ H_lat = height // vae_spatial_compression_ratio
+ W_lat = width // vae_spatial_compression_ratio
+
+ Notes:
+ - Seeding: per-window hard-condition initialization uses `seed + w_start` when `seed` is provided; otherwise uses `generator`.
+ - CFG: unified `noise_pred = uncond + w * (text - uncond)` with optional `guidance_rescale`.
+ - Memory: denoising performs full-frame predictions (no spatial tiling); decoding can be tiled to avoid OOM.
+
+ Example:
+ >>> out = pipe(prompt="a chimpanzee walks | a chimpanzee eats", num_frames=161, height=512, width=704,
+ ... temporal_tile_size=80, temporal_overlap=24, output_type="pil", return_dict=True)
+ >>> frames = out.frames[0] # list of PIL.Image.Image
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 0. Input validation: height/width must be divisible by 32
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+ self._current_timestep = None
+
+ # 1. Device & generator
+ device = self._execution_device
+ # Normalize generator input: accept list but use the first (batch_size=1)
+ # When `seed` is set, per-window noise will override `generator` with `seed + w_start`
+ # to ensure deterministic window-local noise.
+ if isinstance(generator, list):
+ generator = generator[0]
+ if seed is not None and generator is None:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ # 2. Global initial latents [B,C,F,H,W]
+ if latents is None:
+ num_channels_latents = self.transformer.config.in_channels
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+ latent_height = height // self.vae_spatial_compression_ratio
+ latent_width = width // self.vae_spatial_compression_ratio
+ # Initialize latents with standard Gaussian noise (Diffusers-style), matching the randn_tensor behavior documented above
+ latents = randn_tensor((1, num_channels_latents, latent_num_frames, latent_height, latent_width), generator=generator, device=device, dtype=torch.float32)
+ else:
+ latent_num_frames = latents.shape[2]
+ latent_height = latents.shape[3]
+ latent_width = latents.shape[4]
+ latents = latents.to(device=device, dtype=torch.float32)
+ if guidance_latents is not None:
+ guidance_latents = guidance_latents.to(device=device, dtype=torch.float32)
+ if latents.shape[2] != guidance_latents.shape[2]:
+ raise ValueError("The number of frames in `latents` and `guidance_latents` must be the same")
+ # 3. Optional i2v first frame conditioning: encode cond_image and inject at frame 0
+ if cond_image is not None and cond_strength > 0:
+ img = self.video_processor.preprocess(cond_image, height=height, width=width)
+ img = img.to(device=device, dtype=self.vae.dtype)
+ enc = self.vae.encode(img.unsqueeze(2)) # [B, C, 1, h, w]
+ init_latents = enc.latent_dist.mode() if hasattr(enc, "latent_dist") else enc.latents
+ init_latents = init_latents.to(torch.float32)
+ init_latents = LTXPipeline._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
+ if negative_index_latents is None:
+ negative_index_latents = init_latents
+ # Blend only first latent frame
+ latents[:, :, 0, :, :] = init_latents.squeeze(2)
+
+ # 4. Prepare sigma schedule; honor explicit sigmas; disable shifting here (handled by scheduler)
+ sigmas_reset = None
+ n_steps_reset = None
+ if sigmas is not None:
+ s = torch.tensor(sigmas, dtype=torch.float32) if not isinstance(sigmas, torch.Tensor) else sigmas
+ # ComfyUI parity: do not filter steps here by threshold; skip inside the loop before scheduler.step
+ sigmas_reset = s
+ self.scheduler.set_timesteps(sigmas=s, device=device)
+ else:
+ if num_inference_steps is None:
+ num_inference_steps = 50
+ n_steps_reset = num_inference_steps
+ self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+ self._num_timesteps = len(timesteps)
+
+ # 5. Sliding windows in latent frames
+ tile_size_lat = max(1, temporal_tile_size // self.vae_temporal_compression_ratio)
+ overlap_lat = max(0, temporal_overlap // self.vae_temporal_compression_ratio)
+ windows = self._split_into_temporal_windows(
+ latent_num_frames, tile_size_lat, overlap_lat, self.vae_temporal_compression_ratio
+ )
+
+ # 6. Multi-prompt segments parsing
+ segment_texts = self._parse_prompt_segments(prompt, prompt_segments)
+
+ out_latents = None
+ first_window_latents = None
+
+ # 7. Process each temporal window
+ for w_idx, (w_start, w_end) in enumerate(windows):
+ if self.interrupt:
+ break
+
+ # 7.1 Encode prompt embeddings per window segment
+ seg_index = min(w_idx, len(segment_texts) - 1) if segment_texts else 0
+ pos_text = segment_texts[seg_index] if segment_texts else (prompt if isinstance(prompt, str) else "")
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = LTXPipeline.encode_prompt(
+ self,
+ prompt=[pos_text],
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=1,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=None,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 7.2 Window-level timesteps reset: fresh sampling for each temporal window
+ if sigmas_reset is not None:
+ self.scheduler.set_timesteps(sigmas=sigmas_reset, device=device)
+ else:
+ self.scheduler.set_timesteps(num_inference_steps=n_steps_reset, device=device)
+
+ # 7.3 Extract window latents [B,C,T,H,W]
+ window_latents = latents[:, :, w_start:w_end]
+ window_guidance_latents = guidance_latents[:, :, w_start:w_end] if guidance_latents is not None else None
+ window_T = window_latents.shape[2]
+
+ # 7.4 Build per-window cond mask and inject previous tails / reference
+ window_cond_mask_5d = torch.ones(
+ (1, 1, window_T, latent_height, latent_width), device=device, dtype=torch.float32
+ )
+ self._current_tile_T = window_T
+ prev_overlap_len = 0
+ # Inter-window tail latent injection (Extend)
+ if w_idx > 0 and overlap_lat > 0 and out_latents is not None:
+ k = min(overlap_lat, out_latents.shape[2])
+ prev_tail = out_latents[:, :, -k:]
+ window_latents, window_cond_mask_5d, prev_overlap_len = self._inject_prev_tail_latents(
+ window_latents, prev_tail, window_cond_mask_5d, overlap_lat, temporal_overlap_cond_strength, prev_overlap_len
+ )
+ # Reference/negative-index latent injection (append 1 frame at window head; controlled by negative_index_strength)
+ if window_guidance_latents is not None:
+ guiding_len = window_guidance_latents.shape[2] if w_idx == 0 else window_guidance_latents.shape[2] - overlap_lat
+ window_latents, window_cond_mask_5d, prev_overlap_len = self._inject_prev_tail_latents(
+ window_latents, window_guidance_latents[:,:,-guiding_len:], window_cond_mask_5d, guiding_len, guiding_strength, prev_overlap_len
+ )
+ else:
+ guiding_len = 0
+ window_latents, window_cond_mask_5d, prev_overlap_len = self._inject_prev_tail_latents(
+ window_latents, negative_index_latents, window_cond_mask_5d, 1, negative_index_strength, prev_overlap_len
+ )
+ if w_idx == 0 and cond_image is not None and cond_strength > 0:
+ # First-frame I2V: smaller mask means stronger preservation of the original latent
+ window_cond_mask_5d[:, :, 0] = 1.0 - cond_strength
+
+ # Update effective window latent sizes (consider injections on T/H/W)
+ w_B, w_C, w_T_eff, w_H_eff, w_W_eff = window_latents.shape
+ p = self.transformer_spatial_patch_size
+ pt = self.transformer_temporal_patch_size
+
+ # 7.5 Pack full-window latents/masks once
+ # randn*mask + (1-mask)*latents implements hard-condition initialization
+ # Seeding policy:
+ # - If `seed` is provided, derive a window-local seed = seed + w_start (latent-frame index),
+ # and use it to create a dedicated local generator. This mirrors ComfyUI behavior and keeps
+ # cross-window reproducibility while avoiding inter-window RNG coupling.
+ # - Otherwise, fall back to the passed-in `generator` (if any).
+ if seed is not None:
+ tile_seed = int(seed) + int(w_start)
+ local_gen = torch.Generator(device=device).manual_seed(tile_seed)
+ else:
+ local_gen = generator
+ init_rand = randn_tensor(
+ window_latents.shape, generator=local_gen, device=device, dtype=torch.float32
+ )
+ mixed_latents = init_rand * window_cond_mask_5d + (1 - window_cond_mask_5d) * window_latents
+ window_latents_packed = LTXPipeline._pack_latents(
+ window_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+ latents_packed = LTXPipeline._pack_latents(
+ mixed_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+ cond_mask_tokens = LTXPipeline._pack_latents(
+ window_cond_mask_5d, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ ).squeeze(-1)
+ if self.do_classifier_free_guidance:
+ cond_mask = torch.cat([cond_mask_tokens, cond_mask_tokens], dim=0)
+ else:
+ cond_mask = cond_mask_tokens
+
+ # 7.6 Denoising loop per full window (no spatial tiling)
+ sigmas_current = self.scheduler.sigmas.to(device=latents_packed.device)
+ if sigmas_current.shape[0] >= 2:
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[:-1])):
+ if self.interrupt:
+ break
+
+ self._current_timestep = t
+
+ # Model input (stack 2 copies under CFG)
+ latent_model_input = torch.cat([latents_packed] * 2) if self.do_classifier_free_guidance else latents_packed
+ # Broadcast timesteps, combine with per-token cond mask (I2V at window head)
+ timestep = t.expand(latent_model_input.shape[0])
+ if cond_mask is not None:
+ timestep = timestep.unsqueeze(-1) * cond_mask
+ # Skip semantics: if sigma exceeds threshold, skip this step (do not call scheduler.step)
+ sigma_val = float(sigmas_current[i].item())
+ if skip_steps_sigma_threshold is not None and float(skip_steps_sigma_threshold) > 0.0:
+ if sigma_val > float(skip_steps_sigma_threshold):
+ continue
+
+ # Micro-conditions: only provide video_coords (num_frames/height/width set to 1)
+ rope_interpolation_scale = (
+ self.vae_temporal_compression_ratio,
+ self.vae_spatial_compression_ratio,
+ self.vae_spatial_compression_ratio,
+ )
+ # Negative-index/overlap lengths (for segmenting time coordinates; RoPE-compatible)
+ k_negative_count = 1 if (window_guidance_latents is None and negative_index_latents is not None and float(negative_index_strength) > 0.0) else 0
+ k_overlap_count = overlap_lat if (w_idx > 0 and overlap_lat > 0) else 0
+ video_coords = self._build_video_coords_for_window(
+ latents=window_latents,
+ overlap_len=int(k_overlap_count),
+ guiding_len=int(guiding_len),
+ negative_len=int(k_negative_count),
+ rope_interpolation_scale=rope_interpolation_scale,
+ frame_rate=frame_rate,
+ )
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input.to(dtype=self.transformer.dtype),
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ num_frames=1,
+ height=1,
+ width=1,
+ rope_interpolation_scale=rope_interpolation_scale,
+ video_coords=video_coords,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # Unified CFG
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = self.apply_cfg(
+ noise_pred, noise_pred_uncond, noise_pred_text, self.guidance_scale, self.guidance_rescale
+ )
+
+ # Use the global timestep for scheduling, but blend the prediction with hard-condition tokens (e.g., the first frame) before the scheduler step to avoid brightness shifts/flicker from time misalignment
+ noise_pred = noise_pred * cond_mask_tokens.unsqueeze(-1) + window_latents_packed * (1.0 - cond_mask_tokens.unsqueeze(-1))
+ latents_packed = self.scheduler.step(
+ noise_pred, t, latents_packed, return_dict=False
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents_packed = callback_outputs.pop("latents", latents_packed)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+ else:
+ # Not enough sigmas to perform a valid step; skip this window safely.
+ pass
+
+ # 7.7 Unpack back to [B,C,T,H,W] once
+ window_out = LTXPipeline._unpack_latents(
+ latents_packed,
+ w_T_eff,
+ w_H_eff,
+ w_W_eff,
+ p,
+ pt,
+ )
+ if prev_overlap_len > 0:
+ window_out = window_out[:, :, :-prev_overlap_len]
+
+ # 7.8 Overlap handling and fusion
+ if out_latents is None:
+ # First window: keep all latent frames and cache as AdaIN reference
+ out_latents = window_out
+ first_window_latents = out_latents
+ else:
+ window_out = window_out[:, :, 1:] # Drop the first frame of the new window
+ if adain_factor > 0 and first_window_latents is not None:
+ window_out = self._adain_normalize_latents(window_out, first_window_latents, adain_factor)
+ overlap_len = max(overlap_lat - 1, 1)
+ prev_tail_chunk = out_latents[:, :, -window_out.shape[2]:]
+ fused = self._linear_overlap_fuse(prev_tail_chunk, window_out, overlap_len)
+ out_latents = torch.cat([out_latents[:, :, :-window_out.shape[2]], fused], dim=2)
+
+ # 8. Decode or return latent
+ if output_type == "latent":
+ video = out_latents
+ else:
+ # Decode with spatial tiling to avoid OOM from full-frame decoding; out_latents are still in the normalized (denoising) space, so vae_decode_tiled's default auto_denormalize=True applies
+ video = self.vae_decode_tiled(
+ out_latents,
+ decode_timestep=decode_timestep,
+ decode_noise_scale=decode_noise_scale,
+ horizontal_tiles=int(decode_horizontal_tiles),
+ vertical_tiles=int(decode_vertical_tiles),
+ overlap=int(decode_overlap),
+ generator=generator,
+ output_type=output_type, # "pt"/"np"/"pil" post-processing is handled inside vae_decode_tiled
+ )
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return LTXPipelineOutput(frames=video)