Custom collate functions
Custom collate functions should be passed to the collate_fn argument of a DataLoader:

- sort_by_length() sorts each batch by descending length
- packed_sequence() returns X and y as PackedSequence objects
- torchtime.collate.sort_by_length(batch_data)[source]

  Collates a batch and sorts it by descending length.

  Pass to the collate_fn argument when creating a PyTorch DataLoader. Batches are a named dictionary with X, y and length data sorted by descending length, so pack_padded_sequence() can be called in the forward method of the model. For example:

  ```python
  from torch.utils.data import DataLoader
  from torchtime.data import UEA
  from torchtime.collate import sort_by_length

  char_traj = UEA(
      dataset="CharacterTrajectories",
      split="train",
      train_prop=0.7,
      seed=123,
  )
  dataloader = DataLoader(
      char_traj,
      batch_size=32,
      collate_fn=sort_by_length,
  )
  print(next(iter(dataloader))["length"])
  # tensor([157, 151, 151, 150, 138, 138, 136, 135, 133, 130, 129, 129, 127,
  #         126, 124, 124, 121, 121, 118, 117, 117, 117, 113, 108, 106, 105,
  #         102,  98,  83,  74,  74,  61])
  ```
  - Parameters:
    batch_data (Dict[str, Tensor]) – Batch from a torchtime.data class.
  - Return type:
    Dict[str, Tensor]
  - Returns:
    Updated batch.
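Because sort_by_length() returns batches already sorted by descending length, the X and length entries can be packed inside a model's forward method. A minimal sketch of that pattern, using a hand-built batch in place of a real DataLoader (the GRUClassifier model and its sizes are illustrative, not part of torchtime):

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

class GRUClassifier(nn.Module):
    """Illustrative model: packs padded batches before the recurrent layer."""

    def __init__(self, n_features, n_hidden, n_classes):
        super().__init__()
        self.gru = nn.GRU(n_features, n_hidden, batch_first=True)
        self.fc = nn.Linear(n_hidden, n_classes)

    def forward(self, X, length):
        # Batches from sort_by_length are already in descending length order,
        # so the default enforce_sorted=True is safe here
        packed = pack_padded_sequence(X, length, batch_first=True)
        _, h_n = self.gru(packed)  # h_n: (1, batch, n_hidden)
        return self.fc(h_n[-1])

# Stand-in for a batch from the DataLoader: padded data plus lengths
X = torch.randn(3, 10, 4)          # (batch, max_seq_len, n_features)
length = torch.tensor([10, 7, 5])  # descending, as sort_by_length returns
model = GRUClassifier(n_features=4, n_hidden=16, n_classes=20)
logits = model(X, length)
print(logits.shape)  # torch.Size([3, 20])
```

Packing before the recurrent layer means the GRU skips computation on padded time steps.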
- torchtime.collate.packed_sequence(batch_data)[source]

  Collates a batch and returns data as PackedSequence objects.

  Pass to the collate_fn argument when creating a PyTorch DataLoader. Batches are a named dictionary with X, y and length data, where X and y are PackedSequence objects. For example:

  ```python
  from torch.utils.data import DataLoader
  from torchtime.data import UEA
  from torchtime.collate import packed_sequence

  char_traj = UEA(
      dataset="CharacterTrajectories",
      split="train",
      train_prop=0.7,
      seed=123,
  )
  dataloader = DataLoader(
      char_traj,
      batch_size=32,
      collate_fn=packed_sequence,
  )
  print(next(iter(dataloader))["X"])
  # PackedSequence(data=tensor([[ 0.0000e+00,  2.2753e-01,  6.0560e-03,  2.0894e-02],
  #         [ 0.0000e+00, -3.6401e-02,  1.1512e-01,  7.3964e-01],
  #         [ 0.0000e+00,  5.6454e-01, -1.0000e-05,  2.9244e-01],
  #         ...,
  #         [ 1.5400e+02, -2.6396e-01,  1.9185e-01, -1.4082e+00],
  #         [ 1.5500e+02, -2.2807e-01,  1.6577e-01, -1.2167e+00],
  #         [ 1.5600e+02, -1.7705e-01,  1.2868e-01, -9.4452e-01]]),
  #     batch_sizes=tensor([32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
  #         32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
  #         32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
  #         32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 31, 31, 31,
  #         31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 29, 29, 29, 29, 29, 29, 29,
  #         29, 29, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28,
  #         27, 27, 27, 27, 26, 26, 26, 25, 24, 24, 23, 23, 23, 23, 23, 22, 22,
  #         22, 22, 19, 18, 18, 18, 16, 16, 16, 14, 14, 13, 12, 12, 10,  9,  9,
  #          9,  8,  8,  7,  6,  6,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
  #          4,  3,  1,  1,  1,  1,  1,  1]),
  #     sorted_indices=None, unsorted_indices=None)
  ```
  - Parameters:
    batch_data (Dict[str, Tensor]) – Batch from a torchtime.data class.
  - Return type:
    Dict[str, Union[PackedSequence, Tensor]]
  - Returns:
    Updated batch.
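A PackedSequence returned by packed_sequence() can be passed straight to a recurrent layer with no further preparation. A minimal sketch, using a hand-built PackedSequence in place of batch["X"] from a real DataLoader (the layer sizes are illustrative):

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

# Stand-in for batch["X"]: pack a padded batch with descending lengths
X = torch.randn(3, 10, 4)          # (batch, max_seq_len, n_features)
length = torch.tensor([10, 7, 5])
packed_X = pack_padded_sequence(X, length, batch_first=True)

# PyTorch RNN layers accept PackedSequence inputs directly, skipping
# computation on padded time steps
gru = nn.GRU(input_size=4, hidden_size=16, batch_first=True)
output, h_n = gru(packed_X)  # output is also a PackedSequence

print(h_n.shape)  # torch.Size([1, 3, 16])
```

Because the collate function has already done the packing, the model's forward method needs no call to pack_padded_sequence() at all.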