Cnn vs rnn machine learning khi nào sử dụng năm 2024
Đối với các bạn học deep learning thì không thể không biết tới RNN, một thuật toán cực kì quan trọng chuyên xử lý thông tin dạng chuỗi. Đầu tiên, hãy nhìn xem RNN có thể làm gì. Dưới đây là một vài ví dụ. Show
Vậy, làm sao RNN làm được những việc này? Hi vọng thông qua bài viết này mình có thể cung cấp một cái nhìn rõ ràng và dễ hiêủ về RNN. Let's go !!! 2. Recurrent Neural NetworkÝ tưởng căn bảnMột cách nôm na, đối với mạng neural thông thường, chúng ta cho tất cả dữ liệu vào cùng một lúc. Nhưng đôi khi, dữ liệu của chúng ta mang ý nghĩa trình tự, tức nếu thay đổi trình tự dữ liệu, kết quả sẽ khác. Dễ thấy rõ nhất ở dữ liệu văn bản. Ví dụ, “Con ăn cơm chưa” và “Con chưa ăn cơm”, nếu tách mỗi câu theo từ, ta được bộ vocab [ ‘con’, ‘ăn’, ‘cơm’, ‘chưa’], one hot encoding và cho tất cả vào mạng neural , có thể thấy ngay, không có sự phân biệt nào giữa 2 câu trên. Việc đảo thứ tự duyệt các từ làm sai lệch ý nghĩ của câu. Nói cách khác, chúng ta cần một mạng neural có thể xử lí tuần tự. Vậy làm sao để xử lí tuần tự, đầu tiên cần đưa đầu vào vào một cách tuần tự. Mình là kiểu người hiểu nhanh hơn thông qua hình ảnh , và mình nghĩ đây là hình ảnh thể hiện rõ ràng nhất rốt cuộc RNN làm gì. Mỗi block RNN sẽ lấy thông tin từ các block trước và input hiện tại.Các x ở đây đại diện cho dữ liệu đầu vào lần lượt (được chia theo time step). xtx_{t} đại diện cho time step thứ t, và yty_{t} là output của một step. Ví dụ, x2x_{2} sẽ là vector đại diện của từ thứ 2 trong câu văn bản. Hình ảnh dưới đây cho thấy rõ hơn điều gì thực sự xảy ra trong một step.
ht=g1(Whh∗ht−1+Whx∗xt+bh) h_{t} = g1 ( W_{hh} * h_{t-1} + W_{hx} * x_{t} + b_{h} ) Hoặc có thể viết gọn hơn: ht=g1((WhhWhx)(ht−1xt)) h_{t} = g1 ( ( W_{hh} W_{hx}) ( \begin{matrix} h_{t-1} \\ x_{t} \end{matrix})) ht=g1((W)(ht−1xt)) h_{t} = g1 ( ( W) ( \begin{matrix} h_{t-1} \\ x_{t} \end{matrix}))
yt=g2(Wyh∗ht+by) y_{t} = g2 ( W_{yh} * h_{t} + b_{y} ) Rất đơn giản và cơ bản. Một step của RNN có thể được triển khai bằng code numpy như sau:
Tính toán lan truyền ngược. (BPTT - Backpropagation Through Time)Như vậy, trong quá trình training, có 3 tham số chúng ta cần tìm là Whh,Whx,WyhW_{hh},W_{hx}, W_{yh}. Chúng ta cần tính ∂L∂Whx,∂L∂Whh∂L∂Why\frac{ \partial L}{ \partial W_{hx}} ,\frac{ \partial L}{ \partial W_{hh} } \frac{ \partial L}{ \partial W_{hy}}. (với LL là loss function) (Do quá lười type công thức) Các bạn có thể tham khảo phần tính đạo hàm đầy đủ dùng chain rule trong bài viết này của anh Tuan Nguyen (hoặc nhiều toán hơn nữa ở đây). Nhìn chung, ta có thể thấy vấn đề cơ bản ở đây là: Trong mạng NN truyền thống, ta không chia sẻ tham số giữa các tầng mạng. Tuy vậy, với RNN, ta có thể thấy, để tính đạo hàm của loss theo WhhW_{hh}, ta phụ thuộc vào ht−1h_{t-1}, mà ht−1h_{t-1} lại phụ thuộc vào ht−2h_{t-2} và xt−1x_{t-1}. Nói nôm na, ta phải cộng tất cả đầu ra ở các bước trước để tính đạo hàm. Điều này gây ra một hạn chế lớn cho RNN.
Implement code RNN sử dụng thư viện hỗ trợ PyTorch:
Implement code sử dụng thư viện TensorFlow:
Như vậy, so với mạng NN bình thường, RNN có thể nắm bắt thông tin dạng chuỗi. Điều này cũng góp phần giúp RNN có thể đáp ứng chuỗi đầu vào có độ dài tùy ý, kích thước model không bị phụ thuộc vào size đầu vào. Hạn chế của RNN là gì?
Để xử lý Vanishing Gradient, có 2 cách phổ biến:
3. LSTM (Long Short-term memory)Mình nghĩ hình ảnh sau đây là rõ nét nhất để so sánh giữa RNN và LSTM. RNNHình 3.1 LSTMHình 3.2Về cơ bản, ý tưởng không khác nhau là mấy. Chúng ta chỉ thêm một số tính toán ở đây. Tất cả được tóm tắt trong hình sau. Đầu tiên, chúng ta có i,f,gi, f, g có công thức gần giống hệt nhau và chỉ khác mỗi ma trận tham số. Chính ma trận này sẽ quyết định chức năng khác nhau của từng cổng. σ\sigma là ký hiệu của hàm sigmoid. Quan sát hình 3.2 để thấy rõ hơn vị trí các cổng:
Nếu nhìn kỹ một chút, ta có thể thấy RNN truyền thống là dạng đặc biệt của LSTM. Nếu thay giá trị đầu ra của input gate là 1 và đầu ra forget gate là 0 (không nhớ trạng thái trước), ta được RNN thuần.
Tổng kếtNhìn một lượt qua kiến trúc LSTM, ta có thể tóm tắt:
Tuy vậy, với những cải tiến so với RNN thuần, LSTM đã và đang được sử dụng phổ biến. Trên thực tế, cách cài đặt LSTM cũng rất đa dạng và linh hoạt theo bài toán, tuy nhiên vẫn dựa trên LSTM chuẩn như trên. |