airflow.providers.google.cloud.utils.mlengine_prediction_summary
¶
DataFlowPythonOperator 调用的一个模板,用于汇总 BatchPrediction。
它接受一个用户函数来计算预测结果中每个实例的指标,然后聚合以作为摘要输出。
它接受以下参数
--prediction_path
:包含 BatchPrediction 结果的 GCS 文件夹,其中包含 json 格式的prediction.results-NNNNN-of-NNNNN
文件。输出也将存储在此文件夹中,命名为 “prediction.summary.json”。--metric_fn_encoded
:一个编码函数,用于计算并返回给定实例(作为字典)的一个或多个指标的元组。它应该通过base64.b64encode(dill.dumps(fn, recurse=True))
进行编码。--metric_keys
:摘要输出中聚合指标的一个或多个逗号分隔的键。键的顺序和大小必须与 metric_fn 的输出匹配。摘要将有一个额外的键 “count” 来表示实例的总数,因此键不应包含 “count”。
使用示例
当输入文件如下所示时
{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}
输出文件将是
{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
在 DAG 之外进行测试
subprocess.check_call(
[
"python",
"-m",
"airflow.providers.google.cloud.utils.mlengine_prediction_summary",
"--prediction_path=gs://...",
"--metric_fn_encoded=" + metric_fn_encoded,
"--metric_keys=log_loss,mse",
"--runner=DataflowRunner",
"--staging_location=gs://...",
"--temp_location=gs://...",
]
)
模块内容¶
函数¶
|
Dataflow 中使用的摘要 PTransform。 |
|
获取预测摘要。 |
- class airflow.providers.google.cloud.utils.mlengine_prediction_summary.JsonCoder[源代码]¶
基类:
apache_beam.coders.coders.Coder
JSON 编码器/解码器。