
Class GraphRunner

A class that drives the training of a graph model on a given dataset. It lets the user supply a set of callbacks that report measurements such as cost, accuracy, and training speed.
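The callback idea can be sketched in a few lines. This is a minimal, self-contained illustration under stated assumptions, not the library's API: the TrainingCallbacks interface, its onCost, onBatchesTrained and onExamplesPerSec members, and the runWithCallbacks function are all hypothetical names, and trainStep stands in for one optimizer step that returns the cost of a batch.

```typescript
// Hypothetical callback interface: these names illustrate the pattern and are
// not taken from the library.
interface TrainingCallbacks {
  onCost?: (cost: number) => void;
  onBatchesTrained?: (totalBatches: number) => void;
  onExamplesPerSec?: (examplesPerSec: number) => void;
}

// Runs numBatches steps of a caller-supplied training step and reports each
// measurement through whichever callbacks were provided.
function runWithCallbacks(
  trainStep: () => number, // assumed: trains one batch and returns its cost
  numBatches: number,
  batchSize: number,
  callbacks: TrainingCallbacks,
  now: () => number = () => Date.now(), // injectable clock, in milliseconds
): void {
  const startMs = now();
  for (let batch = 1; batch <= numBatches; batch++) {
    const cost = trainStep();
    callbacks.onCost?.(cost);
    callbacks.onBatchesTrained?.(batch);
  }
  const elapsedSec = (now() - startMs) / 1000;
  if (elapsedSec > 0) {
    callbacks.onExamplesPerSec?.((numBatches * batchSize) / elapsedSec);
  }
}

// Usage with a fake clock that advances 2 seconds per reading.
let fakeTimeMs = 0;
const fakeNow = (): number => {
  const t = fakeTimeMs;
  fakeTimeMs += 2000;
  return t;
};
const costs: number[] = [];
let lastBatch = 0;
let speed = 0;
const pendingCosts = [4, 3, 2, 1];
runWithCallbacks(() => pendingCosts.shift()!, 4, 8, {
  onCost: (c) => { costs.push(c); },
  onBatchesTrained: (n) => { lastBatch = n; },
  onExamplesPerSec: (s) => { speed = s; },
}, fakeNow);
// costs is [4, 3, 2, 1], lastBatch is 4, speed is 16 (32 examples in 2 s)
```

Injecting the clock keeps the speed measurement deterministic; in a real run the default Date.now() would be used.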

Hierarchy

  • GraphRunner

Index

Constructors

constructor

Methods

computeMetric

getLastComputedMetric

  • getLastComputedMetric(): Scalar

getTotalBatchesTrained

  • getTotalBatchesTrained(): number

infer

  • infer(inferenceTensor: Tensor, inferenceFeedEntries: FeedEntry[], inferenceExampleIntervalMs?: number, inferenceExampleCount?: number, numPasses?: number): void

    Parameters

    • inferenceTensor: Tensor
    • inferenceFeedEntries: FeedEntry[]
    • Default value inferenceExampleIntervalMs: number = DEFAULT_INFERENCE_EXAMPLE_INTERVAL_MS
    • Default value inferenceExampleCount: number = 5
    • Optional numPasses: number

    Returns void
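One plausible reading of the infer parameters, stated as an assumption rather than documented behavior: each pass runs inference on inferenceExampleCount examples, passes are spaced inferenceExampleIntervalMs apart, and omitting numPasses keeps passes going until they are stopped (the role that stopInferring and isInferenceRunning would play). The PeriodicInference class, the Schedule type and the inferOne callback below are hypothetical; the scheduler is injected so the sketch can run synchronously.

```typescript
// Schedules fn to run after delayMs; injectable so tests can avoid real timers.
type Schedule = (fn: () => void, delayMs: number) => void;

// Hypothetical periodic inference loop; not the library implementation.
class PeriodicInference {
  private running = false;
  private passesDone = 0;

  constructor(
    private readonly inferOne: (exampleIndex: number) => void,
    private readonly schedule: Schedule = (fn, delayMs) => {
      setTimeout(fn, delayMs);
    },
  ) {}

  // Runs one pass immediately, then one pass every intervalMs until numPasses
  // passes have completed (or indefinitely, if numPasses is omitted) or stop() is called.
  infer(exampleCount: number, intervalMs: number, numPasses?: number): void {
    this.running = true;
    this.passesDone = 0;
    const pass = (): void => {
      if (!this.running) return;
      for (let i = 0; i < exampleCount; i++) {
        this.inferOne(i);
      }
      this.passesDone++;
      if (numPasses !== undefined && this.passesDone >= numPasses) {
        this.running = false;
        return;
      }
      this.schedule(pass, intervalMs);
    };
    pass();
  }

  isRunning(): boolean {
    return this.running;
  }

  stop(): void {
    this.running = false;
  }

  getPassesDone(): number {
    return this.passesDone;
  }
}

// Usage with a fake scheduler that queues callbacks instead of waiting.
const queue: Array<() => void> = [];
const delays: number[] = [];
const fakeSchedule: Schedule = (fn, delayMs) => {
  delays.push(delayMs);
  queue.push(fn);
};
let examplesInferred = 0;
const inference = new PeriodicInference(() => { examplesInferred++; }, fakeSchedule);
inference.infer(5, 3000, 3);
while (queue.length > 0) {
  queue.shift()!();
}
// examplesInferred is 15 (3 passes x 5 examples); delays is [3000, 3000]
```

Rescheduling after each pass, rather than using a fixed-rate timer, means a slow pass never overlaps the next one.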

isInferenceRunning

  • isInferenceRunning(): boolean

resetStatistics

  • resetStatistics(): void

resumeTraining

  • resumeTraining(): void

setInferenceExampleCount

  • setInferenceExampleCount(inferenceExampleCount: number): void

setInferenceTensor

  • setInferenceTensor(inferenceTensor: Tensor): void

setMath

setSession

  • setSession(session: Session): void

stopInferring

  • stopInferring(): void

stopTraining

  • stopTraining(): void
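stopTraining and resumeTraining suggest a loop that can be paused without losing its place. A minimal sketch, assuming that resuming continues the batch count rather than restarting it and that a fixed numBatches ends training on its own; StoppableTrainingLoop and trainOneBatch are hypothetical names. The loop is synchronous to keep it testable; an in-browser loop would normally yield between batches (for example with setTimeout) so that a stop request from the UI can be handled.

```typescript
// Hypothetical stoppable training loop; trainOneBatch stands in for one
// optimizer step on the graph.
class StoppableTrainingLoop {
  private isTraining = false;
  private batchesTrained = 0;
  private numBatches?: number;

  constructor(private readonly trainOneBatch: (batchIndex: number) => void) {}

  // Starts from zero; without numBatches, trains until stop() is called.
  start(numBatches?: number): void {
    this.numBatches = numBatches;
    this.batchesTrained = 0;
    this.isTraining = true;
    this.run();
  }

  stop(): void {
    this.isTraining = false;
  }

  // Continues from the current batch count without resetting it.
  resume(): void {
    this.isTraining = true;
    this.run();
  }

  getTotalBatchesTrained(): number {
    return this.batchesTrained;
  }

  private run(): void {
    while (this.isTraining) {
      if (this.numBatches !== undefined && this.batchesTrained >= this.numBatches) {
        this.isTraining = false;
        break;
      }
      this.trainOneBatch(this.batchesTrained);
      this.batchesTrained++;
    }
  }
}

// Usage: stop during the third batch, then resume to the 10-batch limit.
const loop: StoppableTrainingLoop = new StoppableTrainingLoop((batchIndex) => {
  if (batchIndex === 2) loop.stop(); // simulates the user pressing "stop"
});
loop.start(10);
const afterStop = loop.getTotalBatchesTrained(); // 3
loop.resume();
const afterResume = loop.getTotalBatchesTrained(); // 10
```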

train

  • train(costTensor: Tensor, trainFeedEntries: FeedEntry[], batchSize: number, optimizer: Optimizer, numBatches?: number, metricTensor?: Tensor, metricFeedEntries?: FeedEntry[], metricBatchSize?: number, metricReduction?: MetricReduction.SUM | MetricReduction.MEAN, evalIntervalMs?: number, costIntervalMs?: number): void
  • Starts the training loop, optionally for a fixed number of batches (numBatches). It can also take a metric tensor and its feed entries, which are evaluated periodically; use this to compute accuracy or a similar metric.

    Parameters

    • costTensor: Tensor
    • trainFeedEntries: FeedEntry[]
    • batchSize: number
    • optimizer: Optimizer
    • Optional numBatches: number
    • Optional metricTensor: Tensor
    • Optional metricFeedEntries: FeedEntry[]
    • Optional metricBatchSize: number
    • Default value metricReduction: MetricReduction.SUM | MetricReduction.MEAN = MetricReduction.MEAN
    • Default value evalIntervalMs: number = DEFAULT_EVAL_INTERVAL_MS
    • Default value costIntervalMs: number = DEFAULT_COST_INTERVAL_MS

    Returns void
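The periodic parts of train can be illustrated with two small helpers. This sketch assumes that metricReduction combines per-example metric values from a metric batch (MEAN turning, say, a 0/1 correctness indicator into accuracy, SUM turning it into a count) and that evalIntervalMs and costIntervalMs act as minimum wall-clock gaps between successive metric and cost computations. The local MetricReduction enum, reduceMetric and IntervalGate are stand-ins written for this example, not imports from the library.

```typescript
// Local stand-in for the two reduction modes named in the train signature.
enum MetricReduction {
  SUM,
  MEAN,
}

// Combines per-example metric values; MEAN of an empty batch yields NaN.
function reduceMetric(values: number[], reduction: MetricReduction): number {
  const sum = values.reduce((acc, v) => acc + v, 0);
  if (reduction === MetricReduction.SUM) {
    return sum;
  }
  return values.length === 0 ? NaN : sum / values.length;
}

// Lets a periodic measurement run only when at least intervalMs has passed
// since it last ran.
class IntervalGate {
  private lastRunMs = -Infinity;

  constructor(private readonly intervalMs: number) {}

  shouldRun(nowMs: number): boolean {
    if (nowMs - this.lastRunMs >= this.intervalMs) {
      this.lastRunMs = nowMs;
      return true;
    }
    return false;
  }
}

// Usage: 1 = correct prediction, 0 = incorrect.
const correct = [1, 0, 1, 1];
const accuracy = reduceMetric(correct, MetricReduction.MEAN); // 0.75
const correctCount = reduceMetric(correct, MetricReduction.SUM); // 3

// With a 1000 ms gate, checks at these timestamps let the cost run three times.
const costGate = new IntervalGate(1000);
const costRuns = [0, 400, 1000, 1500, 2100].filter((t) => costGate.shouldRun(t));
// costRuns is [0, 1000, 2100]
```

Gating on elapsed time rather than on a batch count keeps the measurement overhead roughly constant regardless of how fast batches are trained.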

Generated using TypeDoc