core/core.go
2024-11-22 16:25:06 -05:00

587 lines
14 KiB
Go

// Copyright 2021 Harran Ali <harran.m@gmail.com>. All rights reserved.
// Copyright (c) 2024 Zeni Kim <zenik@smarteching.com>
// Use of this source code is governed by MIT-style
// license that can be found in the LICENSE file.
package core
import (
	"embed"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"runtime/debug"
	"strconv"
	"strings"
	"syscall"

	"git.smarteching.com/goffee/core/env"
	"git.smarteching.com/goffee/core/logger"
	"github.com/julienschmidt/httprouter"
	"golang.org/x/crypto/acme/autocert"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)
var loggr *logger.Logger
var logsDriver *logger.LogsDriver
var requestC RequestConfig
var jwtC JWTConfig
var gormC GormConfig
var cacheC CacheConfig
var db *gorm.DB
var mailer *Mailer
var basePath string
var runMode string
var disableEvents bool = false
//go:embed all:template
var components_resources embed.FS
// configContainer groups the configuration exposed through App.Config.
type configContainer struct {
	Request RequestConfig
}

// App is the framework application object. It holds the handler chain that
// is built and executed for each request and the hooks prepended to it.
type App struct {
	// t is the index of the chain node currently executing; Next advances it.
	t      int
	chain  *chain
	hooks  *Hooks
	Config *configContainer
}

// app is the singleton created by New and returned by ResolveApp.
var app *App
// New creates the application, records it as the package singleton (see
// ResolveApp) and returns it. Config.Request is a copy of the request
// configuration current at the time of the call.
func New() *App {
	a := &App{
		chain: &chain{},
		hooks: NewHooks(),
	}
	a.Config = &configContainer{Request: requestC}
	app = a
	return app
}
// ResolveApp returns the application created by the most recent call to New,
// or nil if New has not been called yet.
func ResolveApp() (a *App) {
	a = app
	return a
}
// SetLogsDriver stores a copy of d; Bootstrap uses it to build the package
// logger.
func (app *App) SetLogsDriver(d logger.LogsDriver) {
	driver := d
	logsDriver = &driver
}
// Bootstrap initialises the framework services: the package logger (from the
// driver set with SetLogsDriver), the router and the events manager.
//
// It panics with an explanatory message if SetLogsDriver has not been called,
// instead of failing with an opaque nil-pointer dereference.
func (app *App) Bootstrap() {
	if logsDriver == nil {
		panic("core: logs driver not set; call SetLogsDriver before Bootstrap")
	}
	loggr = logger.NewLogger(*logsDriver)
	NewRouter()
	NewEventsManager()
}
// RegisterTemplates hands the framework's embedded component templates and
// the application's own templates (templates_resources) to NewTemplates.
func (app *App) RegisterTemplates(templates_resources embed.FS) {
	builtin, userTemplates := components_resources, templates_resources
	NewTemplates(builtin, userTemplates)
}
// Run registers all framework routes on router and starts serving. It blocks
// until the server returns an error, at which point the process exits via
// log.Fatal.
//
// Environment variables:
//   - App_HTTP_PORT: port for plain HTTP (default "80").
//   - TEMPLATE_ENABLE: when true, serves storage/public under /public/.
//   - App_USE_HTTPS: when true, serve TLS instead of plain HTTP.
//   - App_USE_LETSENCRYPT: with HTTPS, obtain certificates through autocert,
//     cached in "letsencrypt-certs-dir". APP_LETSENCRYPT_EMAIL sets the
//     account e-mail; App_HTTPS_HOSTS is a comma-separated list of allowed
//     host names.
//   - App_CERT_FILE_PATH / App_KEY_FILE_PATH: certificate and key (relative
//     to the base path; defaults tls/server.crt and tls/server.key) used for
//     HTTPS without Let's Encrypt. That mode listens on :443.
//
// Boolean variables that are unset or unparsable count as false.
func (app *App) Run(router *httprouter.Router) {
	// envFlag reports whether the named environment variable parses as true.
	envFlag := func(key string) bool {
		v, _ := strconv.ParseBool(os.Getenv(key)) // v is false on error
		return v
	}

	portNumber := os.Getenv("App_HTTP_PORT")
	if portNumber == "" {
		portNumber = "80"
	}

	router = app.RegisterRoutes(ResolveRouter().GetRoutes(), router)

	if envFlag("TEMPLATE_ENABLE") {
		router.ServeFiles("/public/*filepath", http.Dir("storage/public"))
	}

	useHttps := envFlag("App_USE_HTTPS")
	if runMode == "dev" {
		fmt.Printf("Welcome to Goffee\n")
		if useHttps {
			fmt.Printf("Listening on https \nWaiting for requests...\n")
		} else {
			fmt.Printf("Listening on port %s\nWaiting for requests...\n", portNumber)
		}
	}

	useLetsEncrypt := envFlag("App_USE_LETSENCRYPT")
	switch {
	case useHttps && useLetsEncrypt:
		m := &autocert.Manager{
			Cache:  autocert.DirCache("letsencrypt-certs-dir"),
			Prompt: autocert.AcceptTOS,
		}
		if email := os.Getenv("APP_LETSENCRYPT_EMAIL"); email != "" {
			m.Email = email
		}
		if rawHosts := os.Getenv("App_HTTPS_HOSTS"); rawHosts != "" {
			// HostWhitelist is variadic; passing the raw value as a single
			// host meant "a.com,b.com" matched nothing.
			var hosts []string
			for _, h := range strings.Split(rawHosts, ",") {
				if h = strings.TrimSpace(h); h != "" {
					hosts = append(hosts, h)
				}
			}
			// A value with no usable hosts still installs a (empty)
			// whitelist, which rejects every host (fail closed).
			m.HostPolicy = autocert.HostWhitelist(hosts...)
		}
		log.Fatal(http.Serve(m.Listener(), router))

	case useHttps:
		certFile := os.Getenv("App_CERT_FILE_PATH")
		if certFile == "" {
			certFile = "tls/server.crt"
		}
		keyFile := os.Getenv("App_KEY_FILE_PATH")
		if keyFile == "" {
			keyFile = "tls/server.key"
		}
		certFilePath := filepath.Join(basePath, certFile)
		keyFilePath := filepath.Join(basePath, keyFile)
		log.Fatal(http.ListenAndServeTLS(":443", certFilePath, keyFilePath, router))

	default:
		log.Fatal(http.ListenAndServe(fmt.Sprintf(":%s", portNumber), router))
	}
}
// RegisterRoutes installs the framework's panic, not-found and
// method-not-allowed handlers on router. It then registers every route
// (wrapped by makeHTTPRouterHandlerFunc) under its HTTP method. Routes with
// a method not listed in the switch are skipped.
//
// When App_USE_CORESERVICES parses as true, it also registers the
// /coregraph/*graph endpoint. The same router is returned for chaining.
func (app *App) RegisterRoutes(routes []Route, router *httprouter.Router) *httprouter.Router {
	router.PanicHandler = panicHandler
	router.NotFound = notFoundHandler{}
	router.MethodNotAllowed = methodNotAllowed{}

	for _, rt := range routes {
		handle := app.makeHTTPRouterHandlerFunc(rt.Controller, rt.Hooks)
		switch rt.Method {
		case GET:
			router.GET(rt.Path, handle)
		case POST:
			router.POST(rt.Path, handle)
		case DELETE:
			router.DELETE(rt.Path, handle)
		case PATCH:
			router.PATCH(rt.Path, handle)
		case PUT:
			router.PUT(rt.Path, handle)
		case OPTIONS:
			router.OPTIONS(rt.Path, handle)
		case HEAD:
			router.HEAD(rt.Path, handle)
		}
	}

	// Unset or unparsable values leave core services disabled.
	if coreServices, _ := strconv.ParseBool(os.Getenv("App_USE_CORESERVICES")); coreServices {
		router.GET("/coregraph/*graph", Graph)
	}
	return router
}
// makeHTTPRouterHandlerFunc adapts controller h and its route hooks ms into an
// httprouter.Handle.
//
// Per request it builds a fresh Context and runs the app hooks, the route
// hooks and finally the controller through app.chain. It then copies the
// accumulated response to w: extra headers, Content-Type (override, then
// explicit, then HTML), and status plus body, or a redirect. Finally it
// processes events fired during the request.
//
// NOTE(review): app.chain and the cursor app.t are shared by all requests,
// while net/http serves requests concurrently, so overlapping requests can
// interleave their chains. Fixing that requires making the chain
// per-request in Next/prepareChain/execute together.
func (app *App) makeHTTPRouterHandlerFunc(h Controller, ms []Hook) httprouter.Handle {
	return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
		ctx := &Context{
			Request: &Request{
				httpRequest:    r,
				httpPathParams: ps,
			},
			Response: &Response{
				headers:             []header{},
				body:                nil,
				contentType:         "",
				overrideContentType: "",
				HttpResponseWriter:  w,
				isTerminated:        false,
				redirectTo:          "",
			},
			GetValidator:     getValidator(),
			GetJWT:           getJWT(),
			GetGorm:          getGormFunc(),
			GetCache:         resolveCache(),
			GetHashing:       resloveHashing(),
			GetMailer:        resolveMailer(),
			GetEventsManager: resolveEventsManager(),
			GetLogger:        resolveLogger(),
		}

		// Clear the shared chain state even when a hook or the controller
		// panics (the panic is recovered further up by httprouter).
		// Otherwise this request's nodes stay in the chain and are run
		// again at the start of the next request.
		defer func() {
			app.t = 0
			ctx.Response.reset()
			app.chain.reset()
		}()

		ctx.prepare(ctx)
		app.prepareChain(app.combHandlers(h, ms))
		app.t = 0
		app.chain.execute(ctx)

		for _, hdr := range ctx.Response.headers {
			w.Header().Add(hdr.key, hdr.val)
		}
		logger.CloseLogsFile()

		ct := CONTENT_TYPE_HTML
		if ctx.Response.overrideContentType != "" {
			ct = ctx.Response.overrideContentType
		} else if ctx.Response.contentType != "" {
			ct = ctx.Response.contentType
		}
		w.Header().Add(CONTENT_TYPE, ct)

		if ctx.Response.redirectTo != "" {
			// http.Redirect writes its own status line and Location header.
			// Calling WriteHeader first would commit the headers and drop
			// Location. A 3xx status set on the response is honoured;
			// anything else falls back to 307.
			code := http.StatusTemporaryRedirect
			if s := ctx.Response.statusCode; s >= 300 && s < 400 {
				code = s
			}
			http.Redirect(w, r, ctx.Response.redirectTo, code)
		} else {
			if ctx.Response.statusCode != 0 {
				w.WriteHeader(ctx.Response.statusCode)
			}
			w.Write(ctx.Response.body)
		}

		if e := ResolveEventsManager(); e != nil {
			e.setContext(ctx).processFiredEvents()
		}
	}
}
// notFoundHandler is installed as the router's NotFound handler in
// RegisterRoutes. It replies 404 with a small JSON body.
type notFoundHandler struct{}

// methodNotAllowed is installed as the router's MethodNotAllowed handler in
// RegisterRoutes. It replies 405 with a small JSON body.
type methodNotAllowed struct{}

// ServeHTTP logs the miss and writes the 404 JSON response.
func (n notFoundHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	loggr.Error("Not Found")
	loggr.Error(string(debug.Stack()))
	// Headers must be set before WriteHeader; changes made afterwards are
	// ignored by net/http.
	w.Header().Add(CONTENT_TYPE, CONTENT_TYPE_JSON)
	w.WriteHeader(http.StatusNotFound)
	w.Write([]byte("{\"message\": \"Not Found\"}"))
}

// ServeHTTP logs the rejected method and writes the 405 JSON response.
func (n methodNotAllowed) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	loggr.Error("Method not allowed")
	loggr.Error(string(debug.Stack()))
	// Set before WriteHeader: httprouter has already touched the header map
	// (Allow), so later additions would certainly be dropped.
	w.Header().Add(CONTENT_TYPE, CONTENT_TYPE_JSON)
	w.WriteHeader(http.StatusMethodNotAllowed)
	w.Write([]byte("{\"message\": \"Method not allowed\"}"))
}
// panicHandler is installed as the router's PanicHandler in RegisterRoutes.
// It turns the recovered panic value e into a 500 response.
//
// The panic value and stack trace are always logged. The client sees:
//   - APP_DEBUG_MODE unset/unparsable: a plain-text configuration message.
//   - APP_DEBUG_MODE false: {"message": "internal error"}.
//   - APP_DEBUG_MODE true: the panic message and stack trace as JSON, unless
//     APP_ENV equals PRODUCTION, where the generic message is used instead.
var panicHandler = func(w http.ResponseWriter, r *http.Request, e interface{}) {
	shrtMsg := fmt.Sprintf("%v", e)
	stack := string(debug.Stack())

	isDebugMode, err := strconv.ParseBool(os.Getenv("APP_DEBUG_MODE"))
	if err != nil {
		errStr := "error parsing env var APP_DEBUG_MODE"
		loggr.Error(errStr)
		loggr.Error(shrtMsg)
		loggr.Error(stack)
		fmt.Println(errStr)
		// Previously no status was written, so a panic answered 200 OK.
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(errStr))
		return
	}

	if !isDebugMode {
		errStr := "internal error"
		loggr.Error(errStr)
		// Keep the actual cause in the server log even when the client only
		// gets the generic message.
		loggr.Error(shrtMsg)
		loggr.Error(stack)
		fmt.Println(errStr)
		w.Header().Add(CONTENT_TYPE, CONTENT_TYPE_JSON)
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(fmt.Sprintf("{\"message\": \"%v\"}", errStr)))
		return
	}

	loggr.Error(shrtMsg)
	fmt.Println(shrtMsg)
	loggr.Error(stack)

	res := []byte("{\"message\": \"internal error\"}")
	if env.GetVarOtherwiseDefault("APP_ENV", "local") != PRODUCTION {
		// Marshal rather than format by hand: the stack trace (and possibly
		// the panic message) contains newlines, tabs and quotes that are not
		// legal inside a JSON string literal.
		body, mErr := json.Marshal(map[string]string{
			"message":     shrtMsg,
			"stack trace": stack,
		})
		if mErr == nil {
			res = body
		}
	}
	w.Header().Add(CONTENT_TYPE, CONTENT_TYPE_JSON)
	w.WriteHeader(http.StatusInternalServerError)
	w.Write(res)
}
// UseHook attaches mw to the hook registry returned by ResolveHooks.
func UseHook(mw Hook) {
	registry := ResolveHooks()
	registry.Attach(mw)
}
// Next advances the chain cursor and runs the following node with c. The
// node may be a Hook or a Controller; Next does nothing once the chain is
// exhausted. It is presumably called by hooks to pass control on (confirm
// against callers).
//
// NOTE(review): the cursor app.t is shared across concurrent requests.
func (app *App) Next(c *Context) {
	app.t++
	node := app.chain.getByIndex(app.t)
	if hook, isHook := node.(Hook); isHook {
		hook(c)
		return
	}
	if ctrl, isCtrl := node.(Controller); isCtrl {
		ctrl(c)
	}
}
// chain is the ordered list of handlers (Hook or Controller values) run for
// a request.
type chain struct {
	nodes []interface{}
}

// reset empties the chain.
func (cn *chain) reset() {
	cn.nodes = nil
}

// getByIndex returns the node at position i, or nil when i is out of range.
// The bounds test replaces a loop over every node used only to check that i
// was a valid index.
func (cn *chain) getByIndex(i int) interface{} {
	if i < 0 || i >= len(cn.nodes) {
		return nil
	}
	return cn.nodes[i]
}
// prepareChain appends the app's hooks and then hs (route hooks followed by
// the controller, as built by combHandlers) to app.chain.
func (app *App) prepareChain(hs []interface{}) {
	for _, hook := range app.hooks.GetHooks() {
		app.chain.nodes = append(app.chain.nodes, hook)
	}
	app.chain.nodes = append(app.chain.nodes, hs...)
}
// execute runs the first node of the chain with ctx, if there is one. Later
// nodes run only when the running node calls App.Next.
func (cn *chain) execute(ctx *Context) {
	first := cn.getByIndex(0)
	if hook, isHook := first.(Hook); isHook {
		hook(ctx)
		return
	}
	if ctrl, isCtrl := first.(Controller); isCtrl {
		ctrl(ctx)
	}
}
// combHandlers returns the route hooks mw in order, followed by the
// controller h, as a single slice ready for prepareChain.
func (app *App) combHandlers(h Controller, mw []Hook) []interface{} {
	handlers := make([]interface{}, 0, len(mw)+1)
	for i := range mw {
		handlers = append(handlers, mw[i])
	}
	return append(handlers, h)
}
// getGormFunc returns the Context accessor for the database handle. The
// accessor panics when GORM is disabled in the config; otherwise it returns
// the (lazily created) connection from ResolveGorm.
func getGormFunc() func() *gorm.DB {
	return func() *gorm.DB {
		if gormC.EnableGorm {
			return ResolveGorm()
		}
		panic("you are trying to use gorm but it's not enabled, you can enable it in the file config/gorm.go")
	}
}
// NewGorm opens a database connection for DB_DRIVER ("mysql", "postgres" or
// "sqlite"), stores it in the package-level db and returns it.
//
// For sqlite, SQLITE_DB_PATH is resolved relative to the base path and the
// file must already exist. NewGorm panics when the driver is unknown, when
// the sqlite file cannot be found, or when GORM is enabled and connecting
// fails.
func NewGorm() *gorm.DB {
	var err error
	switch os.Getenv("DB_DRIVER") {
	case "mysql":
		db, err = mysqlConnect()
	case "postgres":
		db, err = postgresConnect()
	case "sqlite":
		sqlitePath := os.Getenv("SQLITE_DB_PATH")
		fullSqlitePath := path.Join(basePath, sqlitePath)
		// Use a distinct name for the Stat error. Declaring err with := here
		// shadowed the outer err, so a failing gorm.Open below went unnoticed.
		if _, statErr := os.Stat(fullSqlitePath); statErr != nil {
			panic(fmt.Sprintf("error locating sqlite file: %v", statErr.Error()))
		}
		db, err = gorm.Open(sqlite.Open(fullSqlitePath), &gorm.Config{})
	default:
		panic("database driver not selected")
	}
	if gormC.EnableGorm && err != nil {
		panic(fmt.Sprintf("gorm has problem connecting to %v, (if it's not needed you can disable it in config/gorm.go): %v", os.Getenv("DB_DRIVER"), err))
	}
	return db
}
// ResolveGorm returns the package database handle, creating it with NewGorm
// on first use.
func ResolveGorm() *gorm.DB {
	if db == nil {
		db = NewGorm()
	}
	return db
}
// resolveCache returns the Context accessor for the cache. Each call of the
// accessor builds a cache from the current cache config with NewCache. It
// panics when caching is disabled.
func resolveCache() func() *Cache {
	return func() *Cache {
		if cacheC.EnableCache {
			return NewCache(cacheC)
		}
		panic("you are trying to use cache but it's not enabled, you can enable it in the file config/cache.go")
	}
}
// postgresConnect opens a PostgreSQL connection using a key=value DSN built
// from the POSTGRES_* environment variables.
//
// NOTE(review): values are not quoted, so a password (or other value)
// containing spaces would break the DSN.
func postgresConnect() (*gorm.DB, error) {
	var (
		host     = os.Getenv("POSTGRES_HOST")
		user     = os.Getenv("POSTGRES_USER")
		password = os.Getenv("POSTGRES_PASSWORD")
		dbName   = os.Getenv("POSTGRES_DB_NAME")
		port     = os.Getenv("POSTGRES_PORT")
		sslMode  = os.Getenv("POSTGRES_SSL_MODE")
		timeZone = os.Getenv("POSTGRES_TIMEZONE")
	)
	dsn := fmt.Sprintf("host=%v user=%v password=%v dbname=%v port=%v sslmode=%v TimeZone=%v",
		host, user, password, dbName, port, sslMode, timeZone)
	return gorm.Open(postgres.Open(dsn), &gorm.Config{})
}
// mysqlConnect opens a MySQL connection using the MYSQL_* environment
// variables. parseTime is enabled and the Local time zone is used.
func mysqlConnect() (*gorm.DB, error) {
	var (
		username = os.Getenv("MYSQL_USERNAME")
		password = os.Getenv("MYSQL_PASSWORD")
		host     = os.Getenv("MYSQL_HOST")
		port     = os.Getenv("MYSQL_PORT")
		dbName   = os.Getenv("MYSQL_DB_NAME")
		charset  = os.Getenv("MYSQL_CHARSET")
	)
	dsn := fmt.Sprintf("%v:%v@tcp(%v:%v)/%v?charset=%v&parseTime=True&loc=Local",
		username, password, host, port, dbName, charset)

	cfg := mysql.Config{
		DSN:                       dsn,
		DefaultStringSize:         256,   // default size for string fields
		DisableDatetimePrecision:  true,  // datetime precision is unsupported before MySQL 5.6
		DontSupportRenameIndex:    true,  // drop & create instead of rename index (pre-5.7 MySQL, MariaDB)
		DontSupportRenameColumn:   true,  // use `change` to rename columns (pre-8 MySQL, MariaDB)
		SkipInitializeWithVersion: false, // auto-configure from the server version
	}
	return gorm.Open(mysql.New(cfg), &gorm.Config{})
}
// getJWT returns the Context accessor for JWT helpers. The accessor reads
// JWT_SECRET (required; it panics if empty) and JWT_LIFESPAN_MINUTES
// (default 10080, i.e. 7 days; it panics if that is not a 32-bit integer).
func getJWT() func() *JWT {
	return func() *JWT {
		secret := os.Getenv("JWT_SECRET")
		if secret == "" {
			panic("jwt secret key is not set")
		}

		minutesStr := os.Getenv("JWT_LIFESPAN_MINUTES")
		if minutesStr == "" {
			minutesStr = "10080" // 7 days
		}
		minutes, err := strconv.ParseInt(minutesStr, 10, 32)
		if err != nil {
			panic(err)
		}

		opts := JWTOptions{
			SigningKey:      secret,
			LifetimeMinutes: int(minutes),
		}
		return newJWT(opts)
	}
}
// getValidator returns an accessor that yields a fresh, zero-valued
// Validator on every call.
func getValidator() func() *Validator {
	return func() *Validator { return new(Validator) }
}

// resloveHashing returns an accessor that yields a fresh, zero-valued
// Hashing on every call.
func resloveHashing() func() *Hashing {
	return func() *Hashing { return new(Hashing) }
}
// resolveMailer returns the Context accessor for the mailer. On first use
// the accessor builds a mailer for the driver named by EMAILS_DRIVER. The
// values "sparkpost", "sendgrid" and "mailgun" select those providers;
// anything else, including unset or "SMTP", uses SMTP. The result is cached
// in the package-level mailer.
func resolveMailer() func() *Mailer {
	return func() *Mailer {
		if mailer != nil {
			return mailer
		}

		// Previously emailsDriver was only assigned when the variable was
		// empty, so a configured driver was ignored and SMTP always used.
		emailsDriver := os.Getenv("EMAILS_DRIVER")
		if emailsDriver == "" {
			emailsDriver = "SMTP"
		}

		var m *Mailer
		switch emailsDriver {
		case "sparkpost":
			m = initiateMailerWithSparkPost()
		case "sendgrid":
			m = initiateMailerWithSendGrid()
		case "mailgun":
			// Cached like the other drivers (it used to be rebuilt on
			// every call).
			m = initiateMailerWithMailGun()
		default: // "SMTP" and unrecognised values
			m = initiateMailerWithSMTP()
		}
		mailer = m
		return mailer
	}
}
// resolveEventsManager returns an accessor that delegates to
// ResolveEventsManager at call time.
func resolveEventsManager() func() *EventsManager {
	return func() *EventsManager { return ResolveEventsManager() }
}

// resolveLogger returns an accessor yielding the package logger as it is
// when the accessor is called (set by Bootstrap).
func resolveLogger() func() *logger.Logger {
	return func() *logger.Logger { return loggr }
}
// MakeDirs creates each directory in dirs (with any missing parents) under
// the base path, with mode 0766. The process umask is cleared for the
// duration of the call so the mode is applied verbatim, then restored.
//
// A directory that cannot be created is logged and the remaining ones are
// still attempted; previously such errors were silently discarded.
//
// NOTE(review): 0766 grants group/others read+write but not execute
// (search) permission on the directories; confirm that is intended.
// syscall.Umask is process-wide, so files created by other goroutines
// during this call are affected too.
func (app *App) MakeDirs(dirs ...string) {
	oldMask := syscall.Umask(0)
	defer syscall.Umask(oldMask)
	for _, dir := range dirs {
		full := path.Join(basePath, dir)
		if err := os.MkdirAll(full, 0766); err != nil {
			log.Printf("core: creating directory %q: %v", full, err)
		}
	}
}
// SetRequestConfig sets the request configuration. New copies the
// package-level value into Config.Request when the App is created, so the
// copy on this App is updated as well. Otherwise App.Config would keep the
// value that was current at construction time.
func (app *App) SetRequestConfig(r RequestConfig) {
	requestC = r
	if app != nil && app.Config != nil {
		app.Config.Request = r
	}
}
// SetGormConfig stores the GORM configuration consulted by NewGorm and the
// Context's GetGorm accessor.
func (app *App) SetGormConfig(g GormConfig) { gormC = g }

// SetCacheConfig stores the cache configuration used by the Context's
// GetCache accessor.
func (app *App) SetCacheConfig(c CacheConfig) { cacheC = c }

// SetBasePath sets the directory that relative paths (TLS files, the sqlite
// database, MakeDirs targets) are resolved against.
func (app *App) SetBasePath(path string) { basePath = path }

// SetRunMode sets the run mode; in "dev" mode Run prints a start-up banner.
func (app *App) SetRunMode(runmode string) { runMode = runmode }
// DisableEvents sets the package-level disableEvents flag. Presumably the
// events manager consults it (not visible in this file).
func DisableEvents() { disableEvents = true }

// EnableEvents clears the package-level disableEvents flag.
func EnableEvents() { disableEvents = false }