Burn-readme-4.rs
· 567 B · Rust
原始檔案
use burn::nn;
use burn::module::Module;
use burn::tensor::backend::Backend;
#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
linear_inner: nn::Linear<B>,
linear_outer: nn::Linear<B>,
dropout: nn::Dropout,
gelu: nn::Gelu,
}
impl<B: Backend> PositionWiseFeedForward<B> {
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let x = self.linear_inner.forward(input);
let x = self.gelu.forward(x);
let x = self.dropout.forward(x);
self.linear_outer.forward(x)
}
}
1 | use burn::nn; |
2 | use burn::module::Module; |
3 | use burn::tensor::backend::Backend; |
4 | |
5 | #[derive(Module, Debug)] |
6 | pub struct PositionWiseFeedForward<B: Backend> { |
7 | linear_inner: nn::Linear<B>, |
8 | linear_outer: nn::Linear<B>, |
9 | dropout: nn::Dropout, |
10 | gelu: nn::Gelu, |
11 | } |
12 | |
13 | impl<B: Backend> PositionWiseFeedForward<B> { |
14 | pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> { |
15 | let x = self.linear_inner.forward(input); |
16 | let x = self.gelu.forward(x); |
17 | let x = self.dropout.forward(x); |
18 | |
19 | self.linear_outer.forward(x) |
20 | } |
21 | } |