Comments Table
Finalizing the gRPC Server - Creating the Comments Table and Endpoints.
Creating the Comments Table
First, open the database console tab and execute the following query to create the `comments` table used by the `TaskChat` endpoint:
-- One row per comment posted through the TaskChat endpoint.
CREATE TABLE comments (
    id         SERIAL PRIMARY KEY,
    task_id    INT,
    user_id    INT,
    comment    TEXT,
    -- Default to the insertion time so rows inserted without an explicit
    -- created_at do not end up NULL.
    created_at timestamptz DEFAULT now(),
    CONSTRAINT fk_task
        FOREIGN KEY (task_id)
        REFERENCES tasks(id)
);
Similar to `tasks`, the `comments` table mirrors the corresponding structure in `tasks.proto`. The `task_id` column is a foreign key referencing the `id` column of the `tasks` table.
Finishing the gRPC Server
With the table in place, it's time to finish the rest of the server.
The ListTasks Endpoint
The first streaming endpoint that you'll write is `ListTasks`. It is a unidirectional, server-side streaming API, as indicated by the function's second parameter, which is of type `pb.TaskService_ListTasksServer`. This is a stream with a `Send` method that can be used to stream data to the client:
// ListTasks is a server-streaming handler: results are written to the client
// with stream.Send. The body is filled in over the following steps.
func (s *taskServiceServer) ListTasks(request *pb.ListTasksRequest, stream pb.TaskService_ListTasksServer) error {
}
The next step is to get a list of all tasks with the given user ID that have deadlines before the given deadline:
// Next step of ListTasks: query the tasks owned by request.UserId whose
// deadline is before request.Deadline. The row loop and the final return are
// added in the following steps.
func (s *taskServiceServer) ListTasks(request *pb.ListTasksRequest, stream pb.TaskService_ListTasksServer) error {
query := `
SELECT * FROM tasks WHERE user_id=$1 AND deadline < $2;
`
rows, err := db.Query(query, request.UserId, request.Deadline.AsTime())
if err != nil {
return err
}
// NOTE(review): rows is a *sql.Rows and should be closed (defer rows.Close())
// so its connection returns to the pool.
}
Now, iterate over all rows and construct a `pb.Task` for each row:
// Turn each row of the result set into a pb.Task.
for rows.Next() {
var (
id int
description string
user_id int
status string
deadline string
created_at string
)
err = rows.Scan(&id, &description, &user_id, &status, &deadline, &created_at)
if err != nil {
return err
}
// The timestamp columns were scanned as strings; parse them as RFC 3339.
deadlineTime, err := time.Parse(time.RFC3339, deadline)
if err != nil {
// NOTE(review): log.Fatalf exits the whole server process on one malformed row.
log.Fatalf("Error: Invalid time for deadline: %v", err)
}
createdAtTime, err := time.Parse(time.RFC3339, created_at)
if err != nil {
log.Fatalf("Error: Invalid time for created_at: %v", err)
}
// status holds the enum name; TaskStatus_value maps it back to the enum value.
task := &pb.Task{
Id: strconv.Itoa(id),
Description: description,
UserId: strconv.Itoa(user_id),
Status: pb.TaskStatus(pb.TaskStatus_value[status]),
Deadline: timestamppb.New(deadlineTime),
CreatedAt: timestamppb.New(createdAtTime),
}
}
This code is very similar to the `GetTask` method, so what's happening here should already be familiar.
The last step is to call the `Send` method of `stream` to send the `Task` to the client:
// Stream this task to the client; a send failure ends the RPC with that error.
err = stream.Send(task)
if err != nil {
return err
}
The whole function looks like this:
// ListTasks implements the server-streaming ListTasks RPC. It sends one
// pb.Task for every task owned by request.UserId whose deadline is strictly
// before request.Deadline.
//
// The query runs under the stream's context, so it is cancelled if the RPC
// ends early. ListTasks returns the first error from any of these steps:
// running the query, scanning or parsing a row, iterating the result set, or
// stream.Send. A row with an unparseable timestamp fails this call; it does
// not terminate the server process.
func (s *taskServiceServer) ListTasks(request *pb.ListTasksRequest, stream pb.TaskService_ListTasksServer) error {
	query := `
SELECT * FROM tasks WHERE user_id=$1 AND deadline < $2;
`
	rows, err := db.QueryContext(stream.Context(), query, request.UserId, request.Deadline.AsTime())
	if err != nil {
		return err
	}
	// Return the connection to the pool on every exit path, including the
	// early returns inside the loop.
	defer rows.Close()

	for rows.Next() {
		var (
			id          int
			description string
			userID      int
			status      string
			deadline    string
			createdAt   string
		)
		if err := rows.Scan(&id, &description, &userID, &status, &deadline, &createdAt); err != nil {
			return err
		}

		// The timestamp columns are scanned as strings and parsed as RFC 3339.
		deadlineTime, err := time.Parse(time.RFC3339, deadline)
		if err != nil {
			return fmt.Errorf("parsing deadline of task %d: %w", id, err)
		}
		createdAtTime, err := time.Parse(time.RFC3339, createdAt)
		if err != nil {
			return fmt.Errorf("parsing created_at of task %d: %w", id, err)
		}

		// status holds the enum name; TaskStatus_value maps it back to the enum value.
		task := &pb.Task{
			Id:          strconv.Itoa(id),
			Description: description,
			UserId:      strconv.Itoa(userID),
			Status:      pb.TaskStatus(pb.TaskStatus_value[status]),
			Deadline:    timestamppb.New(deadlineTime),
			CreatedAt:   timestamppb.New(createdAtTime),
		}
		if err := stream.Send(task); err != nil {
			return err
		}
	}
	// rows.Next returns false both when the rows are exhausted and when
	// iteration fails; rows.Err distinguishes the two.
	return rows.Err()
}
The RecordTasks Endpoint
Before proceeding with the `RecordTasks` endpoint, refactor the logic that creates a `Task` from a `CreateTaskRequest` and saves it to the database into a separate function, because you'll reuse the same code in the `RecordTasks` endpoint. Create the following `createTaskFromRequest` function:
// createTaskFromRequest builds a new pb.Task from request and inserts it into
// the tasks table. The task gets status TASK_STATUS_INCOMPLETE and a CreatedAt
// of the current time, and is returned with Id set to the generated row id.
// It is shared by CreateTask and RecordTasks. On a database error it returns a
// nil task and the error unchanged.
func createTaskFromRequest(request *pb.CreateTaskRequest) (*pb.Task, error) {
	const insertStmt = `
INSERT INTO tasks("description", "user_id", "status", "deadline", "created_at")
VALUES($1, $2, $3, $4, $5) RETURNING id;
`
	task := &pb.Task{
		Description: request.Description,
		UserId:      request.UserId,
		Deadline:    request.Deadline,
		CreatedAt:   timestamppb.Now(),
		Status:      pb.TaskStatus_TASK_STATUS_INCOMPLETE,
	}

	// The status column stores the enum's name rather than its number.
	statusName := pb.TaskStatus_name[int32(task.Status)]
	row := db.QueryRow(insertStmt,
		task.Description,
		task.UserId,
		statusName,
		task.Deadline.AsTime(),
		task.CreatedAt.AsTime(),
	)

	var newID int
	if err := row.Scan(&newID); err != nil {
		return nil, err
	}
	task.Id = strconv.Itoa(newID)
	return task, nil
}
The `CreateTask` method now simply delegates to `createTaskFromRequest`:
// CreateTask implements the unary CreateTask RPC. ctx is not used; the
// request is handed to createTaskFromRequest, and its task and error are
// returned unchanged.
func (s *taskServiceServer) CreateTask(ctx context.Context, request *pb.CreateTaskRequest) (*pb.Task, error) {
	created, err := createTaskFromRequest(request)
	return created, err
}
The `RecordTasks` endpoint is a unidirectional, client-side streaming endpoint. This time, it receives a parameter of type `pb.TaskService_RecordTasksServer`:
// RecordTasks is a client-streaming handler; the body is written over the
// following steps.
func (s *taskServiceServer) RecordTasks(stream pb.TaskService_RecordTasksServer) error {
}
This type has two methods that you'll use: `Recv`, to receive data from the stream, and `SendAndClose`, to send a response and close the stream.
Start by declaring the necessary variables, including `count`, which stores the number of tasks recorded and is returned in the response:
// First step of RecordTasks: declare the loop state. count tracks how many
// tasks have been recorded and is reported back in the response.
func (s *taskServiceServer) RecordTasks(stream pb.TaskService_RecordTasksServer) error {
var createTaskRequest *pb.CreateTaskRequest
var err error
count := 0
}
Then, you'll iterate in an infinite loop until you reach the end of the stream or encounter an error:
// Receive requests until the client has finished sending (io.EOF), then reply
// with the number of tasks created and close the stream. Any other receive
// error ends the RPC.
for {
createTaskRequest, err = stream.Recv()
if err == io.EOF {
return stream.SendAndClose(&pb.TaskSummary{
NoOfTasksCreated: strconv.Itoa(count),
})
}
if err != nil {
return err
}
}
Finally, save the task and increment `count`:
// Persist the request. Assign to the err declared before the loop instead of
// shadowing it with :=.
_, err = createTaskFromRequest(createTaskRequest)
if err != nil {
	return err
}
count++
Here's the `RecordTasks` function in its entirety:
// RecordTasks implements the client-streaming RecordTasks RPC. Each
// CreateTaskRequest received is persisted via createTaskFromRequest. Once the
// client finishes sending (io.EOF), the number of tasks created is returned in
// a TaskSummary and the stream is closed.
//
// The first receive or insert error aborts the call. Tasks inserted before
// that point stay stored, because each insert is an independent statement.
func (s *taskServiceServer) RecordTasks(stream pb.TaskService_RecordTasksServer) error {
	var createTaskRequest *pb.CreateTaskRequest
	var err error
	count := 0
	for {
		createTaskRequest, err = stream.Recv()
		if err == io.EOF {
			return stream.SendAndClose(&pb.TaskSummary{
				NoOfTasksCreated: strconv.Itoa(count),
			})
		}
		if err != nil {
			return err
		}
		// Assign to the function-level err rather than shadowing it with :=.
		_, err = createTaskFromRequest(createTaskRequest)
		if err != nil {
			return err
		}
		count++
	}
}
The TaskChat Endpoint
The final candidate for this part is the `TaskChat` function, a bidirectional streaming endpoint. It receives an instance of `pb.TaskService_TaskChatServer`, which has both `Recv` and `Send` methods for receiving and sending data, respectively:
// TaskChat is a bidirectional streaming handler: stream offers Recv for
// incoming messages and Send for replies. The body is written below.
func (s *taskServiceServer) TaskChat(stream pb.TaskService_TaskChatServer) error {
}
As before, start with an infinite loop that stops when the end of the stream is reached or an error occurs:
// First step of TaskChat: read messages until the client closes its side of
// the stream (io.EOF, which ends the RPC successfully) or a receive error
// occurs. in is consumed in the next steps; until then, this snippet does not
// compile because in is unused.
func (s *taskServiceServer) TaskChat(stream pb.TaskService_TaskChatServer) error {
for {
in, err := stream.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
}
}
Next, use the message returned by `Recv` to save the comment to the database:
// Persist the received comment.
// NOTE(review): created_at is not supplied, so the stored value comes from the
// column's default (NULL unless the table declares one).
taskId := in.TaskId
userId := in.UserId
comment := in.Comment
insertStmt := `INSERT INTO comments("task_id", "user_id", "comment") VALUES($1, $2, $3)`
_, err = db.Exec(insertStmt, taskId, userId, comment)
if err != nil {
return err
}
Finally, construct a `pb.TaskComment` and send it to the stream:
// Echo the comment back to the client. CreatedAt is the server's current time,
// taken after the insert.
taskComment := &pb.TaskComment{
TaskId: taskId,
UserId: userId,
Comment: comment,
CreatedAt: timestamppb.Now(),
}
if err := stream.Send(taskComment); err != nil {
return err
}
The entire function looks like this:
// TaskChat implements the bidirectional TaskChat RPC. For each message
// received, it stores the comment in the comments table and echoes it back to
// the client. The echoed CreatedAt is the same instant written to the
// created_at column.
//
// It returns nil when the client closes its side of the stream (io.EOF), and
// otherwise the first receive, database, or send error. Inserts run under the
// stream's context, so they are cancelled if the RPC ends.
func (s *taskServiceServer) TaskChat(stream pb.TaskService_TaskChatServer) error {
	const insertStmt = `INSERT INTO comments("task_id", "user_id", "comment", "created_at") VALUES($1, $2, $3, $4)`
	for {
		in, err := stream.Recv()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}

		// One timestamp for both the stored row and the reply, so the client
		// sees exactly what was persisted.
		createdAt := timestamppb.Now()
		if _, err := db.ExecContext(stream.Context(), insertStmt, in.TaskId, in.UserId, in.Comment, createdAt.AsTime()); err != nil {
			return err
		}

		reply := &pb.TaskComment{
			TaskId:    in.TaskId,
			UserId:    in.UserId,
			Comment:   in.Comment,
			CreatedAt: createdAt,
		}
		if err := stream.Send(reply); err != nil {
			return err
		}
	}
}
Here's the complete `server.go` code for your reference:
package main
import (
"context"
"database/sql"
"fmt"
_ "github.com/lib/pq"
pb "go-grpc-demo/src/go"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/timestamppb"
"io"
"log"
"net"
"strconv"
"time"
)
// Connection settings for the PostgreSQL database backing the service.
// NOTE(review): credentials are hard-coded; consider loading them from the
// environment outside of local development.
const (
	dbname   = "go_grpc_demo"
	host     = "localhost"
	password = ""
	port     = 5432
	user     = "postgres"
)

// db is the shared connection pool. initDB assigns it, and main calls initDB
// before the gRPC server starts serving.
var (
	db *sql.DB
)
// initDB opens the PostgreSQL connection pool described by the package
// constants, stores it in the package-level db, and verifies connectivity
// with a ping.
//
// If the ping fails, the pool is closed again before the error is returned,
// so a failed initialisation does not leave an open pool behind. Errors are
// wrapped with the step that failed.
func initDB() error {
	connectionString := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", user, password, host, port, dbname)

	var err error
	db, err = sql.Open("postgres", connectionString)
	if err != nil {
		return fmt.Errorf("opening database: %w", err)
	}
	if err := db.Ping(); err != nil {
		// Best effort: the ping failure is the error worth reporting.
		_ = db.Close()
		return fmt.Errorf("pinging database: %w", err)
	}
	return nil
}
// taskServiceServer implements the TaskService gRPC service; main registers it
// with the gRPC server. Embedding pb.UnimplementedTaskServiceServer supplies
// the generated fallback methods (per protoc-gen-go-grpc convention these
// presumably report "unimplemented" — confirm against the generated code).
type taskServiceServer struct {
pb.UnimplementedTaskServiceServer
}
// createTaskFromRequest builds a new pb.Task from request and inserts it into
// the tasks table. The task gets status TASK_STATUS_INCOMPLETE and a CreatedAt
// of the current time, and is returned with Id set to the generated row id.
// It is shared by CreateTask and RecordTasks. On a database error it returns a
// nil task and the error unchanged.
func createTaskFromRequest(request *pb.CreateTaskRequest) (*pb.Task, error) {
	const insertStmt = `
INSERT INTO tasks("description", "user_id", "status", "deadline", "created_at")
VALUES($1, $2, $3, $4, $5) RETURNING id;
`
	task := &pb.Task{
		Description: request.Description,
		UserId:      request.UserId,
		Deadline:    request.Deadline,
		CreatedAt:   timestamppb.Now(),
		Status:      pb.TaskStatus_TASK_STATUS_INCOMPLETE,
	}

	// The status column stores the enum's name rather than its number.
	statusName := pb.TaskStatus_name[int32(task.Status)]
	row := db.QueryRow(insertStmt,
		task.Description,
		task.UserId,
		statusName,
		task.Deadline.AsTime(),
		task.CreatedAt.AsTime(),
	)

	var newID int
	if err := row.Scan(&newID); err != nil {
		return nil, err
	}
	task.Id = strconv.Itoa(newID)
	return task, nil
}
// CreateTask implements the unary CreateTask RPC. ctx is not used; the
// request is handed to createTaskFromRequest, and its task and error are
// returned unchanged.
func (s *taskServiceServer) CreateTask(ctx context.Context, request *pb.CreateTaskRequest) (*pb.Task, error) {
	created, err := createTaskFromRequest(request)
	return created, err
}
// GetTask implements the unary GetTask RPC. It loads the row of the tasks
// table whose id equals request.TaskId and converts it to a pb.Task.
//
// The lookup runs under ctx, so it is cancelled together with the RPC. The
// lookup error (including sql.ErrNoRows when no row matches) is returned
// unchanged. A stored timestamp that cannot be parsed fails this request; it
// does not terminate the server process.
func (s *taskServiceServer) GetTask(ctx context.Context, request *pb.GetTaskRequest) (*pb.Task, error) {
	var (
		id          int
		description string
		userID      int
		status      string
		deadline    string
		createdAt   string
	)
	err := db.QueryRowContext(ctx, "SELECT * FROM tasks WHERE tasks.id = $1", request.TaskId).Scan(
		&id, &description, &userID, &status, &deadline, &createdAt)
	if err != nil {
		return nil, err
	}

	// The timestamp columns are scanned as strings and parsed as RFC 3339.
	deadlineTime, err := time.Parse(time.RFC3339, deadline)
	if err != nil {
		return nil, fmt.Errorf("parsing deadline of task %d: %w", id, err)
	}
	createdAtTime, err := time.Parse(time.RFC3339, createdAt)
	if err != nil {
		return nil, fmt.Errorf("parsing created_at of task %d: %w", id, err)
	}

	// status holds the enum name; TaskStatus_value maps it back to the enum value.
	task := &pb.Task{
		Id:          strconv.Itoa(id),
		Description: description,
		UserId:      strconv.Itoa(userID),
		Status:      pb.TaskStatus(pb.TaskStatus_value[status]),
		Deadline:    timestamppb.New(deadlineTime),
		CreatedAt:   timestamppb.New(createdAtTime),
	}
	return task, nil
}
// ListTasks implements the server-streaming ListTasks RPC. It sends one
// pb.Task for every task owned by request.UserId whose deadline is strictly
// before request.Deadline.
//
// The query runs under the stream's context, so it is cancelled if the RPC
// ends early. ListTasks returns the first error from any of these steps:
// running the query, scanning or parsing a row, iterating the result set, or
// stream.Send. A row with an unparseable timestamp fails this call; it does
// not terminate the server process.
func (s *taskServiceServer) ListTasks(request *pb.ListTasksRequest, stream pb.TaskService_ListTasksServer) error {
	query := `
SELECT * FROM tasks WHERE user_id=$1 AND deadline < $2;
`
	rows, err := db.QueryContext(stream.Context(), query, request.UserId, request.Deadline.AsTime())
	if err != nil {
		return err
	}
	// Return the connection to the pool on every exit path, including the
	// early returns inside the loop.
	defer rows.Close()

	for rows.Next() {
		var (
			id          int
			description string
			userID      int
			status      string
			deadline    string
			createdAt   string
		)
		if err := rows.Scan(&id, &description, &userID, &status, &deadline, &createdAt); err != nil {
			return err
		}

		// The timestamp columns are scanned as strings and parsed as RFC 3339.
		deadlineTime, err := time.Parse(time.RFC3339, deadline)
		if err != nil {
			return fmt.Errorf("parsing deadline of task %d: %w", id, err)
		}
		createdAtTime, err := time.Parse(time.RFC3339, createdAt)
		if err != nil {
			return fmt.Errorf("parsing created_at of task %d: %w", id, err)
		}

		// status holds the enum name; TaskStatus_value maps it back to the enum value.
		task := &pb.Task{
			Id:          strconv.Itoa(id),
			Description: description,
			UserId:      strconv.Itoa(userID),
			Status:      pb.TaskStatus(pb.TaskStatus_value[status]),
			Deadline:    timestamppb.New(deadlineTime),
			CreatedAt:   timestamppb.New(createdAtTime),
		}
		if err := stream.Send(task); err != nil {
			return err
		}
	}
	// rows.Next returns false both when the rows are exhausted and when
	// iteration fails; rows.Err distinguishes the two.
	return rows.Err()
}
// RecordTasks implements the client-streaming RecordTasks RPC. Each
// CreateTaskRequest received is persisted via createTaskFromRequest. Once the
// client finishes sending (io.EOF), the number of tasks created is returned in
// a TaskSummary and the stream is closed.
//
// The first receive or insert error aborts the call. Tasks inserted before
// that point stay stored, because each insert is an independent statement.
func (s *taskServiceServer) RecordTasks(stream pb.TaskService_RecordTasksServer) error {
	var createTaskRequest *pb.CreateTaskRequest
	var err error
	count := 0
	for {
		createTaskRequest, err = stream.Recv()
		if err == io.EOF {
			return stream.SendAndClose(&pb.TaskSummary{
				NoOfTasksCreated: strconv.Itoa(count),
			})
		}
		if err != nil {
			return err
		}
		// Assign to the function-level err rather than shadowing it with :=.
		_, err = createTaskFromRequest(createTaskRequest)
		if err != nil {
			return err
		}
		count++
	}
}
// TaskChat implements the bidirectional TaskChat RPC. For each message
// received, it stores the comment in the comments table and echoes it back to
// the client. The echoed CreatedAt is the same instant written to the
// created_at column.
//
// It returns nil when the client closes its side of the stream (io.EOF), and
// otherwise the first receive, database, or send error. Inserts run under the
// stream's context, so they are cancelled if the RPC ends.
func (s *taskServiceServer) TaskChat(stream pb.TaskService_TaskChatServer) error {
	const insertStmt = `INSERT INTO comments("task_id", "user_id", "comment", "created_at") VALUES($1, $2, $3, $4)`
	for {
		in, err := stream.Recv()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}

		// One timestamp for both the stored row and the reply, so the client
		// sees exactly what was persisted.
		createdAt := timestamppb.Now()
		if _, err := db.ExecContext(stream.Context(), insertStmt, in.TaskId, in.UserId, in.Comment, createdAt.AsTime()); err != nil {
			return err
		}

		reply := &pb.TaskComment{
			TaskId:    in.TaskId,
			UserId:    in.UserId,
			Comment:   in.Comment,
			CreatedAt: createdAt,
		}
		if err := stream.Send(reply); err != nil {
			return err
		}
	}
}
// main initialises the database pool, registers taskServiceServer on a new
// gRPC server, and serves on localhost:9090 until Serve returns. Any failure
// along the way is fatal.
func main() {
	if err := initDB(); err != nil {
		log.Fatalf("Error initiating database: %v", err)
	}

	address := fmt.Sprintf("localhost:%d", 9090)
	listener, err := net.Listen("tcp", address)
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}

	server := grpc.NewServer()
	pb.RegisterTaskServiceServer(server, &taskServiceServer{})
	if err := server.Serve(listener); err != nil {
		log.Fatalf("Error starting gRPC server: %v", err)
	}
}
The entire code so far can be found in the `part3` branch of this GitHub repo.
Conclusion
With that, you've completed the server! Well done! Now, why not take a break? You might want to finish that cup of coffee before proceeding to the final part, where you'll write the client and test the whole app.