Objectif
Supposons que nous ayons un ensemble de données de caractéristiques, mais sans étiquettes. Si nous savons (ou devinons) qu’il y a classes dans l’ensemble de données, nous pourrions modéliser l’ensemble de données comme la moyenne pondérée de Gaussiennes conditionnelles par classe. C’est ce que font les modèles de mélange gaussiens.
Nous supposons que le modèle est paramétré par , où détermine le poids de la -ième Gaussienne dans le modèle.
Puisque notre ensemble de données est i.i.d., sa log-vraisemblance est
Maximisation de l’Espérance
Pour trouver les qui maximisent la vraisemblance de nos données, nous utiliserons le processus suivant :
-
Calculer des estimations initiales pour
-
Calculer la vraisemblance que appartienne à la classe . Nous notons cela ou responsabilité de pour
- Nous mettons à jour
- les poids pour qu’ils soient la responsabilité moyenne pour la Gaussienne
- les moyennes pour qu’elles soient la moyenne des points de données, pondérée par pour tout
- les variances pour qu’elles soient la variance des points de données par rapport au nouveau , de même pondérée par
Remarquez la similitude entre ce processus et la Régression par noyau ! Dans ce cas, la fonction noyau est , qui définit un voisinage de caractéristiques qui appartiennent probablement à la classe .
Les étapes 2 et 3 sont répétées jusqu’à ce que les poids convergent.
Démonstration interactive
Ci-dessous se trouve une démonstration interactive de l’algorithme EM. Les données sont générées à partir de Gaussiennes, dont les moyennes, variances et poids sont sélectionnés aléatoirement. Ensuite, un modèle GM de Gaussiennes est ajusté aux données.
Vous pouvez appuyer plusieurs fois sur Start
. Un ordinateur de bureau est recommandé.
$k$ | True $(\mu_k, \sigma_k^2)$ | Est $(\hat \mu_k, \hat \sigma_k^2)$ | True $\pi_k$ | Est $\hat{\pi}_k$ |
---|
Code
Voici le code principal de l’algorithme en JavaScript. Celui-ci ne reproduira pas directement
le graphique ci-dessus, car j’ai omis le code de tracé ainsi que le HTML/CSS.
Vous pouvez utiliser Inspect Element
pour voir l’intégralité.
function randn() {
let u = 0, v = 0;
while(u === 0) u = Math.random();
while(v === 0) v = Math.random();
return Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v);
}
function gaussianPDF(x, mean, variance) {
const std = Math.sqrt(variance);
const coeff = 1.0 / (std * Math.sqrt(2 * Math.PI));
const exponent = -0.5 * Math.pow((x - mean)/std, 2);
return coeff * Math.exp(exponent);
}
function generateSeparatedMeans(C) {
let candidate = [];
for (let i = 0; i < C; i++) {
candidate.push(Math.random());
}
candidate.sort((a,b) => a - b);
let means = candidate.map(x => -5 + x*10);
for (let i = 1; i < C; i++) {
if (means[i] - means[i-1] < 0.5) {
means[i] = means[i-1] + 0.5;
}
}
return means;
}
function generateData(C, N=1000) {
let means = generateSeparatedMeans(C);
let variances = [];
let weights = [];
for (let i = 0; i < C; i++) {
variances.push(0.5 + 1.5*Math.random());
weights.push(1.0/C);
}
let data = [];
for (let i = 0; i < N; i++) {
const comp = Math.floor(Math.random() * C);
const x = means[comp] + Math.sqrt(variances[comp])*randn();
data.push(x);
}
return {data, means, variances, weights};
}
function decentInitialGuess(C, data) {
const N = data.length;
let means = [];
let variances = [];
let weights = [];
for (let c = 0; c < C; c++) {
means.push(data[Math.floor(Math.random()*N)]);
variances.push(1.0);
weights.push(1.0/C);
}
return {means, variances, weights};
}
function emGMM(data, C, maxIter=100, tol=1e-4) {
const N = data.length;
let init = decentInitialGuess(C, data);
let means = init.means.slice();
let variances = init.variances.slice();
let weights = init.weights.slice();
let logLikOld = -Infinity;
let paramsHistory = [];
for (let iter = 0; iter < maxIter; iter++) {
let resp = new Array(N).fill(0).map(() => new Array(C).fill(0));
for (let i = 0; i < N; i++) {
let total = 0;
for (let c = 0; c < C; c++) {
const val = weights[c]*gaussianPDF(data[i], means[c], variances[c]);
resp[i][c] = val;
total += val;
}
for (let c = 0; c < C; c++) {
resp[i][c] /= (total + 1e-15);
}
}
for (let c = 0; c < C; c++) {
let sumResp = 0;
let sumMean = 0;
let sumVar = 0;
for (let i = 0; i < N; i++) {
sumResp += resp[i][c];
sumMean += resp[i][c]*data[i];
}
const newMean = sumMean / (sumResp + 1e-15);
for (let i = 0; i < N; i++) {
let diff = data[i] - newMean;
sumVar += resp[i][c]*diff*diff;
}
const newVar = sumVar/(sumResp + 1e-15);
means[c] = newMean;
variances[c] = Math.max(newVar, 1e-6);
weights[c] = sumResp/N;
}
let logLik = 0;
for (let i = 0; i < N; i++) {
let p = 0;
for (let c = 0; c < C; c++) {
p += weights[c]*gaussianPDF(data[i], means[c], variances[c]);
}
logLik += Math.log(p + 1e-15);
}
paramsHistory.push({
means: means.slice(),
variances: variances.slice(),
weights: weights.slice()
});
if (Math.abs(logLik - logLikOld) < tol) {
break;
}
logLikOld = logLik;
}
return paramsHistory;
}