/*
 * fdtextract -- Tool to extract sub images from FIT image.
 *
 * Copyright (C) 2021 IOPSYS Software Solutions AB. All rights reserved.
 *
 * Author: jonas.hoglund@iopsys.eu
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * version 2 as published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
 * 02110-1301 USA
 */

#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <stdarg.h>
#include <string.h>
#include <stdbool.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/sendfile.h>
#include <fcntl.h>
#include <unistd.h>
#include <errno.h>

#include <libfdt.h>
#include <fdt.h>

#include "util.h"

#define MAX_PATH_LEN 100
#define SHA_256_LEN 32

#define FIT_HASH_NODENAME "hash"

/*
 * Usage related data.
 *
 * Three parallel tables describe the command line: the synopsis line, the
 * short option string, and the long option table with one help string per
 * entry. The help strings below are listed in the same order as the entries
 * of usage_long_opts.
 *
 * NOTE(review): these tables are presumably consumed by the usage() and
 * util_getopt_long() helpers from util.h, which is why the names and the
 * ordering matter -- confirm against util.h before renaming or reordering.
 */
static const char usage_synopsis[] = "fdtextract [options] <file>";
/* Only 'l' is listed without a trailing ':'; every other letter has one. */
static const char usage_short_opts[] = "le:o:s:a:i:z:" USAGE_COMMON_SHORT_OPTS;
static struct option const usage_long_opts[] = {
	{"list",	no_argument, NULL, 'l'},
	{"extract",	a_argument, NULL, 'e'},
	{"out",		a_argument, NULL, 'o'},
	{"hash",	a_argument, NULL, 's'},
	{"attribute",	a_argument, NULL, 'a'},
	{"image",	a_argument, NULL, 'i'},
	{"size",	a_argument, NULL, 'z'},
	USAGE_COMMON_LONG_OPTS
};
/* One help string per entry of usage_long_opts, in the same order. */
static const char * const usage_opts_help[] = {
	"List images embedded in FIT",
	"Extract image from FIT",
	"Output image name",
	"SHA256 hash of image",
	"Get attribute of FIT",
	"Use image in FIT to get attribute",
	"Size of embedded image in FIT",
	USAGE_COMMON_OPTS_HELP
};

/*
 * List images embedded in the FIT.
 *
 * Prints one "[index]: name" line on stdout for every direct child node of
 * "/images".
 *
 * buf: FIT (FDT) blob to inspect.
 *
 * Returns 0 on success, -1 if the blob has no "/images" node or if the node
 * walk fails part-way through (malformed or truncated structure block).
 */
static int list_images(char *buf)
{
	int ndepth = 0, count = 0, noffset;
	const char *name;

	/* Find images root node. */
	noffset = fdt_path_offset(buf, "/images");
	if (noffset < 0) {
		return -1;
	}

	/*
	 * Iterate over all images. The walk must also stop on a negative
	 * offset: fdt_next_node() reports errors that way, and feeding a
	 * negative offset back in restarts the walk from the start of the
	 * blob, which on a damaged blob never terminates.
	 */
	for (noffset = fdt_next_node(buf, noffset, &ndepth);
	     noffset >= 0 && ndepth > 0;
	     noffset = fdt_next_node(buf, noffset, &ndepth)) {
		/* Depth 1 is a direct child of /images; deeper nodes are skipped. */
		if (ndepth != 1)
			continue;

		name = fdt_get_name(buf, noffset, NULL);
		/* Never hand a NULL pointer to %s. */
		printf("[%d]: %s\n", count, name ? name : "(null)");
		count++;
	}

	if (noffset < 0) {
		fprintf(stderr, "Error: malformed FIT structure.\n");
		return -1;
	}

	return 0;
}

/*
 * Print the size of an image embedded in the FIT.
 *
 * The size is taken from the "data-size" property when the image node has
 * one; otherwise the length of the "data" property is used. The value is
 * printed on stdout as a decimal number followed by a newline.
 *
 * buf:  FIT (FDT) blob to inspect.
 * name: image node name below "/images".
 *
 * Returns 0 on success, -1 if the image or its size cannot be determined.
 */
static int get_size(char *buf, char *name)
{
	char path[MAX_PATH_LEN] = {0};
	int noffset, data_size, len, n;
	const fdt32_t *val;

	/* A truncated path could name a different node, so treat it as not found. */
	n = snprintf(path, sizeof(path), "/images/%s", name);
	if (n < 0 || (size_t)n >= sizeof(path))
		noffset = -1;
	else
		noffset = fdt_path_offset(buf, path);
	if (noffset < 0) {
		fprintf(stderr, "Error: could not find image: %s.\n", name);
		return -1;
	}

	val = fdt_getprop(buf, noffset, "data-size", &len);
	if (val) {
		/* Only dereference the cell when the property really holds one. */
		if (len == (int)sizeof(*val))
			data_size = fdt32_to_cpu(*val);
		else
			data_size = -1;
	} else {
		/* On failure fdt_getprop() stores a negative error code in data_size. */
		fdt_getprop(buf, noffset, "data", &data_size);
	}

	if (data_size < 0) {
		fprintf(stderr, "Error: Could not get image data size: %s.\n", name);
		return -1;
	}

	printf("%d\n", data_size);

	return 0;
}


/*
 * Copy up to size bytes from in_fd (starting at its current file offset) to
 * out_fd.
 *
 * sendfile() is tried first. If it reports EINVAL or ENOSYS the remainder is
 * copied with a plain read()/write() loop. Calls interrupted by a signal
 * (EINTR) are retried, and a short write() is continued until the whole
 * chunk that was read has been written, so no bytes already consumed from
 * in_fd are dropped.
 *
 * out_fd: descriptor to write to.
 * in_fd:  descriptor to read from.
 * size:   number of bytes to copy; must not be negative.
 *
 * Returns the number of bytes copied, which is smaller than size only when
 * the input ends early or the output stops accepting data. Returns -1 with
 * errno set on error (EINVAL for a negative size).
 */
static ssize_t copy_data(int out_fd, int in_fd, ssize_t size)
{
	ssize_t left = size;

	if (left < 0) {
		errno = EINVAL;
		return -1;
	}

	/* Fast path: let the kernel move the data. */
	while (left > 0) {
		ssize_t sent = sendfile(out_fd, in_fd, NULL, (size_t)left);

		if (sent < 0) {
			if (errno == EINTR)
				continue;
			/* Not supported for these descriptors: use the fallback. */
			if (errno == EINVAL || errno == ENOSYS)
				break;
			return -1;
		}
		if (sent == 0)
			return size - left;	/* end of input */
		left -= sent;
	}

	/* Fallback: bounce the data through a stack buffer. */
	while (left > 0) {
		char buf[4096];
		size_t want = left < (ssize_t)sizeof(buf) ? (size_t)left : sizeof(buf);
		ssize_t got, done = 0;

		got = read(in_fd, buf, want);
		if (got < 0) {
			if (errno == EINTR)
				continue;
			return -1;
		}
		if (got == 0)
			break;	/* end of input */

		/* Write out everything that was read, continuing after short writes. */
		while (done < got) {
			ssize_t put = write(out_fd, buf + done, (size_t)(got - done));

			if (put < 0) {
				if (errno == EINTR)
					continue;
				return -1;
			}
			if (put == 0) {
				/* Output accepts nothing more: report a short copy. */
				return size - left + done;
			}
			done += put;
		}
		left -= got;
	}

	return size - left;
}

/*
 * Extract an image embedded in the FIT.
 *
 * Two layouts are handled:
 *  - external data: the image node carries "data-offset" (relative to the
 *    4-byte aligned end of the FDT blob) or "data-position" (absolute file
 *    offset) together with "data-size"; the bytes are copied from in_fd.
 *  - embedded data: the bytes are the value of the "data" property and are
 *    written straight from buf.
 *
 * buf:    FIT (FDT) blob, already loaded in memory.
 * name:   image node name below "/images".
 * in_fd:  descriptor of the FIT file; repositioned with lseek() for external data.
 * out_fd: descriptor the image bytes are written to.
 *
 * Returns 0 on success and -1 on any failure (callers only test for
 * non-zero). A message describing the failure is printed on stderr.
 */
static int extract_image(char *buf, char *name, int in_fd, int out_fd)
{
	char path[MAX_PATH_LEN] = {0};
	int noffset, data_size, len, n;
	ssize_t count;
	unsigned int data_offset = 0;
	const fdt32_t *val;
	bool is_external = false, is_relative = false;

	/* A truncated path could name a different node, so treat it as not found. */
	n = snprintf(path, sizeof(path), "/images/%s", name);
	if (n < 0 || (size_t)n >= sizeof(path))
		noffset = -1;
	else
		noffset = fdt_path_offset(buf, path);
	if (noffset < 0) {
		fprintf(stderr, "Error: could not find image: %s.\n", name);
		return -1;
	}

	/* Get offset of image. Try both relative and absolute offset. */
	val = fdt_getprop(buf, noffset, "data-offset", &len);
	if (val)
		is_relative = true;
	else
		val = fdt_getprop(buf, noffset, "data-position", &len);

	if (val) {
		/* Only dereference the cell when the property really holds one. */
		if (len != (int)sizeof(*val)) {
			fprintf(stderr, "Error: Could not get image data: %s.\n", name);
			return -1;
		}
		data_offset = fdt32_to_cpu(*val);
		if (is_relative)
			data_offset += ((fdt_totalsize(buf) + 3) & ~3);
		is_external = true;
	}

	if (is_external) {
		/* Size */
		val = fdt_getprop(buf, noffset, "data-size", &len);
		if (!val || len != (int)sizeof(*val)) {
			fprintf(stderr, "Error: Could not get image size: %s.\n", name);
			return -1;
		}
		data_size = fdt32_to_cpu(*val);
		if (lseek(in_fd, data_offset, SEEK_SET) < 0) {
			fprintf(stderr, "Error: Could not lseek to: %u\n", data_offset);
			return -1;
		}
		count = copy_data(out_fd, in_fd, data_size);
	} else {
		const char *data = fdt_getprop(buf, noffset, "data", &data_size);
		if (!data || data_size < 0) {
			fprintf(stderr, "Error: Could not get image data: %s.\n", name);
			return -1;
		}

		/* Keep writing after short writes; retry when interrupted. */
		count = 0;
		while (count < data_size) {
			ssize_t put = write(out_fd, data + count,
					    (size_t)(data_size - count));
			if (put < 0) {
				if (errno == EINTR)
					continue;
				count = -1;
				break;
			}
			if (put == 0)
				break;
			count += put;
		}
	}

	if (count < 0) {
		fprintf(stderr, "Error: I/O error while copying image data.\n");
		return -1;
	}

	if (count < data_size) {
		fprintf(stderr, "Error: Image data was truncated.\n");
		return -1;
	}

	return 0;
}

/*
 * Print the sha256 hash of an image embedded in the FIT.
 *
 * Searches the subnodes of "/images/<name>" whose name starts with
 * FIT_HASH_NODENAME for the first one whose "algo" property is "sha256",
 * then prints the first SHA_256_LEN bytes of its "value" property as
 * lowercase hex followed by a newline.
 *
 * buf:  FIT (FDT) blob to inspect.
 * name: image node name below "/images".
 *
 * Returns 0 on success, -1 if the image does not exist or carries no usable
 * sha256 hash (including a "value" property shorter than SHA_256_LEN bytes).
 */
static int get_hash(char *buf, char *name)
{
	static const char sha256_algo[] = "sha256";
	char path[MAX_PATH_LEN] = {0};
	int noffset, i, n;
	const uint8_t *val = NULL;

	/* Get path of image hash node; a truncated path counts as not found. */
	n = snprintf(path, sizeof(path), "/images/%s", name);
	if (n < 0 || (size_t)n >= sizeof(path))
		noffset = -1;
	else
		noffset = fdt_path_offset(buf, path);
	if (noffset < 0) {
		fprintf(stderr, "Error: could not find image: %s.\n", name);
		return -1;
	}

	for (noffset = fdt_first_subnode(buf, noffset);
	    noffset >= 0;
	    noffset = fdt_next_subnode(buf, noffset)) {
		const char *node_name = fdt_get_name(buf, noffset, NULL);
		const char *algo;
		int algo_len, val_len;

		/* Check subnode name, must start with "hash" */
		if (!node_name || strncmp(node_name, FIT_HASH_NODENAME,
				strlen(FIT_HASH_NODENAME)))
			continue;

		/*
		 * Verify that we know the hash algo. The comparison is bounded by
		 * the property length and includes the terminating NUL, so a
		 * property without a terminator is never read past its end.
		 */
		algo = fdt_getprop(buf, noffset, "algo", &algo_len);
		if (!algo || algo_len < (int)sizeof(sha256_algo) ||
		    memcmp(algo, sha256_algo, sizeof(sha256_algo)))
			continue;

		val = fdt_getprop(buf, noffset, "value", &val_len);
		/* Refuse values too short to hold a full digest. */
		if (val && val_len < SHA_256_LEN)
			val = NULL;
		break;
	}
	if (!val) {
		fprintf(stderr, "Error: No suitable hash found for image %s.\n", name);
		return -1;
	}
	/* Print the hash. */
	for (i = 0; i < SHA_256_LEN; i++)
		printf("%02x", val[i]);
	printf("\n");

	return 0;
}

/*
 * Print a property of the root node, or of an image node.
 *
 * buf:       FIT (FDT) blob to inspect.
 * name:      property name to look up.
 * imagename: when non-NULL the property is read from "/images/<imagename>",
 *            otherwise from the root node "/".
 *
 * The value is printed as text followed by a newline. Output stops at the
 * first NUL byte or at the end of the property, whichever comes first, so
 * properties that are empty or not NUL-terminated are never read past
 * their end.
 *
 * Returns 0 on success, -1 if the node or the property cannot be found.
 */
static int get_attribute(char *buf, char *name, char *imagename)
{
	int noffset, len = 0;
	const char *val = NULL;
	char path[MAX_PATH_LEN] = "/";
	bool path_ok = true;

	if (imagename) {
		int n = snprintf(path, sizeof(path), "/images/%s", imagename);

		/* A truncated path could name a different node: reject it. */
		path_ok = (n >= 0 && (size_t)n < sizeof(path));
	}

	noffset = path_ok ? fdt_path_offset(buf, path) : -1;
	if (noffset < 0) {
		fprintf(stderr, "Error: invalid FDT path: %s\n", path);
		return -1;
	}

	/* Print the property value, bounded by the property length. */
	val = fdt_getprop(buf, noffset, name, &len);
	if (!val || len < 0) {
		fprintf(stderr, "Error: could not find property %s.\n", name);
		return -1;
	}
	printf("%.*s\n", len, val);

	return 0;
}

/*
 * Read exactly len bytes from fd into dst, continuing after short reads and
 * retrying reads interrupted by a signal.
 *
 * Returns 0 when all len bytes were read, -1 on a read error or when the
 * input ends before len bytes were available.
 */
static int read_fully(int fd, char *dst, size_t len)
{
	size_t done = 0;

	while (done < len) {
		ssize_t got = read(fd, dst + done, len - done);

		if (got < 0) {
			if (errno == EINTR)
				continue;
			return -1;
		}
		if (got == 0)
			return -1;	/* premature end of input */
		done += (size_t)got;
	}

	return 0;
}

/*
 * Read the FDT blob found at the current position of fd into a newly
 * allocated buffer.
 *
 * The first FDT_V1_SIZE bytes are read to learn the total size stored in
 * the header; the buffer is then grown to that size and the remainder is
 * read. The contents are not validated here beyond requiring the total size
 * to be at least FDT_V1_SIZE.
 *
 * fd: descriptor to read from.
 *
 * Returns the buffer, which the caller must free(), or NULL on allocation
 * failure, read error, short input or an implausibly small total size.
 */
char *read_header(int fd)
{
	char *buf, *tmp;
	ssize_t total_size;

	/* Read minimal static struct */
	buf = malloc(FDT_V1_SIZE);
	if (!buf)
		return NULL;

	if (read_fully(fd, buf, FDT_V1_SIZE) < 0) {
		free(buf);
		return NULL;
	}

	/* Read rest of header */
	total_size = fdt_totalsize(buf);
	if (total_size < (ssize_t)FDT_V1_SIZE) {
		free(buf);
		return NULL;
	}

	/* Keep the old pointer until realloc() is known to have succeeded. */
	tmp = realloc(buf, total_size);
	if (!tmp) {
		free(buf);
		return NULL;
	}
	buf = tmp;

	if (read_fully(fd, buf + FDT_V1_SIZE,
		       (size_t)(total_size - (ssize_t)FDT_V1_SIZE)) < 0) {
		free(buf);
		return NULL;
	}

	return buf;
}

/*
 * Program entry point.
 *
 * Parses the command line, loads the FDT blob of the FIT file named by the
 * single positional argument and runs the requested actions (list, size,
 * hash, attribute, extract) against it.
 *
 * --size keeps its own image name, so combining it with --extract, --hash
 * or --attribute no longer lets the later option replace the name that
 * --size was given.
 *
 * Returns 0 when every requested action succeeded, 1 if any of them failed
 * or the output file could not be closed cleanly.
 *
 * NOTE(review): usage() and die() come from util.h and are presumably
 * non-returning; the code below relies on that -- confirm.
 */
int main(int argc, char *argv[])
{
	const char *file;
	int opt, ret = 0;
	char *buf, *name = NULL, *size_name = NULL, *out = NULL, *imagename = NULL;
	bool list = false, extract = false,
		hash = false, attribute = false, size = false;
	int in_fd, out_fd = STDOUT_FILENO;

	while ((opt = util_getopt_long()) != EOF) {
		if ((opt == 'l' || opt == 'e' || opt == 's' || opt == 'a') &&
		    (list || extract || hash || attribute))
			usage("only one of --list/--extract/--hash/--attribute allowed");

		switch (opt) {
		case_USAGE_COMMON_FLAGS

		case 'l':
			list = true;
			break;
		case 'z':
			size = true;
			size_name = optarg;
			break;
		case 'e':
			extract = true;
			name = optarg;
			break;
		case 'o':
			out = optarg;
			break;
		case 's':
			hash = true;
			name = optarg;
			break;
		case 'a':
			attribute = true;
			name = optarg;
			break;
		case 'i':
			imagename = optarg;
			break;
		}
	}
	if (optind != argc - 1)
		usage("missing input filename");
	file = argv[optind];

	in_fd = open(file, O_RDONLY);
	if (in_fd < 0) {
		die("could not open: %s\n", file);
	}

	buf = read_header(in_fd);
	if (!buf) {
		die("could not read header from: %s\n", file);
	}

	if (fdt_check_header(buf)) {
		die("Bad header in %s\n", file);
	}

	if (imagename != NULL && attribute == false) {
		die("--image should be used with --attribute\n");
	}

	if (out != NULL && extract == false) {
		die("--out should be used with --extract\n");
	}

	if (out && (out_fd = open(out, O_WRONLY | O_CREAT | O_TRUNC, 0666)) < 0) {
		die("could not open output file: %s\n", out);
	}

	/*
	 * Pass the pointer to the header to read attributes. A failure is
	 * remembered rather than overwritten by a later action that succeeds.
	 */
	if (list && list_images(buf))
		ret = 1;

	if (size && get_size(buf, size_name))
		ret = 1;

	if (hash && get_hash(buf, name))
		ret = 1;

	if (attribute && get_attribute(buf, name, imagename))
		ret = 1;

	/* Let extract image read the entire file. */
	if (extract && extract_image(buf, name, in_fd, out_fd))
		ret = 1;

	/* close() can report write errors that were deferred until now. */
	if (out_fd != STDOUT_FILENO && close(out_fd) < 0) {
		fprintf(stderr, "Error: could not close output file: %s\n", out);
		ret = 1;
	}

	free(buf);
	close(in_fd);

	return ret ? 1 : 0;
}