Custom collate functions

Custom collate functions are passed to the collate_fn argument of a PyTorch DataLoader:

torchtime.collate.sort_by_length(batch_data)

Collates a batch and sorts by descending length.

Pass to the collate_fn argument when creating a PyTorch DataLoader. Each batch is a dictionary with X, y and length entries, sorted by descending length, so pack_padded_sequence() can be called in the forward method of the model. For example:

from torch.utils.data import DataLoader
from torchtime.data import UEA
from torchtime.collate import sort_by_length

char_traj = UEA(
    dataset="CharacterTrajectories",
    split="train",
    train_prop=0.7,
    seed=123,
)
dataloader = DataLoader(
    char_traj,
    batch_size=32,
    collate_fn=sort_by_length,
)
print(next(iter(dataloader))["length"])
tensor([157, 151, 151, 150, 138, 138, 136, 135, 133, 130, 129, 129, 127, 126,
        124, 124, 121, 121, 118, 117, 117, 117, 113, 108, 106, 105, 102,  98,
         83,  74,  74,  61])
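The sorting step itself is straightforward: reorder every tensor in the batch by the descending order of the length tensor. A minimal sketch (sort_by_length_sketch is a hypothetical name, not the library's implementation; it operates on an already-collated batch dictionary as documented in the signature below):

```python
import torch

def sort_by_length_sketch(batch_data):
    # Indices that sort the length tensor in descending order
    idx = torch.argsort(batch_data["length"], descending=True)
    # Apply the same reordering to every tensor in the batch
    return {key: tensor[idx] for key, tensor in batch_data.items()}

# Synthetic batch: 4 sequences, 10 time steps, 3 channels
batch = {
    "X": torch.randn(4, 10, 3),
    "y": torch.randn(4, 2),
    "length": torch.tensor([7, 10, 4, 9]),
}
sorted_batch = sort_by_length_sketch(batch)
print(sorted_batch["length"])  # tensor([10,  9,  7,  4])
```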
Parameters:

batch_data (Dict[str, Tensor]) – Batch from a torchtime.data class.

Return type:

Dict[str, Tensor]

Returns:

Updated batch.
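Because the batch arrives sorted by descending length, pack_padded_sequence() can be called with its default enforce_sorted=True inside the model. A sketch of such a model (GRUClassifier and the synthetic batch are illustrative assumptions, not part of torchtime):

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

class GRUClassifier(nn.Module):
    def __init__(self, n_channels, hidden_size, n_classes):
        super().__init__()
        self.gru = nn.GRU(n_channels, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, n_classes)

    def forward(self, X, length):
        # Batch is already sorted by descending length, so the
        # default enforce_sorted=True is safe here.
        packed = pack_padded_sequence(X, length, batch_first=True)
        _, h_n = self.gru(packed)
        # Classify from the final hidden state of the last layer
        return self.head(h_n[-1])

# Synthetic sorted batch: 3 sequences, 5 time steps, 4 channels
X = torch.randn(3, 5, 4)
length = torch.tensor([5, 4, 2])  # descending, as sort_by_length produces
model = GRUClassifier(n_channels=4, hidden_size=8, n_classes=2)
out = model(X, length)
print(out.shape)  # torch.Size([3, 2])
```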

torchtime.collate.packed_sequence(batch_data)

Collates a batch and returns data as PackedSequence objects.

Pass to the collate_fn argument when creating a PyTorch DataLoader. Each batch is a dictionary with X, y and length entries, where X and y are PackedSequence objects. For example:

from torch.utils.data import DataLoader
from torchtime.data import UEA
from torchtime.collate import packed_sequence

char_traj = UEA(
    dataset="CharacterTrajectories",
    split="train",
    train_prop=0.7,
    seed=123,
)
dataloader = DataLoader(
    char_traj,
    batch_size=32,
    collate_fn=packed_sequence,
)
print(next(iter(dataloader))["X"])
PackedSequence(data=tensor([[ 0.0000e+00,  2.2753e-01,  6.0560e-03,  2.0894e-02],
        [ 0.0000e+00, -3.6401e-02,  1.1512e-01,  7.3964e-01],
        [ 0.0000e+00,  5.6454e-01, -1.0000e-05,  2.9244e-01],
        ...,
        [ 1.5400e+02, -2.6396e-01,  1.9185e-01, -1.4082e+00],
        [ 1.5500e+02, -2.2807e-01,  1.6577e-01, -1.2167e+00],
        [ 1.5600e+02, -1.7705e-01,  1.2868e-01, -9.4452e-01]]), batch_sizes=tensor([32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31,
        31, 31, 29, 29, 29, 29, 29, 29, 29, 29, 29, 28, 28, 28, 28, 28, 28, 28,
        28, 28, 28, 28, 28, 28, 28, 28, 27, 27, 27, 27, 26, 26, 26, 25, 24, 24,
        23, 23, 23, 23, 23, 22, 22, 22, 22, 19, 18, 18, 18, 16, 16, 16, 14, 14,
        13, 12, 12, 10,  9,  9,  9,  8,  8,  7,  6,  6,  4,  4,  4,  4,  4,  4,
         4,  4,  4,  4,  4,  4,  3,  1,  1,  1,  1,  1,  1]), sorted_indices=None, unsorted_indices=None)
Parameters:

batch_data (Dict[str, Tensor]) – Batch from a torchtime.data class.

Return type:

Dict[str, Union[PackedSequence, Tensor]]

Returns:

Updated batch.
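A PackedSequence can be fed directly to PyTorch recurrent layers, and pad_packed_sequence() recovers the padded tensor and the per-sequence lengths when a dense view is needed. A small round-trip sketch using the standard torch.nn.utils.rnn API (the tensors are synthetic, assuming a batch already sorted by descending length):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Padded batch of 3 univariate sequences with lengths 3, 2 and 1,
# sorted by descending length; zeros are padding
X = torch.tensor([[1., 2., 3.],
                  [4., 5., 0.],
                  [6., 0., 0.]]).unsqueeze(-1)  # shape (3, 3, 1)
length = torch.tensor([3, 2, 1])

# Pack: data is stored time-step-major, dropping the padding
packed = pack_padded_sequence(X, length, batch_first=True)
print(packed.batch_sizes)  # tensor([3, 2, 1])

# Unpack: recover the padded tensor and the lengths
padded, lengths = pad_packed_sequence(packed, batch_first=True)
print(torch.equal(padded, X))  # True
print(lengths)                 # tensor([3, 2, 1])
```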