最近想做一个获取博文头图颜色的功能,想起了好久好久之前写过的一篇通过中位切分算法实现图片颜色提取的文章。
## 中位切分算法与 K-Means 聚类算法
中位切分算法是一种基于颜色直方图的算法:它递归地沿颜色范围最大的维度、在中位数处把颜色空间切分成一个个小立方体,最后每个小立方体通常用其中所有颜色的平均值作为代表色。这种算法在图像颜色分布不均匀的情况下容易出现明显的颜色偏差,导致提取出来的颜色不准确,例如图片中有大量红色、少量绿色时,提取出来的颜色会偏向于红色。同时,中位切分算法是确定性的,同一张图片每次提取出来的颜色都是一样的。
而 K-Means 聚类算法 比较擅长处理图像颜色分布不均匀的情况,而且采用 随机初始化聚类中心 提取出来的颜色可以实现每次都不一样。
当然,由于 K-Means 聚类算法的收敛速度较慢,图片过大时耗时可能会比较长,不过对于博客图片的大小来说可以忽略不计。和中位切分算法一样,K-Means 聚类算法也需要把图片绘制到 canvas 里再读取像素,所以同样会遇到跨域问题:直接用 <img> 绘制跨域图片会导致画布被污染;而下面的代码用 fetch 获取图片,若图片服务器没有返回 CORS 响应头,请求会直接失败。
## 效果
这是应用于网站后的效果:如果你的浏览器支持显示网页的主题色(theme-color),可以看到状态栏的颜色会变为从头图中随机提取的颜色。
下面是使用 K-Means 聚类算法获取图片五种颜色的实时渲染:
## K-Means 聚类算法 JavaScript 代码
主要 JavaScript 代码如下:
const canvas = document.getElementById('canvas');
const ctx = canvas.getContext('2d');

// Load the image, read its pixels through the canvas, and render the 5 colors
// K-Means extracts as swatches inside #colors.
// For a cross-origin image, fetch() succeeds only if the host sends CORS headers.
fetch('{图片地址}')
  .then(response => {
    // fetch only rejects on network errors; treat HTTP error statuses as failures too.
    if (!response.ok) {
      throw new Error(`Failed to load image: HTTP ${response.status}`);
    }
    return response.blob();
  })
  .then(blob => createImageBitmap(blob))
  .then(imageBitmap => {
    canvas.width = imageBitmap.width;
    canvas.height = imageBitmap.height;
    ctx.drawImage(imageBitmap, 0, 0);
    // The bitmap is not needed once it is drawn; release its memory early.
    imageBitmap.close();

    const { data } = ctx.getImageData(0, 0, canvas.width, canvas.height);
    // getImageData yields RGBA bytes; keep R, G, B and drop alpha.
    const pixels = Array.from({ length: data.length / 4 }, (_, i) => [data[i * 4], data[i * 4 + 1], data[i * 4 + 2]]);

    const kMeans = new KMeans();
    const { centroids: colors } = kMeans.cluster(pixels, 5);

    const colorsDiv = document.getElementById('colors');
    colors.forEach((color, index) => {
      // Centroids are channel means (floats); round to whole 8-bit values for display.
      const [r, g, b] = color.map(channel => Math.round(channel));
      const div = document.createElement('div');
      console.log(`颜色${index + 1}:rgb(${r}, ${g}, ${b})`);
      div.style.backgroundColor = `rgb(${r}, ${g}, ${b})`;
      colorsDiv.appendChild(div);
    });
  })
  .catch(error => {
    // Without this handler, a failed fetch/decode would surface only as an unhandled rejection.
    console.error('Failed to extract image colors:', error);
  });
/**
 * Lloyd's K-Means clustering for RGB pixels.
 *
 * Each point is an array whose first three entries are the R, G, B channels.
 * Centroids are seeded with k pixels picked at random, so repeated runs on the
 * same image can return different palettes.
 */
class KMeans {
  /**
   * Partitions `data` into (at most) `k` clusters.
   *
   * @param {number[][]} data - Pixels as [r, g, b] arrays.
   * @param {number} k - Number of clusters to extract.
   * @param {number} [maxIterations=300] - Safety cap on assign/update rounds.
   * @returns {{centroids: number[][], clusters: number[][][]}} The final centroids
   *   (floating-point channel means) and, per centroid index, the points assigned to it.
   *   Empty input or k < 1 yields `{ centroids: [], clusters: [] }`.
   */
  cluster(data, k, maxIterations = 300) {
    if (data.length === 0 || k < 1) {
      return { centroids: [], clusters: [] };
    }
    const centroids = this._initCentroids(data, k);
    let clusters = [];
    for (let iteration = 0; iteration < maxIterations; iteration++) {
      // Distances must be recomputed every round because the centroids move
      // between rounds; caching them per point would freeze the assignment.
      const assignment = this._assign(data, centroids, k);
      // Converged once a pass changes no assignment: the centroids are then
      // already the means of these clusters, so no further update is needed.
      const converged = iteration > 0 && this._clustersEqual(clusters, assignment);
      clusters = assignment;
      if (converged) {
        break;
      }
      this._updateCentroids(centroids, clusters);
    }
    return { centroids, clusters };
  }

  /** Assigns every point to its nearest centroid; returns k lists of points. */
  _assign(data, centroids, k) {
    const clusters = Array.from({ length: k }, () => []);
    data.forEach(point => {
      const distances = centroids.map(centroid => this._euclideanDistance(point, centroid));
      // indexOf returns the first minimum, so ties go to the lowest centroid index.
      clusters[distances.indexOf(Math.min(...distances))].push(point);
    });
    return clusters;
  }

  /** Moves each centroid to the mean of its cluster; an empty cluster keeps its centroid. */
  _updateCentroids(centroids, clusters) {
    centroids.forEach((_, i) => {
      const cluster = clusters[i];
      if (cluster.length === 0) {
        return;
      }
      const [sumR, sumG, sumB] = cluster.reduce(([accR, accG, accB], [r, g, b]) => [accR + r, accG + g, accB + b], [0, 0, 0]);
      centroids[i] = [sumR / cluster.length, sumG / cluster.length, sumB / cluster.length];
    });
  }

  /** Picks the first k points of a shuffled copy of `data` as the initial centroids. */
  _initCentroids(data, k) {
    const shuffledData = this._shuffle(data);
    return shuffledData.slice(0, k);
  }

  /** Euclidean distance between the first three channels of two points. */
  _euclideanDistance(p1, p2) {
    const dR = p1[0] - p2[0];
    const dG = p1[1] - p2[1];
    const dB = p1[2] - p2[2];
    return Math.sqrt(dR * dR + dG * dG + dB * dB);
  }

  /** Returns a Fisher–Yates shuffled copy of `array`; the input is left untouched. */
  _shuffle(array) {
    const shuffledArray = [...array];
    for (let i = shuffledArray.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [shuffledArray[i], shuffledArray[j]] = [shuffledArray[j], shuffledArray[i]];
    }
    return shuffledArray;
  }

  /**
   * Compares two cluster lists point by point, by reference. Points are pushed
   * in `data` order, so identical lists mean the assignment did not change.
   */
  _clustersEqual(oldClusters, newClusters) {
    if (oldClusters.length !== newClusters.length) {
      return false;
    }
    for (let i = 0; i < oldClusters.length; i++) {
      if (oldClusters[i].length !== newClusters[i].length) {
        return false;
      }
      for (let j = 0; j < oldClusters[i].length; j++) {
        if (oldClusters[i][j] !== newClusters[i][j]) {
          return false;
        }
      }
    }
    return true;
  }
}
## mini-batch K-Means 算法
mini-batch K-Means 算法是 K-Means 算法的优化版:它每一轮只随机抽取一小批数据来更新聚类中心,因此可以降低计算成本,更快地处理大规模数据集。但也正因为每轮只使用一部分数据进行计算,mini-batch K-Means 算法的结果通常不如 K-Means 算法准确。
下面是使用 mini-batch K-Means 算法优化后的代码:
const canvas = document.getElementById('canvas');
const ctx = canvas.getContext('2d');

// Load the image, read its pixels through the canvas, and render the 5 colors
// mini-batch K-Means extracts as swatches inside #colors.
// For a cross-origin image, fetch() succeeds only if the host sends CORS headers.
fetch('{图片地址}')
  .then(response => {
    // fetch only rejects on network errors; treat HTTP error statuses as failures too.
    if (!response.ok) {
      throw new Error(`Failed to load image: HTTP ${response.status}`);
    }
    return response.blob();
  })
  .then(blob => createImageBitmap(blob))
  .then(imageBitmap => {
    canvas.width = imageBitmap.width;
    canvas.height = imageBitmap.height;
    canvas.classList.remove("animate-pulse");
    ctx.drawImage(imageBitmap, 0, 0);
    // The bitmap is not needed once it is drawn; release its memory early.
    imageBitmap.close();

    const { data } = ctx.getImageData(0, 0, canvas.width, canvas.height);
    // getImageData yields RGBA bytes; keep R, G, B and drop alpha.
    const pixels = Array.from({
      length: data.length / 4
    }, (_, i) => [data[i * 4], data[i * 4 + 1], data[i * 4 + 2]]);

    const kMeans = new KMeans();
    const { centroids: colors } = kMeans.cluster(pixels, 5);

    const colorsDiv = document.getElementById('colors');
    colors.forEach((color, index) => {
      // Centroids are running means (floats); round to whole 8-bit values for display.
      const [r, g, b] = color.map(channel => Math.round(channel));
      const div = document.createElement('div');
      console.log(`颜色${index + 1}:rgb(${r}, ${g}, ${b})`);
      div.style.backgroundColor = `rgb(${r}, ${g}, ${b})`;
      colorsDiv.appendChild(div);
    });
  })
  .catch(error => {
    // Without this handler, a failed fetch/decode would surface only as an unhandled rejection.
    console.error('Failed to extract image colors:', error);
  });
/**
 * Mini-batch K-Means clustering for RGB pixels (Sculley, "Web-Scale K-Means
 * Clustering", 2010).
 *
 * Instead of assigning every pixel on every round, each round samples
 * `batchSize` random pixels and nudges their nearest centroids toward them with
 * a per-centroid learning rate of 1 / (samples that centroid has received).
 * This is much cheaper per round than full K-Means on large images, at the cost
 * of a less exact result.
 */
class KMeans {
  /**
   * Clusters `data` into (at most) `k` groups.
   *
   * @param {number[][]} data - Pixels as [r, g, b] arrays.
   * @param {number} k - Number of clusters to extract.
   * @param {number} [batchSize=100] - Pixels sampled (with replacement) per round.
   * @param {number} [maxIterations=100] - Upper bound on the number of rounds.
   * @param {number} [tolerance=0.5] - Stop once no centroid moves farther than this
   *   (Euclidean, RGB units) in one round; 0.5 is half an 8-bit channel step.
   * @returns {{centroids: number[][], clusters: number[][][]}} The centroids and,
   *   per centroid index, every pixel of `data` nearest to it.
   *   Empty input or k < 1 yields `{ centroids: [], clusters: [] }`.
   */
  cluster(data, k, batchSize = 100, maxIterations = 100, tolerance = 0.5) {
    if (data.length === 0 || k < 1) {
      return { centroids: [], clusters: [] };
    }
    // Copy the seeds: centroids are updated in place below and must not alias
    // (and thereby overwrite) pixels in `data`.
    const centroids = this._initCentroids(data, k).map(point => point.slice(0, 3));
    const counts = new Array(centroids.length).fill(0);
    const sampleSize = Math.max(1, Math.min(batchSize, data.length));

    for (let iteration = 0; iteration < maxIterations; iteration++) {
      const previous = centroids.map(centroid => [...centroid]);
      const batch = Array.from({ length: sampleSize }, () => data[Math.floor(Math.random() * data.length)]);
      // Assign the whole batch against this round's centroids before moving any of them.
      const nearest = batch.map(point => this._nearestCentroid(point, centroids));
      batch.forEach((point, j) => {
        const index = nearest[j];
        counts[index] += 1;
        // A 1/count rate makes each centroid the running mean of all samples it has received.
        const rate = 1 / counts[index];
        const centroid = centroids[index];
        for (let channel = 0; channel < 3; channel++) {
          centroid[channel] += rate * (point[channel] - centroid[channel]);
        }
      });
      if (this._clustersEqual(previous, centroids, tolerance)) {
        break;
      }
    }

    // One full pass so `clusters` describes every pixel, like standard K-Means.
    const clusters = Array.from({ length: k }, () => []);
    data.forEach(point => {
      clusters[this._nearestCentroid(point, centroids)].push(point);
    });
    return { centroids, clusters };
  }

  /** Returns the index of the centroid closest to `point` (lowest index on ties). */
  _nearestCentroid(point, centroids) {
    let bestIndex = 0;
    let bestDistance = Infinity;
    centroids.forEach((centroid, index) => {
      const distance = this._euclideanDistance(point, centroid);
      if (distance < bestDistance) {
        bestIndex = index;
        bestDistance = distance;
      }
    });
    return bestIndex;
  }

  /** Picks the first k points of a shuffled copy of `data` as the initial centroids. */
  _initCentroids(data, k) {
    const shuffledData = this._shuffle(data);
    return shuffledData.slice(0, k);
  }

  /** Euclidean distance between the first three channels of two points. */
  _euclideanDistance(p1, p2) {
    const dR = p1[0] - p2[0];
    const dG = p1[1] - p2[1];
    const dB = p1[2] - p2[2];
    return Math.sqrt(dR * dR + dG * dG + dB * dB);
  }

  /**
   * Converts a CSS `rgb()`/`rgba()` string to `#rrggbb`, compositing any alpha
   * over white. Not called by `cluster` itself.
   */
  _rgbToHex(color) {
    const values = color.replace(/rgba?\(/, '').replace(/\)/, '').replace(/[\s+]/g, '').split(',');
    const a = parseFloat(values[3] || 1);
    const [r, g, b] = values.slice(0, 3).map(value => Math.floor(a * Number.parseInt(value, 10) + (1 - a) * 255));
    return "#" + [r, g, b].map(channel => ("0" + channel.toString(16)).slice(-2)).join('');
  }

  /** Returns a Fisher–Yates shuffled copy of `array`; the input is left untouched. */
  _shuffle(array) {
    const shuffledArray = [...array];
    for (let i = shuffledArray.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [shuffledArray[i], shuffledArray[j]] = [shuffledArray[j], shuffledArray[i]];
    }
    return shuffledArray;
  }

  /**
   * Reports whether two centroid lists agree within `threshold`: same length,
   * and no centroid moved farther than `threshold` between them.
   */
  _clustersEqual(oldCentroids, newCentroids, threshold = 0.5) {
    if (oldCentroids.length !== newCentroids.length) {
      return false;
    }
    return oldCentroids.every((centroid, i) => this._euclideanDistance(centroid, newCentroids[i]) <= threshold);
  }
}
此文由 Mix Space 同步更新至 xLog
原始链接为 https://www.vinking.top/posts/codes/median-cut-k-means-algorithms-javascript-implementation