Class MultiHeadAttention<H extends Head>
java.lang.Object
me.damoebe.architectures.transformer.mha.MultiHeadAttention<H>
-
Constructor Summary
Constructors
Constructor: MultiHeadAttention(Class<H> c, int headAmount, int inputEmbeddingAmounts, int inputEmbeddingSize, boolean masked)
Description: Main constructor for the MultiHeadAttention class.
-
Method Summary
Modifier and Type, Method, Description
clone()
    Clones a MultiHeadAttention object.
static double[][] concatMatrices(List<double[][]> matrices)
    Merges a list of matrices.
generateOutputFor(Sequence[] input)
    Generates an output for a list of inputs (embedding lists).
int getEmbeddingAmount()
    Getter for the used embedding amount of this MHA.
int getEmbeddingSize()
    Getter for the used embedding size of this MHA.
getHeads()
    Getter for the heads stored in the MultiHeadAttention.
void updateAllWeights(double[][] deltas)
    Updates the QKV weights of all heads as well as the weights used to merge the head matrices.
-
Constructor Details
-
MultiHeadAttention
public MultiHeadAttention(Class<H> c, int headAmount, int inputEmbeddingAmounts, int inputEmbeddingSize, boolean masked)
Main constructor for the MultiHeadAttention class.
Parameters:
c - The class used for the heads
headAmount - The number of heads this MHA object should have
inputEmbeddingAmounts - The number of input embeddings
inputEmbeddingSize - The size of each input embedding
-
Method Details
-
generateOutputFor
-
concatMatrices
Merges a list of matrices.
Parameters:
matrices - the matrices to concatenate
Returns:
the merged matrix as a 2D array
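The description above does not state along which axis the matrices are merged. The following is a minimal sketch, assuming column-wise concatenation (each head contributes extra columns to every row, as is common when merging attention head outputs) and rectangular matrices with equal row counts. The class ConcatSketch and the method concatColumns are hypothetical names used for illustration only; they are not part of this library, and the actual concatMatrices may behave differently.

```java
import java.util.Arrays;
import java.util.List;

final class ConcatSketch {

    // Hypothetical stand-in for concatMatrices. ASSUMPTION: matrices are
    // rectangular, share the same row count, and are merged column-wise.
    public static double[][] concatColumns(List<double[][]> matrices) {
        if (matrices.isEmpty()) {
            return new double[0][0];
        }
        int rows = matrices.get(0).length;
        int totalCols = 0;
        for (double[][] m : matrices) {
            if (m.length != rows) {
                throw new IllegalArgumentException("All matrices must have the same number of rows");
            }
            // Width of this matrix, taken from its first row.
            totalCols += (rows == 0) ? 0 : m[0].length;
        }
        double[][] merged = new double[rows][totalCols];
        for (int r = 0; r < rows; r++) {
            int offset = 0;
            for (double[][] m : matrices) {
                // Copy row r of this matrix behind the columns already written.
                System.arraycopy(m[r], 0, merged[r], offset, m[r].length);
                offset += m[r].length;
            }
        }
        return merged;
    }

    public static void main(String[] args) {
        double[][] headA = {{1, 2}, {3, 4}};
        double[][] headB = {{5}, {6}};
        double[][] merged = concatColumns(List.of(headA, headB));
        // prints [[1.0, 2.0, 5.0], [3.0, 4.0, 6.0]]
        System.out.println(Arrays.deepToString(merged));
    }
}
```

Under this assumption, merging a 2x2 and a 2x1 matrix yields a 2x3 matrix whose rows keep their original order.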
-
updateAllWeights
public void updateAllWeights(double[][] deltas)
Updates the QKV weights of all heads as well as the weights used to merge the head matrices.
Parameters:
deltas - The first hidden-layer deltas of the block's MLP
-
clone
Clones a MultiHeadAttention object.
-
getEmbeddingAmount
public int getEmbeddingAmount()
Getter for the used embedding amount of this MHA.
Returns:
the embedding amount
-
getEmbeddingSize
public int getEmbeddingSize()
Getter for the used embedding size of this MHA.
Returns:
the embedding size
-
getHeads
-