/* University of Strasbourg - Master ILC-ISI-RISE
 * Compilation Lab - Compiler for the arith language
 * Written by Cedric Bastoul cedric.bastoul@unistra.fr
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "symbol.h"
#include "ast.h"
#include "quad.h"
#include "codegen.h"

// Generate code & symbol containing the result for number nodes.
// A literal needs no quads: its result is the constant symbol obtained from
// symbol_new_constant for ast->u.number, and its code list is empty (NULL).
// Exits with status 1 if the codegen record cannot be allocated.
struct codegen* codegen_number(struct ast* ast,
                               struct symbol** symbol_table) {
  struct codegen* cg = malloc(sizeof *cg);
  if (cg == NULL) {
    fprintf(stderr, "Out of memory in codegen_number\n");
    exit(1);
  }
  cg->result = symbol_new_constant(symbol_table, ast->u.number);
  cg->code = NULL;
  return cg;
}

// Generate code & symbol containing the result for symbol nodes.
// A symbol node already names its value, so that symbol (ast->u.symbol) is
// reused directly as the result and no quads are emitted (code is NULL).
// Exits with status 1 if the codegen record cannot be allocated.
struct codegen* codegen_symbol(struct ast* ast) {
  struct codegen* cg = malloc(sizeof *cg);
  if (cg == NULL) {
    fprintf(stderr, "Out of memory in codegen_symbol\n");
    exit(1);
  }
  cg->result = ast->u.symbol;
  cg->code = NULL;
  return cg;
}

// Generate code & symbol containing the result for statement nodes.
// Appends one statement to the code accumulated so far:
//   1. the quads of previous_statements (may be NULL for the first statement);
//      that wrapper is consumed and freed here, its quad list is moved over,
//   2. the quads computing the statement's expression,
//   3. one final quad: 'P' on the expression result when the statement has no
//      target symbol (presumably a print), otherwise '=' storing the
//      expression result into ast->u.statement.symbol.
// Statements produce no value, so the returned codegen's result is NULL.
// Exits with status 1 if the codegen record cannot be allocated.
struct codegen* codegen_statement(struct codegen* previous_statements,
                                  struct ast* ast,
                                  struct symbol** symbol_table) {
  struct codegen* cg = malloc(sizeof *cg);
  if (cg == NULL) {
    fprintf(stderr, "Out of memory in codegen_statement\n");
    exit(1);
  }
  struct codegen* expr = codegen_ast(ast->u.statement.expression, symbol_table);
  struct quad* new_code;
  cg->result = NULL;
  cg->code = NULL;

  if (ast->u.statement.symbol == NULL)
    new_code = quad_gen('P', NULL, NULL, expr->result);
  else
    new_code = quad_gen('=', expr->result, NULL, ast->u.statement.symbol);

  // Keep source order: earlier statements, then this expression, then the
  // statement's own quad.
  if (previous_statements != NULL) {
    quad_add(&cg->code, previous_statements->code);
    free(previous_statements);
  }
  quad_add(&cg->code, expr->code);
  quad_add(&cg->code, new_code);
  // Only the wrapper is freed; its quads now belong to cg->code.
  free(expr);
  return cg;
}

// Generate code & symbol containing the result for operation nodes
struct codegen* codegen_operation(enum ast_type type,
                                  struct ast* ast,
                                  struct symbol** symbol_table) {
  struct codegen* cg = malloc(sizeof(struct codegen));
  struct codegen* left  = codegen_ast(ast->u.operation.left,  symbol_table);
  struct codegen* right = codegen_ast(ast->u.operation.right, symbol_table);
  struct quad* new_code;
  cg->result = symbol_new_temp(symbol_table);
  
  if (type == ast_type_add)
    new_code = quad_gen('+', left->result, right->result, cg->result);
  else
    new_code = quad_gen('*', left->result, right->result, cg->result);

  cg->code = left->code;
  quad_add(&cg->code, right->code);
  quad_add(&cg->code, new_code);
  free(left);
  free(right);
  return cg;
}

// Generate code & symbol containing the result for ast nodes.
// A statement node heads a list linked through u.statement.next. All
// consecutive statements are folded, in order, into a single codegen via
// codegen_statement. Any other node is dispatched to the generator for its
// type.
// Identifier nodes are rejected as a fatal error: returning NULL for them
// would be dereferenced by every caller in this file.
// NOTE(review): identifiers are presumably resolved to symbol nodes before
// code generation — confirm.
// Diagnostics go to stderr; unknown node types exit with status 1.
struct codegen* codegen_ast(struct ast* ast, struct symbol** symbol_table) {
  if (ast->type == ast_type_statement) {
    struct codegen* cg = NULL;
    while ((ast != NULL) && (ast->type == ast_type_statement)) {
      cg = codegen_statement(cg, ast, symbol_table);
      ast = ast->u.statement.next;
    }
    return cg;
  }

  switch (ast->type) {
    case ast_type_number:
      return codegen_number(ast, symbol_table);
    case ast_type_symbol:
      return codegen_symbol(ast);
    case ast_type_add:
      return codegen_operation(ast_type_add, ast, symbol_table);
    case ast_type_mul:
      return codegen_operation(ast_type_mul, ast, symbol_table);
    case ast_type_identifier:
      fprintf(stderr, "Unresolved identifier node during code generation\n");
      exit(1);
    default:
      fprintf(stderr, "Unknown AST node type\n");
      exit(1);
  }
}

// Translate an AST to a list of quads.
// Returns the quad list built by codegen_ast; only the temporary codegen
// wrapper is freed here, not the list itself.
// An empty program (NULL AST), or a NULL result from codegen_ast, yields an
// empty (NULL) quad list instead of dereferencing NULL.
struct quad* codegen(struct ast* ast,  struct symbol** symbol_table) {
  if (ast == NULL)
    return NULL;

  struct codegen* cg = codegen_ast(ast, symbol_table);
  if (cg == NULL)
    return NULL;

  struct quad* quad_list = cg->code;
  free(cg);
  return quad_list;
}
