#!/usr/bin/env python

import nsysstats

class SyncAPI(nsysstats.Report):
    """Report rule listing the longest host-blocking CUDA synchronization calls.

    The query selects rows of ``CUPTI_ACTIVITY_KIND_RUNTIME`` whose API name
    starts with ``cudaDeviceSynchronize`` or ``cudaStreamSynchronize`` and
    returns them ordered by duration, longest first, capped at a configurable
    number of rows.
    """

    # Default cap on the number of rows returned; overridable with rows=<limit>.
    ROW_LIMIT = 50

    # Help text. This is an f-string: {ROW_LIMIT} is interpolated right here
    # at class-definition time, while the doubled braces leave a literal
    # {SCRIPT} placeholder in the text (presumably filled in later by
    # nsysstats -- confirm against the base class).
    usage = f"""{{SCRIPT}}[:rows=<limit>] -- SyncAPI

    Options:
        rows=<limit> - Maximum number of rows returned by the query.
            Default is {ROW_LIMIT}.

    Output: All time values default to nanoseconds
        Duration : Duration
        Start : Start time
        PID : Process identifier
        TID : Thread identifier
        API Name : Name of runtime API function

    This rule identifies the following synchronization APIs that block the host
    until the issued CUDA calls are complete:
    - cudaDeviceSynchronize()
    - cudaStreamSynchronize()
"""

    # SQL template; {ROW_LIMIT} is substituted by setup() via str.format().
    # The "sync" CTE collects the StringIds rows whose value starts with one
    # of the two synchronization API names, and the main SELECT joins those
    # ids against runtime.nameId. PID is taken from bits 24-47 of globalTid
    # and TID from its low 24 bits. Rows are sorted by the first output
    # column (duration), longest first.
    query_sync_api = """
    WITH
        sync AS (
            SELECT
                id,
                value
            FROM
                StringIds
            WHERE
                value like 'cudaDeviceSynchronize%'
                OR value like 'cudaStreamSynchronize%'
        )
    SELECT
        end - start AS "Duration:dur_ns",
        start AS "Start:ts_ns",
        (globalTid >> 24) & 0x00FFFFFF AS "PID",
        globalTid & 0xFFFFFF AS "TID",
        value AS "API Name",
        globalTid AS "_Global TID"
    FROM
        CUPTI_ACTIVITY_KIND_RUNTIME AS runtime
    JOIN
        sync
        ON sync.id = runtime.nameId
    ORDER BY
        1 DESC
    LIMIT {ROW_LIMIT}
"""

    # Table name -> message to report when that table is missing.
    # NOTE(review): presumably consumed by nsysstats.Report before the query
    # runs -- confirm against the base class.
    table_checks = {
        'CUPTI_ACTIVITY_KIND_RUNTIME':
            "{DBFILE} could not be analyzed because it does not contain CUDA trace data."
    }

    def setup(self):
        """Parse the report arguments and build the final SQL query.

        Only ``rows=<limit>`` arguments are accepted, where ``<limit>`` is a
        non-negative decimal integer. If the option is given more than once,
        the last occurrence wins. On success the formatted query is stored
        in ``self.query``.

        Returns:
            The value returned by the base-class ``setup()`` when it is not
            ``None``; otherwise ``None``.

        Raises:
            SystemExit: with code ``self.EXIT_INVALID_ARG`` when an argument
                is not a well-formed ``rows=<limit>`` option.
        """
        err = super().setup()
        if err is not None:
            return err

        row_limit = self.ROW_LIMIT
        for arg in self.args:
            key, sep, value = arg.partition('=')
            # isdecimal() rather than isdigit(): isdigit() also accepts
            # characters such as superscript digits, which int() rejects and
            # which would otherwise be pasted verbatim into the LIMIT clause.
            if sep and key == 'rows' and value.isdecimal():
                # Normalise to an int so the SQL always receives plain ASCII
                # digits and row_limit has the same type as the default.
                row_limit = int(value)
            else:
                # Raise SystemExit directly: the bare exit() helper is added
                # by the site module and is not guaranteed to exist (for
                # example under "python -S" or in embedded interpreters).
                raise SystemExit(self.EXIT_INVALID_ARG)

        self.query = self.query_sync_api.format(ROW_LIMIT=row_limit)

# Script entry point: run the report only when this file is executed directly,
# not when it is imported. SyncAPI defines no Main of its own, so the call
# resolves to an attribute inherited from nsysstats.Report.
if __name__ == "__main__":
    SyncAPI.Main()
