import protobuf from 'protobufjs';

/**
 * ONNX `TensorProto.DataType` enum names mapped to Triton `data_type` names.
 * Resolved to numeric enum values against the loaded schema at call time.
 */
const ONNX_TO_TRITON_TYPE = Object.freeze({
  FLOAT: 'TYPE_FP32',
  UINT8: 'TYPE_UINT8',
  INT8: 'TYPE_INT8',
  UINT16: 'TYPE_UINT16',
  INT16: 'TYPE_INT16',
  INT32: 'TYPE_INT32',
  INT64: 'TYPE_INT64',
  STRING: 'TYPE_STRING',
  BOOL: 'TYPE_BOOL',
  FLOAT16: 'TYPE_FP16',
  DOUBLE: 'TYPE_FP64',
  UINT32: 'TYPE_UINT32',
  UINT64: 'TYPE_UINT64',
});

/**
 * Quote a value as a double-quoted protobuf text-format string.
 *
 * @param {*} value - Value to quote; converted with `String()`.
 * @returns {string} The value wrapped in `"` with `\` and `"` escaped.
 */
const quoteTextProtoString = (value) =>
  `"${String(value).replace(/\\/g, '\\\\').replace(/"/g, '\\"')}"`;

/**
 * Convert decoded ONNX shape dimensions to Triton `dims` entries.
 *
 * A positive `dimValue` is kept as-is. Anything else (symbolic `dimParam`,
 * zero, or unset) becomes `-1`, Triton's marker for a variable-size dimension.
 * `dimValue` may be a number or a protobufjs Long; both stringify to decimal.
 *
 * @param {Array<Object>} dims - Decoded `TensorShapeProto.dim` entries.
 * @returns {string[]} One string per dimension.
 */
const toTritonDims = (dims) =>
  dims.map((dim) => (Number(dim.dimValue) > 0 ? String(dim.dimValue) : '-1'));

/**
 * Render one Triton `input [...]` or `output [...]` section.
 *
 * @param {string} sectionName - Section keyword, e.g. `"input"` or `"output"`.
 * @param {Array<Object>} tensors - Decoded ONNX `ValueInfoProto` entries.
 * @param {(elemType: number) => string} toTritonType - Maps an ONNX element
 *   type to a Triton `data_type` name.
 * @returns {string[]} The section's config lines.
 */
const buildTensorSection = (sectionName, tensors, toTritonType) => {
  const lines = [`${sectionName} [`];
  tensors.forEach((tensor, index) => {
    // Non-tensor values (sequence/map) have no tensorType, and a tensor may
    // have no shape; render those as TYPE_INVALID / empty dims instead of
    // throwing on a null sub-message.
    const tensorType = (tensor.type && tensor.type.tensorType) || {};
    const dims = tensorType.shape && tensorType.shape.dim ? tensorType.shape.dim : [];

    lines.push('  {');
    lines.push(`    name: ${quoteTextProtoString(tensor.name)}`);
    lines.push(`    data_type: ${toTritonType(tensorType.elemType)}`);
    lines.push(`    dims: [ ${toTritonDims(dims).join(', ')} ]`);
    // Entries are comma-separated; only the last one omits the comma.
    lines.push(index < tensors.length - 1 ? '  },' : '  }');
  });
  lines.push(']');
  return lines;
};

/**
 * Build a Triton Inference Server model config (`config.pbtxt` text) for an
 * ONNX model file.
 *
 * The ONNX schema is loaded with `protobuf.load`, and the file bytes are decoded
 * as `onnx.ModelProto`. One `input` section (graph inputs minus initializers)
 * and one `output` section (graph outputs) are emitted.
 *
 * @param {{name: string, arrayBuffer: () => Promise<ArrayBuffer>}} onnxFile -
 *   File/Blob-like object holding the serialized model; its `name` (minus a
 *   trailing `.onnx`) becomes the Triton model name.
 * @param {string} [protoPath='/onnx.proto'] - Location of `onnx.proto` passed
 *   to `protobuf.load`.
 * @returns {Promise<string>} Config text, lines joined with `'\n'`.
 * @throws {Error} If the decoded model has no graph.
 */
const generateConfigFromOnnxFile = async (onnxFile, protoPath = '/onnx.proto') => {
  const root = await protobuf.load(protoPath);

  // Resolve enum names to numeric values once per call. Names missing from
  // the loaded schema are skipped so no bogus "undefined" key is created.
  const onnxDataTypes = root.lookupEnum('onnx.TensorProto.DataType').values;
  const tritonTypeByOnnxValue = Object.create(null);
  Object.entries(ONNX_TO_TRITON_TYPE).forEach(([onnxName, tritonName]) => {
    if (typeof onnxDataTypes[onnxName] === 'number') {
      tritonTypeByOnnxValue[onnxDataTypes[onnxName]] = tritonName;
    }
  });
  const toTritonType = (elemType) => tritonTypeByOnnxValue[elemType] || 'TYPE_INVALID';

  const ModelProto = root.lookupType('onnx.ModelProto');
  const model = ModelProto.decode(new Uint8Array(await onnxFile.arrayBuffer()));
  const { graph } = model;
  if (!graph) {
    throw new Error(`ONNX model "${onnxFile.name}" has no graph`);
  }

  // Older ONNX IR versions also list initializers (weights) in graph.input.
  // They are not runtime inputs, so exclude them from the Triton config.
  const initializerNames = new Set((graph.initializer || []).map((init) => init.name));
  const modelInputs = graph.input.filter((input) => !initializerNames.has(input.name));

  // Strip only a trailing ".onnx" extension, not the first occurrence anywhere.
  const modelName = onnxFile.name.replace(/\.onnx$/, '');

  return [
    `name: ${quoteTextProtoString(modelName)}`,
    'platform: "onnxruntime_onnx"',
    ...buildTensorSection('input', modelInputs, toTritonType),
    ...buildTensorSection('output', graph.output, toTritonType),
  ].join('\n');
};

export { generateConfigFromOnnxFile };
