/*
 * This file is part of FFmpeg.
 *
 * FFmpeg is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * FFmpeg is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with FFmpeg; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

#version 460
#pragma shader_stage(compute)
#extension GL_GOOGLE_include_directive : require

#define GET_BITS_SMEM 4
#include "common.comp"

/* Specialization constant. When true, main() copies the interlaced scan
 * table and put_px() writes row y of the decoded data to image row
 * 2 * y + bottom_field. */
layout (constant_id = 0) const bool interlaced = false;

/* Start offset of every slice into slice_data: slice i spans
 * [slice_offsets[i], slice_offsets[i + 1]), so the array holds one entry
 * more than there are slices. NOTE(review): units presumably bytes, since
 * slice_data is a u8buf; confirm against the host code. */
layout (set = 0, binding = 0) readonly buffer slice_offsets_buf {
    uint32_t slice_offsets[];
};
/* Quantization index per macroblock, indexed mb_y * mb_width + mb_x.
 * Written by the luma (gid.z == 0) invocations and forwarded to the
 * IDCT shader. */
layout (set = 0, binding = 1) writeonly buffer quant_idx_buf {
    uint8_t quant_idx[];
};
/* Output planes, indexed by gid.z: 0..2 receive the entropy-decoded color
 * coefficients at their spatial positions, 3 receives alpha samples. */
layout (set = 0, binding = 2) uniform writeonly uimage2D dst[];

/* Per-dispatch parameters supplied by the host. */
layout (push_constant, scalar) uniform pushConstants {
   /* Base reference of the packed slice data; slice_offsets are relative to it */
   u8buf    slice_data;
   /* Not referenced by the code in this file */
   uint     bitstream_size;

   /* width is not referenced here; height limits the alpha rows decoded */
   uint16_t width;
   uint16_t height;
   /* mb_width is the quant_idx row stride and encodes the slice layout
    * (see main()); mb_height is not referenced here */
   uint16_t mb_width;
   uint16_t mb_height;
   /* Slice grid: slices per row (invocation x range) and slice rows */
   uint16_t slice_width;
   uint16_t slice_height;
   /* Slices are (1 << log2_slice_width) MBs wide, except the extra
    * narrower ones at the right edge */
   uint8_t  log2_slice_width;
   /* Chroma planes are (16 >> log2_chroma_w) pixels wide per MB */
   uint8_t  log2_chroma_w;
   /* Output bit depth; alpha samples are rescaled to it */
   uint8_t  depth;
   /* 1 selects 8-bit alpha codes, other values 16-bit codes */
   uint8_t  alpha_info;
   /* Added to the doubled row index in put_px() when interlaced */
   uint8_t  bottom_field;
};

/**
 * Table 9: DC delta codebooks, packed for decode_codeword() as
 * (last_rice_q << 0) | (krice << 4) | (kexp << 8).
 * For Golomb combination codes the two upper fields are the table's krice
 * and kexp. A pure exp-Golomb code of order k is stored as
 * last_rice_q = 0, krice = k, kexp = k + 1; with that parameterization
 * decode_codeword() yields exactly the order-k exp-Golomb value.
 * According to the SMPTE document, abs(prev_dc_diff) should be used
 * to index the table. decode_comp() instead indexes by the previous unsigned
 * codeword (clamped to 6): codewords 2n - 1 and 2n both map to
 * |diff| = n, so duplicating the entries removes the abs operation.
 */
const uint16_t k_dc_codebook[] = { U16(0x100),
                                   U16(0x210), U16(0x210),
                                   U16(0x321), U16(0x321),
                                   U16(0x430), U16(0x430), };

/* Table 10: AC run codebooks, indexed by min(previous run codeword, 15).
 * Packed as (last_rice_q << 0) | (krice << 4) | (kexp << 8). */
const uint16_t k_ac_run_codebook  [] = { U16(0x102), U16(0x102), U16(0x101), U16(0x101),
                                         U16(0x100), U16(0x211), U16(0x211), U16(0x211),
                                         U16(0x211), U16(0x210), U16(0x210), U16(0x210),
                                         U16(0x210), U16(0x210), U16(0x210), U16(0x320), };
/* Table 11: AC level codebooks, indexed by min(previous level codeword, 8).
 * The level codeword is |level| - 1, so entry 0 corresponds to |level| == 1.
 * Same packing as the run codebooks. */
const uint16_t k_ac_level_codebook[] = { U16(0x202), U16(0x101), U16(0x102), U16(0x100),
                                         U16(0x210), U16(0x210), U16(0x210), U16(0x210),
                                         U16(0x320) };

/* Figure 4: progressive scan order. Entry i is the position, inside an
 * 8x8 block, of the i-th coefficient in scan order, encoded as
 * (x << 0) | (y << 4). */
const uint8_t k_scan_tbl[] = {
    U8(0x00), U8(0x01), U8(0x10), U8(0x11), U8(0x02), U8(0x03), U8(0x12), U8(0x13),
    U8(0x20), U8(0x21), U8(0x30), U8(0x31), U8(0x22), U8(0x23), U8(0x32), U8(0x33),
    U8(0x04), U8(0x05), U8(0x14), U8(0x24), U8(0x15), U8(0x06), U8(0x07), U8(0x16),
    U8(0x25), U8(0x34), U8(0x35), U8(0x26), U8(0x17), U8(0x27), U8(0x36), U8(0x37),
    U8(0x40), U8(0x41), U8(0x50), U8(0x60), U8(0x51), U8(0x42), U8(0x43), U8(0x52),
    U8(0x61), U8(0x70), U8(0x71), U8(0x62), U8(0x53), U8(0x44), U8(0x45), U8(0x54),
    U8(0x63), U8(0x72), U8(0x73), U8(0x64), U8(0x55), U8(0x46), U8(0x47), U8(0x56),
    U8(0x65), U8(0x74), U8(0x75), U8(0x66), U8(0x57), U8(0x67), U8(0x76), U8(0x77),
};

/* Figure 5: interlaced scan order, same (x << 0) | (y << 4) encoding.
 * Used instead of Figure 4 when the `interlaced` constant is set. */
const uint8_t k_scan_tbl_interlaced[] = {
    U8(0x00), U8(0x10), U8(0x01), U8(0x11), U8(0x20), U8(0x30), U8(0x21), U8(0x31),
    U8(0x02), U8(0x12), U8(0x03), U8(0x13), U8(0x22), U8(0x32), U8(0x23), U8(0x33),
    U8(0x40), U8(0x50), U8(0x41), U8(0x42), U8(0x51), U8(0x60), U8(0x70), U8(0x61),
    U8(0x52), U8(0x43), U8(0x53), U8(0x62), U8(0x71), U8(0x72), U8(0x63), U8(0x73),
    U8(0x04), U8(0x14), U8(0x05), U8(0x06), U8(0x15), U8(0x24), U8(0x34), U8(0x25),
    U8(0x16), U8(0x07), U8(0x17), U8(0x26), U8(0x35), U8(0x44), U8(0x54), U8(0x45),
    U8(0x36), U8(0x27), U8(0x37), U8(0x46), U8(0x55), U8(0x64), U8(0x74), U8(0x65),
    U8(0x56), U8(0x47), U8(0x57), U8(0x66), U8(0x75), U8(0x76), U8(0x67), U8(0x77),
};

/* Shared-memory copies of the constant tables, filled by every invocation in
 * main() before decoding so that dynamically indexed lookups read from
 * shared memory. */
shared uint16_t dc_codebook      [k_dc_codebook      .length()],
                ac_run_codebook  [k_ac_run_codebook  .length()],
                ac_level_codebook[k_ac_level_codebook.length()];

/* Holds k_scan_tbl or k_scan_tbl_interlaced, depending on `interlaced` */
shared uint8_t  scan_tbl[k_scan_tbl.length()];

/*
 * Store v, truncated to 16 bits, into output plane tex_idx at pos.
 * In interlaced mode, row y is written to image row 2 * y + bottom_field.
 */
void put_px(uint tex_idx, ivec2 pos, uint v)
{
    ivec2 store_pos = pos;
    if (interlaced)
        store_pos.y = (pos.y << 1) + int(bottom_field);

    imageStore(dst[nonuniformEXT(tex_idx)], store_pos, uvec4(uint16_t(v)));
}

/*
 * 7.5.3 Pixel Arrangement
 * Map a block index inside a slice to the pixel offset of the block's
 * top-left corner.
 *   luma == 1: 2x2 blocks per MB, x = 2 * (pos >> 2) + (pos & 1),
 *              y = (pos >> 1) & 1
 *   luma == 0: vertical pairs,   x = pos >> 1, y = pos & 1
 * Both coordinates are in 8-pixel units before the final << 3.
 */
ivec2 pos_to_block(uint pos, uint luma)
{
    uint col_keep = -luma - 2;                       /* ~2 for luma, ~1 for chroma */
    uint col      = ((pos & col_keep) + luma) >> 1;
    uint row      = (pos >> luma) & 1;

    return ivec2(col, row) << 3;
}

/*
 * 7.1.1.2 Signed Golomb Combination Codes
 * Map an unsigned codeword to its signed value, returned in two's
 * complement: even x -> x / 2, odd x -> -(x + 1) / 2.
 */
uint to_signed(uint x)
{
    uint neg_mask = -(x & 1);
    return neg_mask ^ (x >> 1);
}

/*
 * 7.1.1.1 Golomb Combination Codes
 * Decode one codeword from gb. The codebook packs three 4-bit fields:
 * bits 0-3 last_rice_q, bits 4-7 krice, bits 8-11 kexp.
 * The length q of the unary prefix (leading zeros) selects Golomb-Rice
 * coding when q <= last_rice_q, and exp-Golomb coding otherwise.
 */
uint decode_codeword(inout GetBitContext gb, int codebook)
{
    int last_rice_q = bitfieldExtract(codebook, 0, 4);
    int krice       = bitfieldExtract(codebook, 4, 4);
    int kexp        = bitfieldExtract(codebook, 8, 4);

    /* Leading zeros of the next 32 bits */
    int q = 31 - findMSB(show_bits(gb, 32));

    if (q > last_rice_q) {
        /* exp-Golomb: value rebased so it follows the Rice-coded range */
        return get_bits(gb, (q << 1) + kexp - last_rice_q) - (1 << kexp) +
               ((last_rice_q + 1) << krice);
    }

    /* Golomb-Rice: q zeros, a one, then krice suffix bits. Drop the marker
     * bit and add the prefix contribution. */
    return (get_bits(gb, krice + q + 1) & ~(1 << krice)) + (q << krice);
}

/**
 * Entropy-decode one color component of a slice and write every decoded
 * coefficient into dst[gid.z] at its spatial position (inverse scan).
 *
 * @param gb       bit reader over this component's coded data (copied, so
 *                 the caller's reader is not advanced)
 * @param mb_pos   position of the slice's first macroblock, in macroblocks
 * @param mb_count number of macroblocks in the slice; main() always passes a
 *                 power of two, which block_mask/block_shift rely on
 */
void decode_comp(in GetBitContext gb, uvec2 mb_pos, uint mb_count)
{
    uvec3 gid = gl_GlobalInvocationID;
    uint is_luma = uint(gid.z == 0);
    uint chroma_shift = bool(is_luma) ? 0 : log2_chroma_w;

    /* 4 blocks per MB for luma, 4 >> log2_chroma_w for chroma */
    uint num_blocks = mb_count << (2 - chroma_shift);
    ivec2 base_pos = ivec2(mb_pos.x << (4 - chroma_shift), mb_pos.y << 4);

    /* 7.1.1.3 DC Coefficients */
    {
        /* First coeff */
        uint c = to_signed(decode_codeword(gb, 0x650));
        put_px(gid.z, base_pos, c);

        /* Each further DC is a delta from the previous one. The delta's sign
         * is relative to the sign of the previous delta (s is 0 or -1). */
        uint cw = 5, prev_dc_diff = 0;
        for (int i = 1; i < num_blocks; ++i) {
            cw = decode_codeword(gb, dc_codebook[min(cw, 6)]);

            int s = int(prev_dc_diff) >> 31;
            c += prev_dc_diff = (to_signed(cw) ^ s) - s;

            put_px(gid.z, base_pos + pos_to_block(i, is_luma), c);
        }
    }

    /* 7.1.1.4 AC Coefficients */
    {
        /* Coefficients of all blocks are interleaved: the low bits of pos
         * select the block, the high bits the index in scan order. */
        uint block_mask  = num_blocks - 1;
        uint block_shift = findLSB(num_blocks);
        uint num_coeffs  = num_blocks << 6;

        uint pos = num_blocks - 1, run = 4, level = 1, s;
        while (pos < num_coeffs) {
            /* Stop once no data, or only zero padding, remains */
            int left = left_bits(gb);
            if (left <= 0 || (left < 32 && show_bits(gb, left) == 0))
                break;

            run   = decode_codeword(gb, ac_run_codebook  [min(run,   15)]);
            level = decode_codeword(gb, ac_level_codebook[min(level, 8 )]);
            s     = get_bits(gb, 1);

            pos += run + 1;

            /* A damaged stream can push pos past the last coefficient. Stop
             * before scan_tbl (64 entries) is indexed out of bounds. */
            if (pos >= num_coeffs)
                break;

            uint bidx  = pos & block_mask, scan = scan_tbl[pos >> block_shift];
            ivec2 spos = pos_to_block(bidx, is_luma);
            ivec2 bpos = ivec2(scan & 0xf, scan >> 4);

            /* |coeff| = level + 1, negated when the sign bit is set */
            uint c = ((level + 1) ^ -s) + s;
            put_px(gid.z, base_pos + spos + bpos, c);
        }
    }
}

/*
 * 7.1.2 Scanned Alpha
 * Decode the run-length coded alpha of one slice into dst[3]. Samples are
 * produced in raster order over an area mb_count * 16 pixels wide, limited
 * to the rows that lie inside the picture height. Each value is rescaled
 * from the coded alpha precision to the output depth.
 */
void decode_alpha(in GetBitContext gb, uvec2 mb_pos, uint mb_count)
{
    uvec3 gid = gl_GlobalInvocationID;

    ivec2 origin = ivec2(mb_pos) << 4;

    /* Raster index -> (x, y); the slice is a power of two pixels wide */
    uint row_shift = findMSB(mb_count) + 4;
    uint col_mask  = (1 << row_shift) - 1;

    /* 8-bit (alpha_info == 1) or 16-bit alpha values */
    uint value_mask = (1 << (4 << alpha_info)) - 1;
    uint total      = (mb_count << 4) * min(height - (gid.y << 4), 16);

    int  cw_bits, flc_bits;
    uint lshift;
    if (alpha_info == 1) {
        cw_bits  = 5;
        flc_bits = 9;
        lshift   = depth - 8;
    } else {
        cw_bits  = 8;
        flc_bits = 17;
        lshift   = 16;
    }
    uint rshift = 16 - depth;

    uint alpha = -1;
    uint pos   = 0;
    while (pos < total) {
        /* Tables 13/14: a leading one introduces a fixed-length delta (its
         * marker bit falls outside value_mask); otherwise a short code holds
         * a magnitude and a sign bit. */
        uint cw = show_bits(gb, cw_bits);
        uint delta;
        if (findMSB(cw) == cw_bits - 1) {
            delta = get_bits(gb, flc_bits);
        } else {
            uint neg = cw & 1;
            delta = (((cw >> 1) + 1) ^ -neg) + neg;
            skip_bits(gb, cw_bits);
        }
        alpha = (alpha + delta) & value_mask;

        /* Table 12: run length. '1' -> 1, '0' + 4 non-zero bits -> bits + 1,
         * five zero bits -> 11-bit escape (read as 16 bits) + 1. */
        uint rl_bits = show_bits(gb, 5);
        int  rl_msb  = findMSB(rl_bits);
        uint run;
        if (rl_msb == 4) {
            run = 1;
            skip_bits(gb, 1);
        } else if (rl_msb >= 0 && rl_msb <= 3) {
            run = rl_bits + 1;
            skip_bits(gb, 5);
        } else {
            run = get_bits(gb, 16) + 1;
        }
        run = min(run, total - pos);

        /**
         * FFmpeg doesn't support color and alpha with different precision,
         * so we need to rescale to the color range.
         */
        uint val = (alpha << lshift) | (alpha >> rshift);
        for (uint end = pos + run; pos < end; ++pos)
            put_px(3, origin + ivec2(pos & col_mask, pos >> row_shift), val);
    }
}

/**
 * Entry point. One invocation handles one plane of one slice: gid.x/gid.y
 * select the slice in the slice grid, and gid.z selects the plane (0..2
 * color components, 3 alpha). It parses the slice header, decodes the
 * plane's entropy-coded data into dst[gid.z], and the luma invocation
 * forwards the slice's quantization index to every macroblock it covers.
 * Slices whose header claims more data than the slice contains are left
 * undecoded.
 */
void main(void)
{
    uvec3 gid = gl_GlobalInvocationID;
    if (gid.x >= slice_width || gid.y >= slice_height)
        return;

    uint slice_idx = gid.y * slice_width + gid.x;
    uint slice_off  = slice_offsets[slice_idx],
         slice_size = slice_offsets[slice_idx + 1] - slice_off;

    u8buf bs = u8buf(slice_data + slice_off);

    /* Decode slice header */
    uint hdr_size, qidx, y_size, u_size, v_size, a_size;
    hdr_size = bs[0].v >> 3, qidx = clamp(bs[1].v, 1, 224);
    y_size = (uint(bs[2].v) << 8) | bs[3].v;
    u_size = (uint(bs[4].v) << 8) | bs[5].v;

    /*
     * The sizes come straight from the bitstream. Reject slices where they
     * exceed the slice payload: the unsigned subtractions below would
     * otherwise wrap, and the bit readers would be given sizes reaching
     * past the slice data.
     */
    uint coded_size = hdr_size + y_size + u_size;
    if (coded_size > slice_size)
        return;

    /**
     * The alpha_info field can be 0 even when an alpha plane is present,
     * if skip_alpha is enabled, so use the header size instead.
     * The V size occupies header bytes 6 and 7, so it is only present when
     * the header is at least 8 bytes long.
     */
    if (hdr_size > 7)
        v_size = (uint(bs[6].v) << 8) | bs[7].v;
    else
        v_size = slice_size - coded_size;

    if (v_size > slice_size - coded_size)
        return;

    a_size = slice_size - coded_size - v_size;

    bs += hdr_size;
    int bs_size = 0;
    switch (gid.z) {
        case 0:
            bs_size = int(y_size);
            break;
        case 1:
            bs_size = int(u_size), bs += y_size;
            break;
        case 2:
            bs_size = int(v_size), bs += y_size + u_size;
            break;
        case 3:
            bs_size = int(a_size), bs += y_size + u_size + v_size;
            break;
    }

    GetBitContext gb;
    init_get_bits(gb, bs, bs_size);

    /**
     * Support for the grayscale "extension" in the prores_aw encoder.
     * According to the spec, entropy coded data should never be empty,
     * and instead contain at least the DC coefficients.
     * This avoids undefined behavior.
     */
    if (left_bits(gb) == 0)
        return;

    /* Copy constant tables to local memory */
    dc_codebook       = k_dc_codebook;
    ac_run_codebook   = k_ac_run_codebook;
    ac_level_codebook = k_ac_level_codebook;

    if (!interlaced)
        scan_tbl = k_scan_tbl;
    else
        scan_tbl = k_scan_tbl_interlaced;

    /**
     * 4 ProRes Frame Structure
     * ProRes tiles pictures into a grid of slices, whose size is determined
     * by the log2_slice_width parameter (height is always 1 MB).
     * Each slice has a width of (1 << log2_slice_width) MBs, until the picture
     * cannot accommodate a full one. At this point, the remaining space
     * is recursively completed using the first smaller power of two that fits
     * (see Figure 1).
     * The maximum number of extra slices is 3, when log2_slice_width is 3,
     * with sizes 4, 2 and 1 MBs.
     * The mb_width parameter therefore also represents the number of full slices,
     * when interpreted as a fixed-point number with log2_slice_width fractional bits.
     */
    uint frac      = bitfieldExtract(uint(mb_width), 0, log2_slice_width),
         num_extra = bitCount(frac);

    uint diff = slice_width - gid.x - 1,
         off  = max(int(diff - num_extra + 1) << 2, 0);

    uint log2_width = min(findLSB(frac - diff >> diff) + diff + off, log2_slice_width);

    uint mb_x = (min(gid.x, slice_width - num_extra) << log2_slice_width) +
                (frac & (0xf << log2_width + 1)),
         mb_y = gid.y;
    uint mb_count = 1 << log2_width;

    if (gid.z < 3) {
        /* Color entropy decoding, inverse scanning */
        decode_comp(gb, uvec2(mb_x, mb_y), mb_count);
    } else {
        /* Alpha entropy decoding */
        decode_alpha(gb, uvec2(mb_x, mb_y), mb_count);
    }

    /* Forward the quantization index to the IDCT shader */
    if (gid.z == 0) {
        uint base = mb_y * mb_width + mb_x;
        for (uint i = 0; i < mb_count; ++i)
            quant_idx[base + i] = uint8_t(qidx);
    }
}
