交互式MNIST探索器

在画布上绘制数字,观察AI如何猜测它是什么!

Try drawing a digit on the canvas!

1v1 Least Squares (n/a ms)

Fully Connected Network (n/a ms)
Convolutional Network (n/a ms)

在本文中,我们将介绍三种处理MNIST数据集的基本模型,并比较它们的特性、优势和劣势。您可以通过上方的交互界面与每个模型互动,并在条形图中查看它们的输出结果。

任务

对于不熟悉机器学习的人来说,将图像转换为数字可能看起来是一项艰巨的任务。然而,如果我们以下面的方式思考这个问题,事情就会变得简单:

灰度图像只是一个由像素亮度组成的网格,这些亮度是实数值。也就是说,每张图像都是集合 中的某个元素,其中 分别是图像的宽度和高度。因此,如果我们能找到一个函数 ,从 映射到 ,我们就能解决这个问题。

为此,我们使用训练图像 和标签 构建一个模型。

最小二乘法

该方法涉及为从类别0到9中选出的每一对唯一 创建45个从 的线性映射,以推断图像最可能属于 还是 。我们可以使用一些线性代数来最小化均方误差(MSE)。首先,我们不再处理 中的图像,而是将它们“展平”为

的权重定义为 ,这是一个长度为 的向量。为了得到模型的输出,我们计算

其中 是一个数字。

我们希望在所有 个样本上最小化MSE

为此,我们创建一个新的矩阵 ,它仅包含属于类别 的图像以及一列 用于偏置,以及一个矩阵 ,它同样仅包含 中的标签,但将 替换为 ,将 替换为

现在,我们的问题简化为

该问题的解由 给出,其中 是矩阵 的伪逆(证明留给读者作为练习😁)。

一旦我们获得了所有 对的 (总共45个),我们就可以表示我们所需的函数

def f(x):
    score = [0] * 10
    for i, j, f_ij in pair_functions:
        out_ij = f_ij(x)
        if out_ij > 0:
            score[i] += 1
            score[j] -= 1
        else:
            score[j] += 1
            score[i] -= 1
    return argmax(score)

每个45个模型都会为其 “投票”。score数组就是你在上面条形图中看到的内容。

全连接网络

全连接网络(Fully Connected Network,简称FCN)是一个比最小二乘模型大得多的模型。与将标签投影到数据的主子空间不同,我们可以直接学习从输入空间到输出空间的映射。

对于单层网络,我们假设 可以通过以下方式近似:

其中 是某个非线性函数。通过梯度下降,可以学习矩阵 ,使得在局部邻域内的误差(分类交叉熵)最小化。在演示中,我们使用了一个两层网络,将图像映射到 ,然后将结果映射到 。这可以表示为:

其中我们需要学习矩阵 。在我们的例子中, ,并且

将输出 转换为概率分布,如上方的柱状图所示。

卷积网络

上述两种模型的一个局限在于,它们无法像人类那样感知视觉特征。例如,手写的1无论在画布的哪个位置绘制,它都是 。然而,由于LS和FCN模型没有空间或邻近的概念,它们只会简单地指向最有可能拥有这些确切像素的类别。

这里,我们引入卷积。卷积操作接收一张图像和一个,将核在图像上滑动,并生成一个输出图像,该图像包含图像像素与核值的加权和。

注意卷积如何编码空间数据,而普通网络则无法做到。由于邻近的像素通常高度相关,我们可以通过最大池化对卷积输出进行下采样,并保留大部分信息。在将图像通过一系列(训练过的)核处理后,我们得到一组矩阵,这些矩阵代表了学习到的空间特征的存在。最后,我们可以将这些矩阵展平并输入到一个FCN中,该FCN现在可以将空间数据映射到类别。

这个FCN(带有softmax激活)的输出如上所示。

## 模型比较

注意:最后三列是定性的,且相互之间具有相对性。

模型 参数数量 训练时间 推理时间 准确性
最小二乘法
全连接网络 (FCN) 良好
卷积网络 (CNN) 非常高 优秀

观察结果:

  • 最小二乘法模型非常快,但泛化能力较弱
  • CNN 的参数存储效率非常高
  • 相对于 CNN 的推理时间,最小二乘法和 FCN 都非常快

练习

观察模型如何响应以下输入:

  • 空白画布
  • 中心位置的 1
  • 最左侧的 1
  • 最右侧的 1
  • 中心带有一条线/点的 0
  • 顶部略微断开的 9
  • 略微旋转的数字
  • 非常细的数字
  • 非常粗的数字

你能找到两个仅相差 1 像素但映射到不同类别的输入吗?

## 实现细节

所有三个模型都在你的浏览器中以纯 JavaScript 运行;没有使用任何框架或包。

### 画布

这个 的画布由一个包含显示透明度值的数组支持。每次更新任何像素时,整个画布都会重新绘制。另一个有趣的细节是我使用的亮度衰减函数:

const plateau = 0.3;
// dist 是距离中心点的平方距离
const alpha = Math.min(1 - dist / r2 + plateau, 1);
pixels[yc * 28 + xc] = Math.max(pixels[yc * 28 + xc], alpha);

我最初尝试使用 1-dist/r2 的衰减函数,但它使中心区域过于暗淡。因此,我添加了 plateau 变量,将函数向上平移,但通过 Math.min 将其限制,以确保透明度值不超过 1。这使得笔触看起来更加自然。

### 最小二乘法

我从在ECE 174课程中与Piya Pal教授合作的一个项目中获得了这些权重。推理过程仅仅是45次点积和评分。

function evalLSModel(digit, weights) {
    const scores = new Array(10).fill(0);
    for (const pairConfig of weights) {
        const [i, j, w] = pairConfig;
        // 向量点积
        const result = vdot(digit, w);
        if (result > 0) {
            scores[i] += 1;
            scores[j] -= 1;
        } else {
            scores[j] += 1;
            scores[i] -= 1;
        }
    }
    return scores;
}

全连接网络

全连接网络(FCN)推理的主要工作是矩阵点积,我以标准方式实现了它。

function matrixDot(matrix1, matrix2, rows1, cols1, rows2, cols2) {
    // 检查矩阵是否可以相乘
    if (cols1 !== rows2) {
        console.error("矩阵维度不匹配,无法进行点积运算");
        return null;
    }

    // 初始化结果矩阵为零
    const result = new Array(rows1 * cols2).fill(0);

    // 执行点积运算
    for (let i = 0; i < rows1; i++) {
        for (let j = 0; j < cols2; j++) {
            for (let k = 0; k < cols1; k++) {
                result[i * cols2 + j] +=
                    matrix1[i * cols1 + k] * matrix2[k * cols2 + j];
            }
        }
    }

    return result;
}

为了更好的缓存局部性和减少堆分配,我将矩阵存储在单个一维 Array 中。根据上述公式,推理过程包括 2 次矩阵点积和 2 次激活函数应用。push(1) 调用用于计算偏置。

function evalNN(digit, weights) {
    const digitCopy = [...digit];
    digitCopy.push(1);
    // 第一层参数
    const [w1, [rows1, cols1]] = weights[0];
    const out1 = matrixDot(digitCopy, w1, 1, digitCopy.length, rows1, cols1).map(relu);
    const [w2, [rows2, cols2]] = weights[1];
    out1.push(1);
    const out2 = matrixDot(out1, w2, 1, out1.length, rows2, cols2);
    return softmax(out2);
}

卷积网络

这里的卷积网络非常小。在 Pytorch 中,它是这样的:

nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(32, 64, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Dropout(0.5),
    nn.Linear(1600, 10),
    nn.Softmax(dim=1)
)

对于推理,我们只需要将前向传播部分移植到 JavaScript 中。Conv2d(带有输入/输出通道)由以下代码实现:

function conv2d(
    nInChan,
    nOutChan,
    inputData,
    inputHeight,
    inputWidth,
    kernel,
    bias,
) {
    if (inputData.length !== inputHeight * inputWidth * nInChan) {
        console.error("输入尺寸无效");
        return;
    }
    if (kernel.length !== 3 * 3 * nInChan * nOutChan) {
        console.error("卷积核尺寸无效");
        return;
    }

    const kernelHeight = 3;
    const kernelWidth = 3;

    // 计算输出尺寸
    const outputHeight = inputHeight - kernelHeight + 1;
    const outputWidth = inputWidth - kernelWidth + 1;

    const output = new Array(nOutChan * outputHeight * outputWidth).fill(0);

    for (let i = 0; i < outputHeight; i++) {
        for (let j = 0; j < outputWidth; j++) {
            for (let outChan = 0; outChan < nOutChan; outChan++) {
                let sum = 0;
                // 在所有输入通道上应用滤波器
                for (let inChan = 0; inChan < nInChan; inChan++) {
                    for (let row = 0; row < 3; row++) {
                        for (let col = 0; col < 3; col++) {
                            const inI =
                                inChan * (inputHeight * inputWidth) +
                                (i + row) * inputWidth +
                                (j + col);

                            const kI =
                                outChan * (nInChan * 3 * 3) +
                                inChan * (3 * 3) +
                                row * 3 +
                                col;
                            sum += inputData[inI] * kernel[kI];
                        }
                    }
                }
                sum += bias[outChan];
                const outI =
                    outChan * (outputHeight * outputWidth) +
                    i * outputWidth +
                    j;
                output[outI] = sum;
            }
        }
    }
    return output;
}

我知道这代码很丑。我只是把它放在这里供参考。接下来是 maxpool 的代码:

function maxPool2d(nInChannels, inputData, inputHeight, inputWidth) {
    if (inputData.length !== inputHeight * inputWidth * nInChannels) {
        console.error("maxpool2d: 输入高度/宽度无效");
        return;
    }
    const poolSize = 2;
    const stride = 2;
    const outputHeight = Math.floor((inputHeight - poolSize) / stride) + 1;
    const outputWidth = Math.floor((inputWidth - poolSize) / stride) + 1;
    const output = new Array(outputHeight * outputWidth * nInChannels).fill(0);

    for (let chan = 0; chan < nInChannels; chan++) {
        for (let i = 0; i < outputHeight; i++) {
            for (let j = 0; j < outputWidth; j++) {
                let m = 0;
                for (let row = 0; row < poolSize; row++) {
                    for (let col = 0; col < poolSize; col++) {
                        const ind =
                            chan * (inputHeight * inputWidth) +
                            (i * stride + row) * inputWidth +
                            (j * stride + col);
                        m = Math.max(m, inputData[ind]);
                    }
                }
                const outI =
                    chan * (outputHeight * outputWidth) + i * outputWidth + j;
                output[outI] = m;
            }
        }
    }
    return output;
}

是的,我处理这些令人头疼的索引计算代码的唯一原因是为了那极速 🔥JavaScript🔥 Web 应用的性能。最后,这是将所有部分整合在一起的函数:

function evalConv(digit, weights) {
    const [
        [f1, fshape1], // 卷积滤波器权重
        [b1, bshape1], // 卷积偏置
        [f2, fshape2],
        [b2, fbshape2],
        [w, wshape],   // 全连接层权重
        [b, bshape],   // 全连接层偏置
    ] = weights;

    const x1 = conv2d(1, 32, digit, 28, 28, f1, b1).map(relu);
    const x2 = maxPool2d(32, x1, 26, 26);
    const x3 = conv2d(32, 64, x2, 13, 13, f2, b2).map(relu);
    const x4 = maxPool2d(64, x3, 11, 11);
    const x5 = matrixDot(w, x4, 10, 1600, 1600, 1);
    const x6 = vsum(x5, b);
    const out = softmax(x6);
    return out;
}

结论

希望大家喜欢使用这个应用。如果有任何问题或反馈,欢迎在下方留言。