Dans la Procédure pour apprendre et inférer le modèle de traduction anglais-japonais du transformateur avec CloudTPU, le modèle de traduction anglais-japonais du transformateur a été appris avec CloudTPU et l'inférence a également été effectuée. C'était. Cette fois, je vais vous expliquer comment exécuter un transformateur formé par Cloud TPU dans un conteneur Docker local. Le code est ici. https://github.com/yolo-kiyoshi/transformer_python_exec
GCS, supposons que les fichiers se trouvent localement dans la structure de répertoires suivante.
Structure du répertoire
bucket
├── training/
│   └── transformer_ende/
│       ├── checkpoint
│       ├── model.ckpt-****.data-00000-of-00001
│       ├── model.ckpt-****.index
│       └── model.ckpt-****.meta
└── transformer/
    └── vocab.translate_jpen.****.subwords
Clonez le référentiel.
git clone https://github.com/yolo-kiyoshi/transformer_python_exec.git
Structure du répertoire
.
├── Dockerfile
├── .env.sample
├── Pipfile
├── Pipfile.lock
├── README.md
├── decode.ipynb
├── docker-compose.yml
├── training/
│   └── transformer_ende/
└── transformer/
Téléchargez le fichier d'informations d'identification du compte de service (json) et placez-le dans le même répertoire que README.md.
Dupliquez et renommez .env.sample pour créer .env.
.env
#Décrivez le chemin du fichier d'identification placé au-dessus
GOOGLE_APPLICATION_CREDENTIALS=*****.json
BUDGET_NAME=
#Mêmes paramètres que lors de l'apprentissage avec CloudTPU
PROBLEM=translate_jpen
DATA_DIR=transformer
TRAIN_DIR=training/transformer_ende/
HPARAMS=transformer_tpu
MODEL=transformer
Après avoir exécuté la commande suivante, vous pouvez utiliser Jupyter lab en accédant à http: // localhost: 8080 / lab.
docker-compose up -d
Notebook
Téléchargez localement l'ensemble des fichiers point de contrôle et des fichiers vocaux créés pendant le processus d'apprentissage du transformateur à partir de GCS.
#Méthode pour télécharger des fichiers depuis GCS(https://cloud.google.com/storage/docs/downloading-objects?hl=ja)
def download_blob(bucket_name, source_blob_name, destination_file_name):
    """Downloads a blob from the bucket."""
    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name)
    blob = bucket.blob(source_blob_name)
    blob.download_to_filename(destination_file_name)
    print('Blob {} downloaded to {}.'.format(
        source_blob_name,
        destination_file_name))
#Se référer à la méthode d'acquisition de la liste de fichiers GCS
# https://cloud.google.com/storage/docs/listing-objects?hl=ja#storage-list-objects-python
def list_match_file_with_prefix(bucket_name, prefix, search_path):
    """Lists all the blobs in the bucket that begin with the prefix."""
    
    storage_client = storage.Client()
    # Note: Client.list_blobs requires at least package version 1.17.0.
    blobs = storage_client.list_blobs(bucket_name, prefix=prefix, delimiter=None)
    file_list = [blob.name for blob in blobs if search_path in blob.name]
    
    return file_list
#Définir les variables d'environnement
BUDGET_NAME = os.environ['BUDGET_NAME']
PROBLEM = os.environ['PROBLEM']
DATA_DIR = os.environ['DATA_DIR']
TRAIN_DIR = os.environ['TRAIN_DIR']
HPARAMS = os.environ['HPARAMS']
MODEL = os.environ['MODEL']
#chemin du fichier de point de contrôle
src_file_name = os.path.join(TRAIN_DIR, 'checkpoint')
dist_file_name = os.path.join(TRAIN_DIR, 'checkpoint')
#Télécharger le fichier de point de contrôle depuis GCS
download_blob(BUDGET_NAME, src_file_name, dist_file_name)
#Dernière séquence de point de contrôle du fichier de point de contrôle(prefix)Obtenir
import re
with open(dist_file_name) as f:
    l = f.readlines(1)
    ckpt_name = re.findall('model_checkpoint_path: "(.*?)"', l[0])[0]
    ckpt_path = os.path.join(TRAIN_DIR, ckpt_name)
#Obtenez la liste de fichiers associée au dernier point de contrôle de GCS
ckpt_file_list = list_match_file_with_prefix(BUDGET_NAME, TRAIN_DIR, ckpt_path)
# checkpoint.Téléchargez un ensemble de variables
for ckpt_file in ckpt_file_list:
    download_blob(BUDGET_NAME, ckpt_file, ckpt_file)
#Obtenez le chemin du fichier de vocabulaire à partir de GCS
vocab_file = list_match_file_with_prefix(BUDGET_NAME, DATA_DIR, os.path.join(DATA_DIR, 'vocab'))[0]
#Télécharger le fichier de vocabulaire depuis GCS
download_blob(BUDGET_NAME, vocab_file, vocab_file)
Chargez le modèle de transformateur en fonction des résultats de la formation sur le transformateur téléchargés à partir de GCS.
#Initialisation
tfe = tf.contrib.eager
tfe.enable_eager_execution()
Modes = tf.estimator.ModeKeys
import pickle
import numpy as np
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
#Prétraitement&Utilisez le même nom de classe que PROBLE défini dans l'apprentissage
@registry.register_problem
class Translate_JPEN(text_problems.Text2TextProblem):
    @property
    def approx_vocab_size(self):
        return 2**13
enfr_problem = problems.problem(PROBLEM)
# Get the encoders from the problem
encoders = enfr_problem.feature_encoders(DATA_DIR)
from functools import wraps
import time
def stop_watch(func) :
    @wraps(func)
    def wrapper(*args, **kargs) :
        start = time.time()
        print(f'{func.__name__} started ...')
        result = func(*args,**kargs)
        elapsed_time =  time.time() - start
        print(f'elapsed_time:{elapsed_time}')
        print(f'{func.__name__} completed')
        return result
    return wrapper
@stop_watch
def translate(inputs):
    encoded_inputs = encode(inputs)
    with tfe.restore_variables_on_create(ckpt_path):
        model_output = translate_model.infer(features=encoded_inputs)["outputs"]
    return decode(model_output)
def encode(input_str, output_str=None):
    """Input str to features dict, ready for inference"""
    inputs = encoders["inputs"].encode(input_str) + [1]
    batch_inputs = tf.reshape(inputs, [1, -1, 1])
    return {"inputs": batch_inputs}
def decode(integers):
    """List of ints to str"""
    integers = list(np.squeeze(integers))
    if 1 in integers:
        integers = integers[:integers.index(1)]
    return encoders["inputs"].decode(np.squeeze(integers))
hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)
translate_model = registry.model(MODEL)(hparams, Modes.PREDICT)
Inférer avec le modèle de transformateur chargé. Lorsqu'elle est exécutée localement, une phrase prend environ 30 secondes.
inputs = "My cat is so cute."
outputs = translate(inputs)
print(outputs)
résultat
>Mon chat est très mignon.