본문 바로가기
프로젝트/레고테크닉 개조

[프로젝트] 레고테크닉 인공지능 RC카 #2 - 아두이노 코드 공유

by 만들오 2021. 3. 18.
728x90

지금까지 만들어 놓은 프로토타입 V1(이하 V1 모델)의 아두이노(ESP32-CAM) 코드를 공유합니다.

 

이 코드로 Object classification 모델을 학습시키고, 그 예측 결과를 RC카로 전송해 실행할 수 있습니다.

 

Object tracking과 Object detection 모두 구현할 수 있는 상태지만,

 

GUI 구성이 시간을 많이 잡아먹어서 업그레이드가 지연되고 있습니다...

 

코드 관련 궁금하신 부분은 댓글로 남겨 주세요.

 

감사합니다.

 

#include <HardwareSerial.h>
#include <WebSocketsServer.h>
#include <WiFi.h>
#include "esp_camera.h"
#include "esp_timer.h"
#include "img_converters.h"
#include "Arduino.h"
#include "fb_gfx.h"
#include "soc/soc.h"
#include "soc/rtc_cntl_reg.h"
#include "esp_http_server.h"

// Camera sensor pin map (ESP32 GPIO numbers), consumed by setup() when it
// fills in camera_config_t.
// NOTE(review): these values match the AI-Thinker ESP32-CAM layout from
// Espressif's CameraWebServer example — confirm against your board.

// Power-down and reset lines (-1 presumably means "not connected").
#define PWDN_GPIO_NUM     32
#define RESET_GPIO_NUM    -1

// Master clock and the SCCB control bus (setup() maps SIOD -> pin_sscb_sda,
// SIOC -> pin_sscb_scl).
#define XCLK_GPIO_NUM      0
#define SIOD_GPIO_NUM     26
#define SIOC_GPIO_NUM     27

// Frame / line / pixel synchronisation signals.
#define VSYNC_GPIO_NUM    25
#define HREF_GPIO_NUM     23
#define PCLK_GPIO_NUM     22

// 8-bit parallel data bus; setup() wires Y2 -> pin_d0 ... Y9 -> pin_d7.
#define Y2_GPIO_NUM        5
#define Y3_GPIO_NUM       18
#define Y4_GPIO_NUM       19
#define Y5_GPIO_NUM       21
#define Y6_GPIO_NUM       36
#define Y7_GPIO_NUM       39
#define Y8_GPIO_NUM       34
#define Y9_GPIO_NUM       35

// Credentials of the Wi-Fi access point joined by setup() via WiFi.begin().
// Both values are placeholders: replace them before flashing.
const char* ssid = "YOUR SSID";
const char* password = "YOUR PASSWORD";

//Custom serial setting
HardwareSerial mySerial(2);
//Websocket setting
WebSocketsServer webSocket = WebSocketsServer(81);
void onWebSocketEvent(uint8_t num, WStype_t type, uint8_t * payload, size_t length) {
  if(type == WStype_TEXT){
    String cmd = (char*)payload;
    Serial.println(cmd);
    mySerial.println(cmd);
  }
}

// multipart/x-mixed-replace (MJPEG) framing used by stream_handler().
// For every frame the handler sends _STREAM_PART (with the JPEG length
// substituted for %u), the JPEG bytes, then _STREAM_BOUNDARY.
// The pointers themselves are const so they cannot be reassigned at runtime.
// NOTE(review): names starting with "_" + uppercase letter are reserved in
// C++; renaming them requires updating stream_handler() at the same time.
#define PART_BOUNDARY "123456789000000000000987654321"
static const char* const _STREAM_CONTENT_TYPE = "multipart/x-mixed-replace;boundary=" PART_BOUNDARY;
static const char* const _STREAM_BOUNDARY = "\r\n--" PART_BOUNDARY "\r\n";
static const char* const _STREAM_PART = "Content-Type: image/jpeg\r\nContent-Length: %u\r\n\r\n";
// Handle of the HTTP server; starts out null and startCameraServer() passes
// its address to httpd_start().
httpd_handle_t stream_httpd = nullptr;

// Single-page control UI served at "/" by index_handler().
// Layout: the camera view (<img> fed from "/stream") with an overlay <canvas>
// for status text, a nipplejs joystick, action buttons L/R/A-F, and KNN model
// controls (SAVE / LOAD / EXECUTE / LABEL SETTING / ID1-ID4).
// The page talks to the ESP32 over the WebSocket on port 81 with text
// messages "JOY:<dir><n>", "BTN:<id><0|1>" and "AI:<label>".
// Notes on the script:
//  - the MJPEG stream URL is derived from the host that served the page;
//    http://192.168.0.12 is only used when the page has no host (file://);
//  - load_model() reads the chosen file from its own change event;
//  - cancelling the label prompt keeps the old label;
//  - the overlay is redrawn with requestAnimationFrame (no busy loop).
static const char PROGMEM INDEX_HTML[] = R"rawliteral(
<!DOCTYPE html>
<html>

<head>
    <meta charset="utf-8">
    <title>MANDLOH</title>
    <style>
        body {
            background-color: rgb(150, 150, 150);
            -webkit-user-select: none;
            color: white;
            font-size: 20px;
        }

        button {
            width: 185px;
            height: 105px;
            margin: 2px 1px 2px 1px;
            font-size: 35px;
            color: white;
            background-color: rgb(51, 51, 51);
            border: none;
            border-radius: 5px;
            box-shadow: 0 9px rgba(104, 104, 104, 0.5);
        }

        button:active {
            box-shadow: 0 5px rgba(104, 104, 104, 0.5);
            transform: translateY(8px);
        }
    </style>
    <script src="https://cdn.jsdelivr.net/npm/nipplejs@0.8.7/dist/nipplejs.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>
</head>

<body>
    <div style="position:relative">
        <img id="img_stream" style="position:absolute;left:12px;">
        <canvas id="canvas" style="position:absolute;left:12px;"></canvas>
    </div>
    <div id="joystick_zone"
        style="position:absolute; top:490px; left:20px; border:1px dashed #BDBDBD; width:240px; height:240px;"></div>
    <div style="position:absolute; top:740px; left:20px; border:1px dashed #BDBDBD; width:240px; height:80px;">
        <button id="btnL" class="action"
            style="position:relative; left:5px; width: 106px; height:73px; font-size: 40px;">L</button>
        <button id="btnR" class="action"
            style="position:relative; left:10px; width: 106px; height:73px; font-size: 40px;">R</button>
    </div>
    <div style="position:absolute; top:490px; left:270px; border:1px dashed #BDBDBD; width:388px; height:330px;">
        <button id="btnA" class="action" style="position:relative; left:3px;">A</button>
        <button id="btnB" class="action" style="position:relative;">B</button>
        <button id="btnC" class="action" style="position:relative; left:3px;">C</button>
        <button id="btnD" class="action" style="position:relative;">D</button>
        <button id="btnE" class="action" style="position:relative; left:3px;">E</button>
        <button id="btnF" class="action" style="position:relative;">F</button>
    </div>
    <div style="position:absolute; left:662px; border:1px dashed #BDBDBD; width:190px; height:812px;">
        <button onclick="save_model()" style="position:relative;">SAVE</button>
        <button onclick="load_model()" style="position:relative;">LOAD</button>
        <button id="exec" style="position:relative;">EXECUTE</button>
        <button id="change" style="position:relative;">LABEL<br>SETTING</button>
        <button class="label" style="position:relative; height:88px;">ID1</button>
        <button class="label" style="position:relative; height:88px;">ID2</button>
        <button class="label" style="position:relative; height:88px;">ID3</button>
        <button class="label" style="position:relative; height:88px;">ID4</button>
    </div>

    <script>
        // Stream from the host that served this page; fixed IP only as a fallback.
        const stream_url = (window.location.hostname
            ? window.location.protocol + "//" + window.location.hostname
            : "http://192.168.0.12") + "/stream";

        //접속환경 구분(PC/MOBILE)
        const pc_platforms = ["win16", "win32", "win64", "mac", "macintel"];
        const platform = pc_platforms.includes(navigator.platform.toLowerCase()) ? "PC" : "MOBILE";
        console.log(`Client platform : ${platform}`);

        //버튼 이벤트. class="action" 버튼만 할당한다.
        const evt_btns = (e) => {
            const btn = e.currentTarget;
            const id = btn.id.slice(3, 4); //btnA -> A
            const pressed = btn.classList.contains("action") && (e.type == "mousedown" || e.type == "touchstart");
            send_msg("BTN:" + id + (pressed ? "1" : "0"));
        }

        //추가 버튼 이벤트
        let is_exec = false;
        document.getElementById("exec").onclick = () => {
            is_exec = !is_exec;
            document.getElementById("exec").style.backgroundColor = is_exec ? "green" : "rgb(51, 51, 51)";
        }

        let is_change = false;
        document.getElementById("change").onclick = () => {
            is_change = !is_change;
            document.getElementById("change").style.backgroundColor = is_change ? "green" : "rgb(51, 51, 51)";
        }

        const change_label = (e) => {
            if (is_change) {
                const new_id = prompt("New ID : ");
                // prompt() returns null on Cancel: keep the current label then.
                if (new_id) e.currentTarget.textContent = new_id;
            }
        }

        const add_label = (e) => {
            if (!is_change) {
                add(e.currentTarget.textContent);
            }
        }

        //라벨 버튼 이벤트. class="label" 버튼만 할당한다.
        for (const btn of document.getElementsByClassName("label")) {
            btn.addEventListener("click", change_label);
            btn.addEventListener("click", add_label);
        }

        //PC는 마우스 이벤트만, MOBILE은 터치 이벤트만 적용한다. 동시 적용 시 MOBILE에서 중복으로 이벤트 발생.
        for (const btn of document.getElementsByClassName("action")) {
            if (platform == "PC") {
                btn.addEventListener("mousedown", evt_btns);
                btn.addEventListener("mouseup", evt_btns);
            } else {
                btn.addEventListener("touchstart", evt_btns);
                btn.addEventListener("touchend", evt_btns);
            }
        }

        //GUI 설정
        const w = 640, h = 480;
        const img_stream = document.getElementById("img_stream");
        [img_stream.width, img_stream.height] = [w, h];

        const canvas = document.getElementById("canvas");
        [canvas.width, canvas.height] = [w, h];
        const ctx = canvas.getContext("2d");
        ctx.fillStyle = "rgb(0, 250, 000)";
        ctx.strokeStyle = "rgb(0, 250, 000)";
        ctx.font = "15px Arial";
        ctx.lineWidth = 3;

        //Websocket
        const ip = window.location.protocol + "//" + window.location.hostname;
        let ws;
        let is_connect = "OFF";
        const connect_ws = () => {
            if (is_connect == "OFF") {
                try {
                    ws = new WebSocket(ip.replace("http", "ws") + ":81");
                    ws.onopen = (msg) => {
                        is_connect = "ON";
                        img_stream.src = stream_url;
                    }
                    ws.onclose = (msg) => is_connect = "OFF";
                    ws.onmessage = (msg) => console.log(msg.data);
                } catch {
                    console.log("Websocket failed.");
                }
            }
            setTimeout(connect_ws, 5000);
        }
        connect_ws();

        //전송 함수. 이 함수를 통해 블루투스나 웹소켓으로 전송할 것
        let msg_last;
        const send_msg = (msg) => {
            if (msg != msg_last) {
                if (is_connect == "ON") ws.send(msg);
                console.log(msg);
                msg_last = msg;
            }
        }

        //키보드 이벤트
        let key_ary = [];
        window.addEventListener("keydown", (e) => (key_ary.indexOf(e.key) < 0) ? (key_ary.push(e.key), key_evt(e)) : false);
        window.addEventListener("keyup", (e) => (key_ary.indexOf(e.key) >= 0) ? (key_ary.splice(key_ary.indexOf(e.key), 1), key_evt(e)) : false);

        let key_evt = (e) => {
            let command_key, value;
            if (key_ary.includes("ArrowUp") && key_ary.includes("ArrowLeft")) {
                command_key = "Q5";
                value = { x: -100, y: -100 };
            } else if (key_ary.includes("ArrowUp") && key_ary.includes("ArrowRight")) {
                command_key = "E5";
                value = { x: 100, y: -100 };
            } else if (key_ary.includes("ArrowDown") && key_ary.includes("ArrowLeft")) {
                command_key = "Z5";
                value = { x: -100, y: 100 };
            } else if (key_ary.includes("ArrowDown") && key_ary.includes("ArrowRight")) {
                command_key = "C5";
                value = { x: 100, y: 100 };
            } else if (key_ary.includes("ArrowUp")) {
                command_key = "W5";
                value = { x: 0, y: -100 };
            } else if (key_ary.includes("ArrowDown")) {
                command_key = "X5";
                value = { x: 0, y: 100 };
            } else if (key_ary.includes("ArrowLeft")) {
                command_key = "A5";
                value = { x: -100, y: 0 };
            } else if (key_ary.includes("ArrowRight")) {
                command_key = "D5";
                value = { x: 100, y: 0 };
            } else {
                command_key = "S0";
                value = { x: 0, y: 0 };
            }
            joystick[0].setPosition(undefined, value);
            send_msg("JOY:" + command_key);
        }

        const joystick = nipplejs.create({
            mode: 'static',
            position: { left: "120px", top: "120px" },
            color: 'red',
            size: 200,
            zone: document.getElementById("joystick_zone")
        });

        //조이스틱 이벤트
        joystick.on('move', function (evt, obj) {
            const r = Math.round(obj.distance / 20);
            const th = obj.angle.degree;
            let dir;
            if (r > 0) {
                if (th >= 0 && th < 22.5) dir = "D";
                else if (th >= 22.5 && th < 22.5 + 45) dir = "E";
                else if (th >= 22.5 && th < 22.5 + 90) dir = "W";
                else if (th >= 22.5 && th < 22.5 + 135) dir = "Q";
                else if (th >= 22.5 && th < 22.5 + 180) dir = "A";
                else if (th >= 22.5 && th < 22.5 + 225) dir = "Z";
                else if (th >= 22.5 && th < 22.5 + 270) dir = "X";
                else if (th >= 22.5 && th < 22.5 + 315) dir = "C";
                else if (th >= 22.5 && th < 22.5 + 360) dir = "D";
            } else dir = "S";
            send_msg("JOY:" + dir + String(r));
        });

        joystick.on('end', function (evt, obj) {
            send_msg("JOY:" + "S0")
        });

        //KNN
        let classifier, mobilenetModule;
        let model_state = "LOADING";
        const init = async function () {
            classifier = knnClassifier.create();
            mobilenetModule = await mobilenet.load();
            model_state = "READY";
        }
        init();

        const add = function (label) {
            if (model_state == "READY") {
                const img = tf.browser.fromPixels(img_stream);
                const logits = mobilenetModule.infer(img, 'conv_preds');
                classifier.addExample(logits, label);
                console.log(`${label} added.`);
                img.dispose();
                logits.dispose();
            } else {
                console.log("Model not loaded.")
            }
        }

        let prediction;
        let error = "";
        const pred = async function () {
            if (is_exec) {
                if (model_state == "READY") {
                    if (Object.keys(classifier.classExampleCount).length > 1) {
                        const x = tf.browser.fromPixels(img_stream);
                        const logits = mobilenetModule.infer(x, 'conv_preds');
                        const result = await classifier.predictClass(logits);
                        x.dispose();
                        logits.dispose();
                        prediction = result.label;
                        error = "";
                        send_msg("AI:" + prediction);
                    } else {
                        console.log("Need at least 2 IDs.");
                        error = "Need at least 2 IDs.";
                    }
                } else {
                    console.log("Model not loaded.");
                    error = "Model not loaded.";
                }
            }
            setTimeout(pred, 100);
        }
        pred();


        const save_model = () => {
            const dataset = classifier.getClassifierDataset();
            const datasetObj = {}
            Object.keys(dataset).forEach((key) => {
                const data = dataset[key].dataSync();
                datasetObj[key] = Array.from(data);
            });
            const jsonStr = JSON.stringify(datasetObj)
            const link = document.createElement('a');
            link.download = "model.json";
            link.href = 'data:text/text;charset=utf-8,' + encodeURIComponent(jsonStr);
            document.body.appendChild(link);
            link.click();
            link.remove();

        }

        const load_model = () => {
            const input = document.createElement("input");
            input.type = "file";
            input.onchange = (evt) => {
                const files = evt.target.files;
                if (files.length > 0) {
                    const fr = new FileReader();
                    fr.onload = () => {
                        const dataset = fr.result;
                        const tensorObj = JSON.parse(dataset);
                        Object.keys(tensorObj).forEach((key) => {
                            tensorObj[key] = tf.tensor(tensorObj[key], [tensorObj[key].length / 1024, 1024]);
                        })
                        classifier.setClassifierDataset(tensorObj);
                    }
                    fr.readAsText(files[0]);
                }
            }
            input.click();
        }

        const refresh = () => {
            ctx.clearRect(0, 0, w, h);
            ctx.fillText("MODEL : " + model_state, 5, 20);
            if (model_state == "READY") ctx.fillText(JSON.stringify(classifier.classExampleCount), 130, 20);
            ctx.fillText("CONNECTION : " + is_connect, 5, 40);
            if (is_exec) ctx.fillText("PREDICTION : " + prediction, 5, 60);
            if (error.length > 0) ctx.fillText(error, 5, 80);
            requestAnimationFrame(refresh);
        }
        refresh();

    </script>
</body>

</html>
)rawliteral";

// Handler for GET "/": sends the control UI page (INDEX_HTML) in one response.
// The charset is declared because the embedded script contains non-ASCII
// (Korean, UTF-8) comments.
static esp_err_t index_handler(httpd_req_t *req){
    httpd_resp_set_type(req, "text/html; charset=utf-8");
    // INDEX_HTML is a char array, so its length is a compile-time constant
    // (sizeof minus the terminating NUL); no strlen() on every request.
    return httpd_resp_send(req, INDEX_HTML, sizeof(INDEX_HTML) - 1);
}

// Handler for GET "/stream": serves the camera as an MJPEG stream using a
// multipart/x-mixed-replace response. Each iteration grabs one frame and
// sends three HTTP chunks:
//   - the part header (_STREAM_PART, filled in with the JPEG length),
//   - the JPEG bytes,
//   - the part separator (_STREAM_BOUNDARY).
// The loop stops at the first failed frame grab, JPEG conversion or chunk
// send, and returns that error. A failed send presumably also covers the
// client disconnecting.
static esp_err_t stream_handler(httpd_req_t *req){
  esp_err_t res = httpd_resp_set_type(req, _STREAM_CONTENT_TYPE);
  if(res != ESP_OK){
    return res;
  }

  // Allow cross-origin reads of the stream.
  httpd_resp_set_hdr(req, "Access-Control-Allow-Origin", "*");

  // Buffer for the formatted part header. It is a plain char array; the
  // longest possible header (10-digit length) still fits in 64 bytes.
  char part_buf[64];

  while(true){
    camera_fb_t * fb = esp_camera_fb_get();
    uint8_t * _jpg_buf = NULL;
    size_t _jpg_buf_len = 0;

    if(!fb){
      Serial.println("ERROR");
      res = ESP_FAIL;
    } else if(fb->format != PIXFORMAT_JPEG){
      // Encode to JPEG. frame2jpg() allocates _jpg_buf, which is freed below;
      // the camera buffer can be handed back right away.
      bool jpeg_converted = frame2jpg(fb, 80, &_jpg_buf, &_jpg_buf_len);
      esp_camera_fb_return(fb);
      fb = NULL;
      if(!jpeg_converted){
        Serial.println("ERROR");
        res = ESP_FAIL;
      }
    } else {
      // Already JPEG: stream straight out of the camera frame buffer.
      // There is deliberately no frame-width filter. A skipped frame would
      // leave the payload empty, and a zero-length chunk ends a chunked
      // HTTP response.
      _jpg_buf_len = fb->len;
      _jpg_buf = fb->buf;
    }

    if(res == ESP_OK){
      const int hlen = snprintf(part_buf, sizeof(part_buf), _STREAM_PART,
                                static_cast<unsigned int>(_jpg_buf_len));
      if(hlen < 0 || static_cast<size_t>(hlen) >= sizeof(part_buf)){
        res = ESP_FAIL;  // header would have been truncated
      } else {
        res = httpd_resp_send_chunk(req, part_buf, hlen);
      }
    }
    if(res == ESP_OK){
      res = httpd_resp_send_chunk(req, (const char *)_jpg_buf, _jpg_buf_len);
    }
    if(res == ESP_OK){
      res = httpd_resp_send_chunk(req, _STREAM_BOUNDARY, strlen(_STREAM_BOUNDARY));
    }

    // Release whichever buffer this frame used: the camera frame buffer when
    // it was sent directly, otherwise the heap buffer from frame2jpg().
    if(fb){
      esp_camera_fb_return(fb);
    } else if(_jpg_buf){
      free(_jpg_buf);
    }

    if(res != ESP_OK){
      break;
    }
  }
  return res;
}

void startCameraServer(){
  httpd_config_t config = HTTPD_DEFAULT_CONFIG();
  config.server_port = 80;

  httpd_uri_t index_uri = {
    .uri       = "/",
    .method    = HTTP_GET,
    .handler   = index_handler,
    .user_ctx  = NULL
  };

  httpd_uri_t stream_uri = {
    .uri       = "/stream",
    .method    = HTTP_GET,
    .handler   = stream_handler,
    .user_ctx  = NULL
    };
  if (httpd_start(&stream_httpd, &config) == ESP_OK) {
    httpd_register_uri_handler(stream_httpd, &index_uri);
    httpd_register_uri_handler(stream_httpd, &stream_uri);
  }
}

// One-time initialisation, in order:
//  1. configure the LED pin and disable the brown-out detector;
//  2. open the USB serial console and UART2 (mySerial, GPIO 12/13);
//  3. initialise the camera (JPEG, VGA, quality 10, two frame buffers);
//  4. join Wi-Fi (blocks until connected), announce "CONNECT" on both serial
//     ports, blink the LED once and print the board's URL;
//  5. start the HTTP UI/stream server and the WebSocket server.
// If camera initialisation fails, the error code is printed and setup()
// returns without starting Wi-Fi or any server.
void setup() {
  // LED blinked once after Wi-Fi connects (GPIO 4 is presumably the
  // ESP32-CAM's on-board flash LED — confirm for your board).
  const int kLedPin = 4;
  pinMode(kLedPin, OUTPUT);
  WRITE_PERI_REG(RTC_CNTL_BROWN_OUT_REG, 0); //disable brownout detector
  Serial.begin(115200);
  mySerial.begin(115200, SERIAL_8N1, 12, 13);
  Serial.setDebugOutput(false);

  // Value-initialise so every camera_config_t field not assigned below is
  // zero instead of indeterminate stack contents.
  // NOTE(review): newer esp32-camera releases add members (e.g. fb_location,
  // grab_mode); check that zero is an acceptable default for your version.
  camera_config_t config = {};
  config.ledc_channel = LEDC_CHANNEL_0;
  config.ledc_timer = LEDC_TIMER_0;
  config.pin_d0 = Y2_GPIO_NUM;
  config.pin_d1 = Y3_GPIO_NUM;
  config.pin_d2 = Y4_GPIO_NUM;
  config.pin_d3 = Y5_GPIO_NUM;
  config.pin_d4 = Y6_GPIO_NUM;
  config.pin_d5 = Y7_GPIO_NUM;
  config.pin_d6 = Y8_GPIO_NUM;
  config.pin_d7 = Y9_GPIO_NUM;
  config.pin_xclk = XCLK_GPIO_NUM;
  config.pin_pclk = PCLK_GPIO_NUM;
  config.pin_vsync = VSYNC_GPIO_NUM;
  config.pin_href = HREF_GPIO_NUM;
  config.pin_sscb_sda = SIOD_GPIO_NUM;
  config.pin_sscb_scl = SIOC_GPIO_NUM;
  config.pin_pwdn = PWDN_GPIO_NUM;
  config.pin_reset = RESET_GPIO_NUM;
  config.xclk_freq_hz = 20000000;
  config.pixel_format = PIXFORMAT_JPEG;
  config.frame_size = FRAMESIZE_VGA;   // 640x480, same size the UI page draws
  config.jpeg_quality = 10;
  config.fb_count = 2;

  // Camera init
  esp_err_t err = esp_camera_init(&config);
  if (err != ESP_OK) {
    Serial.printf("ERROR: camera init failed (0x%x)\n", static_cast<unsigned int>(err));
    return;
  }

  // Wi-Fi connection (blocks until associated)
  WiFi.begin(ssid, password);
  while (WiFi.status() != WL_CONNECTED) {
    delay(500);
  }
  Serial.println("CONNECT");
  mySerial.println("CONNECT");
  // Short blink to signal that Wi-Fi is up.
  digitalWrite(kLedPin, HIGH);
  delay(30);
  digitalWrite(kLedPin, LOW);

  Serial.print("http://");
  Serial.println(WiFi.localIP());

  // Start streaming web server
  startCameraServer();

  // Start websocket server
  webSocket.begin();
  webSocket.onEvent(onWebSocketEvent);
}

// Arduino main loop. Its only job is to service the WebSocket server started
// in setup(); webSocket.loop() dispatches incoming events to the
// onWebSocketEvent() callback registered there.
void loop()
{
  webSocket.loop();
}

댓글