ms.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. package main
  2. import (
  3. "io"
  4. "math/rand"
  5. "net/http"
  6. "os"
  7. "strconv"
  8. "time"
  9. "github.com/gin-gonic/gin"
  10. "github.com/gorilla/websocket"
  11. )
  12. func main() {
  13. port := "8080"
  14. if os.Getenv("PORT") != "" {
  15. port = os.Getenv("PORT")
  16. }
  17. rand.Seed(time.Now().Unix())
  18. router := gin.Default()
  19. router.LoadHTMLFiles("index.html")
  20. router.GET("/", func(c *gin.Context) {
  21. c.HTML(http.StatusOK, "index.html", gin.H{
  22. "websocket": "ws://localhost:" + port + "/ws",
  23. "sse": "http://localhost:" + port + "/sse",
  24. })
  25. })
  26. router.GET("/ws", websocketHandler(websocket.Upgrader{
  27. Subprotocols: []string{"protocolOne"},
  28. }))
  29. router.GET("/sse", serverSentEventHandler())
  30. router.Run(":" + port)
  31. }
  32. func websocketHandler(upgrader websocket.Upgrader) func(c *gin.Context) {
  33. return func(c *gin.Context) {
  34. conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
  35. if err != nil {
  36. c.AbortWithError(http.StatusInternalServerError, err)
  37. return
  38. }
  39. t := task{}
  40. t.start()
  41. for {
  42. progress, ok := <-t.progress
  43. if !ok {
  44. break
  45. }
  46. if err := conn.WriteMessage(1, []byte(strconv.Itoa(progress))); err != nil {
  47. c.AbortWithError(http.StatusInternalServerError, err)
  48. return
  49. }
  50. }
  51. }
  52. }
  53. func serverSentEventHandler() func(c *gin.Context) {
  54. return func(c *gin.Context) {
  55. t := task{}
  56. t.start()
  57. c.Stream(func(w io.Writer) bool {
  58. progress, ok := <-t.progress
  59. if !ok {
  60. return false
  61. }
  62. c.SSEvent("", strconv.Itoa(progress))
  63. return true
  64. })
  65. }
  66. }
  // task simulates a long-running job that reports its progress.
  type task struct {
  	// progress delivers the values 0, 1, ..., 100 in order and is closed
  	// after 100 has been sent. It is created by start.
  	progress chan int
  }

  // start allocates the unbuffered progress channel and launches a goroutine
  // that produces the progress values. The goroutine sends 0 right away, then
  // repeatedly decides at random (a little under half of the time) whether to
  // advance by one step, pausing for a random 0-99 ms after every decision.
  // Each send blocks until a receiver takes the value.
  func (t *task) start() {
  	t.progress = make(chan int)
  	go func() {
  		t.progress <- 0
  		for step := 0; step < 100; {
  			advance := rand.Intn(100) > 50
  			if advance {
  				step++
  				t.progress <- step
  			}
  			pause := time.Duration(rand.Intn(100)) * time.Millisecond
  			time.Sleep(pause)
  		}
  		// Signal completion to the receiver.
  		close(t.progress)
  	}()
  }