airflow.providers.google.cloud.utils.mlengine_prediction_summary

一个由 DataFlowPythonOperator 调用的模板,用于汇总批量预测 (BatchPrediction)。

它接受一个用户函数来计算预测结果中每个实例的指标,然后进行聚合并输出汇总。

它接受以下参数

  • --prediction_path: 包含批量预测结果的 GCS 文件夹,其中包含 JSON 格式的 prediction.results-NNNNN-of-NNNNN 文件。输出也将存储在此文件夹中,名称为 ‘prediction.summary.json’。

  • --metric_fn_encoded: 一个编码的函数,用于计算并返回给定实例(以字典形式)的一个或多个指标元组。它应该通过 base64.b64encode(dill.dumps(fn, recurse=True)) 进行编码。

  • --metric_keys: 汇总输出中聚合指标的逗号分隔键。键的顺序和数量必须与 metric_fn 的输出匹配。汇总将包含一个额外的键 ‘count’,表示实例总数,因此键不应包含 ‘count’。

使用示例

当输入文件如下所示时

{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}

输出文件将是

{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}

在 DAG 外部测试

subprocess.check_call(
    [
        "python",
        "-m",
        "airflow.providers.google.cloud.utils.mlengine_prediction_summary",
        "--prediction_path=gs://...",
        "--metric_fn_encoded=" + metric_fn_encoded,
        "--metric_keys=log_loss,mse",
        "--runner=DataflowRunner",
        "--staging_location=gs://...",
        "--temp_location=gs://...",
    ]
)

JsonCoder

JSON 编码器/解码器。

函数

MakeSummary(pcoll, metric_fn, metric_keys)

在 Dataflow 中使用的汇总 PTransform。

run([argv])

获取预测汇总。

模块内容

class airflow.providers.google.cloud.utils.mlengine_prediction_summary.JsonCoder[源代码]

基类: apache_beam.coders.coders.Coder

JSON 编码器/解码器。

static encode(x)[源代码]

JSON 编码器。

static decode(x)[源代码]

JSON 解码器。

airflow.providers.google.cloud.utils.mlengine_prediction_summary.MakeSummary(pcoll, metric_fn, metric_keys)[源代码]

在 Dataflow 中使用的汇总 PTransform。

airflow.providers.google.cloud.utils.mlengine_prediction_summary.run(argv=None)[源代码]

获取预测汇总。

此条目有帮助吗?