/* syscall logger module code.
 * 
 * Authors: guy keren <choo@actcom.co.il>
 *
 */

#include <linux/config.h>
#include <linux/version.h>
#include <linux/module.h>
#include <linux/tasks.h>

#include <linux/errno.h>
#include <linux/kernel.h>
#include <linux/malloc.h>
#include <linux/init.h>
#include <asm/uaccess.h>
#include <asm/unistd.h>

#ifdef MODULE

/* which system calls we should hijack on module init (non-zero = hijack). */
int hijack_sys_open = 0;
int hijack_sys_unlink = 0;

/* define module input parameters as integers ('i'). the "0-1" prefix is */
/* the allowed number of elements (0 or 1), not a range on the value:    */
/* any non-zero value enables the corresponding hijack.                  */
#if LINUX_VERSION_CODE >= KERNEL_VERSION(2,1,0)
MODULE_PARM(hijack_sys_open,"0-1i");
MODULE_PARM(hijack_sys_unlink,"0-1i");
#endif

/* integer type used to hold (and store back) sys_call_table entries. */
/* this is architecture-dependent. */
#ifdef __i386__
typedef unsigned long func_ptr;
#else
#error This module works on i386 architectures only, until i get a sparc machine to try it with..
#endif

/* forward function declarations.                                        */
/* these are private to this module, so give them internal linkage:      */
/* modules of this kernel era export every non-static global symbol      */
/* unless told otherwise, which would pollute the kernel symbol table.   */
/* the later definitions (written without 'static') inherit this linkage. */
static int logger_module_init(void);
static void logger_module_cleanup(void);
static asmlinkage int logged_sys_unlink(const char* path);
static asmlinkage int logged_sys_open(const char* filename,
				      int flags,
				      int mode);

/* addresses of the original system calls, saved from sys_call_table   */
/* when they are hijacked. 0 means "not hijacked" (checked on cleanup). */
static func_ptr sys_unlink_addr = 0;
static func_ptr sys_open_addr = 0;

/* typedefs for function pointers of supported system calls. */
typedef asmlinkage int (*sys_unlink_type)(const char*);
typedef asmlinkage int (*sys_open_type)(const char*, int, int);

/* DON'T mess with these - they are required to allow proper expansion of */
/* (possibly) versioned symbols into strings: the extra level of macro    */
/* expansion lets a versioned #define of the symbol be substituted first. */
#define do_turn_into_versioned_string(x) #x
#define versioned_string(x) do_turn_into_versioned_string(x)

/* location of supported system calls in sys_call_table.                 */
/* taken from the kernel's own syscall numbers (5 and 10 on i386) rather */
/* than hard-coded, so they cannot drift from the table being patched.   */
#define SYS_OPEN_CODE __NR_open
#define SYS_UNLINK_CODE __NR_unlink


/* log the call to sys_open, and then call the original system call. */
asmlinkage int logged_sys_open(const char * filename, int flags, int mode)
{
        unsigned long page;
        long retval;
	sys_open_type sys_open_ptr = (sys_open_type)sys_open_addr;

	/* copy the 'filename' parameter into kernel space, so we can */
	/* print it out. use __get_free_page to avoid zeroing out */
	/* which would be a waste of time. 			  */
        page = __get_free_page(GFP_KERNEL);
        if (!page)
                return -ENOMEM;

        retval = strncpy_from_user((char *)page, filename, PAGE_SIZE);
        if (retval > 0) {
                if (retval < PAGE_SIZE) {
                        const char* buf = (char *)page;
			printk(KERN_INFO
			       "'open' invoked for file '%s', "
			       "flags '0x%x', mode '0%o' by UID %d\n",
			       buf, flags, mode, current->euid);
        		free_page(page);
			/* invoke original system call */
			return (*sys_open_ptr)(filename, flags, mode);
                }
                retval = -ENAMETOOLONG;
        } else if (!retval)
                retval = -EINVAL;

        free_page(page);

	return retval;
}

/* log the call to sys_unlink, and then call the original systm call. */
asmlinkage int logged_sys_unlink(const char* path)
{
        unsigned long page;
        long retval;
	sys_unlink_type sys_unlink_ptr = (sys_unlink_type)sys_unlink_addr;

	/* copy the 'path' parameter into kernel space, so we can */
	/* print it out. use __get_free_page to avoid zeroing out */
	/* which would be a waste of time. 			  */
        page = __get_free_page(GFP_KERNEL);
        if (!page)
                return -ENOMEM;

        retval = strncpy_from_user((char *)page, path, PAGE_SIZE);
        if (retval > 0) {
                if (retval < PAGE_SIZE) {
                        const char* buf = (char *)page;
			printk(KERN_INFO
			       "'unlink' invoked for '%s' by UID %d\n",
			       buf, current->euid);
        		free_page(page);
			/* invoke original system call */
			return (*sys_unlink_ptr)(path);
                }
                retval = -ENAMETOOLONG;
        } else if (!retval)
                retval = -EINVAL;

        free_page(page);

	return retval;
}

/*
 * logger_module_init - locate sys_call_table and install the requested
 * logging wrappers.
 *
 * For each system call whose hijack_* parameter is non-zero, the current
 * table entry is saved first (so the wrapper can forward to it), and only
 * then is the wrapper written into the table.
 *
 * Returns 0 on success, or -ENOSYS if sys_call_table cannot be found.
 */
int logger_module_init()
{
	func_ptr* table;

	/* look for the sys_call_table symbol inside the kernel. */
	table = (func_ptr*)get_module_symbol(NULL,
					     versioned_string(sys_call_table));
	if (!table) {
		printk("logger_module_init: failed locating sys_call_table\n");
		return -ENOSYS;
	}

	printk("logger_module_init: initializing syscall logger module\n");

	if (hijack_sys_unlink) {
		/* remember the original before overwriting the slot */
		sys_unlink_addr = table[SYS_UNLINK_CODE];
		table[SYS_UNLINK_CODE] = (func_ptr)logged_sys_unlink;
		printk(KERN_INFO "replaced sys_unlink with our own version\n");
	}

	if (hijack_sys_open) {
		/* remember the original before overwriting the slot */
		sys_open_addr = table[SYS_OPEN_CODE];
		table[SYS_OPEN_CODE] = (func_ptr)logged_sys_open;
		printk(KERN_INFO "replaced sys_open with our own version\n");
	}

	return 0;
}

void logger_module_cleanup()
{
	// look for the sys_call_table symbol inside the kernel.
	func_ptr* sys_call_table_addr = (func_ptr*)
		get_module_symbol(NULL, versioned_string(sys_call_table));
	printk("logger_module_cleanup: cleaning up syscall logger module\n");

	if (sys_call_table_addr) {
		if (sys_unlink_addr) {
			func_ptr* sys_unlink_ptr =
					&sys_call_table_addr[SYS_UNLINK_CODE];
			if (*sys_unlink_ptr == (func_ptr)logged_sys_unlink) {
				*sys_unlink_ptr = sys_unlink_addr;
				printk(KERN_INFO
				       "released 'unlink' system call\n");
			} else {
				printk("oops... someone else hijacked"
				       " sys_unlink... cannot restore"
				       " function.\n");
			}
		}
		if (sys_open_addr) {
			func_ptr* sys_open_ptr =
					&sys_call_table_addr[SYS_OPEN_CODE];
			if (*sys_open_ptr == (func_ptr)logged_sys_open) {
				*sys_open_ptr = sys_open_addr;
				printk(KERN_INFO
				       "released 'open' system call\n");
			} else {
				printk("oops... someone else hijacked sys_open"
				       "... cannot restore function.\n");
			}
		}
	} else {
		printk("logger_module_cleanup:"
		       " failed locating sys_call_table\n");
	}
}

/* module entry point, invoked when the module is loaded:
 * delegates to logger_module_init() and reports its status. */
int init_module(void)
{
	int status = logger_module_init();

	return status;
}

/* module exit point, invoked when the module is removed:
 * delegates to logger_module_cleanup(). */
void cleanup_module(void)
{
	logger_module_cleanup();
	return;
}

#endif

