Skorch: Help constructing classifier for multiple outputs

I am trying to learn skorch by translating a simple PyTorch model that predicts the 2 digits contained in each picture of a multi-digit MNIST dataset. Each picture contains 2 overlapping digits, which are the output labels (y). When I call fit, I get the following error:

ValueError: Stratified CV requires explicitely passing a suitable y
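
To show what the labels look like from scikit-learn's point of view, here is a minimal sketch. It assumes (I have not confirmed this) that the stratified split relies on scikit-learn's target-type detection. `y_dummy` is a made-up stand-in for my real labels:

```python
import numpy as np
from sklearn.utils.multiclass import type_of_target

# Hypothetical stand-in for y: 8 samples, 2 digit labels per sample.
y_dummy = np.array([[0, 1], [2, 3], [4, 5], [6, 7],
                    [8, 9], [1, 0], [3, 2], [5, 4]])

# The full 2-column array is seen as a multi-output target ...
target_2d = type_of_target(y_dummy)
# ... while a single column is an ordinary multiclass target.
target_1d = type_of_target(y_dummy[:, 0])

print(target_2d, target_1d)
```

If that assumption holds, the 2-column shape of y, rather than its values, might be what the stratified split rejects.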

I followed the "MNIST with SciKit-Learn and skorch" notebook and applied the multiple-output fix outlined in "Multiple return values from forward" by subclassing NeuralNetClassifier and overriding its get_loss method. Data dimensions are:

X - (40000, 1, 42, 28)
y - (40000, 2)
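
As a sanity check on the input size against the model below, this is a small sketch of the shape arithmetic. It assumes 42x28 images (my assumption) and the usual output-size formula floor((n + 2p - k) / s) + 1 for convolution and pooling layers. The helper names `conv_out` and `flattened_features` are made up for illustration:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a conv or pooling layer along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(height, width, channels=64):
    """Feature count after conv 3x3 -> pool 2x2 -> conv 3x3 -> pool 2x2."""
    for kernel, stride in [(3, 1), (2, 2), (3, 1), (2, 2)]:
        height = conv_out(height, kernel, stride)
        width = conv_out(width, kernel, stride)
    return channels * height * width


# Assumed input: 1 x 42 x 28 images, 64 channels after the second conv.
n_features = flattened_features(42, 28)
print(n_features)
```

Under these assumptions the result matches the 2880 input features expected by fc1.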

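import torch
import torch.nn as nn
import torch.nn.functional as F
from skorch import NeuralNetClassifier


class Flatten(nn.Module):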
    """A custom layer that views an input as 1D."""

    def forward(self, input):
        return input.view(input.size(0), -1)


class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, 3)
        self.pool1 = nn.MaxPool2d((2, 2))
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.pool2 = nn.MaxPool2d((2, 2))
        self.flatten = Flatten()
        self.fc1 = nn.Linear(2880, 64)
        self.drop1 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(64, 10)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.drop1(x)
        out_first_digit = self.fc2(x)
        out_second_digit = self.fc3(x)

        return out_first_digit, out_second_digit


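torch.manual_seed(0)

# Assumption: use the GPU when available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"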

class CNN_net(NeuralNetClassifier):
    def get_loss(self, y_pred, y_true, *args, **kwargs):

        loss1 = F.cross_entropy(y_pred[0], y_true[:,0])
        loss2 = F.cross_entropy(y_pred[1], y_true[:,1])

        return 0.5 * (loss1 + loss2)

net = CNN_net(
    CNN,
    max_epochs=5,
    lr=0.1,
    device=device,
)

net.fit(X_train, y_train)

My questions:

  1. Do I need to change the format of y?
  2. Do I need to override any other methods (for example, predict)?
  3. Do you have any other suggestions?