/**
 * Main program code for the fwflash utility.
 *
 * Copyright 2006 David Anderson <david.anderson@calixo.net>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "error.h"
#include "lowlevel.h"
#include "samba.h"

/* Set by the -n command-line option: download the image but do not jump
   to its entry point.  */
int flag_norun;

/* Fixed-width scalar types used by the ELF32 structures below.  */
typedef uint16_t Elf32_Half;
typedef uint32_t Elf32_Word, Elf32_Addr, Elf32_Off;

/* Length of the e_ident identification array at the start of the header.  */
#define EI_NIDENT 16
/* ELF32 file header, fields in on-disk order.
   load_elf() fills this with a single fread() of sizeof (Elf32_Ehdr)
   bytes, so the fields must never be reordered.  Multi-byte fields are
   decoded through get_elf32_half()/get_elf32_word() rather than read
   directly, because load_elf() accepts only little-endian (ELFDATA2LSB)
   files whatever the host byte order.
   NOTE(review): assumes the compiler adds no padding (52 bytes total) —
   TODO confirm with a compile-time size check.  */
typedef struct
{
  unsigned char e_ident[EI_NIDENT];	/* Magic number and other info */
  Elf32_Half    e_type;                 /* Object file type */
  Elf32_Half    e_machine;              /* Architecture */
  Elf32_Word    e_version;              /* Object file version */
  Elf32_Addr    e_entry;                /* Entry point virtual address */
  Elf32_Off     e_phoff;                /* Program header table file offset */
  Elf32_Off     e_shoff;                /* Section header table file offset */
  Elf32_Word    e_flags;                /* Processor-specific flags */
  Elf32_Half    e_ehsize;               /* ELF header size in bytes */
  Elf32_Half    e_phentsize;            /* Program header table entry size */
  Elf32_Half    e_phnum;                /* Program header table entry count */
  Elf32_Half    e_shentsize;            /* Section header table entry size */
  Elf32_Half    e_shnum;                /* Section header table entry count */
  Elf32_Half    e_shstrndx;             /* Section header string table index */
} Elf32_Ehdr;

/* ELF32 program header (one entry of the program header table), fields in
   on-disk order.  load_elf() reads an array of these with one fread() and
   requires e_phentsize == sizeof (Elf32_Phdr), so the layout must match
   the file format exactly.  Fields are decoded with get_elf32_word().  */
typedef struct
{
  Elf32_Word    p_type;                 /* Segment type */
  Elf32_Off     p_offset;               /* Segment file offset */
  Elf32_Addr    p_vaddr;                /* Segment virtual address */
  Elf32_Addr    p_paddr;                /* Segment physical address */
  Elf32_Word    p_filesz;               /* Segment size in file */
  Elf32_Word    p_memsz;                /* Segment size in memory */
  Elf32_Word    p_flags;                /* Segment flags */
  Elf32_Word    p_align;                /* Segment alignment */
} Elf32_Phdr;

/* Indices into e_ident[] and the values load_elf() requires there.
   Per the ELF specification e_ident[] holds the magic number in bytes
   0-3, then the file class (4), data encoding (5) and version (6).  */
#define EI_CLASS 4
#define ELFCLASS32 1

#define EI_DATA 5
#define ELFDATA2LSB 1

/* Must differ from EI_DATA, otherwise the version check in load_elf()
   would just re-test the data-encoding byte.  */
#define EI_VERSION 6
#define EV_CURRENT 1

/* e_type and e_machine values accepted by load_elf().  */
#define ET_EXEC 2
#define EM_ARM 40

/* p_type of a loadable segment.  */
#define PT_LOAD 1

/* Decode a 16-bit little-endian ELF field without depending on host byte
   order: byte 0 is the low byte, byte 1 the high byte.  */
static Elf32_Half
get_elf32_half (Elf32_Half *v)
{
  const unsigned char *bytes = (const unsigned char *) v;
  unsigned int low = bytes[0];
  unsigned int high = bytes[1];

  return (Elf32_Half) ((high << 8) | low);
}

/* Decode a 32-bit little-endian ELF field without depending on host byte
   order: byte 0 is the least significant byte.
   Each byte is widened to Elf32_Word before shifting.  Left alone, the
   byte is promoted to int, and shifting a value >= 0x80 left by 24
   overflows int, which is undefined behaviour.  */
static Elf32_Word
get_elf32_word (Elf32_Word *v)
{
  unsigned char *p = (unsigned char *)v;
  return (Elf32_Word)p[0]
         | ((Elf32_Word)p[1] << 8)
         | ((Elf32_Word)p[2] << 16)
         | ((Elf32_Word)p[3] << 24);
}

/* Evaluate EXPR (an nxt_error_t) exactly once.  If it is non-zero, pass
   NXT, MSG and the error code to handle_error() and return its result
   from the enclosing function.  handle_error() calls exit(), so in
   practice the program terminates.  */
#define NXT_HANDLE_ERR(expr, nxt, msg)                             \
  do                                                               \
    {                                                              \
      nxt_error_t nxt_handle_err_code_ = (expr);                   \
      if (nxt_handle_err_code_ != 0)                               \
        return handle_error ((nxt), (msg), nxt_handle_err_code_);  \
    }                                                              \
  while (0)

/* Report an NXT library error and terminate the program.

   Prints "MSG: <nxt_str_error(ERR)>" on stdout, closes NXT if it is
   non-NULL, then calls exit(ERR).  It never returns.  The int return type
   exists only so NXT_HANDLE_ERR can write "return handle_error (...)".
   MSG is only read, hence const.  */
static int handle_error(nxt_t *nxt, const char *msg, nxt_error_t err)
{
  printf("%s: %s\n", msg, nxt_str_error(err));
  if (nxt != NULL)
    nxt_close(nxt);
  exit(err);
}

/* Entry point (e_entry) recorded by load_elf(); main() jumps there unless
   -n was given.  */
static unsigned elf_entry;

/* Download the PT_LOAD segments of a 32-bit little-endian ARM ELF
   executable to the NXT and record its entry point in elf_entry.

   nxt: connection passed to nxt_send_file().
   f:   open ELF file; it is rewound and read from the start.

   Returns 0 on success, or -1 if the file cannot be read or is not an
   acceptable image.  nxt_send_file() failures do not return: they go
   through NXT_HANDLE_ERR, i.e. handle_error(), which exits.  */
static int
load_elf (nxt_t *nxt, FILE *f)
{
  unsigned char buf[256];
  Elf32_Ehdr ehdr;
  Elf32_Phdr phdr[16];
  int nbr_phdr;
  int i;

  rewind(f);
  if (fread(&ehdr, sizeof(ehdr), 1, f) != 1)
    {
      fprintf (stderr, "cannot read ELF header\n");
      return -1;
    }
  /* Reject anything that does not start with the ELF magic "\177ELF"
     before trusting the other identification bytes.  */
  if (memcmp (ehdr.e_ident, "\177ELF", 4) != 0
      || ehdr.e_ident[EI_CLASS] != ELFCLASS32
      || ehdr.e_ident[EI_DATA] != ELFDATA2LSB
      || ehdr.e_ident[EI_VERSION] != EV_CURRENT)
    {
      fprintf (stderr, "Invalid ELF header\n");
      return -1;
    }
  if (get_elf32_half (&ehdr.e_type) != ET_EXEC
      || get_elf32_half(&ehdr.e_machine) != EM_ARM
      || get_elf32_word(&ehdr.e_version) != EV_CURRENT
      || get_elf32_half(&ehdr.e_phentsize) != sizeof (Elf32_Phdr))
    {
      fprintf (stderr, "Wrong ELF type\n");
      return -1;
    }
  elf_entry = get_elf32_word (&ehdr.e_entry);

  /* The program header table is read into a fixed-size array.  */
  nbr_phdr = get_elf32_half(&ehdr.e_phnum);
  if ((unsigned)nbr_phdr > sizeof(phdr)/sizeof(*phdr))
    return -1;

  if (fseek (f, get_elf32_word(&ehdr.e_phoff), SEEK_SET) != 0)
    return -1;
  if (fread(phdr, sizeof(Elf32_Phdr), nbr_phdr, f) != (size_t)nbr_phdr)
    return -1;

  for (i = 0; i < nbr_phdr; i++)
    {
      Elf32_Addr addr = get_elf32_word (&phdr[i].p_paddr);
      Elf32_Word filesz = get_elf32_word (&phdr[i].p_filesz);
      Elf32_Word memsz = get_elf32_word (&phdr[i].p_memsz);

      if (get_elf32_word (&phdr[i].p_type) != PT_LOAD)
        continue;
      /* Segments with no file data are skipped entirely, even when
         memsz > 0: no zero-fill is sent for them.  */
      if (filesz == 0)
        continue;
      /* memsz is decremented by every byte sent from the file below; a
         malformed header with filesz > memsz would wrap it to a huge
         unsigned value and trigger a near-4 GiB zero-fill.  */
      if (filesz > memsz)
        {
          fprintf (stderr, "Invalid program header\n");
          return -1;
        }
      if (verbose)
        printf ("Load 0x%06x - 0x%06x\n",
                (unsigned)addr, (unsigned)(addr + memsz));
      if (fseek (f, get_elf32_word(&phdr[i].p_offset), SEEK_SET) != 0)
        return -1;

      /* Send the file image in chunks of at most sizeof (buf) bytes.  */
      while (filesz > 0)
        {
          size_t l = filesz > sizeof (buf) ? sizeof (buf) : filesz;

          if (fread (buf, l, 1, f) != 1)
            return -1;
          NXT_HANDLE_ERR (nxt_send_file (nxt, addr, buf, l), nxt,
                          "send buffer");

          filesz -= l;
          addr += l;
          memsz -= l;
        }

      /* Zero-fill the rest of the segment (the memsz - filesz tail).  */
      if (memsz > 0)
        {
          memset (buf, 0, sizeof (buf));

          while (memsz > 0)
            {
              size_t l = memsz > sizeof (buf) ? sizeof (buf) : memsz;

              NXT_HANDLE_ERR (nxt_send_file (nxt, addr, buf, l), nxt,
                              "clear mem");

              addr += l;
              memsz -= l;
            }
        }
    }
  return 0;
}

/* Program name shown in the usage message; main() sets it from argv[0].  */
const char *progname;

/* Print the usage summary on stdout.  It lists every option main()
   accepts: -d DEVICE (passed to nxt_init()), -v (more verbose output,
   may be repeated) and -n (download the image without starting it).  */
void
help (void)
{
  printf ("Usage: %s [-d device] [-v] [-n] elf-file\n", progname);
}

int main(int argc, char *argv[])
{
  nxt_t *nxt;
  nxt_error_t err;
  char *file;
  FILE *f;
  char *device = NULL;
  int c;
  int i;

  progname = argv[0];

  for (i = 1; i < argc; i++)
    {
      if (argv[i][0] != '-')
        break;
      if (strcmp (argv[i], "-d") == 0)
        {
          i++;
          if (i >= argc)
            {
              help ();
              return 1;
            }
          device = argv[i];
        }
      else if (strcmp (argv[i], "-v") == 0)
        verbose++;
      else if (strcmp (argv[i], "-n") == 0)
        flag_norun = 1;
      else
        {
          help ();
          return 1;
        }
    }

  if (i + 1 != argc)
    {
      help ();
      return 1;
    }

  file = argv[i];

  f = fopen (file, "rb");
  if (f == NULL)
    {
      fprintf (stderr, "cannot open elf file %s\n", file);
      return 1;
    }

  NXT_HANDLE_ERR(nxt_init(&nxt, device), NULL,
                 "Error during library initialization");

  err = nxt_find(nxt);
  if (err)
    {
      if (err == NXT_NOT_PRESENT)
        printf("NXT not found. Is it properly plugged in via USB?\n");
      else
        NXT_HANDLE_ERR(0, NULL, "Error while scanning for NXT");
      exit(1);
    }

  if (!nxt_in_reset_mode(nxt))
    {
      printf("NXT found, but not running in reset mode.\n");
      printf("Please reset your NXT manually and restart this program.\n");
      exit(2);
    }

  NXT_HANDLE_ERR(nxt_open(nxt), NULL, "Error while connecting to NXT");


  NXT_HANDLE_ERR(nxt_samba_ping(nxt), NULL, "Error while pinging NXT");

  if (load_elf (nxt, f) != 0)
    {
      fprintf (stderr, "failed to download image\n");
      return 1;
    }

  printf("Image download complete.\n");

  if (!flag_norun)
    {
      NXT_HANDLE_ERR(nxt_jump(nxt, elf_entry), nxt,
                     "Error booting new firmware");
      printf("Image started at 0x%08x\n", elf_entry);
    }

  NXT_HANDLE_ERR(nxt_close(nxt), NULL,
                 "Error while closing connection to NXT");
  return 0;
}
