Try drawing a digit on the canvas!
1v1 Least Squares (n/a ms)
在本文中,我们将介绍三种处理MNIST数据集的基本模型,并比较它们的特性、优势和劣势。您可以通过上方的交互界面与每个模型互动,并在条形图中查看它们的输出结果。
任务
对于不熟悉机器学习的人来说,将图像转换为数字可能看起来是一项艰巨的任务。然而,如果我们以下面的方式思考这个问题,事情就会变得简单:
灰度图像只是一个由像素亮度组成的网格,这些亮度是实数值。也就是说,每张图像都是集合 中的某个元素,其中 和 分别是图像的宽度和高度。因此,如果我们能找到一个函数 ,从 映射到 ,我们就能解决这个问题。
为此,我们使用训练图像 和标签 构建一个模型。
最小二乘法
该方法涉及为从类别0到9中选出的每一对唯一 创建45个从 的线性映射,以推断图像最可能属于 还是 。我们可以使用一些线性代数来最小化均方误差(MSE)。首先,我们不再处理 中的图像,而是将它们“展平”为 。
将 的权重定义为 ,这是一个长度为 的向量。为了得到模型的输出,我们计算
其中 是一个数字。
我们希望在所有 个样本上最小化MSE
为此,我们创建一个新的矩阵 ,它仅包含属于类别 或 的图像以及一列 用于偏置,以及一个矩阵 ,它同样仅包含 中的标签,但将 替换为 ,将 替换为 。
现在,我们的问题简化为
该问题的解由 给出,其中 是矩阵 的伪逆(证明留给读者作为练习😁)。
一旦我们获得了所有 对的 (总共45个),我们就可以表示我们所需的函数 为
def f(x):
score = [0] * 10
for i, j, f_ij in pair_functions:
out_ij = f_ij(x)
if out_ij > 0:
score[i] += 1
score[j] -= 1
else:
score[j] += 1
score[i] -= 1
return argmax(score)
每个45个模型都会为其
或
“投票”。score
数组就是你在上面条形图中看到的内容。
全连接网络
全连接网络(Fully Connected Network,简称FCN)是一个比最小二乘模型大得多的模型。与将标签投影到数据的主子空间不同,我们可以直接学习从输入空间到输出空间的映射。
对于单层网络,我们假设 可以通过以下方式近似:
其中 是某个非线性函数。通过梯度下降,可以学习矩阵 ,使得在局部邻域内的误差(分类交叉熵)最小化。在演示中,我们使用了一个两层网络,将图像映射到 ,然后将结果映射到 。这可以表示为:
其中我们需要学习矩阵 和 。在我们的例子中, ,并且
将输出 转换为概率分布,如上方的柱状图所示。
卷积网络
上述两种模型的一个局限在于,它们无法像人类那样感知视觉特征。例如,手写的1
无论在画布的哪个位置绘制,它都是
。然而,由于LS和FCN模型没有空间或邻近的概念,它们只会简单地指向最有可能拥有这些确切像素的类别。
这里,我们引入卷积。卷积操作接收一张图像和一个核,将核在图像上滑动,并生成一个输出图像,该图像包含图像像素与核值的加权和。
注意卷积如何编码空间数据,而普通网络则无法做到。由于邻近的像素通常高度相关,我们可以通过最大池化对卷积输出进行下采样,并保留大部分信息。在将图像通过一系列(训练过的)核处理后,我们得到一组矩阵,这些矩阵代表了学习到的空间特征的存在。最后,我们可以将这些矩阵展平并输入到一个FCN中,该FCN现在可以将空间数据映射到类别。
这个FCN(带有softmax激活)的输出如上所示。
## 模型比较
注意:最后三列是定性的,且相互之间具有相对性。
模型 | 参数数量 | 训练时间 | 推理时间 | 准确性 |
---|---|---|---|---|
最小二乘法 | 低 | 快 | 低 | |
全连接网络 (FCN) | 高 | 快 | 良好 | |
卷积网络 (CNN) | 非常高 | 慢 | 优秀 |
观察结果:
- 最小二乘法模型非常快,但泛化能力较弱
- CNN 的参数存储效率非常高
- 相对于 CNN 的推理时间,最小二乘法和 FCN 都非常快
练习
观察模型如何响应以下输入:
- 空白画布
- 中心位置的
1
- 最左侧的
1
- 最右侧的
1
- 中心带有一条线/点的
0
- 顶部略微断开的
9
- 略微旋转的数字
- 非常细的数字
- 非常粗的数字
你能找到两个仅相差 1 像素但映射到不同类别的输入吗?
## 实现细节
所有三个模型都在你的浏览器中以纯 JavaScript 运行;没有使用任何框架或包。
### 画布
这个 的画布由一个包含显示透明度值的数组支持。每次更新任何像素时,整个画布都会重新绘制。另一个有趣的细节是我使用的亮度衰减函数:
const plateau = 0.3;
// dist 是距离中心点的平方距离
const alpha = Math.min(1 - dist / r2 + plateau, 1);
pixels[yc * 28 + xc] = Math.max(pixels[yc * 28 + xc], alpha);
我最初尝试使用 1-dist/r2
的衰减函数,但它使中心区域过于暗淡。因此,我添加了 plateau
变量,将函数向上平移,但通过 Math.min
将其限制,以确保透明度值不超过 1。这使得笔触看起来更加自然。
### 最小二乘法
我从在ECE 174课程中与Piya Pal教授合作的一个项目中获得了这些权重。推理过程仅仅是45次点积和评分。
function evalLSModel(digit, weights) {
const scores = new Array(10).fill(0);
for (const pairConfig of weights) {
const [i, j, w] = pairConfig;
// 向量点积
const result = vdot(digit, w);
if (result > 0) {
scores[i] += 1;
scores[j] -= 1;
} else {
scores[j] += 1;
scores[i] -= 1;
}
}
return scores;
}
全连接网络
全连接网络(FCN)推理的主要工作是矩阵点积,我以标准方式实现了它。
function matrixDot(matrix1, matrix2, rows1, cols1, rows2, cols2) {
// 检查矩阵是否可以相乘
if (cols1 !== rows2) {
console.error("矩阵维度不匹配,无法进行点积运算");
return null;
}
// 初始化结果矩阵为零
const result = new Array(rows1 * cols2).fill(0);
// 执行点积运算
for (let i = 0; i < rows1; i++) {
for (let j = 0; j < cols2; j++) {
for (let k = 0; k < cols1; k++) {
result[i * cols2 + j] +=
matrix1[i * cols1 + k] * matrix2[k * cols2 + j];
}
}
}
return result;
}
为了更好的缓存局部性和减少堆分配,我将矩阵存储在单个一维 Array
中。根据上述公式,推理过程包括 2 次矩阵点积和 2 次激活函数应用。push(1)
调用用于计算偏置。
function evalNN(digit, weights) {
const digitCopy = [...digit];
digitCopy.push(1);
// 第一层参数
const [w1, [rows1, cols1]] = weights[0];
const out1 = matrixDot(digitCopy, w1, 1, digitCopy.length, rows1, cols1).map(relu);
const [w2, [rows2, cols2]] = weights[1];
out1.push(1);
const out2 = matrixDot(out1, w2, 1, out1.length, rows2, cols2);
return softmax(out2);
}
卷积网络
这里的卷积网络非常小。在 Pytorch 中,它是这样的:
nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(32, 64, kernel_size=3),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Dropout(0.5),
nn.Linear(1600, 10),
nn.Softmax(dim=1)
)
对于推理,我们只需要将前向传播部分移植到 JavaScript 中。Conv2d(带有输入/输出通道)由以下代码实现:
function conv2d(
nInChan,
nOutChan,
inputData,
inputHeight,
inputWidth,
kernel,
bias,
) {
if (inputData.length !== inputHeight * inputWidth * nInChan) {
console.error("输入尺寸无效");
return;
}
if (kernel.length !== 3 * 3 * nInChan * nOutChan) {
console.error("卷积核尺寸无效");
return;
}
const kernelHeight = 3;
const kernelWidth = 3;
// 计算输出尺寸
const outputHeight = inputHeight - kernelHeight + 1;
const outputWidth = inputWidth - kernelWidth + 1;
const output = new Array(nOutChan * outputHeight * outputWidth).fill(0);
for (let i = 0; i < outputHeight; i++) {
for (let j = 0; j < outputWidth; j++) {
for (let outChan = 0; outChan < nOutChan; outChan++) {
let sum = 0;
// 在所有输入通道上应用滤波器
for (let inChan = 0; inChan < nInChan; inChan++) {
for (let row = 0; row < 3; row++) {
for (let col = 0; col < 3; col++) {
const inI =
inChan * (inputHeight * inputWidth) +
(i + row) * inputWidth +
(j + col);
const kI =
outChan * (nInChan * 3 * 3) +
inChan * (3 * 3) +
row * 3 +
col;
sum += inputData[inI] * kernel[kI];
}
}
}
sum += bias[outChan];
const outI =
outChan * (outputHeight * outputWidth) +
i * outputWidth +
j;
output[outI] = sum;
}
}
}
return output;
}
我知道这代码很丑。我只是把它放在这里供参考。接下来是 maxpool 的代码:
function maxPool2d(nInChannels, inputData, inputHeight, inputWidth) {
if (inputData.length !== inputHeight * inputWidth * nInChannels) {
console.error("maxpool2d: 输入高度/宽度无效");
return;
}
const poolSize = 2;
const stride = 2;
const outputHeight = Math.floor((inputHeight - poolSize) / stride) + 1;
const outputWidth = Math.floor((inputWidth - poolSize) / stride) + 1;
const output = new Array(outputHeight * outputWidth * nInChannels).fill(0);
for (let chan = 0; chan < nInChannels; chan++) {
for (let i = 0; i < outputHeight; i++) {
for (let j = 0; j < outputWidth; j++) {
let m = 0;
for (let row = 0; row < poolSize; row++) {
for (let col = 0; col < poolSize; col++) {
const ind =
chan * (inputHeight * inputWidth) +
(i * stride + row) * inputWidth +
(j * stride + col);
m = Math.max(m, inputData[ind]);
}
}
const outI =
chan * (outputHeight * outputWidth) + i * outputWidth + j;
output[outI] = m;
}
}
}
return output;
}
是的,我处理这些令人头疼的索引计算代码的唯一原因是为了那极速 🔥JavaScript🔥 Web 应用的性能。最后,这是将所有部分整合在一起的函数:
function evalConv(digit, weights) {
const [
[f1, fshape1], // 卷积滤波器权重
[b1, bshape1], // 卷积偏置
[f2, fshape2],
[b2, fbshape2],
[w, wshape], // 全连接层权重
[b, bshape], // 全连接层偏置
] = weights;
const x1 = conv2d(1, 32, digit, 28, 28, f1, b1).map(relu);
const x2 = maxPool2d(32, x1, 26, 26);
const x3 = conv2d(32, 64, x2, 13, 13, f2, b2).map(relu);
const x4 = maxPool2d(64, x3, 11, 11);
const x5 = matrixDot(w, x4, 10, 1600, 1600, 1);
const x6 = vsum(x5, b);
const out = softmax(x6);
return out;
}
结论
希望大家喜欢使用这个应用。如果有任何问题或反馈,欢迎在下方留言。