Membangun Jaringan Saraf Hopfield di JavaScript

Dihadapkan dengan jaringan saraf di universitas, jaringan Hopfield menjadi salah satu favorit saya. Saya terkejut bahwa itu adalah yang terakhir dalam daftar lab, karena pekerjaannya dapat ditunjukkan dengan jelas menggunakan gambar dan tidak begitu sulit untuk diterapkan.





Artikel ini menunjukkan bagaimana memecahkan masalah memulihkan gambar yang terdistorsi menggunakan jaringan saraf Hopfield, yang sebelumnya dilatih pada gambar referensi.





Saya mencoba menjelaskan langkah demi langkah dan sesederhana mungkin proses penerapan program yang memungkinkan Anda bermain dengan jaringan saraf tepat di browser, melatih jaringan menggunakan gambar yang saya buat sendiri dan menguji operasinya pada distorsi. gambar-gambar.





Sumber di Github dan demo .





Untuk implementasi Anda akan membutuhkan:





  • Peramban





  • Pemahaman dasar tentang jaringan saraf





  • Pengetahuan dasar tentang JavaScript / HTML





Sedikit teori

Jaringan saraf Hopfield adalah jaringan saraf yang sepenuhnya terhubung dengan matriks koneksi simetris. Jaringan tersebut dapat digunakan untuk mengatur memori asosiatif, sebagai filter, dan juga untuk memecahkan beberapa masalah optimasi.





-  .     ,   .   ,   , .





Diagram blok jaringan saraf Hopfieldfield

, , . (, ), .   ,  , «» ( ).





:







  1. :



    w_ {ij} = \ kiri \ {\ begin {matrix} \ sum_ {k = 1} ^ {m} x_ {i} ^ {k} * x_ {j} ^ {k} & i \ neq j \\ 0 , & i = j \ end {matriks} \ kanan.



    sayaβ€”  

    x_ {i} ^ {k}, x_ {j} ^ {k} β€” saya- j- k- .





  2.   . :

    y_ {j} (0) = x_ {j}





  3. (   ):



    y_ {j} (t + 1) = f \ kiri (\ sum_ {i = 1} ^ {n} w_ {ij} * y_ {i} (t) \ kanan)



    f β€”   [-1; 1];

    untuk β€” ;

    j = 1 ... n;  tidak β€”  .





  4.   .  β€”   3, , , . ,   .





.





Demonstrasi program

    Canvas   . HTML  CSS ,     (  ).





Canvas , ( ) .   ,     Canvas (  «»     ).





 , 10Γ—10   .   , ,   100 (  100  ).  β€” ,  βˆ’1  1, βˆ’1 β€” ,  1 β€” .





-   , .





//     10   
const gridSize = 10;
//     
const squareSize = 45;
//    (100)
const inputNodes = gridSize * gridSize;

//         ,
//      
let userImageState = [];
//      
let isDrawing = false;
//  
for (let i = 0; i < inputNodes; i += 1) {  
  userImageState[i] = -1;  
}

//   :
const userCanvas = document.getElementById('userCanvas');
const userContext = userCanvas.getContext('2d');
const netCanvas = document.getElementById('netCanvas');
const netContext = netCanvas.getContext('2d');
      
      



, .





//      
//   100  (gridSize * gridSize)
const drawGrid = (ctx) => {
  ctx.beginPath();
  ctx.fillStyle = 'white';
  ctx.lineWidth = 3;
  ctx.strokeStyle = 'black';
  for (let row = 0; row < gridSize; row += 1) {
    for (let column = 0; column < gridSize; column += 1) {
      const x = column * squareSize;
      const y = row * squareSize;
      ctx.rect(x, y, squareSize, squareSize);
      ctx.fill();
      ctx.stroke();
    }
  }
  ctx.closePath();
};
      
      



«» ,    .





//   
const handleMouseDown = (e) => {
  userContext.fillStyle = 'black';
  //      x, y
  //  squareSize  squareSize (4545 )
  userContext.fillRect(
    Math.floor(e.offsetX / squareSize) * squareSize,
    Math.floor(e.offsetY / squareSize) * squareSize,
    squareSize, squareSize,
  );

  //     ,
  //      
  const { clientX, clientY } = e;
  const coords = getNewSquareCoords(userCanvas, clientX, clientY, squareSize);
  const index = calcIndex(coords.x, coords.y, gridSize);

  //       
  if (isValidIndex(index, inputNodes) && userImageState[index] !== 1) {
    userImageState[index] = 1;
  }

  //   (   )
  isDrawing = true;
};

//     
const handleMouseMove = (e) => {
  //   , ..      ,    
  if (!isDrawing) return;

  //  ,   handleMouseDown
  //     isDrawing = true;
  userContext.fillStyle = 'black';

  userContext.fillRect(
    Math.floor(e.offsetX / squareSize) * squareSize,
    Math.floor(e.offsetY / squareSize) * squareSize,
    squareSize, squareSize,
  );

  const { clientX, clientY } = e;
  const coords = getNewSquareCoords(userCanvas, clientX, clientY, squareSize);
  const index = calcIndex(coords.x, coords.y, gridSize);

  if (isValidIndex(index, inputNodes) && userImageState[index] !== 1) {
    userImageState[index] = 1;
  }
};
      
      



  , , getNewSquareCoords, calcIndex  isValidIndex.  .





//      
//      
const calcIndex = (x, y, size) => x + y * size;

// ,     
const isValidIndex = (index, len) => index < len && index >= 0;

//        
//  ,      0  9
const getNewSquareCoords = (canvas, clientX, clientY, size) => {
  const rect = canvas.getBoundingClientRect();
  const x = Math.ceil((clientX - rect.left) / size) - 1;
  const y = Math.ceil((clientY - rect.top) / size) - 1;
  return { x, y };
};
      
      



.   .





const clearCurrentImage = () => {
  //    ,    
  //       
  drawGrid(userContext);
  drawGrid(netContext);
  userImageState = new Array(gridSize * gridSize).fill(-1);
};
      
      



  «» .





 β€” .   ( ).





...
const weights = [];  //   
for (let i = 0; i < inputNodes; i += 1) {
  weights[i] = new Array(inputNodes).fill(0); //       0
  userImageState[i] = -1;
}
...
      
      



    , , inputNodes .     100 ,      100 .





( )   .     . .





const memorizeImage = () => {
  for (let i = 0; i < inputNodes; i += 1) {
    for (let j = 0; j < inputNodes; j += 1) {
      if (i === j) weights[i][j] = 0;
      else {
        // ,       userImageState  
        //  -1  1,  -1 -  ,  1 -     
        weights[i][j] += userImageState[i] * userImageState[j];
      }
    }
  }
};
      
      



,   ,    ,   . :





// -  html   lodash:
<script src="https://cdnjs.cloudflare.com/ajax/libs/lodash.js/4.17.21/lodash.min.js"></script>
...
const recognizeSignal = () => {
  let prevNetState;
  //      .  
  //       
  // (2  ),     
  const currNetState = [...userImageState];
  do {
    //    , 
		// ..     
    prevNetState = [...currNetState];
    //      3  
    for (let i = 0; i < inputNodes; i += 1) {
      let sum = 0;
      for (let j = 0; j < inputNodes; j += 1) {
        sum += weights[i][j] * prevNetState[j];
      }
      //    ( - )
      currNetState[i] = sum >= 0 ? 1 : -1;
    }
    //      
    //     - isEqual
  } while (!_.isEqual(currNetState, prevNetState));

  //    ( ),   
  drawImageFromArray(currNetState, netContext);
};
      
      



    isEqual   lodash.





drawImageFromArray.       .





const drawImageFromArray = (data, ctx) => {
  const twoDimData = [];
  //     
  while (data.length) twoDimData.push(data.splice(0, gridSize));

  //   
  drawGrid(ctx);
  //     ( )
  for (let i = 0; i < gridSize; i += 1) {
    for (let j = 0; j < gridSize; j += 1) {
      if (twoDimData[i][j] === 1) {
        ctx.fillStyle = 'black';
        ctx.fillRect((j * squareSize), (i * squareSize), squareSize, squareSize);
      }
    }
  }
};
      
      



  HTML   .





HTML
const resetButton = document.getElementById('resetButton');
const memoryButton = document.getElementById('memoryButton');
const recognizeButton = document.getElementById('recognizeButton');

//    
resetButton.addEventListener('click', () => clearCurrentImage());
memoryButton.addEventListener('click', () => memorizeImage());
recognizeButton.addEventListener('click', () => recognizeSignal());

//    
userCanvas.addEventListener('mousedown', (e) => handleMouseDown(e));
userCanvas.addEventListener('mousemove', (e) => handleMouseMove(e));
//  ,         
userCanvas.addEventListener('mouseup', () => isDrawing = false);
userCanvas.addEventListener('mouseleave', () => isDrawing = false);

//  
drawGrid(userContext);
drawGrid(netContext);
      
      



, :





Gambar referensi untuk pelatihan jaringan

:





Mencoba mengenali gambar terdistorsi dari huruf H
Mencoba mengenali gambar terdistorsi dari huruf T

! .





, saya  , 0,15 * n( tidakβ€”   ). , , , ,   ,               .





Sumber di Github dan demo .





Alih-alih sastra, kuliah digunakan oleh seorang guru yang sangat baik di jaringan saraf - Sergei Mikhailovich Roshchin , yang banyak terima kasih kepadanya.








All Articles