Prediction
Predictions
class eole.predict.prediction.PredictionBuilder(vocabs, n_best=1, replace_unk=False, phrase_table='', tgt_eos_idx=None, id_tokenization=False)
Bases: object
Build a word-based prediction from the batch output of the predictor and the underlying dictionaries.
Replacement is based on "Addressing the Rare Word Problem in Neural Machine Translation" (Luong et al., 2015).
- Parameters:
  - vocabs (dict[str, Vocab]) – dictionaries to use for building the prediction
  - n_best (int) – number of predictions produced
  - replace_unk (bool) – replace unknown words using attention
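The attention-based replacement can be sketched in pure Python: each predicted unknown token is swapped for the source token that received the highest attention weight at that decoding step. The function and argument names below are illustrative, not the library's API, and the real implementation works on batched tensors.

```python
def replace_unk_tokens(pred_tokens, src_tokens, attn, unk_token="<unk>"):
    """Replace each <unk> in the prediction with the source token that
    received the highest attention weight at that decoding step."""
    out = []
    for step, tok in enumerate(pred_tokens):
        if tok == unk_token:
            # argmax over source positions for this target step
            best = max(range(len(src_tokens)), key=lambda j: attn[step][j])
            tok = src_tokens[best]
        out.append(tok)
    return out
```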
Predictor Classes
class eole.predict.inference.Inference(model, vocabs, gpu=-1, n_best=1, min_length=0, max_length=100, max_length_ratio=1.5, ratio=0.0, beam_size=30, top_k=0, top_p=0.0, temperature=1.0, stepwise_penalty=None, dump_beam=False, block_ngram_repeat=0, ignore_when_blocking=frozenset({}), replace_unk=False, ban_unk_token=False, tgt_file_prefix=False, phrase_table='', data_type='text', verbose=False, report_time=False, global_scorer=None, report_align=False, gold_align=False, report_score=True, logger=None, seed=-1, with_score=False, estim_only=False, return_gold_log_probs=False, add_estimator=False, estimator_type='average', optional_eos=[], id_tokenization=False, image_token_id=10, fuse_kvq=False, fuse_gate=False)
Bases: object
Predict a batch of sentences with a saved model.
- Parameters:
  - model (eole.modules.BaseModel) – Model to use for prediction
  - vocabs (dict[str, Vocab]) – A dict mapping each side's Vocab.
  - gpu (int) – GPU device. Set to negative for no GPU.
  - n_best (int) – How many beams to wait for.
  - min_length (int) – See eole.predict.decode_strategy.DecodeStrategy.
  - max_length (int) – See eole.predict.decode_strategy.DecodeStrategy.
  - beam_size (int) – Number of beams.
  - top_k (int) – See eole.predict.greedy_search.GreedySearch.
  - top_p (float) – See eole.predict.greedy_search.GreedySearch.
  - temperature (float) – See eole.predict.greedy_search.GreedySearch.
  - stepwise_penalty (bool) – Whether coverage penalty is applied every step or not.
  - dump_beam (bool) – Debugging option.
  - block_ngram_repeat (int) – See eole.predict.decode_strategy.DecodeStrategy.
  - ignore_when_blocking (set or frozenset) – See eole.predict.decode_strategy.DecodeStrategy.
  - replace_unk (bool) – Replace unknown token.
  - tgt_file_prefix (bool) – Force the predictions to begin with the provided -tgt.
  - data_type (str) – Source data type.
  - verbose (bool) – Print/log every prediction.
  - report_time (bool) – Print/log total time/frequency.
  - global_scorer (eole.predict.GNMTGlobalScorer) – Prediction scoring/reranking object.
  - report_score (bool) – Whether to report scores.
  - logger (logging.Logger or NoneType) – Logger.
classmethod from_config(model, vocabs, config, model_config, device_id=0, global_scorer=None, report_align=False, report_score=True, logger=None)
Alternate constructor.
- Parameters:
  - model (eole.modules.BaseModel) – See __init__().
  - vocabs (dict[str, Vocab]) – See __init__().
  - config – Running configuration.
  - model_config – Configuration saved with the model checkpoint.
  - global_scorer (eole.predict.GNMTGlobalScorer) – See __init__().
  - report_align (bool) – See __init__().
  - report_score (bool) – See __init__().
  - logger (logging.Logger or NoneType) – See __init__().
predict_batch(batch, attn_debug)
Predict a batch of sentences.
Decoding Strategies
eole.predict.greedy_search.sample_with_temperature(logits, temperature, top_k, top_p)
Select next tokens randomly from the top k possible next tokens.
Samples from a categorical distribution over the top_k words using
the category probabilities logits / temperature.
- Parameters:
  - logits (FloatTensor) – Shaped (batch_size, vocab_size). These can be logits ((-inf, inf)) or log-probs ((-inf, 0]). (The distribution actually uses the log-probabilities logits - logits.logsumexp(-1), which equals the logits if they are log-probabilities summing to 1.)
  - temperature (float) – Used to scale down logits. The higher the value, the more likely it is that a non-max word will be sampled.
  - top_k (int) – This many words could potentially be chosen. The other logits are set to have probability 0.
  - top_p (float) – Keep the most likely words until the cumulative probability is greater than p. If used with top_k, both conditions are applied.
- Returns:
  - topk_ids: Shaped (batch_size, 1). These are the sampled word indices in the output vocab.
  - topk_scores: Shaped (batch_size, 1). These are essentially (logits / temperature)[topk_ids].
- Return type: (LongTensor, FloatTensor)
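The behavior described above can be sketched in pure Python for a single example. This is a standalone illustration, not the library's batched tensor implementation; it returns the sampled index and its temperature-scaled logit, mirroring (topk_ids, topk_scores).

```python
import math
import random

def sample_with_temperature(logits, temperature=1.0, top_k=0, top_p=0.0):
    """Sketch of temperature/top-k/top-p sampling for one example."""
    # Scale by temperature, then softmax to get a categorical distribution.
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    probs = [e / z for e in exps]

    # Indices sorted by descending probability.
    order = sorted(range(len(probs)), key=lambda i: -probs[i])

    # top-k: only the k most likely tokens may be chosen.
    keep = set(order[:top_k]) if top_k > 0 else set(order)

    # top-p: keep the most likely tokens until cumulative probability > p.
    if top_p > 0.0:
        nucleus, cum = set(), 0.0
        for i in order:
            nucleus.add(i)
            cum += probs[i]
            if cum > top_p:
                break
        keep &= nucleus  # both conditions apply when combined with top_k

    # Zero out everything outside `keep`, then sample.
    weights = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    token_id = random.choices(range(len(probs)), weights=weights, k=1)[0]
    return token_id, scaled[token_id]
```

Raising the temperature flattens `probs`, making non-max tokens more likely, exactly as the parameter description states.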
Scoring
class eole.predict.penalties.PenaltyBuilder(cov_pen, length_pen)
Bases: object
Returns the Length and Coverage Penalty function for Beam Search.
- Parameters:
  - length_pen (str) – option name of the length penalty
  - cov_pen (str) – option name of the coverage penalty
- Variables:
  - has_cov_pen (bool) – Whether the coverage penalty is None (applying it is a no-op). Note that the converse isn't true: setting beta to 0 should also force the coverage penalty to be a no-op.
  - has_len_pen (bool) – Whether the length penalty is None (applying it is a no-op). Note that the converse isn't true: setting alpha to 1 should also force the length penalty to be a no-op.
  - coverage_penalty (callable[[FloatTensor, float], FloatTensor]) – Calculates the coverage penalty.
  - length_penalty (callable[[int, float], float]) – Calculates the length penalty.
coverage_none(cov, beta=0.0)
Returns zero as penalty.
coverage_summary(cov, beta=0.0)
Our summary penalty.
coverage_wu(cov, beta=0.0)
GNMT coverage re-ranking score.
See "Google's Neural Machine Translation System" (Wu et al., 2016).
cov is expected to be sized (*, seq_len), where * is probably batch_size x beam_size but could be several dimensions like (batch_size, beam_size). If cov is attention, then the seq_len axis probably sums to (almost) 1.
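For a single hypothesis, the Wu et al. coverage penalty can be sketched in plain Python. The library version operates on batched tensors and its sign convention may differ; this sketch assumes cov[j] is the total attention mass received by source position j.

```python
import math

def coverage_wu(cov, beta=0.0):
    """GNMT coverage penalty sketch: positions attended less than once
    are penalized; attention mass above 1 earns no extra reward."""
    return -beta * sum(math.log(min(c, 1.0)) for c in cov)
```

With full coverage (every position's mass at least 1) the penalty is zero; under-attended positions drive it up.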
length_average(cur_len, alpha=1.0)
Returns the current sequence length.
length_none(cur_len, alpha=0.0)
Returns unmodified scores.
length_wu(cur_len, alpha=0.0)
GNMT length re-ranking score.
See "Google's Neural Machine Translation System" (Wu et al., 2016).