Try drawing a digit on the canvas!
1v1 Least Squares (n/a ms)
本文探讨了三种以不同方式处理MNIST数据集的基础模型,并比较了它们的特性、优势与不足。您可以通过上方交互界面体验每个模型,并在条形图中查看它们的输出结果。
任务
对于不熟悉机器学习的人来说,将图像转换为数字可能看起来是一项艰巨的任务。然而,如果我们以如下方式思考这个问题,事情就会变得简单:
一张灰度图本质上就是一个像素亮度的网格,这些亮度值是实数。也就是说,每张图像都是集合 中的一个元素,其中 分别是图像的宽度和高度。因此,如果我们能找到一个从 映射到 的函数 ,我们就能解决这个问题。
为此,我们使用训练图像 和标签 来构建一个模型。
最小二乘法
该方法涉及为从类别 0..9 中选出的每一对不同的 ,创建 45 个从 到 的线性映射,用于推断图像最可能属于 还是 。我们可以使用线性代数来最小化均方误差(MSE)。首先,我们不直接处理 中的图像,而是将其“展平”到 。
将 的权重定义为 ,这是一个长度为 的向量。为了得到模型的输出,我们计算
其中 是一个数字。
我们希望在所有 个样本上最小化 MSE
为此,我们创建一个新矩阵 ,它仅包含属于类别 或 的图像以及一列用于偏置的 ,以及一个矩阵 ,它同样仅包含 中的标签,但将 替换为 ,将 替换为 。
现在,我们的问题被简化为
其解由 给出,其中 是矩阵 的伪逆(证明留给读者作为练习 😁)。
一旦我们获得了所有 对(共 45 对)的 ,我们就可以将我们期望的函数 表示为
def f(x):
score = [0] * 10
for i, j, f_ij in pair_functions:
out_ij = f_ij(x)
if out_ij > 0:
score[i] += 1
score[j] -= 1
else:
score[j] += 1
score[i] -= 1
return argmax(score)
45 个模型中的每一个都为其
或
“投票”。上面条形图中显示的就是 score 数组。
全连接网络
全连接网络(FCN)是一个比最小二乘模型大得多的模型。我们不再是将标签投影到数据的主子空间上,而是可以直接学习从输入空间到输出空间的映射。
对于单层网络,我们假设 可以近似为
其中 是某个非线性函数。我们可以通过梯度下降学习矩阵 ,使得在局部邻域内误差(分类交叉熵)最小化。在演示中,我们使用一个2层网络,先将图像映射到 ,再将结果映射到 。这表示为
其中我们需要学习矩阵 和 。在我们的例子中, ,并且
将 的输出转换为概率分布,如上方的条形图所示。
卷积网络
上述两种模型的一个局限在于,它们无法像人类那样感知视觉特征。例如,手写数字1无论出现在画布的哪个位置,它始终是
。然而,由于LS和FCN模型不具备空间或邻近关系的概念,它们只会简单地指向最可能拥有完全一致像素分布的类别。
为此,我们引入卷积运算。卷积操作接收一张图像和一个卷积核,将卷积核在图像上滑动计算,生成输出图像,该图像的每个像素值是输入图像局部区域与卷积核权重的加权和。
请注意卷积如何编码空间信息——这是普通全连接网络所不具备的。由于相邻像素通常高度相关,我们可以通过最大池化对卷积输出进行下采样,同时保留大部分信息。将图像通过一系列(训练得到的)卷积核处理后,我们得到一组矩阵,这些矩阵表示已学习的空间特征是否存在。最后,我们将这些矩阵展平并输入到FCN中,此时FCN便能够将空间数据映射到类别。
上方展示了该FCN(采用softmax激活函数)的输出结果。
模型对比
注:最后三列为定性指标,且为相对比较。
| 模型 | 参数量 | 训练时间 | 推理时间 | 准确度 |
|---|---|---|---|---|
| 最小二乘法 | 低 | 快 | 低 | |
| 全连接网络 | 高 | 快 | 良好 | |
| 卷积网络 | 非常高 | 慢 | 优秀 |
观察结果:
- 最小二乘法模型速度极快,但泛化能力较弱
- 卷积网络的参数存储效率非常高
- 相对于卷积网络的推理时间,最小二乘法和全连接网络都非常快
练习
观察模型对这些输入的响应:
- 空白画布
- 中心位置的一个
1 - 最左侧的一个
1 - 最右侧的一个
1 - 中心带有一条线/点的
0 - 顶部略有断开的
9 - 轻微旋转的数字
- 非常细的数字
- 非常粗的数字
你能找到两个仅相差1个像素、但被分类到不同类别的输入吗?
实现细节
所有三个模型均在您的浏览器中通过纯JavaScript运行;未使用任何框架或软件包。
画布
这块 的画布由一个数字数组支持,数组中包含所显示的透明度值。每次任何像素被更新时,整个画布都会重新绘制。另一个值得注意的细节是我使用的亮度衰减函数:
const plateau = 0.3;
// dist 是到中心距离的平方
const alpha = Math.min(1 - dist / r2 + plateau, 1);
pixels[yc * 28 + xc] = Math.max(pixels[yc * 28 + xc], alpha);
我最初尝试使用 1-dist/r2 作为衰减函数,但它使中心区域褪色过多。因此我添加了 plateau 变量来将函数整体上移,同时用 Math.min 将其限制在 1 以内,确保透明度不会超过 1。这使得笔触看起来更自然。
最小二乘法
这些权重来自我在ECE 174课程中与Piya Pal教授合作完成的一个项目。推理过程仅需45次点积运算和评分。
function evalLSModel(digit, weights) {
const scores = new Array(10).fill(0);
for (const pairConfig of weights) {
const [i, j, w] = pairConfig;
// 向量点积
const result = vdot(digit, w);
if (result > 0) {
scores[i] += 1;
scores[j] -= 1;
} else {
scores[j] += 1;
scores[i] -= 1;
}
}
return scores;
}
### 全连接网络
全连接网络推理的主要工作是矩阵点积,我以标准方式实现了它。
```javascript
function matrixDot(matrix1, matrix2, rows1, cols1, rows2, cols2) {
// 检查矩阵是否可以相乘
if (cols1 !== rows2) {
console.error("Invalid matrix dimensions for dot product");
return null;
}
// 初始化结果矩阵为零
const result = new Array(rows1 * cols2).fill(0);
// 执行点积运算
for (let i = 0; i < rows1; i++) {
for (let j = 0; j < cols2; j++) {
for (let k = 0; k < cols1; k++) {
result[i * cols2 + j] +=
matrix1[i * cols1 + k] * matrix2[k * cols2 + j];
}
}
}
return result;
}
我将矩阵存储在单个一维 Array 中,以获得更好的缓存局部性和更少的堆分配。根据上述公式,推理过程包括 2 次矩阵点积和 2 次激活函数应用。push(1) 调用是为了计算偏置项。
function evalNN(digit, weights) {
const digitCopy = [...digit];
digitCopy.push(1);
// 第一层参数
const [w1, [rows1, cols1]] = weights[0];
const out1 = matrixDot(digitCopy, w1, 1, digitCopy.length, rows1, cols1).map(relu);
const [w2, [rows2, cols2]] = weights[1];
out1.push(1);
const out2 = matrixDot(out1, w2, 1, out1.length, rows2, cols2);
return softmax(out2);
}
卷积网络
这里的卷积网络相当小。在 PyTorch 中为
nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(32, 64, kernel_size=3),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Dropout(0.5),
nn.Linear(1600, 10),
nn.Softmax(dim=1)
)
对于推理,我们只需将前向传播过程移植到 JavaScript 中。Conv2d(包含输入/输出通道)由以下函数实现:
function conv2d(
nInChan,
nOutChan,
inputData,
inputHeight,
inputWidth,
kernel,
bias,
) {
if (inputData.length !== inputHeight * inputWidth * nInChan) {
console.error("输入尺寸无效");
return;
}
if (kernel.length !== 3 * 3 * nInChan * nOutChan) {
console.error("卷积核尺寸无效");
return;
}
const kernelHeight = 3;
const kernelWidth = 3;
// 计算输出维度
const outputHeight = inputHeight - kernelHeight + 1;
const outputWidth = inputWidth - kernelWidth + 1;
const output = new Array(nOutChan * outputHeight * outputWidth).fill(0);
for (let i = 0; i < outputHeight; i++) {
for (let j = 0; j < outputWidth; j++) {
for (let outChan = 0; outChan < nOutChan; outChan++) {
let sum = 0;
// 在所有输入通道上对单个位置应用滤波器
for (let inChan = 0; inChan < nInChan; inChan++) {
for (let row = 0; row < 3; row++) {
for (let col = 0; col < 3; col++) {
const inI =
inChan * (inputHeight * inputWidth) +
(i + row) * inputWidth +
(j + col);
const kI =
outChan * (nInChan * 3 * 3) +
inChan * (3 * 3) +
row * 3 +
col;
sum += inputData[inI] * kernel[kI];
}
}
}
sum += bias[outChan];
const outI =
outChan * (outputHeight * outputWidth) +
i * outputWidth +
j;
output[outI] = sum;
}
}
}
return output;
}
我知道这很丑。我只是放在这里供参考。注意最大池化函数:
function maxPool2d(nInChannels, inputData, inputHeight, inputWidth) {
if (inputData.length !== inputHeight * inputWidth * nInChannels) {
console.error("maxpool2d: 输入高度/宽度无效");
return;
}
const poolSize = 2;
const stride = 2;
const outputHeight = Math.floor((inputHeight - poolSize) / stride) + 1;
const outputWidth = Math.floor((inputWidth - poolSize) / stride) + 1;
const output = new Array(outputHeight * outputWidth * nInChannels).fill(0);
for (let chan = 0; chan < nInChannels; chan++) {
for (let i = 0; i < outputHeight; i++) {
for (let j = 0; j < outputWidth; j++) {
let m = 0;
for (let row = 0; row < poolSize; row++) {
for (let col = 0; col < poolSize; col++) {
const ind =
chan * (inputHeight * inputWidth) +
(i * stride + row) * inputWidth +
(j * stride + col);
m = Math.max(m, inputData[ind]);
}
}
const outI =
chan * (outputHeight * outputWidth) + i * outputWidth + j;
output[outI] = m;
}
}
}
return output;
}
是的,我之所以要处理那些令人厌恶的索引计算代码,全都是为了那极速 🔥JavaScript🔥 Web 应用的性能。最后,这是将所有部分整合在一起的函数:
function evalConv(digit, weights) {
const [
[f1, fshape1], // 卷积滤波器权重
[b1, bshape1], // 卷积偏置
[f2, fshape2],
[b2, fbshape2],
[w, wshape], // 全连接层权重
[b, bshape], // 全连接层偏置
] = weights;
const x1 = conv2d(1, 32, digit, 28, 28, f1, b1).map(relu);
const x2 = maxPool2d(32, x1, 26, 26);
const x3 = conv2d(32, 64, x2, 13, 13, f2, b2).map(relu);
const x4 = maxPool2d(64, x3, 11, 11);
const x5 = matrixDot(w, x4, 10, 1600, 1600, 1);
const x6 = vsum(x5, b);
const out = softmax(x6);
return out;
}
总结
希望大家喜欢试用这款应用。如有任何问题或反馈,欢迎在下方留言。