#! /usr/bin/python3
# vim: set filetype=python:

# opt-jpg: losslessly optimize JPEG files

# Copyright (C) 2004-2026 by Brian Lindholm.  This file is part of the
# littleutils utility set.
#
# The opt-jpg utility is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by the Free
# Software Foundation; either version 3, or (at your option) any later version.
#
# The opt-jpg utility is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# the littleutils.  If not, see <https://www.gnu.org/licenses/>.

import concurrent.futures, getopt, os, signal, subprocess, sys, tempfile

### PREP SIGNAL HANDLER ###
interrupted = False  # becomes True once any of the signals below is received
def handler(signum, frame):
    """Signal handler: raise the global `interrupted` flag so work winds down.

    `signum` and `frame` are required by the signal API but are not used.
    """
    global interrupted
    interrupted = True
for signal_VAL in [signal.SIGHUP, signal.SIGINT, signal.SIGPIPE,
                   signal.SIGQUIT, signal.SIGTERM]:
    signal.signal(signal_VAL, handler)  # one handler for every stop request

### GET INPUT ARGUMENTS ###
# print online help
def usage(rc: int) -> None:
    """Print the version and usage summary to stdout, then exit with status *rc*."""
    print('opt-jpg 1.4.0')
    # -a is accepted by the option parser, so it must be advertised here too
    print('usage: opt-jpg [-a(rithmetic)] [-d DCT] [-f filelist] [-g(rayscale)] [-h(elp)]')
    print('         [-m markers] [-M max_quality] [-p(ipe)] [-q(uiet)] [-r rotation_angle]')
    print('         [-t(ouch)] [-T threads] filename ...')
    sys.exit(rc)
# load list of files
def load_list_from_file() -> None:
    """Append the filenames listed one per line in the -f file (opt_f) to filelist.

    Prints an error and exits with status 1 if the list file does not exist or
    cannot be opened or read.  Bytes that do not decode in the current locale
    are preserved with 'surrogateescape' rather than aborting the whole run.
    """
    if not os.path.isfile(opt_f):  # abort if file does not exist
        print('opt-jpg error: file list %s does not exist' % opt_f, file=sys.stderr)
        sys.exit(1)
    try:
        # 'with' guarantees the handle is closed even if read() fails
        with open(opt_f, 'r', errors='surrogateescape') as FILE:
            names = FILE.read().splitlines()
    except OSError:  # abort if file cannot be opened or read
        print('opt-jpg error: file list %s cannot be opened' % opt_f, file=sys.stderr)
        sys.exit(1)
    filelist.extend(names)
# load list of files from stdin
def load_list_from_stdin() -> None:
    """Read filenames (one per line) from standard input into filelist, then close stdin."""
    piped_names = sys.stdin.read().splitlines()
    sys.stdin.close()
    filelist.extend(piped_names)
# set defaults
filelist = []
opt_a = False   # perform trials with arithmetic coding
opt_d = 'float' # type of DCT arithmetic to use
opt_f = None    # file containing list of files to process
opt_g = False   # convert to grayscale
opt_m = 'none'  # markers to copy
opt_M = None    # max quality (1 to 100)
opt_p = False   # read list of files to process from stdin
opt_q = False   # be quiet
opt_r = None    # rotation angle (in degrees)
opt_t = False   # "touch" re-written files to preserve timestamps
opt_T = None    # requested thread-count
# convert an integer-valued option argument, aborting cleanly if it is malformed
def int_option(flag: str, value: str) -> int:
    try:
        return int(value)
    except ValueError:
        print('opt-jpg error: %s requires an integer argument, not %r' % (flag, value), file=sys.stderr)
        sys.exit(1)
# get command-line options
try:
    opts, filelist = getopt.getopt(sys.argv[1:], 'ad:f:ghm:M:pqr:tT:', ['help'])
except getopt.GetoptError as msg:
    # report the bad option on stderr, print help, then quit
    print(msg, file=sys.stderr)
    usage(1)
# parse options
for o, v in opts:
    if o in ('-h', '--help'): usage(0)
    elif o == '-a': opt_a = True
    elif o == '-d': opt_d = v
    elif o == '-f': opt_f = v
    elif o == '-g': opt_g = True
    elif o == '-m': opt_m = v
    elif o == '-M': opt_M = int_option(o, v)
    elif o == '-p': opt_p = True
    elif o == '-q': opt_q = True
    elif o == '-r': opt_r = v
    elif o == '-t': opt_t = True
    elif o == '-T': opt_T = int_option(o, v)
# ensure that we're not using -M and -r simultaneously
if (opt_M is not None) and (opt_r is not None):
    print('opt-jpg error: -M and -r options cannot be used simultaneously', file=sys.stderr)
    sys.exit(1)
# enforce the documented quality range for -M
if (opt_M is not None) and not (1 <= opt_M <= 100):
    print('opt-jpg error: -M max_quality must be between 1 and 100', file=sys.stderr)
    sys.exit(1)
# NOTE(review): presumably markers are dropped with -M so the lossless jpegtran
# results compare fairly against cjpeg output (re-encoded from a marker-free
# PNM) -- confirm
if opt_M is not None: opt_m = 'none'
# load file list from file and/or stdin if requested
if opt_f is not None: load_list_from_file()
if opt_p: load_list_from_stdin()
# make sure we have at least one file to process
if not filelist:
    if (not opt_f) and (not opt_p): usage(1)
    sys.exit(0)
# drop blank entries (e.g. empty lines in a list file), remove leading './',
# and trim the list to unique items while keeping first-seen order
filelist = [x.removeprefix('./') for x in filelist if x]
unique_filelist = list(dict.fromkeys(filelist))
# pick a reasonable default if thread-count is unspecified
# (os.cpu_count() may return None when the count cannot be determined)
if opt_T is None: opt_T = max(1, min((os.cpu_count() or 1) // 2, len(unique_filelist)))

### MAIN PROGRAM ###
# preload image types and depths
image_depth = {}  # filename -> 'depth=' value from imagsize (JPEGs only); > 16 allows -grayscale
image_samp = {}   # filename -> 'samp=' value from imagsize (JPEGs only); passed to cjpeg -sample
image_type = {}   # filename -> 'type=' value from imagsize; JPEGs start with 'jpg'
# parse `imagsize -p` output into the three caches above
def parse_imagsize_output(text: str) -> None:
    """Record each file's type and, for JPEGs, its depth and sampling.

    Every non-empty line is tab-separated, with the filename first and
    'type=...' last.  For JPEGs, 'depth=' is expected in the 4th field and
    'samp=' in the 6th.  If those fields are missing or malformed, only the
    type is recorded; this avoids aborting the whole run on one bad line.
    """
    for line in text.splitlines():
        if not line:
            continue
        field = line.split('\t')
        name = field[0]
        image_type[name] = field[-1].removeprefix('type=')
        if image_type[name].startswith('jpg'):
            try:
                depth = int(field[3].removeprefix('depth='))
                samp = field[5].removeprefix('samp=')
            except (IndexError, ValueError):
                continue
            image_depth[name] = depth
            image_samp[name] = samp
# run imagsize once over the whole file list
def preload_image_types() -> None:
    """Query `imagsize -p` about every name in unique_filelist and cache the answers.

    The names are fed to imagsize on stdin, one per line.  Prints an error and
    exits with status 1 if imagsize cannot be executed.
    """
    names = ''.join(name + '\n' for name in unique_filelist)
    try:
        # surrogateescape round-trips filenames that are not valid in the locale encoding
        IMAGSIZE = subprocess.run(['imagsize', '-p'], input=names, capture_output=True,
                                  text=True, errors='surrogateescape')
    except OSError:
        print('opt-jpg error: cannot run imagsize', file=sys.stderr)
        sys.exit(1)
    parse_imagsize_output(IMAGSIZE.stdout)
# optimize the JPEG file
# run one encoder command; return its output bytes, or None on failure/interrupt
def _run_trial(label, filename, args, stdin_data=None):
    """Run *args* for *filename* and return the command's stdout bytes.

    *label* names the trial in the error message.  Returns None if a
    termination signal arrived meanwhile, or if the command produced no output;
    in the latter case an error is printed unless -q was given.
    """
    result = subprocess.run(args, input=stdin_data, capture_output=True, text=False)
    if interrupted:
        return None
    if len(result.stdout) == 0:
        if not opt_q: print('opt-jpg error: failed %s processing of %s' % (label, filename), file=sys.stderr)
        return None
    return result.stdout
# run a command as given and again with -progressive, collecting both outputs
def _run_pair(filename, args, labels, results, stdin_data=None):
    """Run *args* unchanged, then with '-progressive' inserted after the program name.

    Each output is stored in *results* under the matching entry of *labels*.
    Returns False as soon as a trial fails or is interrupted.
    """
    progressive = [args[0], '-progressive'] + args[1:]
    for label, cmd in zip(labels, (args, progressive)):
        data = _run_trial(label, filename, cmd, stdin_data)
        if data is None:
            return False
        results[label] = data
    return True
# overwrite a file with new contents, optionally restoring its timestamp
def _write_result(filename, data, timestamp):
    """Rewrite *filename* in place with *data*.

    If *timestamp* is not None, it is applied as both access and modification
    time.  Returns True once the data is written.  On a write error, prints a
    message (even with -q) and returns False.
    """
    try:
        # NOTE: opening with 'wb' truncates first, so a failed write can leave
        # the file damaged; in exchange, the file's inode and permissions are kept
        with open(filename, 'wb') as out:
            out.write(data)
    except OSError:
        print('opt-jpg error: cannot write results to %s' % filename, file=sys.stderr)
        return False
    if timestamp is not None:
        try:
            os.utime(filename, (timestamp, timestamp))
        except OSError:
            print('opt-jpg error: cannot restore timestamp of %s' % filename, file=sys.stderr)
    return True
# optimize the JPEG file
def process_file(filename: str) -> None:
    """Re-encode one JPEG file in place, keeping the smallest acceptable result.

    Lossless trials run jpegtran with sequential and progressive Huffman
    coding, plus arithmetic coding with -a.  With -M, the image is also decoded
    with djpeg and re-encoded with cjpeg at quality opt_M.  That lossy result
    is used only if it is more than 5% smaller than the best lossless one.
    With -r, the rotated result is always written.  Otherwise the file is
    rewritten only when the chosen result is smaller than the original.

    Returns early, leaving the file untouched, on interruption, on a
    missing/empty/non-JPEG file, or when any trial fails.
    """
    if interrupted: return
    # abort if file does not exist
    if not os.path.isfile(filename):
        if not opt_q: print('opt-jpg error: %s is not a file' % filename, file=sys.stderr)
        return
    # skip zero-length files
    origsize = os.path.getsize(filename)
    if origsize == 0: return
    # skip images that imagsize identified as something other than JPEG
    kind = image_type.get(filename)
    if (kind is not None) and not kind.startswith('jpg'):
        if not opt_q: print('opt-jpg error: %s is not a JPEG image' % filename, file=sys.stderr)
        return
    # grab initial timestamp if required
    timestamp = os.path.getmtime(filename) if opt_t else None
    # only convert to grayscale when the depth is > 16; an unreported depth
    # means "leave as is" instead of raising KeyError
    gray = opt_g and (image_depth.get(filename, 0) > 16)
    # lossless jpegtran trials 1-2 (Huffman) and, with -a, 3-4 (arithmetic)
    lossless = {}
    codings = [('-optimize', ('huff-base', 'huff-prog'))]
    if opt_a: codings.append(('-arithmetic', ('arith-sequ', 'arith-prog')))
    for coding, labels in codings:
        args = ['jpegtran', coding, '-copy', opt_m]
        if gray: args.insert(1, '-grayscale')
        if opt_r is not None: args.extend(['-rotate', opt_r, '-trim'])
        args.append(filename)
        if not _run_pair(filename, args, labels, lossless): return
    # lossy djpeg/cjpeg trials 5-6 (Huffman) and, with -a, 7-8 (arithmetic)
    lossy = {}
    if opt_M is not None:
        decode = ['djpeg', '-dct', opt_d, '-pnm']
        if gray: decode.insert(1, '-grayscale')
        decode.append(filename)
        pnm = subprocess.run(decode, capture_output=True, text=False)
        if interrupted: return
        if len(pnm.stdout) == 0:
            if not opt_q: print('opt-jpg error: failed djpeg to PNM of %s' % filename, file=sys.stderr)
            return
        codings = [('-optimize', ('huff-base-maxq', 'huff-prog-maxq'))]
        if opt_a: codings.append(('-arithmetic', ('arith-sequ-maxq', 'arith-prog-maxq')))
        samp = image_samp.get(filename)
        for coding, labels in codings:
            args = ['cjpeg', coding, '-dct', opt_d, '-quality', str(opt_M)]
            # reuse the sampling reported by imagsize when it is known
            if samp is not None: args.extend(['-sample', samp])
            if gray: args.insert(1, '-grayscale')
            if not _run_pair(filename, args, labels, lossy, pnm.stdout): return
    # find smallest file of results (ties keep the earliest trial)
    if interrupted: return
    best = min(lossless, key=lambda idx: len(lossless[idx]))
    if (opt_M is None) and (opt_r is not None):
        # a rotation was requested, so write the result even if it is larger
        if _write_result(filename, lossless[best], timestamp) and not opt_q:
            print('%s: rotated by %s and optimized [%s]' % (filename, opt_r, best))
        return
    chosen, data = best, lossless[best]
    if opt_M is not None:
        best_maxq = min(lossy, key=lambda idx: len(lossy[idx]))
        # accept the lossy re-encode only if it is more than 5% smaller
        if (20 * len(lossy[best_maxq])) < (19 * len(data)):
            chosen, data = best_maxq, lossy[best_maxq]
    # use best result if smaller
    if len(data) < origsize:
        if _write_result(filename, data, timestamp) and not opt_q:
            print('%s: [%s] %d vs. %d' % (filename, chosen, origsize, len(data)))
    elif not opt_q:
        print('%s: unchanged' % filename, file=sys.stderr)

# process files
preload_image_types()
if opt_T < 2:
    for filename in unique_filelist:
        if interrupted: break
        process_file(filename)
else:
    with concurrent.futures.ThreadPoolExecutor(max_workers=opt_T) as executor:
        # consume the results: Executor.map only re-raises a worker's exception
        # when its result is retrieved, so discarding the iterator hid failures
        for _ in executor.map(process_file, unique_filelist):
            pass
# signal an aborted run through the exit status instead of reporting success
if interrupted: sys.exit(1)
