Transformer

Transformer

September 17, 2024

C X

Camille X

			import torch
import torch.nn as nn
import math
class InputEmbeddings(nn.Module):
    
    def __init__(self, vocab_size: int , d_model: int):
        super.__init__()
        self.vocab_size = vocab_size
        self.d_model=d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        
    def forward(self, x):
        return self.embedding(x)*math.sqrt(self.d_model)
 
class PositionEmbedding(nn.Module):
    
    def __init__(self, Seq_len:int, d_model:int, dropout:float)-> None:
        super().__init__()
        self.Seq_len=Seq_len
        self.d_model=d_model
        self.dropout=nn.Dropout(dropout)
        
        pe=torch.zeros(self.Seq_len,self.d_model)
        position = torch.arange(0,self.Seq_len,dtype=torch.float).unsqueeze(1)
        
        # apply sin to 2i
        div_term = torch.exp(torch.arange(0,d_model,2).float()*(-math.log(10000.0)/d_model))
        pe[:,0::2] = torch.sin(position * div_term)
        pe[:,1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        
        self.register_buffer("pe",pe)
        
    def forward(self,x):
        x = x + (self.pe[:,:x.shape[1],:]).requires_grad_(False)
        return self.dropout(x)
    
class LayerNormalization(nn.Module):
    
    def __init__(self, epsilon:float=10**-6)->None:
        super().__init__()
        self.epsilon=epsilon
        self.alpha=nn.Parameter(torch.ones(1))
        self.bias=nn.Parameter(torch.zeros(1))
    
    def forward(self,x):
        mean = x.mean(dim = -1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha*(x-mean) / (std+self.epsilon) + self.bias
 
class FeedForwardBlock(nn.Module):
    
    def __init__(self, d_model: int, d_ff: int ,dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
        
    def forward(self,x):
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))
    
    
class MultiHeadAttentionBlock(nn.Module):
    
    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h 
        assert d_model % h ==0, "d_model isn't divisible by h"
        
        self.d_k=d_model//h
        self.w_q = nn.Linear(d_model,d_model)
        self.w_k = nn.Linear(d_model,d_model)
        self.w_v = nn.Linear(d_model,d_model)
        
        self.w_o=nn.Linear(d_model,d_model)
        self.dropout=nn.Dropout(dropout)
        
    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]
        
        # (Batch, h, Seq_len, d_k) --> # (Batch, h, Seq_len, Seq_len)
        attention_scores = (query @ key.transpose(-2,-1)) / math.sqrt(value)
        if mask is not None:
            attention_scores.masked_fill_(mask == 0 ,-1e-9)
        attention_scores = attention_scores.softmax(dim = -1) # (Batch, h, Seq_len, Seq_len)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        return (attention_scores @ value), attention_scores
    
    def forward(self, q, k, v, mask):
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)
        
        # (Batch, Seq_len, d_model) --> (Batch, Seq_len, h, d_k) --> (Batch, h, Seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1,2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1,2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1,2)
        
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        
        # (Batch, h, Seq_len, d_k) --> (Batch, Seq_len, h, d_k) --> (Batch, Seq_len, d_model)
        x = x.transpose(1,2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        
        # (Batch, Seq_len, d_model) --> (Batch, Seq_len, d_model)
        return self.w_o(x)
 
 
class ResidualConnection(nn.Module):
    
    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.dropout=nn.Dropout(dropout)
        self.norm = LayerNormalization()
        
    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))
 
 
class EncoderBlock(nn.Module):
    
    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout), ResidualConnection(dropout)])
        
    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x 
 
class Encoder(nn.Module):
    
    def __init__(self, layers: nn.ModuleList) -> None:
        super.__init__()
        self.layers = layers
        self.norm = LayerNormalization()
    
    def forward(self, x, mask):
        for layer in self.layers:
            x=layer(x,mask)
        return self.norm(x)
    
class DecoderBlock(nn.Module):
 
    def __init__(self, self_attention_block: MultiHeadAttentionBlock,
                       cross_attention_block: MultiHeadAttentionBlock,
                       feed_forward_block: FeedForwardBlock, 
                       dropout: float) -> None:
        super.__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout),ResidualConnection(dropout),ResidualConnection(dropout)])
        
    def forward(self, x, encoder_output, src_mask, tag_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tag_mask))
        x = self.residual_connections[1](x, lambda x: self.self_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x 
    
class Decoder(nn.Module):
    
    def __init__(self, layers: nn.ModuleList) -> None:
        self.layers = layers
        self.norm = LayerNormalization()
    
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)
    
class ProjectionLayer(nn.Module):
    
    def __init__(self, d_model: int, vocab_size: int) -> None:
        super.__init__()
        self.proj = nn.Linear(d_model, vocab_size)
        
    def forward(self, x):
        return torch.log_softmax(self.proj(x), dim=-1)
 
class Transformer(nn.Module):
    
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, 
                 src_pos: PositionEmbedding, tgt_pos: PositionEmbedding, projection_layer: ProjectionLayer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer
    
    def encoder(self, src, src_mask):
        return self.encoder(self.src_pos(self.src_embed(src)), src_mask)
    
    def decoder(self, tgt, encoder_output, src_mask, tgt_mask):
        return self.decoder(self.tgt_pos(self.tgt_embed(tgt)), encoder_output, src_mask, tgt_mask)  
    
    def project(self, x):
        return self.projection_layer(x)
    
class build_transformer(nn.Module):
    
    def __init__(self, src_vocab_size: int, 
                       tgt_vocab_size: int, 
                       src_seq_len: int, 
                       tgt_seq_len: int,
                       d_model: int = 512, 
                       N: int =6, 
                       h: int=8, 
                       d_ff: int=2048,
                       dropout: float=0.1
                ) -> Transformer:
        super().__init__()
        src_embed = InputEmbeddings(src_vocab_size, d_model)
        tgt_embed = InputEmbeddings(tgt_vocab_size, d_model)
        
        src_pos = PositionEmbedding(src_seq_len, d_model, dropout)
        tgt_pos = PositionEmbedding(tgt_seq_len, d_model, dropout)
        
        encoder_blocks = []
        for _ in range(N):
            encoder_blocks.append(EncoderBlock(MultiHeadAttentionBlock(d_model, h, dropout), 
                                               FeedForwardBlock(d_model, d_ff, dropout), 
                                               dropout))
            
        encoder = Encoder(nn.ModuleList(encoder_blocks))
        
        decoder_blocks = []
        for _ in range(N):
            decoder_blocks.append(DecoderBlock(MultiHeadAttentionBlock(d_model, h, dropout), 
                                               MultiHeadAttentionBlock(d_model, h, dropout),
                                               FeedForwardBlock(d_model, d_ff, dropout), 
                                               dropout))
        decoder = Decoder(nn.ModuleList(decoder_blocks))
        
        projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
        
        transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)
        
        for p in transformer.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        
        return transformer
        
        
	
			import torch
import torch.nn as nn
import math
class InputEmbeddings(nn.Module):
    
    def __init__(self, vocab_size: int , d_model: int):
        super.__init__()
        self.vocab_size = vocab_size
        self.d_model=d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        
    def forward(self, x):
        return self.embedding(x)*math.sqrt(self.d_model)
 
class PositionEmbedding(nn.Module):
    
    def __init__(self, Seq_len:int, d_model:int, dropout:float)-> None:
        super().__init__()
        self.Seq_len=Seq_len
        self.d_model=d_model
        self.dropout=nn.Dropout(dropout)
        
        pe=torch.zeros(self.Seq_len,self.d_model)
        position = torch.arange(0,self.Seq_len,dtype=torch.float).unsqueeze(1)
        
        # apply sin to 2i
        div_term = torch.exp(torch.arange(0,d_model,2).float()*(-math.log(10000.0)/d_model))
        pe[:,0::2] = torch.sin(position * div_term)
        pe[:,1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        
        self.register_buffer("pe",pe)
        
    def forward(self,x):
        x = x + (self.pe[:,:x.shape[1],:]).requires_grad_(False)
        return self.dropout(x)
    
class LayerNormalization(nn.Module):
    
    def __init__(self, epsilon:float=10**-6)->None:
        super().__init__()
        self.epsilon=epsilon
        self.alpha=nn.Parameter(torch.ones(1))
        self.bias=nn.Parameter(torch.zeros(1))
    
    def forward(self,x):
        mean = x.mean(dim = -1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha*(x-mean) / (std+self.epsilon) + self.bias
 
class FeedForwardBlock(nn.Module):
    
    def __init__(self, d_model: int, d_ff: int ,dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
        
    def forward(self,x):
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))
    
    
class MultiHeadAttentionBlock(nn.Module):
    
    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h 
        assert d_model % h ==0, "d_model isn't divisible by h"
        
        self.d_k=d_model//h
        self.w_q = nn.Linear(d_model,d_model)
        self.w_k = nn.Linear(d_model,d_model)
        self.w_v = nn.Linear(d_model,d_model)
        
        self.w_o=nn.Linear(d_model,d_model)
        self.dropout=nn.Dropout(dropout)
        
    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]
        
        # (Batch, h, Seq_len, d_k) --> # (Batch, h, Seq_len, Seq_len)
        attention_scores = (query @ key.transpose(-2,-1)) / math.sqrt(value)
        if mask is not None:
            attention_scores.masked_fill_(mask == 0 ,-1e-9)
        attention_scores = attention_scores.softmax(dim = -1) # (Batch, h, Seq_len, Seq_len)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        return (attention_scores @ value), attention_scores
    
    def forward(self, q, k, v, mask):
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)
        
        # (Batch, Seq_len, d_model) --> (Batch, Seq_len, h, d_k) --> (Batch, h, Seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1,2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1,2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1,2)
        
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        
        # (Batch, h, Seq_len, d_k) --> (Batch, Seq_len, h, d_k) --> (Batch, Seq_len, d_model)
        x = x.transpose(1,2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        
        # (Batch, Seq_len, d_model) --> (Batch, Seq_len, d_model)
        return self.w_o(x)
 
 
class ResidualConnection(nn.Module):
    
    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.dropout=nn.Dropout(dropout)
        self.norm = LayerNormalization()
        
    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))
 
 
class EncoderBlock(nn.Module):
    
    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout), ResidualConnection(dropout)])
        
    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x 
 
class Encoder(nn.Module):
    
    def __init__(self, layers: nn.ModuleList) -> None:
        super.__init__()
        self.layers = layers
        self.norm = LayerNormalization()
    
    def forward(self, x, mask):
        for layer in self.layers:
            x=layer(x,mask)
        return self.norm(x)
    
class DecoderBlock(nn.Module):
 
    def __init__(self, self_attention_block: MultiHeadAttentionBlock,
                       cross_attention_block: MultiHeadAttentionBlock,
                       feed_forward_block: FeedForwardBlock, 
                       dropout: float) -> None:
        super.__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout),ResidualConnection(dropout),ResidualConnection(dropout)])
        
    def forward(self, x, encoder_output, src_mask, tag_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tag_mask))
        x = self.residual_connections[1](x, lambda x: self.self_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x 
    
class Decoder(nn.Module):
    
    def __init__(self, layers: nn.ModuleList) -> None:
        self.layers = layers
        self.norm = LayerNormalization()
    
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)
    
class ProjectionLayer(nn.Module):
    
    def __init__(self, d_model: int, vocab_size: int) -> None:
        super.__init__()
        self.proj = nn.Linear(d_model, vocab_size)
        
    def forward(self, x):
        return torch.log_softmax(self.proj(x), dim=-1)
 
class Transformer(nn.Module):
    
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, 
                 src_pos: PositionEmbedding, tgt_pos: PositionEmbedding, projection_layer: ProjectionLayer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer
    
    def encoder(self, src, src_mask):
        return self.encoder(self.src_pos(self.src_embed(src)), src_mask)
    
    def decoder(self, tgt, encoder_output, src_mask, tgt_mask):
        return self.decoder(self.tgt_pos(self.tgt_embed(tgt)), encoder_output, src_mask, tgt_mask)  
    
    def project(self, x):
        return self.projection_layer(x)
    
class build_transformer(nn.Module):
    
    def __init__(self, src_vocab_size: int, 
                       tgt_vocab_size: int, 
                       src_seq_len: int, 
                       tgt_seq_len: int,
                       d_model: int = 512, 
                       N: int =6, 
                       h: int=8, 
                       d_ff: int=2048,
                       dropout: float=0.1
                ) -> Transformer:
        super().__init__()
        src_embed = InputEmbeddings(src_vocab_size, d_model)
        tgt_embed = InputEmbeddings(tgt_vocab_size, d_model)
        
        src_pos = PositionEmbedding(src_seq_len, d_model, dropout)
        tgt_pos = PositionEmbedding(tgt_seq_len, d_model, dropout)
        
        encoder_blocks = []
        for _ in range(N):
            encoder_blocks.append(EncoderBlock(MultiHeadAttentionBlock(d_model, h, dropout), 
                                               FeedForwardBlock(d_model, d_ff, dropout), 
                                               dropout))
            
        encoder = Encoder(nn.ModuleList(encoder_blocks))
        
        decoder_blocks = []
        for _ in range(N):
            decoder_blocks.append(DecoderBlock(MultiHeadAttentionBlock(d_model, h, dropout), 
                                               MultiHeadAttentionBlock(d_model, h, dropout),
                                               FeedForwardBlock(d_model, d_ff, dropout), 
                                               dropout))
        decoder = Decoder(nn.ModuleList(decoder_blocks))
        
        projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
        
        transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)
        
        for p in transformer.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        
        return transformer
        
        
	
Camille-X | © 2025

Made with

svelte-logo