Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ jobs:
uses: astral-sh/setup-uv@v4
with:
enable-cache: true
cache-prefix: diffulab-uv-cache

- name: "Set up Python"
if: steps.changes.outputs.sources == 'true'
Expand Down
67 changes: 67 additions & 0 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Runs the pytest suite on pushes to main and on non-draft pull requests
# targeting main. Expensive steps are skipped when no relevant file changed.
name: Unit Tests

on:
  push:
    branches: [ main ]
    paths:
      - '**.py'
      - 'pyproject.toml'
      - 'uv.lock'
      - '.github/workflows/unit_tests.yml'
      - 'tests/**'
  pull_request:
    # ready_for_review is needed because draft PRs are skipped by the job-level
    # `if`; without it, marking a draft as ready would not trigger a run.
    types: [opened, synchronize, reopened, ready_for_review]
    branches: [ main ]
  workflow_dispatch:

permissions:
  contents: read
  # dorny/paths-filter lists changed files through the REST API on
  # pull_request events and documents this permission as required.
  pull-requests: read

concurrency:
  group: unit-tests-${{ github.ref }}
  cancel-in-progress: true

jobs:
  tests:
    # The condition belongs to the job, not to the `jobs` map. Non-PR events
    # (push, workflow_dispatch) always run; PRs run only when not in draft.
    if: github.event_name != 'pull_request' || github.event.pull_request.draft == false
    runs-on: ubuntu-latest
    steps:
      - name: checkout
        uses: actions/checkout@v4

      - name: check changes
        uses: dorny/paths-filter@v3
        id: changes
        with:
          filters: |
            sources:
              - src/**
              - tests/**
              - .github/workflows/unit_tests.yml

      - name: Install uv
        if: steps.changes.outputs.sources == 'true'
        uses: astral-sh/setup-uv@v4
        with:
          enable-cache: true

      - name: "Set up Python"
        if: steps.changes.outputs.sources == 'true'
        uses: actions/setup-python@v5
        with:
          python-version-file: ".python-version"

      - name: Install the project
        if: steps.changes.outputs.sources == 'true'
        run: uv sync --all-extras

      - name: Run tests
        if: steps.changes.outputs.sources == 'true'
        run: uv run pytest --junitxml=results.xml

      - name: Upload test report
        # Without a status function, `success()` is implied and the report
        # would be dropped exactly when tests fail. `!cancelled()` keeps the
        # upload on failure while still skipping it for cancelled runs.
        if: ${{ !cancelled() && steps.changes.outputs.sources == 'true' }}
        uses: actions/upload-artifact@v4
        with:
          name: test-results
          path: results.xml
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dev = [
"jupyter>=1.1.1",
"pre-commit>=4.2.0",
"pyright>=1.1.396",
"pytest>=8.4.1",
"ruff>=0.9.7",
"typos>=1.30.0",
]
Expand Down
18 changes: 9 additions & 9 deletions src/diffulab/networks/denoisers/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ def __init__(
self.n_classes = n_classes
self.classifier_free = classifier_free

self.label_embed: LabelEmbed | None = None
if not self.simple_dit:
assert self.context_embedder is not None, "for MMDiT context embedder must be provided"
assert self.context_embedder.n_output == 2, "for MMDiT context embedder should provide 2 embeddings"
Expand All @@ -513,9 +514,8 @@ def __init__(
)
self.context_embed = nn.Linear(self.context_embedder.output_size[1], context_dim)
else:
self.label_embed = (
LabelEmbed(self.n_classes, embedding_dim, self.classifier_free) if self.n_classes is not None else None
)
if self.n_classes is not None:
self.label_embed = LabelEmbed(self.n_classes, embedding_dim, self.classifier_free)

self.last_layer = ModulatedLastLayer(
embedding_dim=embedding_dim,
Expand Down Expand Up @@ -623,14 +623,14 @@ def mmdit_forward(
# Pass through each layer sequentially
for layer in self.layers:
x, context = layer(x, context_pooled, context)
if features:
if features is not None:
features.append(x)

x = self.last_layer(x, context_pooled)
if features:
if features is not None:
features.append(x)
model_output: ModelOutput = {"x": x}
if features:
if features is not None:
model_output["features"] = features
return model_output

Expand All @@ -655,14 +655,14 @@ def simple_dit_forward(
# Pass through each layer sequentially
for layer in self.layers:
x = layer(x, emb)
if features:
if features is not None:
features.append(x)

x = self.last_layer(x, emb)
if features:
if features is not None:
features.append(x)
model_output: ModelOutput = {"x": x}
if features:
if features is not None:
model_output["features"] = features
return model_output

Expand Down
13 changes: 5 additions & 8 deletions src/diffulab/networks/denoisers/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ def _forward(
context_len = k.shape[-1]

# reshape for multi-head attention
q = q.view(b * self.num_heads, self.dim_head, x_len)
k = k.view(b * self.num_heads, self.dim_head, context_len)
v = v.view(b * self.num_heads, self.dim_head, context_len)
q = q.reshape(b * self.num_heads, self.dim_head, x_len)
k = k.reshape(b * self.num_heads, self.dim_head, context_len)
v = v.reshape(b * self.num_heads, self.dim_head, context_len)

dots = torch.einsum("bct,bcs->bts", q * self.scale, k * self.scale) # (b*h, x_len, context_len)
attn = self.attend(dots.float()).type(dots.dtype)
Expand Down Expand Up @@ -374,7 +374,6 @@ class UNetModel(Denoiser):
label_embed (LabelEmbed | None): Label embedding module (if class-conditional).
context_embedder (ContextEmbedder | None): External context conditioning module.
classifier_free (bool): Whether classifier-free guidance is enabled for labels.
dtype (torch.dtype): Chosen compute dtype (bfloat16 or float32).
input_blocks (nn.ModuleList): Encoder path (with optional attention + downsampling).
middle_block (EmbedSequential): Bottleneck block (res-attn-res).
output_blocks (nn.ModuleList): Decoder path (skip connections + optional attention + upsampling).
Expand Down Expand Up @@ -450,7 +449,6 @@ def __init__(
channel_mult: str = "1, 2, 4, 8",
conv_resample: bool = True,
use_checkpoint: bool = False,
use_fp16: bool = False,
num_heads: int = 1,
use_scale_shift_norm: bool = False,
resblock_updown: bool = False,
Expand All @@ -477,7 +475,6 @@ def __init__(
self.channel_mult: list[int] = eval(f"[{channel_mult}]")
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
self.dtype = torch.bfloat16 if use_fp16 else torch.float32
self.num_heads = num_heads
self.context_embedder = context_embedder
self.classifier_free = classifier_free
Expand Down Expand Up @@ -714,10 +711,10 @@ class labels and/or external context (e.g. text or images) with support for
if self.label_embed is not None:
emb = emb + self.label_embed(y, p)
if self.context_embedder is not None:
context = self.context_embedder(context, p)
context = self.context_embedder(context, p)[0]
if x_context is not None:
x = torch.cat([x, x_context], dim=1)
h = x.type(self.dtype)
h = x
for module in self.input_blocks:
h: Tensor = module(h, emb=emb, context=context)
hs.append(h)
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Pytest configuration and fixtures for the entire test suite."""

import pytest
import torch


@pytest.fixture(autouse=True)
def set_random_seed():
    """Seed PyTorch's random number generators before each test.

    The fixture is declared with ``autouse=True``, so pytest applies it to
    every test without the test having to request it. It has no teardown:
    control is handed to the test at ``yield`` and nothing runs afterwards.
    """
    seed = 42
    torch.manual_seed(seed)  # type: ignore
    if torch.cuda.is_available():
        # Seed every visible GPU rather than only the current device, so
        # multi-GPU runs are reproducible as well.
        torch.cuda.manual_seed_all(seed)
    yield
Empty file.
Loading
Loading