# Source code for torchtime.collate

"""
========================
Custom collate functions
========================

Custom collate functions should be passed to the ``collate_fn`` argument of a
`DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_:

* `sort_by_length() <torchtime.collate.sort_by_length>`_ sorts each batch by descending
  length

* `packed_sequence() <torchtime.collate.packed_sequence>`_ returns ``X`` and ``y`` as a
  ``PackedSequence`` object
"""

from typing import Dict, List, Union

import torch
from torch import Tensor
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence


def sort_by_length(batch_data: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
    """**Collates a batch and sorts by descending length.**

    Pass to the ``collate_fn`` argument when creating a PyTorch DataLoader.
    Batches are a named dictionary with ``X``, ``y`` and ``length`` data that
    is sorted by length. ``pack_padded_sequence()`` can therefore be called in
    the forward method of the model. For example:

    .. testcode::

        from torch.utils.data import DataLoader
        from torchtime.data import UEA
        from torchtime.collate import sort_by_length

        char_traj = UEA(
            dataset="CharacterTrajectories",
            split="train",
            train_prop=0.7,
            seed=123,
        )
        dataloader = DataLoader(
            char_traj,
            batch_size=32,
            collate_fn=sort_by_length,
        )
        print(next(iter(dataloader))["length"])

    .. testoutput::

        ...
        tensor([157, 151, 151, 150, 138, 138, 136, 135, 133, 130, 129, 129,
                127, 126, 124, 124, 121, 121, 118, 117, 117, 117, 113, 108,
                106, 105, 102,  98,  83,  74,  74,  61])

    Args:
        batch_data: Batch from a ``torchtime.data`` class, i.e. a list with
            one ``{"X", "y", "length"}`` dictionary per sample.

    Returns:
        Updated batch: a single dictionary with stacked ``X``, ``y`` and
        ``length`` tensors, all sorted by descending ``length``.
    """
    # Stack per-sample tensors into batch-first tensors.
    X = torch.stack([i["X"] for i in batch_data], dim=0)
    y = torch.stack([i["y"] for i in batch_data])
    length = torch.stack([i["length"] for i in batch_data])
    # Sort descending so pack_padded_sequence(enforce_sorted=True) accepts
    # the batch, and reorder X/y to match.
    length, sorted_index = torch.sort(length, descending=True)
    X = torch.index_select(X, 0, sorted_index)
    y = torch.index_select(y, 0, sorted_index)
    return {"X": X, "y": y, "length": length}
def packed_sequence(
    batch_data: List[Dict[str, Tensor]]
) -> Dict[str, Union[PackedSequence, Tensor]]:
    """**Collates a batch and returns data as PackedSequence objects.**

    Pass to the ``collate_fn`` argument when creating a PyTorch DataLoader.
    Batches are a named dictionary with ``X``, ``y`` and ``length`` data where
    ``X`` and ``y`` are PackedSequence objects. For example:

    .. testcode::

        from torch.utils.data import DataLoader
        from torchtime.data import UEA
        from torchtime.collate import packed_sequence

        char_traj = UEA(
            dataset="CharacterTrajectories",
            split="train",
            train_prop=0.7,
            seed=123,
        )
        dataloader = DataLoader(
            char_traj,
            batch_size=32,
            collate_fn=packed_sequence,
        )
        print(next(iter(dataloader))["X"])

    .. testoutput::

        ...
        PackedSequence(data=tensor([[ 0.0000e+00,  2.2753e-01,  6.0560e-03,  2.0894e-02],
                [ 0.0000e+00, -3.6401e-02,  1.1512e-01,  7.3964e-01],
                [ 0.0000e+00,  5.6454e-01, -1.0000e-05,  2.9244e-01],
                ...,
                [ 1.5400e+02, -2.6396e-01,  1.9185e-01, -1.4082e+00],
                [ 1.5500e+02, -2.2807e-01,  1.6577e-01, -1.2167e+00],
                [ 1.5600e+02, -1.7705e-01,  1.2868e-01, -9.4452e-01]]),
            batch_sizes=tensor([32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
                32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
                32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
                32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
                32, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 29, 29,
                29, 29, 29, 29, 29, 29, 29, 28, 28, 28, 28, 28, 28, 28, 28, 28,
                28, 28, 28, 28, 28, 28, 27, 27, 27, 27, 26, 26, 26, 25, 24, 24,
                23, 23, 23, 23, 23, 22, 22, 22, 22, 19, 18, 18, 18, 16, 16, 16,
                14, 14, 13, 12, 12, 10,  9,  9,  9,  8,  8,  7,  6,  6,  4,  4,
                 4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  3,  1,  1,  1,  1,  1,
                 1]), sorted_indices=None, unsorted_indices=None)

    Args:
        batch_data: Batch from a ``torchtime.data`` class, i.e. a list with
            one ``{"X", "y", "length"}`` dictionary per sample.

    Returns:
        Updated batch: a dictionary where ``X`` and ``y`` are
        ``PackedSequence`` objects and ``length`` is a tensor of sequence
        lengths sorted in descending order.
    """  # noqa: E501
    # Sort by descending length first: pack_padded_sequence requires sorted
    # input by default (enforce_sorted=True).
    batch_data = sort_by_length(batch_data)
    length = batch_data["length"]
    X = pack_padded_sequence(batch_data["X"], length, batch_first=True)
    # NOTE(review): packing y assumes labels are per time step (same leading
    # sequence dimension as X) — confirm against the torchtime.data classes.
    y = pack_padded_sequence(batch_data["y"], length, batch_first=True)
    return {"X": X, "y": y, "length": length}