본문 바로가기
소프트웨어/자바스크립트

[자바스크립트] Node.js - Tensorflow.js를 이용한 이미지 분류 예제

by 만들오 2021. 1. 23.

안녕하세요? 만들오입니다. 요즘 머신러닝 공부를 하고 있습니다.

Python이 쉽고 널리 사용되어 접근하기 좋지만,

JavaScript로도 머신러닝을 할 수 있는 방법이 있어 소개해 드립니다.

tensorflow.js의 가장 큰 장점은 브라우저에서 결과물을 쉽게 사용할 수 있다는 점입니다.

오늘은 그동안 작업한 최종 결과물만 공유하고 나중에 조금씩 나누어서 글을 작성해 보겠습니다.

const tf = require('@tensorflow/tfjs-node');
const path = require('path');
const fs = require('fs');
const { createCanvas, loadImage } = require('canvas');
// One canvas and 2D context shared by the whole script: imageGenerator sets
// the canvas width/height to each loaded image and draws the image onto it
// before converting the canvas contents to a tensor.
const canvas = createCanvas();
const ctx = canvas.getContext('2d');

/**
 * Lists the image files stored one directory level below `startPath`.
 *
 * Only the immediate sub-directories of `startPath` are scanned; files that
 * sit directly in `startPath` are ignored.
 *
 * @param {string} startPath - Root directory holding one sub-directory per class.
 * @param {string} fileExt - File extension to keep, with or without the
 *     leading dot (e.g. 'png' or '.png').
 * @returns {string[]} Paths built as path.join(startPath, subDir, file).
 */
function getImgPath(startPath, fileExt) {
    // Always match on ".ext": a bare endsWith('png') would also accept a file
    // that is merely named e.g. "notpng".
    const suffix = fileExt.startsWith('.') ? fileExt : `.${fileExt}`;
    const imgPath = [];
    const entries = fs.readdirSync(startPath, { withFileTypes: true });
    for (const dirent of entries) {
        if (!dirent.isDirectory()) {
            continue;
        }
        for (const file of fs.readdirSync(path.join(startPath, dirent.name))) {
            if (file.endsWith(suffix)) {
                imgPath.push(path.join(startPath, dirent.name, file));
            }
        }
    }
    return imgPath;
}

/**
 * Async generator that yields one image tensor per entry of the module-level
 * `imgPath` array, in array order.
 *
 * Each image is drawn onto the shared canvas, read back with
 * tf.browser.fromPixels, resized to [imgSize, imgSize] with nearest-neighbor
 * sampling, and divided by 255.
 *
 * @yields {tf.Tensor} The resized, scaled image tensor.
 */
async function* imageGenerator() {
    for (const file of imgPath) {
        const img = await loadImage(file);
        // The canvas is reused, so it has to take the size of every new image.
        canvas.width = img.width;
        canvas.height = img.height;
        ctx.drawImage(img, 0, 0);
        // tf.tidy disposes the intermediate tensors (raw pixels, resized copy)
        // and keeps only the returned result, so they do not pile up per image.
        yield tf.tidy(() => tf.browser.fromPixels(canvas)
            .resizeNearestNeighbor([imgSize, imgSize])
            .div(255.0));
    }
}

/**
 * Async generator that yields one one-hot label tensor per entry of the
 * module-level `imgPath` array, in array order.
 *
 * The label of a file is the name of its parent directory. Class indices are
 * assigned in order of first appearance in `imgPath`, and the one-hot depth is
 * the number of distinct labels.
 *
 * @yields {tf.Tensor} Result of tf.oneHot(classIndex, numberOfClasses).
 */
async function* labelGenerator() {
    const labels = imgPath.map((filePath) => {
        // Split on both separators: path.join produces "/" on POSIX and "\"
        // on Windows, and splitting on "\" alone leaves every label undefined
        // on POSIX.
        const parts = filePath.split(/[\\/]/);
        return parts[parts.length - 2];
    });
    // A Set keeps insertion order, so indices follow first appearance.
    const uniq = [...new Set(labels)];
    const classIndex = new Map(uniq.map((label, index) => [label, index]));

    for (const label of labels) {
        yield tf.oneHot(classIndex.get(label), uniq.length);
    }
}

/**
 * Builds and compiles a small convolutional classifier.
 *
 * Architecture, in order: conv2d(32) -> conv2d(64) -> 2x2 max pooling ->
 * flatten -> dense(128, relu) -> dense(classes, softmax). The input shape is
 * [imgSize, imgSize, 3], where `imgSize` is the module-level image size.
 *
 * @param {number} classes - Number of units in the softmax output layer.
 * @returns {tf.Sequential} Model compiled with categorical cross-entropy
 *     loss, the Adam optimizer and the accuracy metric.
 */
function makeCNN(classes) {
    const firstConv = tf.layers.conv2d({
        inputShape: [imgSize, imgSize, 3],
        filters: 32,
        kernelSize: 3,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    });
    const secondConv = tf.layers.conv2d({
        filters: 64,
        kernelSize: 3,
        activation: 'relu'
    });
    const pooling = tf.layers.maxPooling2d({ poolSize: [2, 2] });
    const flatten = tf.layers.flatten();
    const hidden = tf.layers.dense({ units: 128, activation: 'relu' });
    const output = tf.layers.dense({ units: classes, activation: 'softmax' });

    const network = tf.sequential({
        layers: [firstConv, secondConv, pooling, flatten, hidden, output]
    });
    network.compile({
        loss: 'categoricalCrossentropy',
        optimizer: tf.train.adam(),
        metrics: ['accuracy']
    });
    return network;
}

// Image file list shared with imageGenerator and labelGenerator; filled by run().
let imgPath;
// Width and height (in pixels) every image is resized to before training.
let imgSize = 24;

/**
 * Loads the 'shapes' PNG data set, builds the CNN and trains it for 5 epochs.
 *
 * Images and labels come from two generators that walk `imgPath` in the same
 * order; they are zipped into {xs, ys} pairs, shuffled and batched by 15.
 *
 * @returns {Promise<void>} Resolves when training has finished.
 * @throws {Error} If no matching image is found.
 */
async function run() {
    console.log('Loading...');
    imgPath = getImgPath('shapes', 'png');
    if (imgPath.length === 0) {
        throw new Error("No 'png' images found under 'shapes'.");
    }
    // One class per folder that actually contains images, so the output layer
    // always matches the data instead of relying on a hard-coded count.
    const classes = new Set(imgPath.map((file) => path.dirname(file))).size;
    const xs = tf.data.generator(imageGenerator);
    const ys = tf.data.generator(labelGenerator);
    const ds = tf.data.zip({ xs, ys }).shuffle(imgPath.length).batch(15);
    const model = makeCNN(classes);
    await model.fitDataset(ds, { epochs: 5 });
    console.log('done');
}

// Report failures explicitly instead of leaving the promise unhandled.
run().catch((err) => {
    console.error(err);
    process.exitCode = 1;
});

추가 수정(19.10.04)

node-canvas를 이용한 예제를 pureimage로 대체하였습니다.

const tf = require('@tensorflow/tfjs-node');
const path = require('path');
const fs = require('fs');
const PImage = require('pureimage')

/**
 * Lists the image files stored one directory level below `startPath`.
 *
 * Only the immediate sub-directories of `startPath` are scanned; files that
 * sit directly in `startPath` are ignored.
 *
 * @param {string} startPath - Root directory holding one sub-directory per class.
 * @param {string} fileExt - File extension to keep, with or without the
 *     leading dot (e.g. 'jpg' or '.jpg').
 * @returns {string[]} Paths built as path.join(startPath, subDir, file).
 */
function getImgPath(startPath, fileExt) {
    // Always match on ".ext": a bare endsWith('jpg') would also accept a file
    // that is merely named e.g. "notjpg".
    const suffix = fileExt.startsWith('.') ? fileExt : `.${fileExt}`;
    const subDirs = fs.readdirSync(startPath, { withFileTypes: true })
        .filter((dirent) => dirent.isDirectory())
        .map((dirent) => dirent.name);
    const imgPath = [];
    subDirs.forEach((name) => {
        fs.readdirSync(path.join(startPath, name))
            .filter((file) => file.endsWith(suffix))
            .forEach((file) => {
                imgPath.push(path.join(startPath, name, file));
            });
    });
    return imgPath;
}

/**
 * Starts decoding a JPEG or PNG file with pureimage.
 *
 * The format is chosen from the file extension, compared case-insensitively.
 *
 * @param {string} imgPath - Path of the image file to decode.
 * @returns {*} Whatever the pureimage decoder returns for the file's read
 *     stream (the caller awaits it), or undefined when the extension is not
 *     supported; in that case a message is logged.
 */
function loadImage(imgPath) {
    // path.extname looks only at the final extension, so dots elsewhere in
    // the path (e.g. './data/a.jpg' or 'my.set/a.jpg') do not confuse it.
    const ext = path.extname(imgPath).toLowerCase();
    let img;
    if (ext === '.jpg' || ext === '.jpeg') {
        img = PImage.decodeJPEGFromStream(fs.createReadStream(imgPath));
    } else if (ext === '.png') {
        img = PImage.decodePNGFromStream(fs.createReadStream(imgPath));
    } else {
        console.log('Input image is not supported format. Use jpg or png.');
    }
    return img;
}

/**
 * Async generator that yields one image tensor per entry of the module-level
 * `imgPath` array, in array order.
 *
 * Each decoded image is converted with tf.browser.fromPixels, resized to
 * [imgSize, imgSize] with nearest-neighbor sampling, and divided by 255.
 *
 * @yields {tf.Tensor} The resized, scaled image tensor.
 */
async function* imageGenerator() {
    for (const file of imgPath) {
        const img = await loadImage(file);
        // tf.tidy disposes the intermediate tensors (raw pixels, resized copy)
        // and keeps only the returned result, so they do not pile up per image.
        yield tf.tidy(() => tf.browser.fromPixels(img)
            .resizeNearestNeighbor([imgSize, imgSize])
            .div(255.0));
    }
}

/**
 * Async generator that yields one one-hot label tensor per entry of the
 * module-level `imgPath` array, in array order.
 *
 * The label of a file is the name of its parent directory. Class indices are
 * assigned in order of first appearance in `imgPath`, and the one-hot depth is
 * the number of distinct labels.
 *
 * @yields {tf.Tensor} Result of tf.oneHot(classIndex, numberOfClasses).
 */
async function* labelGenerator() {
    const labels = imgPath.map((filePath) => {
        // Split on both separators: path.join produces "/" on POSIX and "\"
        // on Windows, and splitting on "\" alone leaves every label undefined
        // on POSIX.
        const parts = filePath.split(/[\\/]/);
        return parts[parts.length - 2];
    });
    // A Set keeps insertion order, so indices follow first appearance.
    const uniq = [...new Set(labels)];
    const classIndex = new Map(uniq.map((label, index) => [label, index]));

    for (const label of labels) {
        yield tf.oneHot(classIndex.get(label), uniq.length);
    }
}

/**
 * Builds and compiles a small convolutional classifier.
 *
 * Layers are added in this order: conv2d(32), conv2d(64), 2x2 max pooling,
 * flatten, dense(128, relu), dense(classes, softmax). The first layer expects
 * input of shape [imgSize, imgSize, 3], using the module-level `imgSize`.
 *
 * @param {number} classes - Number of units in the softmax output layer.
 * @returns {tf.Sequential} Model compiled with categorical cross-entropy
 *     loss, the Adam optimizer and the accuracy metric.
 */
function makeCNN(classes) {
    const cnn = tf.sequential();

    // Feature extraction: two 3x3 convolutions followed by 2x2 max pooling.
    const convInput = {
        filters: 32,
        kernelSize: 3,
        activation: 'relu',
        inputShape: [imgSize, imgSize, 3],
        strides: 1,
        kernelInitializer: 'varianceScaling'
    };
    cnn.add(tf.layers.conv2d(convInput));
    cnn.add(tf.layers.conv2d({ filters: 64, kernelSize: 3, activation: 'relu' }));
    cnn.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));

    // Classification head.
    cnn.add(tf.layers.flatten());
    cnn.add(tf.layers.dense({ units: 128, activation: 'relu' }));
    cnn.add(tf.layers.dense({ units: classes, activation: 'softmax' }));

    const trainingConfig = {
        loss: 'categoricalCrossentropy',
        optimizer: tf.train.adam(),
        metrics: ['accuracy']
    };
    cnn.compile(trainingConfig);
    return cnn;
}

// Image file list shared with imageGenerator and labelGenerator; filled by run().
let imgPath;
// Width and height (in pixels) every image is resized to before training.
let imgSize = 24;

/**
 * Loads the 'cat-dog' JPEG data set, builds the CNN and trains it for 5 epochs.
 *
 * Images and labels come from two generators that walk `imgPath` in the same
 * order; they are zipped into {xs, ys} pairs, shuffled and batched by 15.
 *
 * @returns {Promise<void>} Resolves when training has finished.
 * @throws {Error} If no matching image is found.
 */
async function run() {
    console.log('Loading...');
    imgPath = getImgPath('cat-dog', 'jpg');
    if (imgPath.length === 0) {
        throw new Error("No 'jpg' images found under 'cat-dog'.");
    }
    // One class per folder that actually contains images, so the output layer
    // always matches the data instead of relying on a hard-coded count.
    const classes = new Set(imgPath.map((file) => path.dirname(file))).size;
    const xs = tf.data.generator(imageGenerator);
    const ys = tf.data.generator(labelGenerator);
    const ds = tf.data.zip({ xs, ys }).shuffle(imgPath.length).batch(15);
    const model = makeCNN(classes);
    await model.fitDataset(ds, { epochs: 5 });
    console.log('done');
}

// Report failures explicitly instead of leaving the promise unhandled.
run().catch((err) => {
    console.error(err);
    process.exitCode = 1;
});

* 이 글은 티스토리 카카오계정 연동정책으로 인해 이전 블로그(오코취) 글을 옮겨왔습니다.

[].

728x90

댓글