Transformer: A Detailed Code Walkthrough

Paper Background

In 2017 Google published the paper "Attention Is All You Need", which introduced a new architecture called the Transformer for machine translation. It abandoned the conventional CNN- and RNN-based recipes, achieved strong results, and sparked wide discussion in both industry and academia. Not long after the paper was released, a TensorFlow implementation of the model appeared: A TensorFlow Implementation of the Transformer: Attention Is All You Need.

The original paper, and a Chinese translation of the paper.

Code Structure

  • hyperparams.py : holds all hyperparameters and the data paths used throughout the project
  • prepro.py : builds the vocabulary files for the source and target languages
  • data_load.py : contains all functions for loading and batching the data
  • modules.py : implements the building blocks of the encoder and decoder networks
  • train.py : training script; defines the model, the loss function, and the training and checkpointing loop
  • eval.py : evaluates the trained model

hyperparams.py

class Hyperparams:
    '''Hyperparameters'''
    # data
    source_train = 'corpora/train.tags.de-en.de'
    target_train = 'corpora/train.tags.de-en.en'
    source_test = 'corpora/IWSLT16.TED.tst2014.de-en.de.xml'
    target_test = 'corpora/IWSLT16.TED.tst2014.de-en.en.xml'

    # training
    batch_size = 32  # alias = N
    lr = 0.0001  # learning rate. In the paper, the learning rate is scheduled according to the global step.
    logdir = 'logdir'  # log directory

    # model
    maxlen = 10  # Maximum number of words in a sentence. alias = T.
                 # Feel free to increase this if you are ambitious.
    min_cnt = 20  # words that occur fewer than min_cnt times are encoded as <UNK>.
    hidden_units = 512  # alias = C
    num_blocks = 6  # number of encoder/decoder blocks
    num_epochs = 20
    num_heads = 8
    dropout_rate = 0.1
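
The class is used purely as a namespace: every other file imports it as hp and reads the attributes directly. Note that with maxlen = 10, data_load.py later drops any sentence pair in which either side (plus the appended </S>) is longer than 10 tokens. A minimal usage sketch, assuming the file above is on the Python path:

from hyperparams import Hyperparams as hp

# Read a few of the settings defined above.
print(hp.batch_size, hp.maxlen, hp.hidden_units)  # 32 10 512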

prepro.py

from __future__ import print_function
from hyperparams import Hyperparams as hp
import tensorflow as tf
import numpy as np
import codecs
import os
import regex
from collections import Counter

def make_vocab(fpath, fname):
    '''Constructs vocabulary.

    Args:
      fpath: A string. Input file path.
      fname: A string. Output file name.

    Writes the vocabulary line by line to `preprocessed/fname`.
    '''
    text = codecs.open(fpath, 'r', 'utf-8').read()
    text = regex.sub("[^\s\p{Latin}']", "", text)  # keep only Latin letters, whitespace and apostrophes
    words = text.split()
    word2cnt = Counter(words)
    if not os.path.exists('preprocessed'): os.mkdir('preprocessed')
    with codecs.open('preprocessed/{}'.format(fname), 'w', 'utf-8') as fout:
        fout.write("{}\t1000000000\n{}\t1000000000\n{}\t1000000000\n{}\t1000000000\n".format("<PAD>", "<UNK>", "<S>", "</S>"))
        for word, cnt in word2cnt.most_common(len(word2cnt)):
            fout.write(u"{}\t{}\n".format(word, cnt))

if __name__ == '__main__':
    make_vocab(hp.source_train, "de.vocab.tsv")
    make_vocab(hp.target_train, "en.vocab.tsv")
    print("Done")
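
To make the output format concrete, the short sketch below (not part of the repository; it assumes prepro.py has already been run) prints the head of the generated preprocessed/de.vocab.tsv. The four special tokens come first with the artificial count 1000000000, which guarantees they survive the min_cnt frequency filter applied later in data_load.py; real words follow, sorted by frequency.

import codecs

# Print the first few lines of the vocabulary file written by make_vocab.
with codecs.open('preprocessed/de.vocab.tsv', 'r', 'utf-8') as f:
    for line in f.read().splitlines()[:6]:
        print(line)
# Expected output starts with the special tokens, then real words and counts:
# <PAD>   1000000000
# <UNK>   1000000000
# <S>     1000000000
# </S>    1000000000
# ...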

data_load.py

from __future__ import print_function
from hyperparams import Hyperparams as hp
import tensorflow as tf
import numpy as np
import codecs
import regex

def load_de_vocab():
    vocab = [line.split()[0] for line in codecs.open('preprocessed/de.vocab.tsv', 'r', 'utf-8').read().splitlines() if int(line.split()[1]) >= hp.min_cnt]
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word


def load_en_vocab():
    vocab = [line.split()[0] for line in codecs.open('preprocessed/en.vocab.tsv', 'r', 'utf-8').read().splitlines() if int(line.split()[1]) >= hp.min_cnt]
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word


def create_data(source_sents, target_sents):
    de2idx, idx2de = load_de_vocab()
    en2idx, idx2en = load_en_vocab()

    # Index
    x_list, y_list, Sources, Targets = [], [], [], []
    for source_sent, target_sent in zip(source_sents, target_sents):
        x = [de2idx.get(word, 1) for word in (source_sent + u" </S>").split()]  # 1: OOV, </S>: End of Text
        y = [en2idx.get(word, 1) for word in (target_sent + u" </S>").split()]
        if max(len(x), len(y)) <= hp.maxlen:
            x_list.append(np.array(x))
            y_list.append(np.array(y))
            Sources.append(source_sent)
            Targets.append(target_sent)

    # Pad
    X = np.zeros([len(x_list), hp.maxlen], np.int32)
    Y = np.zeros([len(y_list), hp.maxlen], np.int32)
    for i, (x, y) in enumerate(zip(x_list, y_list)):
        X[i] = np.lib.pad(x, [0, hp.maxlen - len(x)], 'constant', constant_values=(0, 0))
        Y[i] = np.lib.pad(y, [0, hp.maxlen - len(y)], 'constant', constant_values=(0, 0))

    return X, Y, Sources, Targets


def load_train_data():
    de_sents = [regex.sub("[^\s\p{Latin}']", "", line) for line in codecs.open(hp.source_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]
    en_sents = [regex.sub("[^\s\p{Latin}']", "", line) for line in codecs.open(hp.target_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]

    X, Y, Sources, Targets = create_data(de_sents, en_sents)
    return X, Y


def load_test_data():
    def _refine(line):
        line = regex.sub("<[^>]+>", "", line)
        line = regex.sub("[^\s\p{Latin}']", "", line)
        return line.strip()

    de_sents = [_refine(line) for line in codecs.open(hp.source_test, 'r', 'utf-8').read().split("\n") if line and line[:4] == "<seg"]
    en_sents = [_refine(line) for line in codecs.open(hp.target_test, 'r', 'utf-8').read().split("\n") if line and line[:4] == "<seg"]

    X, Y, Sources, Targets = create_data(de_sents, en_sents)
    return X, Sources, Targets  # (1064, 150)


def get_batch_data():
    # Load data
    X, Y = load_train_data()

    # calc total batch count
    num_batch = len(X) // hp.batch_size

    # Convert to tensor
    X = tf.convert_to_tensor(X, tf.int32)
    Y = tf.convert_to_tensor(Y, tf.int32)

    # Create Queues
    input_queues = tf.train.slice_input_producer([X, Y])

    # create batch queues
    x, y = tf.train.shuffle_batch(input_queues,
                                  num_threads=8,
                                  batch_size=hp.batch_size,
                                  capacity=hp.batch_size * 64,
                                  min_after_dequeue=hp.batch_size * 32,
                                  allow_smaller_final_batch=False)

    return x, y, num_batch  # (N, T), (N, T), ()
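
Because make_vocab writes the special tokens first, their indices after load_de_vocab/load_en_vocab are fixed: <PAD>=0, <UNK>=1, <S>=2, </S>=3. create_data therefore maps out-of-vocabulary words to 1, appends </S> to every sentence, keeps only pairs that fit within maxlen, and right-pads with 0. A toy NumPy illustration of that padding step (the token ids are made up):

import numpy as np

maxlen = 10
x = np.array([45, 1, 7, 3])  # e.g. "wort <UNK> wort </S>"
# Equivalent to the np.lib.pad call in create_data: right-pad with <PAD> (id 0).
padded = np.pad(x, (0, maxlen - len(x)), 'constant', constant_values=0)
print(padded)  # [45  1  7  3  0  0  0  0  0  0]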

★ modules.py

from __future__ import print_function
import tensorflow as tf

def normalize(inputs,
              epsilon=1e-8,
              scope='ln',
              reuse=None):
    '''Applies layer normalization over the last dimension.'''
    with tf.variable_scope(scope, reuse=reuse):
        inputs_shape = inputs.get_shape()
        params_shape = inputs_shape[-1:]

        mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
        beta = tf.Variable(tf.zeros(params_shape))
        gamma = tf.Variable(tf.ones(params_shape))
        normalized = (inputs - mean) / ((variance + epsilon) ** 0.5)
        output = gamma * normalized + beta

    return output

def embedding(inputs,
              vocab_size,
              num_units,
              zero_pad=True,
              scale=True,
              scope="embedding",
              reuse=None):
    '''Embeds the integer ids in `inputs` into num_units-dimensional vectors.'''
    with tf.variable_scope(scope, reuse=reuse):
        lookup_table = tf.get_variable('lookup_table',
                                       dtype=tf.float32,
                                       shape=[vocab_size, num_units],
                                       initializer=tf.contrib.layers.xavier_initializer())
        if zero_pad:
            # Force the embedding of id 0 (<PAD>) to be the zero vector.
            lookup_table = tf.concat((tf.zeros(shape=[1, num_units]),
                                      lookup_table[1:, :]), 0)
        outputs = tf.nn.embedding_lookup(lookup_table, inputs)

        if scale:
            outputs = outputs * (num_units ** 0.5)

    return outputs

def multihead_attention(queries,  # [N, T_q, C_q] by default
                        keys,     # [N, T_k, C_k] by default
                        num_units=None,
                        num_heads=8,
                        dropout_rate=0,
                        is_training=True,
                        causality=False,
                        scope="multihead_attention",
                        reuse=None):
    with tf.variable_scope(scope, reuse=reuse):
        if num_units is None:
            num_units = queries.get_shape().as_list()[-1]

        # Linear projections
        Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu)  # (N, T_q, C)
        K = tf.layers.dense(keys, num_units, activation=tf.nn.relu)     # (N, T_k, C)
        V = tf.layers.dense(keys, num_units, activation=tf.nn.relu)     # (N, T_k, C)

        # Split into heads and stack them along the batch dimension
        Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, C/h)
        K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)
        V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)

        # Multiplication
        outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)

        # Scale
        outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)

        # Key Masking: block attention to positions whose key is all zeros (padding)
        key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1)))  # (N, T_k)
        key_masks = tf.tile(key_masks, [num_heads, 1])  # (h*N, T_k)
        key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1])  # (h*N, T_q, T_k)

        paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
        outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Causality = Future blinding: each position may only attend to earlier positions
        if causality:
            diag_vals = tf.ones_like(outputs[0, :, :])  # (T_q, T_k)
            tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
            masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1])  # (h*N, T_q, T_k)

            paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
            outputs = tf.where(tf.equal(masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Activation
        outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)

        # Query Masking: zero out rows that correspond to padded query positions
        query_masks = tf.sign(tf.abs(tf.reduce_sum(queries, axis=-1)))  # (N, T_q)
        query_masks = tf.tile(query_masks, [num_heads, 1])  # (h*N, T_q)
        query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]])  # (h*N, T_q, T_k)
        outputs *= query_masks  # broadcasting. (h*N, T_q, T_k)

        # Dropout
        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=tf.convert_to_tensor(is_training))

        # Weighted sum
        outputs = tf.matmul(outputs, V_)  # (h*N, T_q, C/h)

        # Restore shape
        outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)  # (N, T_q, C)

        # Residual connection
        outputs += queries

        # Normalize
        outputs = normalize(outputs)  # (N, T_q, C)

    return outputs

def feedforward(inputs,
                num_units=[2048, 512],
                scope="feedforward",
                reuse=None):
    '''Position-wise feed-forward sublayer implemented with two kernel-size-1 convolutions.'''
    with tf.variable_scope(scope, reuse=reuse):
        # Inner layer
        params = {"inputs": inputs, "filters": num_units[0], "kernel_size": 1,
                  "activation": tf.nn.relu, "use_bias": True}
        outputs = tf.layers.conv1d(**params)

        # Readout layer
        params = {"inputs": outputs, "filters": num_units[1], "kernel_size": 1,
                  "activation": None, "use_bias": True}
        outputs = tf.layers.conv1d(**params)

        # Residual connection
        outputs += inputs

        # Normalize
        outputs = normalize(outputs)

    return outputs

def label_smoothing(inputs, epsilon=0.1):
    '''Smooths one-hot labels: each 1 becomes 1 - epsilon + epsilon/K, each 0 becomes epsilon/K.'''
    K = inputs.get_shape().as_list()[-1]
    return ((1 - epsilon) * inputs) + (epsilon / K)
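
The heart of modules.py is multihead_attention. Stripped of the head splitting, the padding masks, dropout, and the residual connection, each head computes ordinary scaled dot-product attention; label_smoothing at the bottom simply redistributes probability mass, e.g. with epsilon=0.1 and K=3 a one-hot row [0, 1, 0] becomes roughly [0.033, 0.933, 0.033]. Below is a small NumPy sketch (an illustration only, not the TensorFlow code above) of what a single head does, including the optional future-blinding mask used for decoder self-attention:

import numpy as np

def scaled_dot_product_attention(Q, K, V, causal=False):
    # Q: (T_q, d), K: (T_k, d), V: (T_k, d_v)
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                       # (T_q, T_k)
    if causal:
        # Each position may only attend to itself and earlier positions.
        mask = np.tril(np.ones_like(scores))
        scores = np.where(mask == 0, -1e9, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)        # softmax over the key axis
    return weights @ V                                    # (T_q, d_v)

T, d = 4, 8
rng = np.random.RandomState(0)
Q = rng.randn(T, d)
K = rng.randn(T, d)
V = rng.randn(T, d)
print(scaled_dot_product_attention(Q, K, V, causal=True).shape)  # (4, 8)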

train.py

import tensorflow as tf

from hyperparams import Hyperparams as hp
from data_load import get_batch_data, load_de_vocab, load_en_vocab
from modules import *
import os, codecs
from tqdm import tqdm

class Graph():
    def __init__(self, is_training=True):
        self.graph = tf.Graph()
        with self.graph.as_default():
            if is_training:
                self.x, self.y, self.num_batch = get_batch_data()  # (N, T)
            else:  # inference
                self.x = tf.placeholder(tf.int32, shape=(None, hp.maxlen))
                self.y = tf.placeholder(tf.int32, shape=(None, hp.maxlen))

            # Define decoder inputs: prepend <S> (id 2) and drop the last token.
            self.decoder_inputs = tf.concat((tf.ones_like(self.y[:, :1]) * 2, self.y[:, :-1]), -1)  # 2: <S>

            # Load vocabulary
            de2idx, idx2de = load_de_vocab()
            en2idx, idx2en = load_en_vocab()

            # Encoder
            with tf.variable_scope("encoder"):
                ## Embedding
                self.enc = embedding(self.x,
                                     vocab_size=len(de2idx),
                                     num_units=hp.hidden_units,
                                     scale=True,
                                     scope="enc_embed")

                ## Positional Encoding (learned embeddings indexed by position 0..T-1)
                self.enc += embedding(tf.tile(tf.expand_dims(tf.range(tf.shape(self.x)[1]), 0), [tf.shape(self.x)[0], 1]),
                                      vocab_size=hp.maxlen,
                                      num_units=hp.hidden_units,
                                      zero_pad=False,
                                      scale=False,
                                      scope="enc_pe")

                ## Dropout
                self.enc = tf.layers.dropout(self.enc,
                                             rate=hp.dropout_rate,
                                             training=tf.convert_to_tensor(is_training))

                ## Blocks
                for i in range(hp.num_blocks):
                    with tf.variable_scope("num_blocks_{}".format(i)):
                        ### Multihead Attention
                        self.enc = multihead_attention(queries=self.enc,
                                                       keys=self.enc,
                                                       num_units=hp.hidden_units,
                                                       num_heads=hp.num_heads,
                                                       dropout_rate=hp.dropout_rate,
                                                       is_training=is_training,
                                                       causality=False)

                        ### Feed Forward
                        self.enc = feedforward(self.enc, num_units=[4 * hp.hidden_units, hp.hidden_units])

            # Decoder
            with tf.variable_scope("decoder"):
                ## Embedding
                self.dec = embedding(self.decoder_inputs,
                                     vocab_size=len(en2idx),
                                     num_units=hp.hidden_units,
                                     scale=True,
                                     scope="dec_embed")

                ## Positional Encoding
                self.dec += embedding(tf.tile(tf.expand_dims(tf.range(tf.shape(self.decoder_inputs)[1]), 0), [tf.shape(self.decoder_inputs)[0], 1]),
                                      vocab_size=hp.maxlen,
                                      num_units=hp.hidden_units,
                                      zero_pad=False,
                                      scale=False,
                                      scope="dec_pe")

                ## Dropout
                self.dec = tf.layers.dropout(self.dec,
                                             rate=hp.dropout_rate,
                                             training=tf.convert_to_tensor(is_training))

                ## Blocks
                for i in range(hp.num_blocks):
                    with tf.variable_scope("num_blocks_{}".format(i)):
                        ## Multihead Attention (self-attention, with future blinding)
                        self.dec = multihead_attention(queries=self.dec,
                                                       keys=self.dec,
                                                       num_units=hp.hidden_units,
                                                       num_heads=hp.num_heads,
                                                       dropout_rate=hp.dropout_rate,
                                                       is_training=is_training,
                                                       causality=True,
                                                       scope="self_attention")

                        ## Multihead Attention (vanilla encoder-decoder attention)
                        self.dec = multihead_attention(queries=self.dec,
                                                       keys=self.enc,
                                                       num_units=hp.hidden_units,
                                                       num_heads=hp.num_heads,
                                                       dropout_rate=hp.dropout_rate,
                                                       is_training=is_training,
                                                       causality=False,
                                                       scope="vanilla_attention")

                        ## Feed Forward
                        self.dec = feedforward(self.dec, num_units=[4 * hp.hidden_units, hp.hidden_units])

            # Final linear projection
            self.logits = tf.layers.dense(self.dec, len(en2idx))
            self.preds = tf.to_int32(tf.argmax(self.logits, axis=-1))
            self.istarget = tf.to_float(tf.not_equal(self.y, 0))  # 1.0 for real tokens, 0.0 for <PAD>
            self.acc = tf.reduce_sum(tf.to_float(tf.equal(self.preds, self.y)) * self.istarget) / (tf.reduce_sum(self.istarget))
            tf.summary.scalar('acc', self.acc)

            if is_training:
                # Loss (with label smoothing)
                self.y_smoothed = label_smoothing(tf.one_hot(self.y, depth=len(en2idx)))
                self.loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.y_smoothed)
                self.mean_loss = tf.reduce_sum(self.loss * self.istarget) / (tf.reduce_sum(self.istarget))

                # Training Scheme
                self.global_step = tf.Variable(0, name='global_step', trainable=False)
                self.optimizer = tf.train.AdamOptimizer(learning_rate=hp.lr, beta1=0.9, beta2=0.98, epsilon=1e-8)
                self.train_op = self.optimizer.minimize(self.mean_loss, global_step=self.global_step)

                # Summary
                tf.summary.scalar('mean_loss', self.mean_loss)
                self.merged = tf.summary.merge_all()


if __name__ == '__main__':
    # Load vocabulary
    de2idx, idx2de = load_de_vocab()
    en2idx, idx2en = load_en_vocab()

    # Construct graph
    g = Graph(is_training=True)
    print("Graph loaded")

    # Start session
    sv = tf.train.Supervisor(graph=g.graph,
                             logdir=hp.logdir,
                             save_model_secs=0)
    with sv.managed_session() as sess:
        for epoch in range(1, hp.num_epochs + 1):
            if sv.should_stop():
                break
            for step in tqdm(range(g.num_batch), total=g.num_batch, ncols=70, leave=False, unit='b'):
                loss, _ = sess.run([g.mean_loss, g.train_op])
                print(step, ":", loss)

            gs = sess.run(g.global_step)
            sv.saver.save(sess, hp.logdir + '/model_epoch_%02d_gs_%d' % (epoch, gs))

    print("Done")
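
One detail worth spelling out is the construction of self.decoder_inputs: the target sequence is shifted right by one position and the <S> token (id 2, following the vocabulary order fixed in prepro.py) is prepended, so during training the decoder predicts token t while only seeing tokens before t. A toy NumPy reproduction of that single line, with made-up ids:

import numpy as np

y = np.array([[17, 52, 9, 3, 0, 0]])  # "... </S> <PAD> <PAD>"
# Mirrors tf.concat((tf.ones_like(self.y[:, :1]) * 2, self.y[:, :-1]), -1)
decoder_inputs = np.concatenate([np.ones_like(y[:, :1]) * 2, y[:, :-1]], axis=-1)
print(decoder_inputs)  # [[ 2 17 52  9  3  0]]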

eval.py

import codecs
import os

import tensorflow as tf
import numpy as np

from hyperparams import Hyperparams as hp
from data_load import load_test_data, load_de_vocab, load_en_vocab
from train import Graph
from nltk.translate.bleu_score import corpus_bleu

def eval():
    # Load graph
    g = Graph(is_training=False)
    print("Graph loaded")

    # Load data
    X, Sources, Targets = load_test_data()
    de2idx, idx2de = load_de_vocab()
    en2idx, idx2en = load_en_vocab()

    # Start session
    with g.graph.as_default():
        sv = tf.train.Supervisor()
        with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            ## Restore parameters
            sv.saver.restore(sess, tf.train.latest_checkpoint(hp.logdir))
            print("Restored!")

            ## Get model name
            mname = open(hp.logdir + '/checkpoint', 'r').read().split('"')[1]  # model name

            ## Inference
            if not os.path.exists('results'): os.mkdir('results')
            with codecs.open("results/" + mname, "w", "utf-8") as fout:
                list_of_refs, hypotheses = [], []
                for i in range(len(X) // hp.batch_size):

                    ### Get mini-batches
                    x = X[i * hp.batch_size: (i + 1) * hp.batch_size]
                    sources = Sources[i * hp.batch_size: (i + 1) * hp.batch_size]
                    targets = Targets[i * hp.batch_size: (i + 1) * hp.batch_size]

                    ### Autoregressive inference
                    preds = np.zeros((hp.batch_size, hp.maxlen), np.int32)
                    for j in range(hp.maxlen):
                        _preds = sess.run(g.preds, {g.x: x, g.y: preds})
                        preds[:, j] = _preds[:, j]

                    ### Write to file
                    for source, target, pred in zip(sources, targets, preds):  # sentence-wise
                        got = " ".join(idx2en[idx] for idx in pred).split("</S>")[0].strip()
                        fout.write("- source: " + source + "\n")
                        fout.write("- expected: " + target + "\n")
                        fout.write("- got: " + got + "\n\n")
                        fout.flush()

                        # bleu score
                        ref = target.split()
                        hypothesis = got.split()
                        if len(ref) > 3 and len(hypothesis) > 3:
                            list_of_refs.append([ref])
                            hypotheses.append(hypothesis)

                ## Calculate bleu score
                score = corpus_bleu(list_of_refs, hypotheses)
                fout.write("Bleu Score = " + str(100 * score))


if __name__ == '__main__':
    eval()
    print("Done")
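
Two things are worth noting here. Inference is greedy and autoregressive: preds starts as all zeros, and on step j the model is run on the partially filled preds while only column j is kept, so every new token is conditioned on the tokens generated so far. The BLEU score is then computed with NLTK's corpus_bleu, which expects one list of reference translations per hypothesis; a toy call with made-up sentences:

from nltk.translate.bleu_score import corpus_bleu

# Each hypothesis is paired with a list of references (here just one reference each).
list_of_refs = [[["this", "is", "a", "test"]],
                [["another", "short", "sentence", "here"]]]
hypotheses = [["this", "is", "a", "test"],
              ["another", "short", "sentence", "indeed"]]
print(corpus_bleu(list_of_refs, hypotheses))  # a value between 0 and 1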

References

机器翻译模型Transformer代码详细解析 (Detailed analysis of the Transformer machine translation model code)

Transformer模型的学习总结 (A study summary of the Transformer model)
