
Notes on Implementing a Connectionist Temporal Classification Model

08 December, 2019 - 5 min read

A CTC (connectionist temporal classification) model is a type of classification model used to label segments of an unsegmented sequence. One application is speech recognition, where you might have the audio for a speaker and want to determine which segments of the (unsegmented) audio wave correspond to which (segmented) spoken words. It has also found applications in genomics, where it's been used for base calling in Nanopore sequencers.

These notes walk through fitting a contrived example in order to illustrate a CTC model's inputs, outputs, and internals.

Example

Imagine two coins are being flipped, but you, the observer, don't see which coin is used, only the result: heads or tails. You are, however, told the order in which the coins were used, just not which coin produced each flip. E.g.

HTHHHHTTT  # Outcome of flips
000011100  # Underlying coin state
ABA  # What observer is told

So you know you're in the "A" coin state first, then the "B" state, then the "A" state again, but you're not told when the switches happen.
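
To make that concrete, here is a small, purely illustrative sketch (the collapse helper is mine, not part of the model) showing that the sequence the observer is told is just the underlying state sequence with consecutive repeats collapsed:

# Illustrative only: collapse consecutive repeats of the coin state
# to recover the sequence the observer is told.
def collapse(states):
    collapsed = []
    for s in states:
        if not collapsed or collapsed[-1] != s:
            collapsed.append(s)
    return collapsed

print(collapse("000011100"))  # ['0', '1', '0'], i.e. the "ABA" the observer is told

This collapsing step is exactly what we'll do below when building the targets.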

This sounds vaguely like an HMM: some latent states and emissions we observe. So let's use an HMM to generate some fake data.

Generating the Coin Flips

Using pomegranate, it's easy to generate sample data with an HMM.

import pomegranate as p

# Two coins: coin 1 comes up heads 1% of the time, coin 2 comes up heads 99% of the time.
# The state names ("1" and "2") become the target labels later on.
state1 = p.State(p.BernoulliDistribution(.01, 1), name=str(1))
state2 = p.State(p.BernoulliDistribution(.99, 1), name=str(2))

model = p.HiddenMarkovModel(name="Biased Coins")
model.add_state(state1)
model.add_state(state2)

model.add_transition(model.start, state1, 0.5)
model.add_transition(model.start, state2, 0.5)

# Transition out of states 5% and within states 95%
model.add_transition(state1, state1, 0.95)
model.add_transition(state1, state2, 0.05)
model.add_transition(state2, state2, 0.95)
model.add_transition(state2, state1, 0.05)

model.bake()
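
To see what a single draw from this HMM looks like (this quick check isn't in the original write-up), you can ask for a short sample along with its state path. Note that the returned path includes the start state, which is why the data-generation code below skips path[0]:

# Illustrative: draw one short sequence together with its state path.
# path[0] is the model's start state; the emitting states begin at path[1].
flips, path = model.sample(length=10, path=True)
print(flips)                   # e.g. [1, 1, 1, 1, 0, 1, 1, 1, 1, 1]
print([s.name for s in path])  # e.g. ['Biased Coins-start', '2', '2', ...]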

Generate Fake Data

Now we'll generate some fake data. We need to keep track of a set of variables:

  • L The number of coin flips per example.
  • examples The set of coin flips; this is an N x L matrix where N is the number of examples and L is defined above.
  • targets The non-redundant sequence of underlying states for each example (010, or "ABA", from the opening example). Targets can differ in length between examples, so each one is padded to the length of the longest target with the blank character (0 by default in torch's CTCLoss implementation).
  • target_lengths The true, unpadded length of each target, which the loss needs to know since the targets are padded.

import numpy as np

# Holds the N examples, their non-redundant targets and their lengths
N = 1000
L = 50

examples = []
targets = []
target_lengths = []

for n in range(N):
    array, path = model.sample(length=L, path=True)

    # Add the first state to the target, then only append a state when it
    # differs from the previous one (collapse repeats).
    # path[0] is the start state, so the emitting states begin at path[1].
    target = [int(path[1].name)]
    for state in path[2:]:
        state_name = int(state.name)
        if target[-1] != state_name:
            target.append(state_name)

    target_lengths.append(len(target))
    targets.append(target)
    examples.append(array)

# Once we know the longest target we can pad the targets with the blank character (0)
target_lengths = np.asarray(target_lengths)
longest = target_lengths.max()
targets = np.asarray([np.pad(t, (0, longest - len(t)), mode='constant') for t in targets])
examples = np.asarray(examples)

In the end we have the three inputs we need to fit the model: examples, which map to targets, and target_lengths, which records the true length of each target since the targets have been padded.
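
As a quick sanity check (not part of the original write-up), the array shapes should line up with the description above:

# Illustrative check of the generated arrays
print(examples.shape)        # (1000, 50), i.e. N x L
print(targets.shape)         # (1000, longest), padded target sequences
print(target_lengths.shape)  # (1000,), true length of each target
print(targets[0], target_lengths[0])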

Training with PyTorch

I'll use torch to build the model. To do that, we first need to set up the proper data loading mechanisms: a dataset and its data loader.

from torch.utils import data
import torch

class Dataset(data.Dataset):
    def __init__(self, examples, targets, target_lengths):
        self.examples = examples
        self.targets = targets
        self.target_lengths = target_lengths

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return (torch.tensor(self.examples[i]),
                torch.tensor(self.targets[i]),
                torch.tensor(self.target_lengths[i]),
                torch.tensor(L)) # Constant length examples

dataset = Dataset(examples, targets, target_lengths)
dataloader = data.DataLoader(dataset, batch_size=100, shuffle=True)
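
Before building the model, it's worth pulling a single batch to confirm the shapes coming out of the loader (again, this check is mine, not from the original post):

# Illustrative: inspect one batch from the data loader
example_batch, target_batch, target_length, example_length = next(iter(dataloader))
print(example_batch.shape)  # torch.Size([100, 50])
print(target_batch.shape)   # torch.Size([100, longest])
print(target_length.shape)  # torch.Size([100])
print(example_length[:5])   # every example has the same length, L = 50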

Then we can build the model. This will be an RNN that, for each position in the input sequence, outputs the log softmax over the possible classes. Here that's three classes: the two coin states plus the CTC blank.

import torch
from torch import nn

class CTCModel(nn.Module):

    def __init__(self):
        super(CTCModel, self).__init__()

        # Embed the two flip outcomes (0/1) into a 2-dimensional space
        self.embedding = torch.nn.Embedding(2, 2)

        self.rnn = torch.nn.RNN(2, 5, num_layers=2)
        # Three output classes: the blank (0) plus the two coin states (1 and 2)
        self.linear = torch.nn.Linear(5, 3)

        self.softmax = torch.nn.LogSoftmax(2)

    def forward(self, inputs):
        # (N, L) -> (L, N): the RNN and CTCLoss both expect the time dimension first
        inputs = inputs.transpose(0, 1)
        inputs = self.embedding(inputs)  # (L, N, 2)
        inputs, _ = self.rnn(inputs)     # (L, N, 5)
        inputs = self.linear(inputs)     # (L, N, 3)

        return self.softmax(inputs)

ctc_model = CTCModel()
ctc_loss = nn.CTCLoss()
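
It helps to keep CTCLoss's expected shapes straight before training: the log probabilities come in time-first as (T, N, C), the targets are batch-first, and class 0 is reserved for the blank. Here's a minimal sketch with dummy tensors (not the real data) to illustrate the contract:

import torch
from torch import nn

# Illustrative shapes for nn.CTCLoss:
#   log_probs:      (T, B, C)  time x batch x classes, class 0 is the blank
#   targets:        (B, S)     padded target sequences
#   input_lengths:  (B,)       length of each input sequence
#   target_lengths: (B,)       true (unpadded) length of each target
T, B, C, S = 50, 4, 3, 10
log_probs = torch.randn(T, B, C).log_softmax(2)
dummy_targets = torch.randint(1, C, (B, S))            # labels 1..C-1; 0 is the blank
input_lengths = torch.full((B,), T, dtype=torch.long)
dummy_lengths = torch.randint(1, S + 1, (B,))
print(nn.CTCLoss()(log_probs, dummy_targets, input_lengths, dummy_lengths))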

At this point the model can be trained, so we'll follow the standard PyTorch pattern. I tried to use ignite, but it expects a single X, y pair, and we need multiple outputs to keep track of metadata like the target length.

from torch.optim import Adam

optimizer = Adam(ctc_model.parameters(), lr=0.01)

for epoch in range(30):
    for example_batch, targets, target_length, example_length in dataloader:
        optimizer.zero_grad()

        output = ctc_model(example_batch)
        # CTCLoss takes (log_probs, targets, input_lengths, target_lengths)
        loss = ctc_loss(output, targets, example_length, target_length)
        loss.backward()

        optimizer.step()

Then we can do some spot checking to see if the model fits the training data. This is obviously not a rigorous evaluation, but if your model can't fit the training data, it's probably not going to fit the test data.

ctc_model.eval()
with torch.no_grad():
    # example_batch and targets still refer to the last batch from the training loop
    preds = ctc_model(example_batch)

# Back to batch-first, then take the most likely class at each position
pred_token = preds.transpose(0, 1).argmax(-1)

join = lambda x: "".join(str(y.item()) for y in x)

print(join(example_batch[0]))  # the flips
print(join(pred_token[0]))     # per-flip state predictions
print(join(targets[0]))        # the padded target

In this case the output is:

00001111111111111111111111111111111111111111111111
11111222222222222222222222222222222222222222222222
120000000

This roughly looks right. The example had four flips of 0 at the start and 1 for the rest. That corresponds to per-flip predictions of state 1 (the coin biased toward 0) for the first five flips and state 2 for the rest, which, once repeats are collapsed, matches the provided target of 1 followed by 2 (the trailing 0s are padding).
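
The per-position predictions above aren't the final CTC output; to recover a predicted target you would also collapse repeated labels and drop the blanks. A minimal greedy decode along those lines (my addition, not from the original post) would look like:

# Illustrative greedy CTC decode: collapse repeats, then drop the blank (0)
def greedy_decode(tokens):
    decoded = []
    previous = None
    for t in tokens:
        t = t.item()
        if t != previous and t != 0:
            decoded.append(t)
        previous = t
    return decoded

print(greedy_decode(pred_token[0]))  # e.g. [1, 2], matching the unpadded target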