Notes on Implementing a Connectionist Temporal Classification Model
08 December, 2019 - 5 min read
A CTC (connectionist temporal classification) model is a type of classification model that can be used to label segments of an unsegmented sequence. One application is speech recognition, where you might have the audio for a speaker and want to classify which segments of the (unsegmented) audio wave correspond to which spoken words. It has also found applications in genomics, where it's been used for base calling in Nanopore sequencers.
These notes are a walk-through of fitting a contrived example, in an effort to illustrate a CTC model's inputs, outputs, and internals.
Imagine two coins are being flipped, but you, the observer, don't see which coin is used, only the result: heads or tails. However, you are told the order of coin usage, just not which coin is used for each flip. E.g.
```
HTHHHHTTT  # Outcome of flips
000011100  # Underlying coin state
ABA        # What observer is told
```
So you know you're in the "A" coin state, then the "B" state, then the "A" state again, but you're not told when the switches happen.
This sounds vaguely like an HMM: some latent states plus emissions we observe. So let's use an HMM to generate some fake data.
With pomegranate, it's easy to generate sample data with an HMM.
```python
import pomegranate as p

# Two coins, one is heads 1% and one is heads 99%
state1 = p.State(p.BernoulliDistribution(.01, 1), name=str(1))
state2 = p.State(p.BernoulliDistribution(.99, 1), name=str(2))

model = p.HiddenMarkovModel(name="Biased Coins")
model.add_state(state1)
model.add_state(state2)

model.add_transition(model.start, state1, 0.5)
model.add_transition(model.start, state2, 0.5)

# Transition out of states 5% and within states 95%
model.add_transition(state1, state1, 0.95)
model.add_transition(state1, state2, 0.05)
model.add_transition(state2, state2, 0.95)
model.add_transition(state2, state1, 0.05)

model.bake()
```
Now we'll generate some fake data. We need to keep track of a set of variables:
- `L`: The number of coin flips.
- `examples`: The set of coin flips; this is an `N x L` matrix where `N` is the number of examples and `L` is defined above.
- `targets`: The non-redundant sequence of underlying states, `010` from the opening example. This will be padded to the longest target sequence: if the length of the target differs between sequences, the list is padded with a blank character (`0` by default in the `CTCLoss` implementation in torch).
```python
import numpy as np

# Holds the N examples, their non-redundant targets and the target lengths
N = 1000
L = 50

examples = []
targets = []
target_lengths = []

for n in range(N):
    array, path = model.sample(length=L, path=True)

    target = []
    # Add the first state to the target,
    # then only add new ones if they don't match the previous one
    target.append(int(path[1].name))
    for p in path[2:]:
        p_name = int(p.name)
        if target[-1] != p_name:
            target.append(p_name)

    l_target = len(target)
    target_lengths.append(l_target)
    targets.append(target)
    examples.append(array)

# Once we know the longest target we can pad the targets with the blank character
target_lengths = np.asarray(target_lengths)
longest = target_lengths.max()
targets = np.asarray([np.pad(t, (0, longest - len(t)), mode='constant') for t in targets])
examples = np.asarray(examples)
```
In the end we have the three inputs we need to fit the model: `examples`, which map to `targets`, and `target_lengths`, which accounts for the exact length of each target's hidden states since they've been padded.
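As a quick sanity check on those inputs, the array shapes should line up as below (the exact value of the longest target will vary from run to run since the data are sampled randomly):

```python
# Shapes of the three arrays built above
print(examples.shape)        # (1000, 50): N x L coin flips
print(targets.shape)         # (1000, longest): padded state sequences
print(target_lengths.shape)  # (1000,): true length of each target
```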
I'll use torch to build the model, so first we need to set up the proper data loading mechanisms: a dataset and its data loader.
```python
from torch.utils import data
import torch


class Dataset(data.Dataset):
    def __init__(self, examples, targets, target_lengths):
        self.examples = examples
        self.targets = targets
        self.target_lengths = target_lengths

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return (torch.tensor(self.examples[i]),
                torch.tensor(self.targets[i]),
                torch.tensor(self.target_lengths[i]),
                torch.tensor(L))  # Constant length examples


dataset = Dataset(examples, targets, target_lengths)
dataloader = data.DataLoader(dataset, batch_size=100, shuffle=True)
```
Then we can build the model. This will be an RNN that outputs the log softmax over the classes at each position of the input sequence. Here that's the two coin types plus the CTC blank, for three output classes.
```python
import torch
from torch import nn


class CTCModel(nn.Module):
    def __init__(self):
        super(CTCModel, self).__init__()
        # Embed the two flip outcomes (0/1) into 2-dimensional vectors
        self.embedding = torch.nn.Embedding(2, 2)
        self.rnn = torch.nn.RNN(2, 5, num_layers=2)
        # Three output classes: the blank plus the two coins
        self.linear = torch.nn.Linear(5, 3)
        self.softmax = torch.nn.LogSoftmax(2)

    def forward(self, inputs):
        # (N, L) -> (L, N), the RNN expects the sequence dimension first
        inputs = inputs.transpose(0, 1)
        inputs = self.embedding(inputs)
        inputs, _ = self.rnn(inputs)
        inputs = self.linear(inputs)
        return self.softmax(inputs)


ctc_model = CTCModel()
ctc_loss = nn.CTCLoss()
```
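To make the shapes concrete: torch's `CTCLoss` expects log probabilities shaped `(T, N, C)`, where `T` is the input length, `N` the batch size, and `C` the number of classes. A quick check against a batch from the dataloader above (just an illustrative sanity check, not part of the original pipeline):

```python
# Pull one batch and confirm the model output is shaped the way CTCLoss expects
example_batch, targets, target_length, example_length = next(iter(dataloader))
output = ctc_model(example_batch)
print(output.shape)  # torch.Size([50, 100, 3]) -> (T, N, C)
```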
At this point the model can be trained, so we'll follow the standard PyTorch pattern. I tried to use ignite, but it expects a single X, y pair, and we need multiple outputs to keep track of metadata like the target length.
```python
from torch.optim import Adam

optimizer = Adam(ctc_model.parameters(), lr=0.01)

for i in range(30):
    for i_batch, (example_batch, targets, target_length, example_length) in enumerate(dataloader):
        optimizer.zero_grad()
        output = ctc_model(example_batch)
        loss = ctc_loss(output, targets, example_length, target_length)
        loss.backward()
        optimizer.step()
```
Then we can do some spot checking to see if the model fit the training data. This is obviously not a rigorous evaluation, but if your model can't fit the training data it's probably not going to fit the test data.
```python
ctc_model.eval()

with torch.no_grad():
    preds = ctc_model(example_batch)
    pred_token = preds.transpose(0, 1).argmax(-1)

join = lambda x: "".join(str(y.item()) for y in x)
print(join(example_batch[0]))
print(join(pred_token[0]))
print(join(targets[0]))
```
In this case the output is:
```
00001111111111111111111111111111111111111111111111
11111222222222222222222222222222222222222222222222
120000000
```
This looks roughly right. The example had four flips of 0 at the start and the rest were 1. This corresponds to state predictions of 1 (the coin biased toward 0) in the first five positions and 2 in the rest, which matches the provided target of 1 followed by 2 (padded with 0s).
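One last note: the per-position argmax above is a raw frame-level prediction, so to recover a label sequence comparable to the target you would normally apply greedy CTC decoding, collapsing repeats and then dropping blanks. Here's a minimal sketch of that step (`greedy_decode` is an illustrative helper, not a torch builtin):

```python
def greedy_decode(frame_preds, blank=0):
    """Collapse repeated frame predictions, then drop the blank token."""
    decoded = []
    prev = None
    for t in frame_preds:
        t = int(t)
        if t != prev and t != blank:  # new, non-blank label
            decoded.append(t)
        prev = t
    return decoded

print(greedy_decode(pred_token[0]))  # [1, 2] for the example shown above
```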