import argparse
import re
from typing import List
import numpy as np
import yaml
from ctranslate2.converters import utils
from ctranslate2.converters.converter import Converter
from ctranslate2.specs import common_spec, transformer_spec
_SUPPORTED_ACTIVATIONS = {
"gelu": common_spec.Activation.GELUSigmoid,
"relu": common_spec.Activation.RELU,
"swish": common_spec.Activation.SWISH,
}
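# Note: Marian's "gelu" corresponds to the sigmoid-based approximation
# x * sigmoid(1.702 * x), hence the GELUSigmoid mapping above rather than the
# exact or tanh-approximated GELU variants.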
_SUPPORTED_POSTPROCESS_EMB = {"", "d", "n", "nd"}
class MarianConverter(Converter):
"""Converts models trained with Marian."""
def __init__(self, model_path: str, vocab_paths: List[str]):
"""Initializes the Marian converter.
Arguments:
model_path: Path to the Marian model (.npz file).
vocab_paths: Paths to the vocabularies (.yml files).
"""
self._model_path = model_path
self._vocab_paths = vocab_paths
def _load(self):
model = np.load(self._model_path)
config = _get_model_config(model)
vocabs = list(map(load_vocab, self._vocab_paths))
activation = config["transformer-ffn-activation"]
pre_norm = "n" in config["transformer-preprocess"]
postprocess_emb = config["transformer-postprocess-emb"]
check = utils.ConfigurationChecker()
check(config["type"] == "transformer", "Option --type must be 'transformer'")
check(
config["transformer-decoder-autoreg"] == "self-attention",
"Option --transformer-decoder-autoreg must be 'self-attention'",
)
check(
not config["transformer-no-projection"],
"Option --transformer-no-projection is not supported",
)
check(
activation in _SUPPORTED_ACTIVATIONS,
"Option --transformer-ffn-activation %s is not supported "
"(supported activations are: %s)"
% (activation, ", ".join(_SUPPORTED_ACTIVATIONS.keys())),
)
check(
postprocess_emb in _SUPPORTED_POSTPROCESS_EMB,
"Option --transformer-postprocess-emb %s is not supported (supported values are: %s)"
% (postprocess_emb, ", ".join(_SUPPORTED_POSTPROCESS_EMB)),
)
if pre_norm:
check(
config["transformer-preprocess"] == "n"
and config["transformer-postprocess"] == "da"
and config.get("transformer-postprocess-top", "") == "n",
"Unsupported pre-norm Transformer architecture, expected the following "
"combination of options: "
"--transformer-preprocess n "
"--transformer-postprocess da "
"--transformer-postprocess-top n",
)
else:
check(
config["transformer-preprocess"] == ""
and config["transformer-postprocess"] == "dan"
and config.get("transformer-postprocess-top", "") == "",
"Unsupported post-norm Transformer architecture, excepted the following "
"combination of options: "
"--transformer-preprocess '' "
"--transformer-postprocess dan "
"--transformer-postprocess-top ''",
)
check.validate()
alignment_layer = config["transformer-guided-alignment-layer"]
alignment_layer = -1 if alignment_layer == "last" else int(alignment_layer) - 1
layernorm_embedding = "n" in postprocess_emb
model_spec = transformer_spec.TransformerSpec.from_config(
(config["enc-depth"], config["dec-depth"]),
config["transformer-heads"],
pre_norm=pre_norm,
activation=_SUPPORTED_ACTIVATIONS[activation],
alignment_layer=alignment_layer,
alignment_heads=1,
layernorm_embedding=layernorm_embedding,
)
set_transformer_spec(model_spec, model)
model_spec.register_source_vocabulary(vocabs[0])
model_spec.register_target_vocabulary(vocabs[-1])
model_spec.config.add_source_eos = True
return model_spec
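
# Example usage (illustrative sketch): the Converter base class is expected to
# expose a convert() method taking an output directory and an optional
# quantization, as in the other ctranslate2 converters; adjust the call if the
# installed API differs.
#
#     converter = MarianConverter("model.npz", ["vocab.src.yml", "vocab.tgt.yml"])
#     converter.convert("ende_ctranslate2", quantization="int8")
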
def _get_model_config(model):
    # Marian serializes its training options as a null-terminated YAML string
    # under the "special:model.yml" key of the .npz checkpoint.
    config = model["special:model.yml"]
    config = config[:-1].tobytes()  # Drop the trailing null byte.
    config = yaml.safe_load(config)
return config
def load_vocab(path):
    # PyYAML skips some entries, so we parse the vocabulary file manually.
with open(path, encoding="utf-8") as vocab:
tokens = []
token = None
idx = None
for i, line in enumerate(vocab):
line = line.rstrip("\n\r")
if not line:
continue
if line.startswith("? "): # Complex key mapping (key)
token = line[2:]
elif token is not None: # Complex key mapping (value)
idx = line[2:]
else:
token, idx = line.rsplit(":", 1)
if token is not None:
if token.startswith('"') and token.endswith('"'):
# Unescape characters and remove quotes.
token = re.sub(r"\\([^x])", r"\1", token)
token = token[1:-1]
if token.startswith("\\x"):
                        # Convert the \xNN escape sequence into the actual character.
token = chr(int(token[2:], base=16))
elif token.startswith("'") and token.endswith("'"):
token = token[1:-1]
token = token.replace("''", "'")
if idx is not None:
try:
idx = int(idx.strip())
except ValueError as e:
raise ValueError(
"Unexpected format at line %d: '%s'" % (i + 1, line)
) from e
tokens.append((idx, token))
token = None
idx = None
return [token for _, token in sorted(tokens, key=lambda item: item[0])]
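
# load_vocab above parses Marian vocabulary files written as a YAML mapping
# from token to index. Illustrative excerpt of the entry formats it handles
# (plain, quoted/escaped, and "? key" / ": value" complex mappings):
#
#     <s>: 0
#     "\"": 42
#     'it''s': 1203
#     ? "\x00"
#     : 57
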
def set_transformer_spec(spec, weights):
set_transformer_encoder(spec.encoder, weights, "encoder")
set_transformer_decoder(spec.decoder, weights, "decoder")
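
# The helpers below copy Marian's flat tensor names into the CTranslate2 spec.
# Layer scopes look like "encoder_l1_self", "decoder_l3_context" or
# "encoder_l2_ffn", and parameter suffixes include "_W"/"_Wt" (weights), "_b"
# (biases) and "_ln_scale"/"_ln_bias" (layer norm), plus the "Wemb" embeddings
# and the optional learned "Wpos" position table.
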
def set_transformer_encoder(spec, weights, scope):
set_common_layers(spec, weights, scope)
for i, layer_spec in enumerate(spec.layer):
set_transformer_encoder_layer(layer_spec, weights, "%s_l%d" % (scope, i + 1))
def set_transformer_decoder(spec, weights, scope):
    # Marian decoders start from an all-zero embedding instead of a BOS token.
    spec.start_from_zero_embedding = True
set_common_layers(spec, weights, scope)
for i, layer_spec in enumerate(spec.layer):
set_transformer_decoder_layer(layer_spec, weights, "%s_l%d" % (scope, i + 1))
set_linear(
spec.projection,
weights,
"%s_ff_logit_out" % scope,
reuse_weight=spec.embeddings.weight,
)
def set_common_layers(spec, weights, scope):
embeddings_specs = spec.embeddings
if not isinstance(embeddings_specs, list):
embeddings_specs = [embeddings_specs]
set_embeddings(embeddings_specs[0], weights, scope)
set_position_encodings(
spec.position_encodings, weights, dim=embeddings_specs[0].weight.shape[1]
)
if hasattr(spec, "layernorm_embedding"):
set_layer_norm(
spec.layernorm_embedding,
weights,
"%s_emb" % scope,
pre_norm=True,
)
if hasattr(spec, "layer_norm"):
set_layer_norm(spec.layer_norm, weights, "%s_top" % scope)
def set_transformer_encoder_layer(spec, weights, scope):
set_ffn(spec.ffn, weights, "%s_ffn" % scope)
set_multi_head_attention(
spec.self_attention, weights, "%s_self" % scope, self_attention=True
)
def set_transformer_decoder_layer(spec, weights, scope):
set_ffn(spec.ffn, weights, "%s_ffn" % scope)
set_multi_head_attention(
spec.self_attention, weights, "%s_self" % scope, self_attention=True
)
set_multi_head_attention(spec.attention, weights, "%s_context" % scope)
def set_multi_head_attention(spec, weights, scope, self_attention=False):
split_layers = [common_spec.LinearSpec() for _ in range(3)]
set_linear(split_layers[0], weights, scope, "q")
set_linear(split_layers[1], weights, scope, "k")
set_linear(split_layers[2], weights, scope, "v")
if self_attention:
utils.fuse_linear(spec.linear[0], split_layers)
else:
spec.linear[0].weight = split_layers[0].weight
spec.linear[0].bias = split_layers[0].bias
utils.fuse_linear(spec.linear[1], split_layers[1:])
set_linear(spec.linear[-1], weights, scope, "o")
set_layer_norm_auto(spec.layer_norm, weights, "%s_Wo" % scope)
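
# In set_multi_head_attention above, the query/key/value projections are fused
# into a single linear layer for self-attention, while for cross-attention the
# query projection is kept separate and only the key/value projections are
# fused.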
def set_ffn(spec, weights, scope):
set_layer_norm_auto(spec.layer_norm, weights, "%s_ffn" % scope)
set_linear(spec.linear_0, weights, scope, "1")
set_linear(spec.linear_1, weights, scope, "2")
def set_layer_norm_auto(spec, weights, scope):
    # Prefer the pre-norm parameter names ("*_ln_scale_pre"/"*_ln_bias_pre")
    # and fall back to the post-norm names when they are not in the checkpoint.
try:
set_layer_norm(spec, weights, scope, pre_norm=True)
except KeyError:
set_layer_norm(spec, weights, scope)
def set_layer_norm(spec, weights, scope, pre_norm=False):
suffix = "_pre" if pre_norm else ""
spec.gamma = weights["%s_ln_scale%s" % (scope, suffix)].squeeze()
spec.beta = weights["%s_ln_bias%s" % (scope, suffix)].squeeze()
def set_linear(spec, weights, scope, suffix="", reuse_weight=None):
weight = weights.get("%s_W%s" % (scope, suffix))
if weight is None:
weight = weights.get("%s_Wt%s" % (scope, suffix), reuse_weight)
else:
weight = weight.transpose()
spec.weight = weight
bias = weights.get("%s_b%s" % (scope, suffix))
if bias is not None:
spec.bias = bias.squeeze()
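
# set_linear above transposes "*_W" weights to match the CTranslate2 layout,
# uses "*_Wt" weights as stored, and falls back to reuse_weight so that the
# decoder's "_ff_logit_out" projection can be tied to the embedding matrix.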
def set_embeddings(spec, weights, scope):
    # Use the per-side embedding matrix ("encoder_Wemb"/"decoder_Wemb") when
    # present, otherwise fall back to the shared "Wemb" matrix.
    spec.weight = weights.get("%s_Wemb" % scope)
    if spec.weight is None:
        spec.weight = weights.get("Wemb")
def set_position_encodings(spec, weights, dim=None):
spec.encodings = weights.get("Wpos", _make_sinusoidal_position_encodings(dim))
def _make_sinusoidal_position_encodings(dim, num_positions=2048):
positions = np.arange(num_positions)
timescales = np.power(10000, 2 * (np.arange(dim) // 2) / dim)
position_enc = np.expand_dims(positions, 1) / np.expand_dims(timescales, 0)
table = np.zeros_like(position_enc)
table[:, : dim // 2] = np.sin(position_enc[:, 0::2])
table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
return table
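
# The sinusoidal table above places the sines in the first dim // 2 columns and
# the cosines in the last dim // 2 columns instead of interleaving them, and it
# is only generated when the checkpoint provides no learned "Wpos" table.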
def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--model_path", required=True, help="Path to the model .npz file."
)
parser.add_argument(
"--vocab_paths",
required=True,
nargs="+",
help="List of paths to the YAML vocabularies.",
)
Converter.declare_arguments(parser)
args = parser.parse_args()
converter = MarianConverter(args.model_path, args.vocab_paths)
converter.convert_from_args(args)
if __name__ == "__main__":
main()
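
# Example command line (illustrative): options other than --model_path and
# --vocab_paths, such as the output directory, are declared by
# Converter.declare_arguments and may differ between ctranslate2 versions.
#
#     python -m ctranslate2.converters.marian \
#         --model_path model.npz \
#         --vocab_paths vocab.src.yml vocab.tgt.yml \
#         --output_dir ende_ctranslate2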