diff --git a/src/athena/Compression.cpp b/src/athena/Compression.cpp index 47bd6e0..e3fc13d 100644 --- a/src/athena/Compression.cpp +++ b/src/athena/Compression.cpp @@ -10,8 +10,7 @@ namespace athena::io::Compression { -atInt32 decompressZlib(const atUint8* src, atUint32 srcLen, atUint8* dst, atUint32 dstLen) { - z_stream strm = {}; +void zlibInitZStrm(const atUint8* src, atUint32 srcLen, atUint8* dst, atUint32 dstLen, z_stream& strm) { strm.avail_in = srcLen; strm.avail_out = dstLen; strm.next_in = const_cast(src); @@ -19,53 +18,34 @@ atInt32 decompressZlib(const atUint8* src, atUint32 srcLen, atUint8* dst, atUint strm.zalloc = Z_NULL; strm.zfree = Z_NULL; strm.opaque = Z_NULL; - - atInt32 ret; - // 15 window bits, and the | 16 tells zlib to to detect if using gzip or zlib - ret = inflateInit2(&strm, MAX_WBITS | 16); - +} +atInt32 zlibInflate(z_stream& strm, atInt32 bits) { + atInt32 ret = inflateInit2(&strm, bits); if (ret == Z_OK) { ret = inflate(&strm, Z_FINISH); if (ret == Z_STREAM_END) { ret = strm.total_out; } } - inflateEnd(&strm); return ret; } - -atInt32 compressZlib(const atUint8* src, atUint32 srcLen, atUint8* dst, atUint32 dstLen) { +atInt32 decompressZlib(const atUint8* src, atUint32 srcLen, atUint8* dst, atUint32 dstLen) { z_stream strm = {}; - strm.total_in = strm.avail_in = srcLen; - strm.total_out = strm.avail_out = dstLen; - strm.next_in = (Bytef*)src; - strm.next_out = (Bytef*)dst; + zlibInitZStrm(src, srcLen, dst, dstLen, strm); - strm.zalloc = Z_NULL; - strm.zfree = Z_NULL; - strm.opaque = Z_NULL; - - atInt32 err = -1; - atInt32 ret = -1; - - err = deflateInit(&strm, Z_BEST_COMPRESSION); - - if (err == Z_OK) { - err = deflate(&strm, Z_FINISH); - - if (err == Z_STREAM_END) + // 15 window bits, and the | 16 tells zlib to to detect if using gzip or zlib + atInt32 ret = zlibInflate(strm, MAX_WBITS | 16); + if (ret != Z_STREAM_END) { + // Try again without gzip + zlibInitZStrm(src, srcLen, dst, dstLen, strm); + ret = zlibInflate(strm, MAX_WBITS); + if (ret == Z_STREAM_END) { ret = strm.total_out; - else { - deflateEnd(&strm); - return err; } } else { - deflateEnd(&strm); - return err; + ret = strm.total_out; } - - deflateEnd(&strm); - + inflateEnd(&strm); return ret; }